RM新时代网站-首页

0
  • 聊天消息
  • 系統(tǒng)消息
  • 評論與回復(fù)
登錄后你可以
  • 下載海量資料
  • 學(xué)習(xí)在線課程
  • 觀看技術(shù)視頻
  • 寫文章/發(fā)帖/加入社區(qū)
會員中心
創(chuàng)作中心

完善資料讓更多小伙伴認(rèn)識你,還能領(lǐng)取20積分哦,立即完善>

3天內(nèi)不再提示

FlashAttenion-V3: Flash Decoding詳解

jf_pmFSk4VX ? 來源:GiantPandaCV ? 2023-10-31 16:18 ? 次閱讀

Flash Attention V1和V2的作者又推出了Flash Decoding,真是太強了!

Flash-Decoding借鑒了FlashAttention的優(yōu)點,將并行化維度擴展到keys/values序列長度。這種方法幾乎不收序列長度影響(這對LLM模型能力很重要),可以充分利用GPU,即使在batch size較小時(inference特點),也可以極大提高了encoding速度。

相關(guān)背景知識先推薦閱讀:

FlashAttention圖解(如何加速Attention)

FlashAttention2詳解(性能比FlashAttention提升200%)

Motivation

最近,像ChatGPT或Llama這樣的LLM模型受到了空前的關(guān)注。然而,它們的運行成本卻非常高昂。雖然單次回復(fù)的成本約為0.01美元(例如在AWS 8塊A100上運行幾秒鐘),但是當(dāng)擴展到數(shù)十億用戶的多次交互時,成本會迅速上升。而且一些場景的成本更高,例如代碼自動補全,因為只要用戶輸入一個新字符就會執(zhí)行。由于LLM應(yīng)用非常廣泛且還在迅速增長,即使稍微提升其運行效率也會產(chǎn)生巨大的收益。

LLM inference(或稱為decoding)是一個迭代的過程:預(yù)測的tokens是逐個生成的。如果生成的句子有N個單詞,那么模型需要進行N次forward。一個常用的優(yōu)化技巧是KV Cache,該方法緩存了之前forward的一些中間結(jié)果,節(jié)約了大部分運算(如MatMul),但是attention操作是個例外。隨著輸出tokens長度增加,attention操作的復(fù)雜度也極具上升。

然而我們希望LLM能處理長上下文。增加了上下文長度,LLM可以輸出更長的文檔、跟蹤更長的對話,甚至在編寫代碼之前處理整個代碼庫。例如,2022年大多數(shù)LLM的上下文長度最多為2k(如GPT-3),但現(xiàn)在LLM上下文長度可以擴展到32k(Llama-2-32k),甚至最近達(dá)到了100k(CodeLlama)。在這種情況下,attention操作在推理過程中占據(jù)了相當(dāng)大的時間比例。此外,當(dāng)batch size增加時,即使在相對較小的上下文中,attention操作也可能成為瓶頸。這是因為該操作需要對內(nèi)存的訪問會隨著batch size增加而增加,而模型中其他操作只和模型大小相關(guān)。

因此,本文提出了Flash-Decoding,可以推理過程中顯著加速attention操作(例如長序列生成速度提高8倍)。其主要思想是最大化并行加載keys和values的效率,通過重新縮放組合得到正確結(jié)果。

Multi-head attention for decoding

在decoding過程中,每個生成的新token需要與先前的tokens合并后,才能繼續(xù)執(zhí)行attention操作,即936fb5aa-77c1-11ee-939d-92fbcf53809c.png。Attention操作在訓(xùn)練過程的瓶頸主要卡在訪問內(nèi)存讀寫中間結(jié)果(例如93895640-77c1-11ee-939d-92fbcf53809c.png)的帶寬,相關(guān)加速方案可以參考FlashAttention和FlashAttention2。

然而,上述優(yōu)化不適合直接應(yīng)用于推理過程。因為在訓(xùn)練過程中,F(xiàn)lashAttention對batch size和query length進行了并行化加速。而在推理過程中,query length通常為1,這意味著如果batch size小于GPU上的SM數(shù)量(例如A100上有108個SMs),那么整個計算過程只使用了GPU的一小部分!特別是當(dāng)上下文較長時,通常會減小batch size來適應(yīng)GPU內(nèi)存。例如batch size = 1時,F(xiàn)lashAttention對GPU利用率小于1%!

下面展示了FlashAttention的計算示意圖,該示例將keys和values分為了2個block:

