寫在前面:近來筆者偶然間接觸了一個深度學習框架 OneFlow,所以這段時間主要在閱讀 OneFlow 框架的 cuda 源碼。官方源碼基于不同場景分三種方式實現(xiàn) Softmax,本文主要介紹其中一種的實現(xiàn)過程,即 Warp 級別 Softmax,適用于矩陣寬度不超過 1024 的情況。
1 Softmax
Softmax 操作是深度學習模型中最常用的操作之一。在深度學習的多分類任務中,最后一層通常是一個 Softmax 操作將 logits 映射成概率,然后結合交叉熵求損失。另外還有一些場景會用到 Softmax 做一個歸一化操作,比如 Transformer 結構中 query 和 key 矩陣相乘并縮放后會執(zhí)行一個 Softmax 操作,這一步的意義是求出 query 和 key 中每一項的兩兩相似度,具體筆者在另一篇文章有詳述——【ASR】基于DFCNN-CTC模型的語音識別系統(tǒng)(二)
圖1 Scaled Dot-Product Attention 結構示意圖
深度學習框架中的所有算子底層都對應著 GPU上的 CUDA kernel function,Softmax 操作也不例外。Softmax 作為一個被廣泛使用的算子,其 CUDA Kernel 的實現(xiàn)會影響很多網(wǎng)絡最終的訓練速度。那么如何實現(xiàn)一個高效的 Softmax CUDA Kernel?本文將會介紹 OneFlow 中優(yōu)化的 Softmax CUDA Kernel 的技巧,在這之前我們先來看一下 Softmax 的計算公式。
定義 x 是一個 n 維向量,其 Softmax 輸出 y 也是一個 n 維向量,那么有如下計算公式:
從上面的公式可以發(fā)現(xiàn)一個問題,當 為一個較大的正數(shù)時,取指數(shù)后 將會非常大,從而導致數(shù)值溢出,如何解決這個問題呢?
一般的處理方法是,讓每個分量去減掉向量的最大值,這樣可以保證取指數(shù)后的結果必然在 0~1 之間,可以有效避免數(shù)值溢出。處理后的公式如下:
根據(jù)公式可以看出,要執(zhí)行 Softmax 計算,需要實現(xiàn) 5 個業(yè)務邏輯:reduceMax、broadcastSub、exp、reduceSum、broadcastDiv。下面筆者將對源碼中的計算技巧進行解讀,有興趣的讀者可以下載源碼來閱讀(https://github.com/BBuf/how-to-optim-algorithm-in-cuda/blob/master/softmax/oneflow_softmax.cu)。
2 三種實現(xiàn)方式
Softmax 函數(shù)的輸入形狀為:(num_rows, num_cols),num_cols 的變化會對有效帶寬產(chǎn)生影響。因為,沒有一種通用的優(yōu)化方法可以實現(xiàn)在所有 num_cols 的情況下都是傳輸最優(yōu)的。所以,在 OneFlow 中采用分段函數(shù)優(yōu)化 SoftmaxKernel:對于不同 num_cols 范圍,選擇不同的實現(xiàn),以期在所有情況下都能達到較高的有效帶寬。
針對不同的 Softmax 場景,OneFlow 提供了三種實現(xiàn),分段對 Softmax kernel 進行優(yōu)化:
(1) 一個 Warp 處理一行的計算,適用于 num_cols <= 1024 情況
(2) 一個 Block 處理一行的計算,借助 Shared Memory 保存中間結果數(shù)據(jù),適用于需要的 Shared Memory 資源滿足 Kernel Launch 的可啟動條件的情況,在本測試環(huán)境中是 1024 < num_cols <= 4096。
(3) 一個 Block 處理一行的計算,不使用 Shared Memory,重復讀輸入 x,適用于不支持(1)、(2)的情況。
分段處理邏輯在 DispatchSoftmax 函數(shù)中體現(xiàn),主體代碼如下:
if (cols < 1024) { return DispatchSoftmaxWarpImpl( stream, load, store, rows, cols); } else { bool dispatch_smem_impl_success; { cudaError_t err = TryDispatchSoftmaxBlockSMemImpl( stream, load, store, rows, cols, &dispatch_smem_impl_success); if (err != cudaSuccess) { return err; } } if (!dispatch_smem_impl_success) { return DispatchSoftmaxBlockUncachedImpl( stream, load, store, rows, cols); } return cudaSuccess; } ,>,>,>
3 WarpSoftmax
3.1 數(shù)據(jù) Pack 提升訪問帶寬
在筆者上一篇文章【CUDA編程】OneFlow Element-Wise 算子源碼解讀中詳細地介紹了如何進行向量化讀寫,有興趣的讀者可以移步,這里我們先看源碼。
template struct GetPackType { using type = typename std::aligned_storage::type; }; template using PackType = typename GetPackType::type; template union Pack { static_assert(sizeof(PackType) == sizeof(T) * N, ""); __device__ Pack() { // do nothing } PackType storage; T elem[N]; }; ,>,>,>
oneflow 利用 union 共享空間的特性實現(xiàn)了一個 Pack 類型,細心的讀者可能會發(fā)現(xiàn),跟 elementwise.cuh 源碼相比,這里少了一個 Packed 類,這是因為 elementwise.cuh 實現(xiàn)的時間晚于 softmax.cuh??赡芸紤]到 Pack 后類型的內(nèi)存對齊特性,重新定義了 Packed 類,并聲明了內(nèi)存對齊值為 pack_size * sizeof(T)。
接下來定義了兩個代表輸入和輸出的數(shù)據(jù)結構 DirectLoad 和 DirectStore,分別實現(xiàn)了 load 和 store 兩個函數(shù)用來把讀取和寫入一個 pack 的數(shù)據(jù)。使用 DirectLoad 和 DirectStore 有兩個好處:
可以在CUDA Kernel中只關心計算類型ComputeType,而不用關心具體的數(shù)據(jù)類型T。
只需要加幾行代碼就可以快速支持 Softmax 和其他 Kernel Fuse,減少帶寬需求,提升整體性能。
/** * @brief 定義了輸入的數(shù)據(jù)結構 * * @tparam SRC 輸入數(shù)據(jù)的類型 * @tparam DST 計算數(shù)據(jù)的類型,ComputeType */ template struct DirectLoad { /** * @brief Construct a new Direct Load object * * @param src 輸入的數(shù)據(jù)源 * @param row_size num of elements per row */ DirectLoad(const SRC* src, int64_t row_size) : src(src), row_size(row_size) {} /** * @brief 從數(shù)據(jù)源 load 一個 pack 數(shù)據(jù)到 dst * * @tparam N pack_size * @param dst * @param row 數(shù)據(jù)源的第 row 行 * @param col 數(shù)據(jù)源的第 col 列 * @return __device__ */ template __device__ void load(DST* dst, int64_t row, int64_t col) const { Pack pack; const int64_t offset = (row * row_size + col) / N; // pack 偏移量 pack.storage = *(reinterpret_cast*>(src) + offset); #pragma unroll for (int i = 0; i < N; ++i) { dst[i] = static_cast(pack.elem[i]); } } const SRC* src; int64_t row_size; }; template struct DirectStore { DirectStore(DST* dst, int64_t row_size) : dst(dst), row_size(row_size) {} template __device__ void store(const SRC* src, int64_t row, int64_t col) { Pack pack; const int64_t offset = (row * row_size + col) / N; #pragma unroll for (int i = 0; i < N; ++i) { pack.elem[i] = static_cast(src[i]); } *(reinterpret_cast*>(dst) + offset) = pack.storage; } DST* dst; int64_t row_size; }; ,>,>
3.2 調(diào)用鏈
針對 WarpSoftmax 這個分支,對源碼中函數(shù)的調(diào)用關系梳理后如下:
DispatchSoftmaxWarpImpl ->DispatchSoftmaxWarpImplPackSize ->DispatchSoftmaxWarpImplCols ->DispatchSoftmaxWarpImplPadding ->LaunchSoftmaxWarpImpl ->SoftmaxWarpImpl(kernel)
接下來將從上到下逐個解讀其實現(xiàn)細節(jié)。
3.3 DispatchSoftmaxWarpImpl
該函數(shù)被 DispatchSoftmax 函數(shù)調(diào)用,其內(nèi)部邏輯非常簡單,實例化了一個 DispatchSoftmaxWarpImplPackSize 類并調(diào)用了其重載的()函數(shù),所有的參數(shù)都是透傳,沒有其他邏輯。
template inline cudaError_t DispatchSoftmaxWarpImpl(cudaStream_t stream, LOAD load, STORE store, const int64_t rows, const int64_t cols) { return DispatchSoftmaxWarpImplPackSize()(stream, load, store, rows, cols); } ,>
3.4 DispatchSoftmaxWarpImplPackSize
顧名思義,pack_size 參數(shù)是在這個結構體內(nèi)部確定的。該結構體內(nèi)部重載了一個小括號運算符,其函數(shù)內(nèi)部只做了一件事,對矩陣的列數(shù)進行判斷,如果是偶數(shù),pack_size 取 2,否則取 1。
template struct DispatchSoftmaxWarpImplPackSize { cudaError_t operator()(cudaStream_t stream, LOAD load, STORE store, const int64_t rows, const int64_t cols) { if (cols % 2 == 0) { return DispatchSoftmaxWarpImplCols(stream, load, store, rows, cols); } else { return DispatchSoftmaxWarpImplCols(stream, load, store, rows, cols); } } }; ,>,>
筆者讀到這里不禁產(chǎn)生了疑問,前面說過數(shù)據(jù) Pack 后可以提升 GPU 訪問帶寬,但是在該函數(shù)中 pack_size 最大也只能取到 2,在前面的文章中筆者提到過在 cuda 中最大支持一次 128 bit的讀寫,意味著針對 float 類型 pack_size 最大可以取 4,對 half 類型甚至可以取 8。所以帶著這個疑問筆者咨詢了官方源碼的作者俊丞大佬,答曰可以取更大的 pack_size,這里是考慮到更多的特化會導致編譯時間過長所以只實現(xiàn)了 2 個模板。獲得解答后,筆者自行實現(xiàn)了一個 pack_size = 4 的模板,然后經(jīng)過實測(矩陣大小為 1024*1024, 32*16)發(fā)現(xiàn), pack_size 取 4 和取 2 相比幾乎沒有提升。。。倒是取 2 相比取 1 有 6% 的提升。猜測可能是 pack_size 影響了 DispatchSoftmaxWarpImplCols 這個 kernel 的啟動參數(shù),所以間接影響了性能,這里官方肯定做過一系列測試。。。
3.5 DispatchSoftmaxWarpImplCols
DispatchSoftmaxWarpImplCols 函數(shù)代碼比較長,讀起來稍顯晦澀,要理解它的實現(xiàn)邏輯,我們可以換個思路,看它的目的是什么,然后倒推它的實現(xiàn)過程。很顯然,該函數(shù)在最后調(diào)用了 DispatchSoftmaxWarpImplPadding 函數(shù),那么我們就來看被調(diào)用的函數(shù)需要哪些參數(shù),DispatchSoftmaxWarpImplCols 的作用就是確定這些參數(shù)。讀了 DispatchSoftmaxWarpImplPadding 的參數(shù)列表我們可以發(fā)現(xiàn),有三個重要參數(shù)需要傳入:cols_per_thread, thread_group_width, rows_per_access,這里先對這三個參數(shù)做一個解釋:
cols_per_thread:每個線程處理的元素列數(shù)
thread_group_width:線程組的大小,一個線程組要處理整行的數(shù)據(jù)
rows_per_access:每個線程組一次處理的行數(shù)
函數(shù)體內(nèi)主要是針對 cols 的大小做了分支,前后代碼有一個分水嶺,即 cols <= 32 * pack_size,可以分開來看。
當 cols <= 32 * pack_size 時,thread_group_width 取 2 的 n 次冪,從 1 到 32 一直判斷,如果 cols <= (thread_group_width)*pack_size 那么 thread_group_width 就取當前的值。cols_per_thread 取 pack_size,就是說當前一個線程只處理一個 Pack 寬度的數(shù)據(jù),這時候數(shù)據(jù)量也比較小,所以對 rows 也做了一層判斷,如果 rows 是偶數(shù),那么 rows_per_access 取 2,每個線程一次處理 2 行數(shù)據(jù),否則一次只處理 1 行。
當 cols > 32 * pack_size 時,這種屬于數(shù)據(jù)量比較大的情況,所以 thread_group_width 直接取能取到的最大值 32,即 Warp 的大小。每個線程也要處理多個 Pack,cols_per_thread 取 pack_size 的整數(shù)倍,直到 32 * cols_per_thread = 1024,一直判斷 cols <= 32 * cols_per_thread,如果滿足條件,cols_per_thread 就取當前值。對于 rows_per_access 參數(shù),直接取 1,即每個線程一次只處理 1 行數(shù)據(jù)。
至此函數(shù)邏輯就介紹完了,這個函數(shù)里有兩個宏,不熟悉 C++ 的讀者讀起來可能沒那么順暢,這里推薦一個網(wǎng)站(https://cppinsights.io/),從編譯器的角度將 C++ 源碼展開顯示,對閱讀泛型編程和宏這類代碼很有幫助。
template typename std::enable_if::type DispatchSoftmaxWarpImplCols( cudaStream_t stream, LOAD load, STORE store, const int64_t rows, const int64_t cols) { if (cols <= 0) { return cudaErrorInvalidValue; } #define DEFINE_ONE_ELIF(thread_group_width) else if (cols <= (thread_group_width)*pack_size) { if (rows % 2 == 0) { return DispatchSoftmaxWarpImplPadding(stream, load, store, rows, cols); } else { return DispatchSoftmaxWarpImplPadding(stream, load, store, rows, cols); } } DEFINE_ONE_ELIF(1) DEFINE_ONE_ELIF(2) DEFINE_ONE_ELIF(4) DEFINE_ONE_ELIF(8) DEFINE_ONE_ELIF(16) DEFINE_ONE_ELIF(32) #undef DEFINE_ONE_ELIF #define DEFINE_ONE_ELIF(col) else if (cols <= (col)*kWarpSize) { return DispatchSoftmaxWarpImplPadding(stream, load, store, rows, cols); } DEFINE_ONE_ELIF(4) DEFINE_ONE_ELIF(6) DEFINE_ONE_ELIF(8) DEFINE_ONE_ELIF(10) DEFINE_ONE_ELIF(12) DEFINE_ONE_ELIF(14) DEFINE_ONE_ELIF(16) DEFINE_ONE_ELIF(18) DEFINE_ONE_ELIF(20) DEFINE_ONE_ELIF(22) DEFINE_ONE_ELIF(24) DEFINE_ONE_ELIF(26) DEFINE_ONE_ELIF(28) DEFINE_ONE_ELIF(30) DEFINE_ONE_ELIF(32) #undef DEFINE_ONE_ELIF else { return cudaErrorInvalidValue; } } ,>,>,>
3.6 DispatchSoftmaxWarpImplPadding
顧名思義,這個函數(shù)內(nèi)部的邏輯跟 padding 相關,實際上這個函數(shù)只做了一件事,當 cols == cols_per_thread * thread_group_width 時說明矩陣列數(shù)能被線程組均分,這時候不需要 padding,否則需要 padding。
template inline cudaError_t DispatchSoftmaxWarpImplPadding(cudaStream_t stream, LOAD load, STORE store, const int64_t rows, const int64_t cols) { if (cols == cols_per_thread * thread_group_width) { return LaunchSoftmaxWarpImpl( stream, load, store, rows, cols); } else { return LaunchSoftmaxWarpImpl( stream, load, store, rows, cols); } } ,>,>
3.7 LaunchSoftmaxWarpImpl
該函數(shù)是核函數(shù)的啟動函數(shù),函數(shù)內(nèi)主要是確定 block_size、num_blocks 這兩個參數(shù)。這兩個參數(shù)的確定筆者在上一篇文章【CUDA編程】OneFlow Element-Wise 算子源碼解讀中有詳細介紹,有興趣的讀者可以移步,這里不再贅述。
函數(shù)中定義了一個 block_dim 對象,從初始化參數(shù)可以看出這是一個二維的 block,寬是 thread_group_width,高取 thread_groups_per_block。從核函數(shù)啟動參數(shù) grid_dim_x 可以看出網(wǎng)格是一維的,由此我們可以確定 cuda 線程網(wǎng)格的形狀。這里筆者給出示意圖如下。
圖2 線程網(wǎng)格示意圖
template inline cudaError_t LaunchSoftmaxWarpImpl(cudaStream_t stream, LOAD load, STORE store, const int64_t rows, const int64_t cols) { constexpr int block_size = 128; constexpr int waves = 32; static_assert(block_size % thread_group_width == 0, ""); constexpr int thread_groups_per_block = block_size / thread_group_width; dim3 block_dim(thread_group_width, thread_groups_per_block); const int64_t num_blocks = (rows / rows_per_access + thread_groups_per_block - 1) / thread_groups_per_block; int grid_dim_x; { cudaError_t err = GetNumBlocks(block_size, num_blocks, waves, &grid_dim_x); if (err != cudaSuccess) { return err; } } SoftmaxWarpImpl <<>>(load, store, rows, cols); return cudaPeekAtLastError(); } ,>,>
3.8 核函數(shù) SoftmaxWarpImpl
接下來就是 WarpSoftmax 的核函數(shù) SoftmaxWarpImpl,該函數(shù)體內(nèi)部實現(xiàn)了 Softmax 的核心計算邏輯。
template __global__ void SoftmaxWarpImpl(LOAD load, STORE store, const int64_t rows, const int64_t cols) { static_assert(cols_per_thread % pack_size == 0, ""); // 確保每個thread處理的元素個數(shù)正好被完全pack static_assert(thread_group_width <= kWarpSize, ""); // 處理元素的線程組的寬度需要小于等于kWarpSize,并且需要被kWarpSize整除 static_assert(kWarpSize % thread_group_width == 0, ""); constexpr int num_packs = cols_per_thread / pack_size; // 每個線程處理的 pack 的數(shù)目,即每個線程需要處理的元素個數(shù) / pack_size assert(cols <= cols_per_thread * thread_group_width); // 確保一個thread group 能處理的列數(shù)大于等于一行 ComputeType buf[rows_per_access][cols_per_thread]; // 聲明寄存器大小,這是一個二維數(shù)組 const int global_thread_group_id = blockIdx.x * blockDim.y + threadIdx.y; // 當前warp的全局index const int num_global_thread_group = gridDim.x * blockDim.y; // warp的總數(shù)量 const int lane_id = threadIdx.x; // warp內(nèi)的線程id const int64_t step = num_global_thread_group * rows_per_access; // 處理的行數(shù)步長 // for 循環(huán)的開始為 row = 全局的線程組id * 每個線程組一次處理的行數(shù),結束為總行數(shù) for (int64_t row = global_thread_group_id * rows_per_access; row < rows; row += step) { // 寄存器中開辟一塊內(nèi)存記錄當前線程組處理的每一行的最大值 ComputeType thread_max[rows_per_access]; // 對每一行的循環(huán) #pragma unroll for (int row_id = 0; row_id < rows_per_access; ++row_id) { // 把當前行最小值初始化為 -inf thread_max[row_id] = -Inf(); // 獲取第 row_id 行的指針 ComputeType* row_buf = buf[row_id]; #pragma unroll for (int pack_id = 0; pack_id < num_packs; ++pack_id) { const int pack_offset = pack_id * pack_size; // 相鄰的線程讀取相鄰的pack,也就是說同一個線程處理的相鄰pack間間隔是thread_group_width*pack_size const int col = (pack_id * thread_group_width + lane_id) * pack_size; if (!padding || col < cols) { // 使用 obj.template 調(diào)用函數(shù)模板防止歧義,load 一個 pack 的數(shù)據(jù)到寄存器 load.template load(row_buf + pack_offset, row + row_id, col); #pragma unroll for (int i = 0; i < pack_size; ++i) { thread_max[row_id] = max(thread_max[row_id], row_buf[pack_offset + i]); } } else { // 需要 padding 且 col > cols,這種情況對于第 col 列的數(shù)據(jù)直接將 row_buf 賦最新小值,不影響 thread_max 計算即可 #pragma unroll for (int i = 0; i < pack_size; ++i) { row_buf[pack_offset + i] = -Inf(); } } } } // 記錄屬于同一個warp的線程組的每一行的最大值,也就是需要進行一次warpReduce max ComputeType warp_max[rows_per_access]; #pragma unroll for (int row_id = 0; row_id < rows_per_access; ++row_id) { // 通過線程束洗牌函數(shù)對一個線程組內(nèi)的所有線程的 thread_max 求規(guī)約得到一個線程組處理的每一行的最大值 warp_max[row_id] = WarpAllReduce(thread_max[row_id]); } // 記錄當前線程組處理的每一行的sum ComputeType thread_sum[rows_per_access]; #pragma unroll for (int row_id = 0; row_id < rows_per_access; ++row_id) { thread_sum[row_id] = 0; ComputeType* row_buf = buf[row_id]; #pragma unroll for (int i = 0; i < cols_per_thread; ++i) { if (algorithm == Algorithm::kSoftmax) { row_buf[i] = Exp(row_buf[i] - warp_max[row_id]); thread_sum[row_id] += row_buf[i]; } else if (algorithm == Algorithm::kLogSoftmax) { row_buf[i] -= warp_max[row_id]; thread_sum[row_id] += Exp(row_buf[i]); } else { __trap(); // 內(nèi)核的執(zhí)行被中止并在主機程序中引發(fā)中斷。 } } } ComputeType warp_sum[rows_per_access]; #pragma unroll for (int row_id = 0; row_id < rows_per_access; ++row_id) { warp_sum[row_id] = WarpAllReduce(thread_sum[row_id]); } #pragma unroll for (int row_id = 0; row_id < rows_per_access; ++row_id) { ComputeType* row_buf = buf[row_id]; #pragma unroll for (int i = 0; i < cols_per_thread; ++i) { if (algorithm == Algorithm::kSoftmax) { row_buf[i] = Div(row_buf[i], warp_sum[row_id]); } else if (algorithm == Algorithm::kLogSoftmax) { row_buf[i] -= Log(warp_sum[row_id]); } else { __trap(); } } #pragma unroll for (int i = 0; i < num_packs; ++i) { const int col = (i * thread_group_width + lane_id) * pack_size; if (!padding || col < cols) { store.template store(row_buf + i * pack_size, row + row_id, col); } } } } } ,>,>
具體代碼如上,在解讀之前,需要先介紹一下幾個重要參數(shù)的意義。
algorithm:代表所使用的的算法,有 Algorithm::kSoftmax 和 Algorithm::kLogSoftmax。
global_thread_group_id:當前線程組的全局索引
lane_id:當前線程在線程組內(nèi)的索引
首先在核函數(shù)內(nèi)部做了幾個編譯期斷言操作,確保核函數(shù)能夠正常啟動。然后在寄存器中定義了一個二維數(shù)組 buf[rows_per_access][cols_per_thread] 用來存儲矩陣中的數(shù)據(jù),我們知道,寄存器中的變量只能對當前線程可見,每個線程中都有一個變量 buf,但是存儲的值可以不同,這里是為了減少對全局內(nèi)存的讀取,所以給每個線程都定義一個寄存器變量用于存儲該線程處理的矩陣元素。
接著是一個 Grip-loop 的循環(huán),因為有可能矩陣行數(shù)過大導致前面求 num_blocks 的時候是根據(jù)硬件參數(shù)選取的,這時候每個線程不止處理一次,所以循環(huán)步長設置為網(wǎng)格大小。Grip-loop 內(nèi)部定義了一個寄存器變量 thread_max[rows_per_access],這個數(shù)組用來存儲當前線程處理的元素中的每一行的最大值。接下來就是一個 reduceMax 操作。
(1)reduceMax
如圖 2,每個線程處理了多個 Pack 的數(shù)據(jù),求最大值需要兩層循環(huán)。第一層循環(huán)中把一個 Pack 的矩陣元素 load 到 buf 數(shù)組中,這里主要是要理解 col 變量的含義,結合圖 2 的示意圖不難理解,相鄰的線程讀取相鄰的 Pack 的目的是讓一個線程束中各線程單次訪問的數(shù)據(jù)在內(nèi)存中相鄰,這是一個合并訪問的概念,目的是提升訪問效率。第二層循環(huán)中對單個 Pack 中的元素求最大值存到 thread_max 中。
注意,這時候 thread_max 中存的只是每個線程內(nèi)部處理的元素的最大值,但是 reduceMax 操作要獲取的是矩陣每一行的最大值,由于 WarpSoftmax 的應用范圍就是一個線程組處理一行數(shù)據(jù),所以再對線程組內(nèi)所有的 thread_max 求最大值即可。前面說過,每個線程內(nèi)部都有一個 thread_max 變量,對這些變量求最大值,必然要在線程間進行通信,源碼中使用了 WarpAllReduce() 函數(shù)完成了這一操作得到了矩陣每一行的最大值 warp_max,核心就是利用了線程束洗牌指令 __shfl_xor_sync 完成了一個束內(nèi)折半規(guī)約操作,筆者之前在另一篇文章也有介紹:【CUDA編程】CUDA編程中的并行規(guī)約問題。有興趣的讀者可以去 cuda 官網(wǎng)詳細了解一下束內(nèi)洗牌指令的用法,當然了這里也可以直接使用共享內(nèi)存存儲數(shù)據(jù),我們知道共享內(nèi)存在整個 block 都是可見的,也就不需要使用束內(nèi)通信,但是從訪問性能出發(fā),共享內(nèi)存是不如寄存器快的,所以 oneflow 選擇了寄存器。,>
template class ReductionOp, typename T, int thread_group_width = kWarpSize> __inline__ __device__ T WarpAllReduce(T val) { for (int mask = thread_group_width / 2; mask > 0; mask /= 2) { val = ReductionOp()(val, __shfl_xor_sync(0xffffffff, val, mask)); } return val; }
(2)reduceSum
接下來就是 reduceSum 操作,這里源碼提供了兩種算法: Algorithm::kSoftmax 和 Algorithm::kLogSoftmax。kSoftmax 就是公式(2)中的計算公式,kLogSoftmax 計算的是 計算公式如下:
reduceSum 的計算思路和 reduceMax 相同,先在寄存器定義一個變量 thread_sum 然后求出各個線程內(nèi)的指數(shù)和,最后束內(nèi)規(guī)約求每一行的指數(shù)和 warp_sum。
broadcastSub、exp、broadcastDiv 這三個操作比較簡單,其邏輯就直接包含在兩個規(guī)約操作的實現(xiàn)代碼里,這里不再贅述,至此 WarpSoftmax 源碼解讀完畢,有興趣的讀者可以自行嘗試。調(diào)用時可以將矩陣 cols 限制在 1024 以內(nèi)調(diào)用 DispatchSoftmax 函數(shù),也可以直接調(diào)用 DispatchSoftmaxWarpImpl 函數(shù)。
4 小結
總結一下 WarpSoftmax 源碼中的一些值得注意的內(nèi)容。
數(shù)據(jù) Pack 可以有效地提升訪問帶寬,pack_size 可以根據(jù) cuda 中最大支持一次 128 bit 的讀寫來確定最大值。
WarpSoftmax 的核心就是束內(nèi)規(guī)約,利用了束內(nèi)線程可互相訪問寄存器的特性提高效率,但受制于單個線程可使用的寄存器大小,所以 WarpSoftmax 不適用于矩陣列數(shù)比較大的場景。
源碼中對于 pack_size 和 row_per_access 的確定都比較簡單粗暴,可以進行更細致的處理。
審核編輯:湯梓紅
-
源碼
+關注
關注
8文章
639瀏覽量
29185 -
模型
+關注
關注
1文章
3226瀏覽量
48806 -
深度學習
+關注
關注
73文章
5500瀏覽量
121111 -
OneFlow
+關注
關注
0文章
9瀏覽量
8802
原文標題:【CUDA編程】OneFlow Softmax 算子源碼解讀之WarpSoftmax
文章出處:【微信號:GiantPandaCV,微信公眾號:GiantPandaCV】歡迎添加關注!文章轉載請注明出處。
發(fā)布評論請先 登錄
相關推薦
評論