不分割成token,直接從字節中高效學習,Mamba原來還能這樣用
機器之心報道
編輯:張倩
在定義語言模型時,通常會使用一種基本分詞方法,把句子分爲詞(word)、子詞(subword)或字符(character)。其中,子詞分詞法一直是最受歡迎的選擇,因爲它在訓練效率和處理詞彙表外單詞的能力之間實現了自然的折中。然而,一些研究指出了子詞分詞法的問題,如對錯別字、拼寫和大小寫變化以及形態變化缺乏穩健性。
因此,有些研究人員另闢蹊徑,採用了一種使用字節序列的方法,即從原始數據到預測的端到端映射,中間不進行任何分詞。與子詞模型相比,基於字節級的語言模型能夠更容易地在不同的書寫形式和形態變化之間進行泛化。當然,將文本建模爲字節意味着生成的序列要比對應的子詞長得多。如此一來,效率的提升就要依靠架構的改進來實現了。
自迴歸 Transformer 在語言建模中占主導地位,但效率問題尤爲突出:計算成本隨序列長度呈二次方增長,因此對長(字節)序列的擴展能力很差。研究人員壓縮了 Transformer 的內部表示,以便處理長序列,例如開發了長度感知建模方法,在這種方法中,token 組在中間層內合併。最近,Yu 等人 [2023] 提出了 MegaByte Transformer,它使用固定大小的字節片段作爲子詞的模擬壓縮形式。因此,MegaByte 可以降低計算成本。不過,這可能還不是最好的方法。
在一份新論文中,來自康奈爾大學的研究者介紹了一種高效、簡單的字節級語言模型 MambaByte。該模型對最近推出的 Mamba 架構進行了直接改造。Mamba 建立在狀態空間模型(SSM)開創的方法基礎上,引入了對文本等離散數據更有效的選擇機制,並提供了高效的 GPU 實現。作者的簡單觀察結果是,使用 Mamba(不做修改)可以緩解語言建模中的主要計算瓶頸,從而消除 patching 並有效利用可用的計算資源。
他們在實驗中將 MambaByte 與 Transformers、SSM 和 MegaByte(patching)架構進行了比較,這些架構都是在固定參數和固定計算設置下,並在多個長篇文本數據集上進行比較的。圖 1 總結了他們的主要發現。
與字節級 Transformers 相比,MambaByte 能更快地實現更好的性能,計算效率也明顯更高。作者還考慮了無 token 語言模型與現有最先進的子詞模型相比的可行性。在這方面,他們發現 MambaByte 與各種子詞基線模型相比具有競爭力,但它能處理更長的序列。研究結果表明,MambaByte 是現有依賴分詞器( tokenizer)的模型的有力替代品,有望用來促進端到端學習。
背景:選擇性狀態空間序列模型
SSM 通過一階微分方程對隱藏狀態的跨時間演變進行建模。線性時不變(time-invariant) SSM 在幾種模態的深度學習中顯示出了良好的效果。然而,Mamba 作者 Gu 和 Dao 最近認爲,這些方法的恆定動態缺乏隱藏狀態中依賴輸入的上下文選擇,而這可能是語言建模等任務所必需的。爲此,他們提出了 Mamba,該方法將給定輸入 x (t) ∈ R、隱藏狀態 h (t) ∈ R^n 和輸出 y (t) ∈ R 在時間 t 的時變連續狀態動態定義爲:
其參數爲對角時不變系統矩陣 A∈R^(n×n),以及隨時間變化的輸入和輸出矩陣 B (t)∈R^(n×1) 和 C (t)∈R^(1×n)。
要對字節等離散時間序列建模,必須通過離散化來逼近 (1) 中的連續時間動態。這就產生了離散時間隱態 recurrence,每個時間步都有新矩陣 A、B 和 C,即
請注意,(2) 類似於循環神經網絡的線性版本,可以在語言模型生成過程中以這種循環形式應用。離散化要求每個輸入位置都有一個時間步,即 ∆[k],對應於 的 x [k] = x (t_k)。然後就可以根據 ∆[k] 計算出離散時間矩陣 A、B 和 C。圖 2 展示了 Mamba 如何爲離散序列建模。
在 Mamba 中,SSM 項是輸入選擇性的,即 B、C 和 ∆ 被定義爲輸入 x [k]∈R^d 的函數:
其中 W_B ∈ R^(n×d)(C 的定義類似),W_∆ ∈ R^(d×r) 和 W_R ∈ R^(r×d)(對於某個 r ≪d)是可學習的權重,而 softplus 則確保正向性。請注意,對於每個輸入維度 d,SSM 參數 A、B 和 C 都是相同的,但時間步數 ∆ 是不同的;這導致每個時間步數 k 的隱藏狀態大小爲 n × d。
Mamba 將這個 SSM 層嵌入到一個完整的神經網絡語言模型中。具體來說,該模型採用了一系列門控層,其靈感來源於之前的門控 SSM。圖 3 顯示了將 SSM 層與門控神經網絡相結合的 Mamba 架構。
實驗結果
表 2 顯示了每個數據集的每字節比特數(BPB)。在本實驗中,MegaByte758M+262M 和 MambaByte 模型使用相同的每字節 FLOP 數(見表 1)。作者發現,在所有數據集上,MambaByte 的性能始終優於 MegaByte。此外,作者注意到,由於資金限制,他們無法對 MambaByte 進行完整的 80B 字節訓練,但 MambaByte 在計算量和訓練數據減少 63% 的情況下仍優於 MegaByte。此外,MambaByte-353M 還優於字節級 Transformer 和 PerceiverAR。
在如此少的訓練步驟中,MambaByte 爲什麼比一個大得多的模型表現得更好?圖 1 通過觀察參數數量相同的模型進一步探討了這種關係。圖中顯示,對於參數大小相同的 MegaByte 模型,輸入 patching 較少的模型表現更好,但在計算歸一化後,它們的表現類似。事實上,全長的 Transformer 雖然在絕對意義上速度較慢,但在計算歸一化後,其性能也與 MegaByte 相似。相比之下,改用 Mamba 架構可以顯著提高計算使用率和模型性能。
根據這些發現,表 3 比較了這些模型在 PG19 數據集上的較大版本。在這個實驗中,作者將 MambaByte-972M 與 MegaByte-1.3B+350M 和其他字節級模型以及幾個 SOTA 子詞模型進行了比較。他們發現,MambaByte-972M 即使只訓練了 150B 字節,其性能也優於所有字節級模型,並與子詞模型相比具有競爭力。
文本生成。Transformer 模型中的自迴歸推理需要緩存整個上下文,這會大大影響生成速度。MambaByte 不存在這一瓶頸,因爲它每層只保留一個隨時間變化的隱藏狀態,因此每生成一步的時間是恆定的。表 4 比較了 MambaByte-972M 和 MambaByte-1.6B 與 MegaByte-1.3B+350M 在 A100 80GB PCIe GPU 上的文本生成速度。雖然 MegaByte 通過 patching 大大降低了生成成本,但他們觀察到 MambaByte 由於使用了循環生成,在參數相似設置下速度達到了前者的 2.6 倍。