93a173e2-77c1-11ee-939d-92fbcf53809c.png

FlashAttention示意圖

對應(yīng)的計算公式:

93b5acae-77c1-11ee-939d-92fbcf53809c.png

FlashAttention示意圖對應(yīng)的計算公式

注意93bdf760-77c1-11ee-939d-92fbcf53809c.png的計算過程依賴93c63aba-77c1-11ee-939d-92fbcf53809c.png,從下圖也可以看出,F(xiàn)lashAttention是按順序更新output的,其實當(dāng)時我在看FlashAttention這篇文章時就覺得這個順序操作可以優(yōu)化的,因為反正都要rescale,不如最后統(tǒng)一rescale,沒必要等之前block計算完(為了獲取上一個block的max值)

93d525ac-77c1-11ee-939d-92fbcf53809c.jpg

flashattention計算過程

A faster attention for decoding: Flash-Decoding

上面提到FlashAttention對batch size和query length進行了并行化加速,F(xiàn)lash-Decoding在此基礎(chǔ)上增加了一個新的并行化維度:keys/values的序列長度。即使batch size很小,但只要上下文足夠長,它就可以充分利用GPU。與FlashAttention類似,F(xiàn)lash-Decoding幾乎不用額外存儲大量數(shù)據(jù)到全局內(nèi)存中,從而減少了內(nèi)存開銷。

93e66074-77c1-11ee-939d-92fbcf53809c.gif

flashdecoding計算過程

Flash Decoding主要包含以下三個步驟(可以結(jié)合上圖來看):

將keys和values分成較小的block

使用FlashAttention并行計算query與每個block的注意力(這是和FlashAttention最大的區(qū)別)。對于每個block的每行(因為一行是一個特征維度),F(xiàn)lash Decoding會額外記錄attention values的log-sum-exp(標(biāo)量值,用于第3步進行rescale)

對所有output blocks進行reduction得到最終的output,需要用log-sum-exp值來重新調(diào)整每個塊的貢獻(xiàn)

實際應(yīng)用中,第1步中的數(shù)據(jù)分塊不涉及GPU操作(因為不需要在物理上分開),只需要對第2步和第3步執(zhí)行單獨的kernels。雖然最終的reduction操作會引入一些額外的計算,但在總體上,F(xiàn)lash-Decoding通過增加并行化的方式取得了更高的效率。

Benchmarks on CodeLlama 34B

作者對CodeLLaMa-34b的decoding throughput進行了基準(zhǔn)測試。該模型與Llama 2具有相同的架構(gòu)。作者在各種序列長度(從512到64k)上測試了decoding速度,并比較了多種attention計算方法:

PyTorch:使用純PyTorch primitives運行注意力計算(不使用FlashAttention)。

FlashAttention v2(v2.2之前的版本)。

FasterTransformer:使用FasterTransformer attention kernel

Flash-Decoding

將從內(nèi)存中讀取整個模型和KV Cache所需的時間作為上限

940efbf6-77c1-11ee-939d-92fbcf53809c.png

Untitled

從上圖可以看出,F(xiàn)lash-Decoding在處理非常大的序列時速度可以提高8倍,并且比其他方法具有更好的可擴展性。所有方法在處理small prompts時表現(xiàn)相似,但隨著序列長度從512增加到64k,其他方法的性能都變差了,而Flash-Decoding對序列長度的增加并不敏感(下圖也是很好的證明)

9422813a-77c1-11ee-939d-92fbcf53809c.png

micro-benchmark on A100

Using Flash-Decoding

作者還通了Flash-Decoding使用方式:

基于FlashAttention package ,從版本2.2開始。

xFormers,在版本0.0.22中提供了xformers.ops.memory_efficient_attention模塊

作者也提供了LLaMa v2/CodeLLaMa的repo1和xFormers repo2。此外,作者還提供了一個針對LLaMa v1/v2的最小示例。

個人總結(jié)

Flash-Decoding對LLM在GPU上inference進行了顯著加速(尤其是batch size較小時),并且在處理長序列時具有更好的可擴展性。

