0x0. 前言
這篇文章來解析一下Megaton-LM涉及到的一個(gè)優(yōu)化gradient_accumulation_fusion。這里fusion的意思是在gemm接口中會(huì)將當(dāng)前的結(jié)果累加到先前計(jì)算的梯度上,所有這些都在一個(gè)操作中完成,可以避免多次訪問global memory提升算子的帶寬。下面解析一下這個(gè)優(yōu)化的調(diào)度邏輯和cuda實(shí)現(xiàn)。
https://github.com/BBuf/how-to-optim-algorithm-in-cuda 這個(gè)倉庫整理了一些cuda優(yōu)化相關(guān)鏈接以及大模型訓(xùn)練推理相關(guān)的知識(shí)鏈接(large-language-model-note子目錄下),歡迎查看。
0x1. 調(diào)度邏輯解析
gradient_accumulation_fusion的調(diào)度邏輯是和LinearWithGradAccumulationAndAsyncCommunication這個(gè)類的實(shí)現(xiàn)有關(guān)的,LinearWithGradAccumulationAndAsyncCommunication 這個(gè)類又被包了一層變成 linear_with_grad_accumulation_and_async_allreduce 這個(gè)函數(shù),這個(gè)函數(shù)又給RowParallelLinear和ColumnParallelLinear這兩個(gè)實(shí)現(xiàn)模型并行的Linear類使用。
下面解析一下linear_with_grad_accumulation_and_async_allreduce這個(gè)函數(shù)(https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/layers.py#L356-L446):
#這部分定義了一個(gè)函數(shù),名為linear_with_grad_accumulation_and_async_allreduce, #它接收七個(gè)參數(shù):輸入張量、權(quán)重張量、一個(gè)可選的偏置張量和3個(gè)布爾標(biāo)志。 deflinear_with_grad_accumulation_and_async_allreduce( input:torch.Tensor, weight:torch.Tensor, bias:Optional[torch.Tensor], gradient_accumulation_fusion:bool, async_grad_allreduce:bool, sequence_parallel_enabled:bool, )->torch.Tensor: """帶有反向傳播的異步通信和梯度累積融合的線性層實(shí)現(xiàn). 此函數(shù)提供了一個(gè)選項(xiàng),可以將反向傳播計(jì)算的結(jié)果累積到一個(gè)現(xiàn)有的梯度緩沖區(qū)中, 從而避免在梯度計(jì)算后進(jìn)行額外的加法核操作。 此外,輸入梯度的張量并行allreduce可以與權(quán)重梯度的計(jì)算異步進(jìn)行。 在使用序列并行的情況下,輸入梯度的reducescatter與權(quán)重梯度的計(jì)算異步進(jìn)行。 使用此模塊需要環(huán)境變量CUDA_DEVICE_MAX_CONNECTIONS=1。代碼中有一些集合操作, 應(yīng)該在計(jì)算核之前調(diào)度,以使通信與計(jì)算重疊,這對于加速是必要的,但對于正確性則不是必要的, 因此調(diào)度器不會(huì)強(qiáng)制這種排序。將CUDA_DEVICE_MAX_CONNECTIONS設(shè)置為1會(huì)強(qiáng)制按照它們被調(diào)用的順序調(diào)度內(nèi)核。 Arguments: input(torch.Tensorrequired):輸入,類似torch.nn.functional.linear weight(torch.Tensorrequired):權(quán)重,類似torch.nn.functional.linear bias(torch.Tensoroptional):偏置,類似torch.nn.functional.linear gradient_accumulation_fusion(boolrequired):執(zhí)行梯度累積融合, 需要自定義的CUDA擴(kuò)展模塊fused_weight_gradient_mlp_cuda。 要使用gradient_accumulation_fusion,你必須使用--cpp_ext和--cuda_ext安裝APEX。 例如:"pipinstall--global-option="--cpp_ext"--global-option="--cuda_ext." 注意,此擴(kuò)展要求CUDA版本大于或等于11。否則,你必須關(guān)閉梯度累積融合。 async_grad_allreduce(boolrequired):異步地與權(quán)重梯度的計(jì)算進(jìn)行輸入梯度的allreduce。 如果sequence_parallel_enabled為True,這必須為False,因?yàn)椴粓?zhí)行allreduce。 sequence_parallel_enabled(boolrequired):表示使用了序列并行, 因此在前向傳播中,輸入是addgather后的,在反向傳播中,輸入梯度是reducescatter后的。 """ #這部分創(chuàng)建了一個(gè)名為args的列表,它基本上是函數(shù)輸入?yún)?shù)的集合。 args=[ input, weight, bias, gradient_accumulation_fusion, async_grad_allreduce, sequence_parallel_enabled, ] #這部分檢查是否已經(jīng)發(fā)出警告。函數(shù)使用一個(gè)類級(jí)別變量warned來記住是否已經(jīng)向用戶顯示了警告。 ifnotlinear_with_grad_accumulation_and_async_allreduce.warned: #這部分檢查環(huán)境變量CUDA_DEVICE_MAX_CONNECTIONS是否設(shè)置為"1"。 #如果沒有,并且滿足某些條件(sequence_parallel_enabled或async_grad_allreduce), #它會(huì)發(fā)出警告。然后將warned標(biāo)志設(shè)置為True,以便不會(huì)重復(fù)發(fā)出此警告。 ifos.environ.get('CUDA_DEVICE_MAX_CONNECTIONS')!="1": ifsequence_parallel_enabled: warnings.warn( "Whenusingsequenceparallelismitisrecommendedtosetthe" "environmentvariableCUDA_DEVICE_MAX_CONNECTIONSto1for" "maximumspeedup") linear_with_grad_accumulation_and_async_allreduce.warned=True ifasync_grad_allreduce: warnings.warn( "Whenusingasyncgradallreduceitisrecommendedtosetthe" "environmentvariableCUDA_DEVICE_MAX_CONNECTIONSto1for" "maximumspeedup") linear_with_grad_accumulation_and_async_allreduce.warned=True #最后,函數(shù)調(diào)用另一個(gè)名為LinearWithGradAccumulationAndAsyncCommunication的類并返回其結(jié)果。 returnLinearWithGradAccumulationAndAsyncCommunication.apply(*args) #在函數(shù)外部,初始化屬性warned為False。這用于檢查是否已經(jīng)向用戶發(fā)出警告。 linear_with_grad_accumulation_and_async_allreduce.warned=False
解著解析一下LinearWithGradAccumulationAndAsyncCommunication這個(gè)類的實(shí)現(xiàn)(https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/layers.py#L232):
#這定義了一個(gè)名為LinearWithGradAccumulationAndAsyncCommunication的類, #該類繼承自torch.autograd.Function。 classLinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function): """Seelinear_with_grad_accumulation_and_async_allreduce""" #使用兩個(gè)裝飾器標(biāo)記forward方法。其中@staticmethod表示這是一個(gè)靜態(tài)方法, #而@custom_fwd是一個(gè)自定義裝飾器,用于特定的前向傳播操作。 @staticmethod @custom_fwd defforward( ctx, input, weight, bias, gradient_accumulation_fusion, async_grad_allreduce, sequence_parallel, ): #使用上下文對象ctx保存輸入和權(quán)重,以便在后向傳播中使用。 ctx.save_for_backward(input,weight) #在上下文對象ctx中存儲(chǔ)其他變量和標(biāo)志。 ctx.use_bias=biasisnotNone ctx.gradient_accumulation_fusion=gradient_accumulation_fusion ctx.async_grad_allreduce=async_grad_allreduce ctx.sequence_parallel=sequence_parallel #如果啟用了序列并行,則進(jìn)行以下操作: ifsequence_parallel: #獲取模型并行的world_size(通常是參與并行處理的GPU數(shù)量)。 world_size=get_tensor_model_parallel_world_size() #更改輸入的第一個(gè)維度以考慮模型并行的全部大小。 dim_size=list(input.size()) dim_size[0]=dim_size[0]*world_size #收集所有GPU上的輸入。 all_gather_buffer=get_global_memory_buffer().get_tensor(dim_size,input.dtype,"mpu") torch.distributed._all_gather_base( all_gather_buffer,input,group=get_tensor_model_parallel_group() ) #更新total_input為收集的數(shù)據(jù)。 total_input=all_gather_buffer else: #如果不使用序列并行,則total_input僅僅是傳入的輸入。 total_input=input #對total_input和weight的轉(zhuǎn)置進(jìn)行矩陣乘法以計(jì)算輸出。 output=torch.matmul(total_input,weight.t()) #如果提供了偏置,則將其添加到輸出中 ifbiasisnotNone: output=output+bias returnoutput @staticmethod @custom_bwd defbackward(ctx,grad_output): #從上下文對象中恢復(fù)前向傳播保存的張量。 input,weight=ctx.saved_tensors #從上下文對象中恢復(fù)偏置使用的信息。 use_bias=ctx.use_bias #如果啟用了序列并行,要如何獲取完整的輸入數(shù)據(jù)。 #它通過分布式的_all_gather_base函數(shù)來異步地聚集所有輸入。 ifctx.sequence_parallel: world_size=get_tensor_model_parallel_world_size() dim_size=list(input.size()) dim_size[0]=dim_size[0]*world_size all_gather_buffer=get_global_memory_buffer().get_tensor(dim_size,input.dtype,"mpu") handle=torch.distributed._all_gather_base( all_gather_buffer,input,group=get_tensor_model_parallel_group(),async_op=True ) #HerewerelyonCUDA_DEVICE_MAX_CONNECTIONS=1toensurethatthe #gatherisscheduledbeforetheinputgradientcomputation total_input=all_gather_buffer #如果沒有啟用序列并行,那么完整的輸入就是原始輸入。 else: total_input=input #通過矩陣乘法計(jì)算關(guān)于輸入的梯度。 grad_input=grad_output.matmul(weight) #如果啟用了序列并行,則等待所有聚集操作完成。 ifctx.sequence_parallel: handle.wait() #Doinggather+slicingduringtheNeMoforwardpasscanmakethistensor #notbecontiguous.PyTorchonlychecksifthetensoriscontiguous,andonly #clonesitifit'snotcontiguous: #https://github.com/pytorch/pytorch/blob/c47cf9bc7f9e02f649ab4ed53fe4d35732c92ab6/torch/_refs/__init__.py#L2761 #這些是注釋,提到在NeMo的前向傳遞中,執(zhí)行g(shù)ather和slicing操作可能會(huì)導(dǎo)致grad_output張量 #不是連續(xù)的。PyTorch只檢查張量是否是連續(xù)的,并且只在不連續(xù)時(shí)克隆它。 grad_output=grad_output.contiguous()#確保grad_output是連續(xù)的 #Convertthetensorshapesto2Dforexecutioncompatibility #將grad_output張量的形狀轉(zhuǎn)化為2D,以確保兼容性。 grad_output=grad_output.view( grad_output.shape[0]*grad_output.shape[1],grad_output.shape[2] ) #同樣地,將total_input張量也轉(zhuǎn)化為2D。 total_input=total_input.view( total_input.shape[0]*total_input.shape[1],total_input.shape[2] ) #如果啟用了異步的梯度all-reduce,執(zhí)行該操作。這是一個(gè)分布式操作,用于聚合所有工作節(jié)點(diǎn)上的梯度。 ifctx.async_grad_allreduce: #Asynchronousall-reduce handle=torch.distributed.all_reduce( grad_input,group=get_tensor_model_parallel_group(),async_op=True ) #HerewerelyonCUDA_DEVICE_MAX_CONNECTIONS=1toensurethatthe #all-reduceisscheduledbeforetheweightgradientcomputation #如果啟用了序列并行,則不應(yīng)該在此處啟用異步all-reduce(由assert語句確保)。 #接著,創(chuàng)建一個(gè)新的sub_grad_input張量,并執(zhí)行一個(gè)reduce_scatter操作。 #這是一個(gè)分布式操作,它會(huì)將輸入的梯度從所有工作節(jié)點(diǎn)上聚合到一個(gè)工作節(jié)點(diǎn)上。 ifctx.sequence_parallel: assertnotctx.async_grad_allreduce dim_size=list(input.size()) sub_grad_input=torch.empty( dim_size,dtype=input.dtype,device=torch.cuda.current_device(),requires_grad=False ) #reduce_scatter handle=torch.distributed._reduce_scatter_base( sub_grad_input,grad_input,group=get_tensor_model_parallel_group(),async_op=True ) #HerewerelyonCUDA_DEVICE_MAX_CONNECTIONS=1toensurethatthe #reducescatterisscheduledbeforetheweightgradientcomputation #根據(jù)是否啟用了梯度累積融合,使用特定的CUDA操作或標(biāo)準(zhǔn)的矩陣乘法來計(jì)算權(quán)重的梯度。 #這個(gè)條件檢查是否啟用了梯度累積融合。梯度累積通常在小批量訓(xùn)練中用于累積梯度以在較大的有效批量上更新模型。 ifctx.gradient_accumulation_fusion: ifweight.main_grad.dtype==torch.float32: fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32( total_input,grad_output,weight.main_grad ) elifweight.main_grad.dtypein(torch.float16,torch.bfloat16): fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16( total_input,grad_output,weight.main_grad ) else: raiseRuntimeError("Unsupportedgradienttypeforgradientaccumulationfusion") #在梯度累積融合的情況下,設(shè)置grad_weight為None, #這意味著梯度已經(jīng)在前面的CUDA函數(shù)中直接更新了(weight.main_grad),所以在這里沒有返回值。 grad_weight=None else: grad_weight=grad_output.t().matmul(total_input) #如果使用偏置,則計(jì)算關(guān)于偏置的梯度。 grad_bias=grad_output.sum(dim=0)ifuse_biaselseNone #如果啟用了序列并行,等待上述操作完成,并返回計(jì)算得到的梯度。 ifctx.sequence_parallel: handle.wait() returnsub_grad_input,grad_weight,grad_bias,None,None,None #如果啟用了異步all-reduce,等待all-reduce操作完成。 ifctx.async_grad_allreduce: handle.wait() returngrad_input,grad_weight,grad_bias,None,None,None
可以看到gradient_accumulation_fusion這個(gè)優(yōu)化作用于Linear層中對weight求梯度的時(shí)候,調(diào)用了apex庫提供的2個(gè)fuse cuda kernel原地更新了weight的梯度。
0x2. fused_weight_gradient_mlp_cuda 實(shí)現(xiàn)
fused_weight_gradient_mlp_cuda接口分別為float32和float16/bfloat16提供了2個(gè)cuda kernel實(shí)現(xiàn),我們先看一下上層的接口。(https://github.com/NVIDIA/apex/blob/master/csrc/megatron/fused_weight_gradient_dense.cpp)
//定義了一個(gè)名為wgrad_gemm_accum_fp32_cuda_stub的函數(shù)原型。這是一個(gè)CUDAC++函數(shù), //用于處理float32數(shù)據(jù)類型的權(quán)重梯度累積。該函數(shù)接受三個(gè)at::Tensor參數(shù): //input_2d,d_output_2d,和d_weight。 voidwgrad_gemm_accum_fp32_cuda_stub( at::Tensor&input_2d, at::Tensor&d_output_2d, at::Tensor&d_weight ); //定義了一個(gè)名為wgrad_gemm_accum_fp16_cuda_stub的函數(shù)原型,與上面的函數(shù)類似, //但它是為float16數(shù)據(jù)類型設(shè)計(jì)的。 voidwgrad_gemm_accum_fp16_cuda_stub( at::Tensor&input_2d, at::Tensor&d_output_2d, at::Tensor&d_weight ); PYBIND11_MODULE(TORCH_EXTENSION_NAME,m){ m.def("wgrad_gemm_accum_fp32",&wgrad_gemm_accum_fp32_cuda_stub,"wgradgemmaccuminfp32"); m.def("wgrad_gemm_accum_fp16",&wgrad_gemm_accum_fp16_cuda_stub,"wgradgemmaccuminfp16"); }
接下來解析一下wgrad_gemm_accum_fp32這個(gè)kernel,對應(yīng) https://github.com/NVIDIA/apex/blob/master/csrc/megatron/fused_weight_gradient_dense_cuda.cu 這個(gè)文件。
//這個(gè)函數(shù)是一個(gè)封裝了NVIDIAcuBLAS庫中的cublasGemmEx函數(shù)的C++函數(shù), //專門用于執(zhí)行BFloat16(BF16)的矩陣乘法(GEMM)操作。 //函數(shù)的名稱為gemmex_wrapper,它的設(shè)計(jì)意圖是提供一個(gè)簡單的接口, //使得PyTorch可以方便地利用cuBLAS中的高效GEMM操作,特別是當(dāng)使用BFloat16數(shù)據(jù)類型時(shí)。 //BF16TensorcorewrapperaroundcublasGEMMEx voidgemmex_wrapper( cublasHandle_thandle,//cuBLAS庫的句柄,用于管理cuBLAS調(diào)用。 cublasOperation_ttransa, cublasOperation_ttransb,//這兩個(gè)參數(shù)描述了兩個(gè)輸入矩陣A和B是否需要轉(zhuǎn)置。 //定義了矩陣A,B和輸出矩陣C的維度。具體來說,矩陣A的維度為mxk, //矩陣B的維度為kxn,輸出矩陣C的維度為mxn。 intm, intn, intk, constfloat*alpha,//標(biāo)量系數(shù),用于計(jì)算alpha*A*B。 at::BFloat16*A,//輸入矩陣A,它們都是BFloat16數(shù)據(jù)類型。 intlda,//這個(gè)參數(shù)是矩陣A的leadingdim,通常與矩陣的行數(shù)相同。 at::BFloat16*B, intldb, constfloat*beta,//標(biāo)量系數(shù),用于計(jì)算beta*C。 float*C,//輸出矩陣C,它是float數(shù)據(jù)類型。 intldc){//矩陣C的leading維度,通常與矩陣C的行數(shù)相同。 //使用TORCH_CUDABLAS_CHECK宏調(diào)用了cublasGemmEx函數(shù)。這是cuBLAS庫中用于執(zhí)行混合精度矩陣乘法的函數(shù)。 //cublasGemmEx函數(shù)的參數(shù)主要用于描述輸入和輸出矩陣的屬性,以及要執(zhí)行的具體操作。 //在這里,輸入矩陣A和B都是BFloat16數(shù)據(jù)類型,而輸出矩陣C是float數(shù)據(jù)類型。 //CUDA_R_16BF和CUDA_R_32F是枚舉值,用于描述矩陣的數(shù)據(jù)類型。 //CUBLAS_GEMM_DEFAULT_TENSOR_OP是一個(gè)枚舉值,指示cuBLAS使用默認(rèn)的TensorCore操作來執(zhí)行GEMM。 TORCH_CUDABLAS_CHECK(cublasGemmEx( handle, transa, transb, m, n, k, alpha, A, CUDA_R_16BF, lda, B, CUDA_R_16BF, ldb, beta, C, CUDA_R_32F, ldc, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); } //類似上面的函數(shù),用于執(zhí)行FP16的矩陣乘法 //FP16TensorcorewrapperaroundcublasGEMMEx voidgemmex_wrapper( cublasHandle_thandle, cublasOperation_ttransa, cublasOperation_ttransb, intm, intn, intk, constfloat*alpha, at::Half*A, intlda, at::Half*B, intldb, constfloat*beta, float*C, intldc){ TORCH_CUDABLAS_CHECK(cublasGemmEx( handle, transa, transb, m, n, k, alpha, A, CUDA_R_16F, lda, B, CUDA_R_16F, ldb, beta, C, CUDA_R_32F, ldc, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); } //類似上面的函數(shù),用于執(zhí)行FP32的矩陣乘法 //FP32wrapperaroundcublasGEMMEx voidgemmex_wrapper( cublasHandle_thandle, cublasOperation_ttransa, cublasOperation_ttransb, intm, intn, intk, constfloat*alpha, float*A, intlda, float*B, intldb, constfloat*beta, float*C, intldc){ TORCH_CUDABLAS_CHECK(cublasGemmEx( handle, transa, transb, m, n, k, alpha, A, CUDA_R_32F, lda, B, CUDA_R_32F, ldb, beta, C, CUDA_R_32F, ldc, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); } //這個(gè)函數(shù)wgrad_gemm_accum_fp32_cuda是一個(gè)模板函數(shù),用于在CUDA上執(zhí)行累加的權(quán)重梯度計(jì)算(矩陣乘法)。 //它使用了前面提到的gemmex_wrapper函數(shù),該函數(shù)是NVIDIAcuBLAS庫中的cublasGemmEx函數(shù)的封裝, //用于執(zhí)行高效的矩陣乘法。 template voidwgrad_gemm_accum_fp32_cuda(T*input,T*d_output,float*d_weight,intin_dim,inthidden_dim,intout_dim){ //獲取當(dāng)前CUDAcuBLAS句柄。 cublasHandle_thandle=at::getCurrentCUDABlasHandle(); //獲取CUDAStream。 cudaStream_tstream; //從cuBLAS句柄獲取當(dāng)前CUDA流。 cublasGetStream(handle,&stream); //定義矩陣乘法的標(biāo)量系數(shù),用于計(jì)算alpha*A*B+beta*C。 constfloatalpha=1.0; constfloatbeta=1.0; //使用CUBLAS_OP_N和CUBLAS_OP_T作為參數(shù),表示輸入矩陣不需要轉(zhuǎn)置,但d_output矩陣需要轉(zhuǎn)置。 //使用輸入矩陣input和輸出矩陣的梯度d_output作為輸入,將結(jié)果存儲(chǔ)在權(quán)重梯度d_weight中。 gemmex_wrapper( handle, CUBLAS_OP_N, CUBLAS_OP_T, in_dim, out_dim, hidden_dim, &alpha, input, in_dim, d_output, out_dim, &beta, d_weight, in_dim); } //這是為數(shù)據(jù)類型at::Half(即半精度浮點(diǎn)型,也稱為FP16)顯式實(shí)例化的wgrad_gemm_accum_fp32_cuda函數(shù)。 //使用此數(shù)據(jù)類型的版本,可以進(jìn)行更快速的計(jì)算,尤其是在支持FP16計(jì)算的硬件上。 templatevoidwgrad_gemm_accum_fp32_cuda(at::Half*input,at::Half*d_output,float*d_weight,intin_dim,inthidden_dim,intout_dim); templatevoidwgrad_gemm_accum_fp32_cuda(at::BFloat16*input,at::BFloat16*d_output,float*d_weight,intin_dim,inthidden_dim,intout_dim); templatevoidwgrad_gemm_accum_fp32_cuda(float*input,float*d_output,float*d_weight,intin_dim,inthidden_dim,intout_dim); //這個(gè)函數(shù)名為wgrad_gemm_accum_fp32_cuda_stub,從名字中可以看出這是一個(gè)為CUDA定義的存根函數(shù)。 //它處理輸入的張量,調(diào)整它們的維度,然后調(diào)用對應(yīng)的CUDA模板函數(shù)來完成具體的操作。 voidwgrad_gemm_accum_fp32_cuda_stub( at::Tensor&input, at::Tensor&d_output, at::Tensor&d_weight ){ at::Tensorinput_2d,d_output_2d; //inputtensor:collapsetothefirstdim autoin_sizes=input.sizes(); //如果input張量的維度大于2,它將最后一個(gè)維度以外的所有維度折疊為第一個(gè)維度, //使其成為一個(gè)2D張量input_2d。否則,它將使用原始input張量。 if(input.dim()>2){ input_2d=input.view({-1,in_sizes[in_sizes.size()-1]}); }else{ input_2d=input; } //d_outputtensor:collapsetothefirstdim //類似地,如果d_output張量的維度大于2,它也會(huì)進(jìn)行同樣的維度轉(zhuǎn)換。 //否則,它會(huì)使用原始的d_output張量。 autod_out_sizes=d_output.sizes(); if(d_output.dim()>2){ d_output_2d=d_output.view({-1,d_out_sizes[d_out_sizes.size()-1]}); }else{ d_output_2d=d_output; } //hidden_dim是input_2d的第一個(gè)維度的大小。 constinthidden_dim=input_2d.size(0); //in_dim是input_2d的第二個(gè)維度的大小。 constintin_dim=input_2d.size(1); //out_dim是d_weight的第一個(gè)維度的大小。 constintout_dim=d_weight.size(0); //使用DISPATCH_FLOAT_HALF_AND_BFLOAT宏來基于input_2d的數(shù)據(jù)類型調(diào)用相應(yīng)的函數(shù)。 //這意味著,根據(jù)輸入數(shù)據(jù)的數(shù)據(jù)類型(浮點(diǎn)、半精度或BFloat16), //它將選擇正確的版本的wgrad_gemm_accum_fp32_cuda函數(shù)進(jìn)行調(diào)用。 DISPATCH_FLOAT_HALF_AND_BFLOAT(input_2d.scalar_type(),0,"wgrad_gemm_accum_fp32", wgrad_gemm_accum_fp32_cuda( input_2d.data_ptr(), d_output_2d.data_ptr(), d_weight.data_ptr(), in_dim, hidden_dim, out_dim); ); }
注意,在Kernel中這里會(huì)將當(dāng)前的結(jié)果累加到先前計(jì)算的梯度上,所有這些都在一個(gè)操作中完成,這是fuse的思想,可以避免多次訪問global memory提升算子的帶寬。
審核編輯:彭菁
-
邏輯
+關(guān)注
關(guān)注
2文章
833瀏覽量
29464 -
異步通信
+關(guān)注
關(guān)注
1文章
57瀏覽量
10124 -
函數(shù)
+關(guān)注
關(guān)注
3文章
4327瀏覽量
62569 -
大模型
+關(guān)注
關(guān)注
2文章
2423瀏覽量
2640
原文標(biāo)題:0x3. 總結(jié)
文章出處:【微信號(hào):GiantPandaCV,微信公眾號(hào):GiantPandaCV】歡迎添加關(guān)注!文章轉(zhuǎn)載請注明出處。
發(fā)布評論請先 登錄
相關(guān)推薦
評論