聲明:本文內(nèi)容及配圖由入駐作者撰寫或者入駐合作網(wǎng)站授權(quán)轉(zhuǎn)載。文章觀點僅代表作者本人,不代表電子發(fā)燒友網(wǎng)立場。文章及其配圖僅供工程師學(xué)習(xí)之用,如有內(nèi)容侵權(quán)或者其他違規(guī)問題,請聯(lián)系本站處理。 舉報投訴
  • gpu
    gpu
    +關(guān)注

    關(guān)注

    28

    文章

    4729

    瀏覽量

    128890
  • 模型
    +關(guān)注

    關(guān)注

    1

    文章

    3226

    瀏覽量

    48807
  • LLM
    LLM
    +關(guān)注

    關(guān)注

    0

    文章

    286

    瀏覽量

    327

原文標(biāo)題:FlashAttenion-V3: Flash Decoding詳解

文章出處:【微信號:GiantPandaCV,微信公眾號:GiantPandaCV】歡迎添加關(guān)注!文章轉(zhuǎn)載請注明出處。

收藏 人收藏

    評論

    相關(guān)推薦

    Flash基本操作——Flash基礎(chǔ)(1)#多媒體技術(shù)

    FlaSh
    未來加油dz
    發(fā)布于 :2023年05月24日 10:43:53

    Flash基本操作——Flash工具1(3)#多媒體技術(shù)

    FlaSh
    未來加油dz
    發(fā)布于 :2023年05月24日 10:46:17

    Flash基本操作——Flash工具2(3)#多媒體技術(shù)

    FlaSh
    未來加油dz
    發(fā)布于 :2023年05月24日 10:48:11

    Flash基本操作——Flash工具3(1)#多媒體技術(shù)

    FlaSh
    未來加油dz
    發(fā)布于 :2023年05月24日 10:49:01

    Flash基本操作——Flash工具3(2)#多媒體技術(shù)

    FlaSh
    未來加油dz
    發(fā)布于 :2023年05月24日 10:49:44

    Flash基本操作——Flash工具3(3)#多媒體技術(shù)

    FlaSh
    未來加油dz
    發(fā)布于 :2023年05月24日 10:50:22

    Necessary to disable "Above 4G Decoding" for View with vGPU?

    /grid-vgpu-deployment-guide.pdf 在第17頁,它為幾個服務(wù)器制造商提供了BIOS建議。 它建議禁用SuperMicro的“Above 4G Decoding”。 對于Dom0為32位
    發(fā)表于 09-04 15:36

    3~25V與10安3~15V電壓可調(diào)電壓電路原理圖詳解

    3~25V與10安3~15V電壓可調(diào)電壓電路原理圖詳解
    發(fā)表于 04-16 20:47

    模電Flash動畫詳解

    模電Flash動畫詳解,一共有161個!
    發(fā)表于 09-27 08:15

    Flash Magic V2.45

    Flash Magic V2.45 Flash Magic V2.45軟件
    發(fā)表于 05-10 11:24 ?8次下載

    基于MSP430功能模塊詳解系列之——FLASH存儲器

    基于MSP430功能模塊詳解系列之——FLASH存儲器
    發(fā)表于 10-12 15:27 ?11次下載
    基于MSP430功能模塊<b class='flag-5'>詳解</b>系列之——<b class='flag-5'>FLASH</b>存儲器

    MP3-FLASH-16P 使用說明書 V1.0

    藍(lán)板MP3-FLASH-16P使用說明書 V1.0 MP3-FLASH-16P 是一個提供串口的語音模塊,完美的集成了 MP3、WAV 的硬解碼。同時軟件支持工業(yè)級別的串口通信協(xié)議,以
    發(fā)表于 11-28 14:08 ?24次下載

    【轉(zhuǎn)載】keil將程序裝入外部FLASH詳解

    【轉(zhuǎn)載】keil將程序裝入外部FLASH詳解
    發(fā)表于 12-01 20:21 ?14次下載
    【轉(zhuǎn)載】keil將程序裝入外部<b class='flag-5'>FLASH</b><b class='flag-5'>詳解</b>

    開源軟件-Morse_Encoding_Decoding摩斯密碼工具

    ./oschina_soft/Morse_Encoding_Decoding.zip
    發(fā)表于 06-28 11:52 ?1次下載
    開源軟件-Morse_Encoding_<b class='flag-5'>Decoding</b>摩斯密碼工具

    瑞薩Flash程序員V3 發(fā)布說明

    電子發(fā)燒友網(wǎng)站提供《瑞薩Flash程序員V3 發(fā)布說明.pdf》資料免費下載
    發(fā)表于 02-19 09:37 ?1次下載
    瑞薩<b class='flag-5'>Flash</b>程序員<b class='flag-5'>V3</b> 發(fā)布說明
    RM新时代网站-首页