OpenAI/Triton MLIR Chapter 4: ROCm-Triton Configuration
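Recently, while organizing some Python-based benchmark code, I reinstalled Triton on an NVIDIA GPU. I noticed that Triton's GitHub repo now specifies the matching LLVM commit ID along with detailed build instructions. I followed them and the installation succeeded. On an NVIDIA GPU, the installation comes down to the following steps: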
```shell
git clone https://github.com/openai/triton.git
cd triton

cd $HOME/llvm-project  # your clone of LLVM
git checkout 49af6502
mkdir build
cd build
cmake -G Ninja -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=ON ../llvm -DLLVM_ENABLE_PROJECTS="mlir;llvm"
ninja -j8

export LLVM_BUILD_DIR=$HOME/llvm-project/build
cd <triton install>  # the triton directory cloned in the first step
LLVM_INCLUDE_DIRS=$LLVM_BUILD_DIR/include \
LLVM_LIBRARY_DIR=$LLVM_BUILD_DIR/lib \
LLVM_SYSPATH=$LLVM_BUILD_DIR \
pip install -e python
```
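If the installed version shows up as 3.0.0, Triton has been installed successfully. After installing Triton, be sure to install PyTorch as well. I am using CUDA 12.1, so the following command installs it with no further fuss: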
```shell
pip install torch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 --index-url https://download.pytorch.org/whl/cu121
```
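The text above uses the reported version (3.0.0) as the success signal. You can also read the installed versions from package metadata with the standard library's `importlib.metadata`. This avoids importing `triton` or `torch`, so it never touches the GPU. What follows is a minimal sketch, and the distribution names it queries are assumptions. A source build installed with `pip install -e python` registers as `triton`. The Triton bundled with PyTorch's ROCm wheels is typically published under a different name such as `pytorch-triton-rocm`.

```python
from importlib import metadata
from typing import Dict, Iterable, Optional


def installed_version(dist_name: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is absent."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


def report(dist_names: Iterable[str]) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version (None if missing)."""
    return {name: installed_version(name) for name in dist_names}


if __name__ == "__main__":
    # Assumed distribution names; adjust them to match how each package was installed.
    names = ["triton", "pytorch-triton-rocm", "torch", "torchvision", "torchaudio"]
    for name, version in report(names).items():
        print(f"{name:20s} {version or 'not installed'}")
```

Reading metadata only tells you what pip registered. Actually importing `triton` is still the real test that `libtriton.so` loads.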
Installing and using Triton on NVIDIA GPUs is by now a well-trodden path. Next, let's explore how to install and configure Triton on AMD GPUs.
0x00 Software Installation
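The official Triton repo supports the AMD backend as a third-party backend. Even so, I recommend the Triton fork maintained by AMD. When I first tried the official main branch, building with TRITON_CODEGEN_AMD_HIP_BACKEND=1 did not complete correctly, so I switched to AMD's fork. Just follow its installation process. I recommend the following commands, which I have tested myself: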
```shell
git clone https://github.com/ROCmSoftwarePlatform/triton.git
cd triton
git checkout triton-mlir
```
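At this point the Triton source is ready to compile. However, Triton's backend is built on LLVM, so to generate code that runs on the target device we also need to build LLVM. This tutorial builds LLVM by hand, although a prebuilt LLVM works as well. This Triton branch was developed against LLVM commit b1115f8c. So we only need to clone LLVM, check out that commit, and build it with the following commands: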
```shell
git clone https://github.com/llvm/llvm-project
cd llvm-project
git checkout b1115f8c
mkdir build
cd build
cmake -G Ninja -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=ON ../llvm -DLLVM_ENABLE_PROJECTS="mlir;llvm"
ninja -j8
```
Once LLVM has been fully built, add its path to your ~/.bashrc:
```shell
export PATH=/home/llvm-project/build/bin:$PATH
```
Then go into the triton directory cloned earlier and run the following commands:
```shell
cd triton
vim CMakeLists.txt  # make sure it reads: option(TRITON_BUILD_PYTHON_MODULE "Build Python Triton bindings" ON)
mkdir build
cd build
cmake ..
make -j8
```
Once the build completes without errors, a libtriton.so file is produced in the current build directory. Next, move libtriton.so into the triton/python/triton/_C directory and add Triton's Python path to your ~/.bashrc:
```shell
export TRITON_HOME=/home/Documents/compiler/triton
export PYTHONPATH=$TRITON_HOME/python:${PYTHONPATH}
```
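Before importing Triton, it can help to catch a misplaced library. The check covers the two conditions described above: `libtriton.so` sits under `python/triton/_C`, and `$TRITON_HOME/python` is on `PYTHONPATH`. The helper below is a hypothetical sketch that uses only the standard library. The function names are mine, not part of Triton.

```python
import os
from pathlib import Path
from typing import List


def libtriton_location(triton_home: str) -> Path:
    """Where the article copies libtriton.so: <TRITON_HOME>/python/triton/_C/."""
    return Path(triton_home) / "python" / "triton" / "_C" / "libtriton.so"


def check_triton_setup(triton_home: str, pythonpath: str) -> List[str]:
    """Return a list of problems; an empty list means both checks passed."""
    problems = []
    so_path = libtriton_location(triton_home)
    if not so_path.is_file():
        problems.append(f"missing {so_path}")
    # Compare normalized paths so trailing slashes do not cause false alarms.
    wanted = os.path.normpath(str(Path(triton_home) / "python"))
    entries = {os.path.normpath(p) for p in pythonpath.split(os.pathsep) if p}
    if wanted not in entries:
        problems.append(f"{wanted} is not on PYTHONPATH")
    return problems


if __name__ == "__main__":
    home = os.environ.get("TRITON_HOME", "")
    if home:
        issues = check_triton_setup(home, os.environ.get("PYTHONPATH", ""))
    else:
        issues = ["TRITON_HOME is not set"]
    print("\n".join(issues) or "libtriton.so and PYTHONPATH look fine")
```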
If GoogleTest cannot be found during the build, install it with the following commands:
```shell
git clone https://github.com/google/googletest
cd googletest
cmake CMakeLists.txt
make -j8
cp ./lib/libgtest*.a /usr/lib
cd googletest
cp -a include/gtest /usr/include
```
If pybind11 cannot be found during the build, install it with the following commands:
```shell
pip install pytest
git clone https://github.com/pybind/pybind11.git
cd pybind11
mkdir build
cd build
cmake ..
make check -j8
sudo make install
```
For PyTorch on AMD GPUs, you must install a PyTorch build that matches your ROCm version. My machine uses ROCm 5.6, so I installed it with the following command (for reference only):
```shell
pip3 install torch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 --index-url https://download.pytorch.org/whl/rocm5.6
```
You can check your ROCm version with the following command:
```shell
dpkg -l | grep rocm
```
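The PyTorch wheel index has to match the ROCm release, so it can help to derive the index URL directly from the `dpkg` output. The sketch below is hypothetical. It assumes the ROCm release is carried by a package whose name starts with `rocm-core`, and it parses a made-up sample line. Real `dpkg -l` output differs between distributions and ROCm versions.

```python
import re
from typing import Optional

# Made-up sample in the usual `dpkg -l` column layout: status, name, version, arch, description.
SAMPLE_DPKG = """\
ii  rocm-core            5.6.0.50600-67~22.04  amd64  ROCm Runtime software stack
ii  rocm-hip-runtime     5.6.0.50600-67~22.04  amd64  HIP runtime
"""


def rocm_major_minor(dpkg_output: str) -> Optional[str]:
    """Return e.g. '5.6' from the installed rocm-core entry, or None if there is none."""
    for line in dpkg_output.splitlines():
        fields = line.split()
        if len(fields) >= 3 and fields[0] == "ii" and fields[1].startswith("rocm-core"):
            match = re.match(r"(\d+)\.(\d+)", fields[2])
            if match:
                return f"{match.group(1)}.{match.group(2)}"
    return None


def pytorch_rocm_index_url(rocm_version: str) -> str:
    """Build a PyTorch wheel index URL in the same form as the install command above."""
    return f"https://download.pytorch.org/whl/rocm{rocm_version}"


if __name__ == "__main__":
    version = rocm_major_minor(SAMPLE_DPKG)
    print(version, pytorch_rocm_index_url(version) if version else "")
```

PyTorch does not publish a wheel index for every ROCm minor release. Treat the URL as a candidate and confirm it exists before installing.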
Keep in mind that PyTorch on AMD GPUs is used almost exactly as on NVIDIA GPUs: tensors are still placed on the GPU with .cuda() (or device='cuda').
0x01 GEMM Code Example
Once everything is built, run the code below to benchmark GEMM on the AMD GPU, comparing Triton against rocBLAS.
```python
import torch
import triton
import triton.language as tl
import sys
import argparse
import pytest

# `triton.jit`'ed functions can be auto-tuned by using the `triton.autotune` decorator, which consumes:
#   - A list of `triton.Config` objects that define different configurations of
#     meta-parameters (e.g., `BLOCK_SIZE_M`) and compilation options (e.g., `num_warps`) to try
#   - An auto-tuning *key* whose change in values will trigger evaluation of all the
#     provided configs
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
    ] if torch.version.hip is None else [
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 1, 'waves_per_eu': 2}, num_warps=4, num_stages=0),
        triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 4, 'waves_per_eu': 2}, num_warps=8, num_stages=0),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 1, 'waves_per_eu': 2}, num_warps=8, num_stages=0),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'waves_per_eu': 3}, num_warps=4, num_stages=0),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 1, 'waves_per_eu': 8}, num_warps=4, num_stages=0),
    ],
    key=['M', 'N', 'K'],
)
@triton.heuristics({
    'EVEN_K': lambda args: args['K'] % args['BLOCK_SIZE_K'] == 0,
})
@triton.jit
def matmul_kernel(
        # Pointers to matrices
        a_ptr, b_ptr, c_ptr,
        # Matrix dimensions
        M, N, K,
        # The stride variables represent how much to increase the ptr by when moving by 1
        # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr`
        # by to get the element one row down (A has M rows).
        stride_am, stride_ak,
        stride_bk, stride_bn,
        stride_cm, stride_cn,
        # Meta-parameters
        BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
        EVEN_K: tl.constexpr,
        GROUP_SIZE_M: tl.constexpr,
        ACTIVATION: tl.constexpr,
):
    """Kernel for computing the matmul C = A x B.
    A has shape (M, K), B has shape (K, N) and C has shape (M, N)
    """
    # -----------------------------------------------------------
    # Map program ids `pid` to the block of C it should compute.
    # This is done in a grouped ordering to promote L2 data reuse.
    # See the `L2 Cache Optimizations` section of the Triton matmul tutorial for details.
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    if GROUP_SIZE_M == 1:
        pid_m = pid // num_pid_n
        pid_n = pid % num_pid_n
    else:
        num_pid_in_group = GROUP_SIZE_M * num_pid_n
        group_id = pid // num_pid_in_group
        first_pid_m = group_id * GROUP_SIZE_M
        group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
        pid_m = first_pid_m + (pid % group_size_m)
        pid_n = (pid % num_pid_in_group) // group_size_m

    # ----------------------------------------------------------
    # Create pointers for the first blocks of A and B.
    # We will advance this pointer as we move in the K direction
    # and accumulate
    # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
    # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
    # See the `Pointer Arithmetics` section of the Triton matmul tutorial for details.
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    # -----------------------------------------------------------
    # Iterate to compute a block of the C matrix.
    # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
    # of fp32 values for higher accuracy.
    # `accumulator` will be converted back to fp16 after the loop.
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Load the next block of A and B, generate a mask by checking the K dimension.
        # If it is out of bounds, set it to 0.
        if EVEN_K:
            a = tl.load(a_ptrs)
            b = tl.load(b_ptrs)
        else:
            a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
            b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
        # We accumulate along the K dimension.
        accumulator += tl.dot(a, b)
        # Advance the ptrs to the next K block.
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk
    # You can fuse arbitrary activation functions here
    # while the accumulator is still in FP32!
    if ACTIVATION == "leaky_relu":
        accumulator = leaky_relu(accumulator)
    c = accumulator.to(tl.float16)

    # -----------------------------------------------------------
    # Write back the block of the output matrix C with masks.
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)


# We can fuse `leaky_relu` by providing it as the `ACTIVATION` meta-parameter of `matmul_kernel`.
@triton.jit
def leaky_relu(x):
    return tl.where(x >= 0, x, 0.01 * x)


# %%
# We can now create a convenience wrapper function that only takes two input tensors,
# and (1) checks any shape constraint; (2) allocates the output; (3) launches the above kernel.
def matmul(a, b, activation=""):
    # Check constraints.
    assert a.shape[1] == b.shape[0], "Incompatible dimensions"
    assert a.is_contiguous(), "Matrix A must be contiguous"
    assert b.is_contiguous(), "Matrix B must be contiguous"
    M, K = a.shape
    K, N = b.shape
    # Allocates output.
    c = torch.empty((M, N), device=a.device, dtype=a.dtype)
    # 1D launch kernel where each block gets its own program.
    grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
    matmul_kernel[grid](
        a, b, c,  #
        M, N, K,  #
        a.stride(0), a.stride(1),  #
        b.stride(0), b.stride(1),  #
        c.stride(0), c.stride(1),  #
        ACTIVATION=activation  #
    )
    return c


# %%
# Unit Test
# ---------
#
# We can test our custom matrix multiplication operation against a native torch
# implementation (i.e., cuBLAS on NVIDIA GPUs, rocBLAS on AMD GPUs).
@pytest.mark.parametrize("M, N, K, in_dtype, out_dtype",
[(*shape, in_dtype, out_dtype)
    for shape in [(128, 256, 32), (128, 16, 32), (32, 128, 64),
                  (128, 128, 64), (64, 128, 128), (32, 128, 64),
                  (64, 64, 32), (32, 32, 128), (128, 128, 64),
                  (64, 128, 128), (512, 512, 512), (1024, 1024, 1024)]
    for in_dtype, out_dtype in [('int8', 'int8'),
                                ('float16', 'float16'),
                                ('bfloat16', 'bfloat16'),
                                ('float16', 'float32'),
                                ('float32', 'float32')]]
)
def test_correctness(M, N, K, in_dtype, out_dtype):
    torch.manual_seed(0)
    a = torch.randn((M, K), device='cuda', dtype=torch.float16)
    b = torch.randn((K, N), device='cuda', dtype=torch.float16)
    triton_output = matmul(a, b)
    torch_output = torch.matmul(a, b)
    print(f"triton_output={triton_output}")
    print(f"torch_output={torch_output}")
    rtol = 0 if torch.version.hip is None else 1e-2
    if torch.allclose(triton_output, torch_output, atol=1e-2, rtol=rtol):
        print("Triton and Torch match")
    else:
        print("Triton and Torch differ")
    assert torch.allclose(triton_output, torch_output, atol=1e-2, rtol=rtol)


# %%
# Benchmark
# ---------
#
# Square Matrix Performance
# ~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# We can now compare the performance of our kernel against that of cuBLAS/rocBLAS. Here we focus
# on square matrices, but feel free to arrange this script as you wish to benchmark any other matrix shape.
verbose = False


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['M', 'N', 'K'],  # Argument names to use as an x-axis for the plot
        x_vals=[
            (1024, 1024, 1024),
            (2048, 2048, 2048),
            (4096, 4096, 4096),
            (8192, 8192, 8192),
            (9728, 8192, 65536)
        ],  # Different possible values for `x_name`
        line_arg='provider',  # Argument name whose value corresponds to a different line in the plot
        # Possible values for `line_arg`
        line_vals=['rocblas', 'triton'],
        # Label name for the lines
        line_names=["rocBLAS", "Triton"],
        # Line styles
        styles=[('green', '-'), ('blue', '-')],
        ylabel="TFLOPS",  # Label name for the y-axis
        plot_name="matmul-performance",  # Name for the plot, used also as a file name for saving the plot.
        args={},
    ))
def benchmark(M, N, K, provider):
    a = torch.randn((M, K), device='cuda', dtype=torch.float16)
    b = torch.randn((K, N), device='cuda', dtype=torch.float16)
    quantiles = [0.5, 0.2, 0.8]
    if provider == 'rocblas':
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), quantiles=quantiles)
    if provider == 'triton':
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b), quantiles=quantiles)
    global verbose
    if verbose:
        print(f'SIZE: {M},{N},{K} Best tuning config: ({matmul_kernel.get_best_config()})')
    perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3)
    return perf(ms), perf(max_ms), perf(min_ms)


def parse_args():
    parser = argparse.ArgumentParser(
        prog="GEMM tutorial example",
        allow_abbrev=False,
    )
    parser.add_argument("-v", action='store_true', default=False, help="Print out the best tuning config")
    args = parser.parse_args()
    return args


def main():
    # assign to a global verbose var to indicate whether print
    # best tuning config
    global verbose
    args = parse_args()
    verbose = args.v
    benchmark.run(show_plots=True, print_data=True)


if __name__ == '__main__':
    sys.exit(main())
```
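To make the host-side arithmetic in the script concrete, the pure-Python sketch below mirrors three of its expressions. No GPU or Triton is needed. The first is the 1D grid computed by the `grid` lambda in `matmul()`. The second is the `EVEN_K` heuristic. The third is the K-dimension mask applied to the last, partial K block. The helper names are illustrative, and `cdiv` reimplements the ceiling division that `triton.cdiv` performs for positive integers.

```python
from typing import Dict, List, Tuple


def cdiv(a: int, b: int) -> int:
    """Ceiling division for positive integers (what triton.cdiv / tl.cdiv compute)."""
    return (a + b - 1) // b


def launch_grid(M: int, N: int, meta: Dict[str, int]) -> Tuple[int]:
    """Same formula as the `grid` lambda in matmul(): one program per tile of C."""
    return (cdiv(M, meta["BLOCK_SIZE_M"]) * cdiv(N, meta["BLOCK_SIZE_N"]),)


def even_k(K: int, meta: Dict[str, int]) -> bool:
    """The @triton.heuristics rule: unmasked loads are safe when BLOCK_SIZE_K divides K."""
    return K % meta["BLOCK_SIZE_K"] == 0


def k_mask(K: int, block_k: int, k: int) -> List[bool]:
    """Mask used at iteration k of the K loop: offs_k < K - k * BLOCK_SIZE_K."""
    return [off < K - k * block_k for off in range(block_k)]


if __name__ == "__main__":
    meta = {"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 16}
    print(launch_grid(9728, 8192, meta))             # programs for the largest benchmark shape
    print(even_k(65536, meta))                       # is K a multiple of BLOCK_SIZE_K?
    print(sum(k_mask(100, 32, cdiv(100, 32) - 1)))   # valid lanes in the last K block
```

All benchmark shapes in the script have K divisible by 16 and 32, so they take the unmasked `EVEN_K` path. The masked path only matters for shapes like K=100.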
0x10 GEMM Code Walkthrough
First comes the definition of the autotuning search space:
```python
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
    ] if torch.version.hip is None else [
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 1, 'waves_per_eu': 2}, num_warps=4, num_stages=0),
        triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 4, 'waves_per_eu': 2}, num_warps=8, num_stages=0),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 1, 'waves_per_eu': 2}, num_warps=8, num_stages=0),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'waves_per_eu': 3}, num_warps=4, num_stages=0),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 1, 'waves_per_eu': 8}, num_warps=4, num_stages=0),
    ],
    key=['M', 'N', 'K'],
)
```
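When torch.version.hip is not None (that is, PyTorch was built for ROCm), the AMD-specific search space is used. Besides the usual tuning knobs BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K and GROUP_SIZE_M, it has a new one, waves_per_eu. The concept was unfamiliar to me at first, so I asked some AMD engineers about it. In summary: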
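An AMD GPU consists of compute units (CUs), the counterpart of streaming multiprocessors (SMs) on NVIDIA GPUs. Each CU contains 4 SIMD units, also called execution engines or EUs. You can think of a SIMD unit as a vector execution unit with its own registers and ALUs for carrying out computation. When you launch a compute grid, workgroups (the counterpart of thread blocks on NVIDIA GPUs) are scheduled to run on CUs.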
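Within a CU, wavefronts (the counterpart of warps on NVIDIA GPUs) are scheduled onto the SIMD units. This is where occupancy comes in: the number of wavefronts that can run concurrently on each SIMD unit. It depends on how many resources each wavefront needs and how many resources each SIMD unit has. The waves_per_eu parameter focuses on register usage. For example, each SIMD (EU) has 512 registers.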
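If each wavefront needs 256 registers, the occupancy is 512 / 256 = 2. If we set waves_per_eu=3, however, the compiler tries to cut each wavefront's register usage to 170 (512 / 3, rounded down), so that occupancy can reach 3. Raising waves_per_eu carries a risk of register spilling and degraded performance. Increasing it may raise occupancy, but it does not necessarily improve performance.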
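The register arithmetic above can be written down as a tiny model. This is a deliberately simplified sketch that only reproduces the numbers in the explanation (512 registers per SIMD, integer division). Real hardware allocates registers in fixed-size granules. Occupancy is also capped by other resources, such as LDS and the hardware's wave-slot limit, and none of that is modelled here.

```python
REGS_PER_SIMD = 512  # figure used in the explanation above; simplified model


def occupancy(regs_per_wave: int, regs_per_simd: int = REGS_PER_SIMD) -> int:
    """Wavefronts that fit on one SIMD when each needs `regs_per_wave` registers."""
    return regs_per_simd // regs_per_wave


def register_budget(waves_per_eu: int, regs_per_simd: int = REGS_PER_SIMD) -> int:
    """Per-wavefront register budget the compiler would aim for to reach `waves_per_eu`."""
    return regs_per_simd // waves_per_eu


if __name__ == "__main__":
    for target in (1, 2, 3, 4, 8):
        budget = register_budget(target)
        print(f"waves_per_eu={target}: <= {budget} regs/wave -> occupancy {occupancy(budget)}")
```

If the budget falls below what the kernel really needs, the excess registers spill to memory. That is why a higher `waves_per_eu` is not automatically faster.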
Next is the kernel definition itself, which is essentially no different from how it is written for NVIDIA GPUs:
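```python
@triton.jit
def matmul_kernel(
        # Pointers to matrices
        a_ptr, b_ptr, c_ptr,
        # Matrix dimensions
        M, N, K,
        # The stride variables represent how much to increase the ptr by when moving by 1
        # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr`
        # by to get the element one row down (A has M rows).
        stride_am, stride_ak,
        stride_bk, stride_bn,
        stride_cm, stride_cn,
        # Meta-parameters
        BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
        EVEN_K: tl.constexpr,
        GROUP_SIZE_M: tl.constexpr,
        ACTIVATION: tl.constexpr,
):
    """Kernel for computing the matmul C = A x B.
    A has shape (M, K), B has shape (K, N) and C has shape (M, N)
    """
    # -----------------------------------------------------------
    # Map program ids `pid` to the block of C it should compute.
    # This is done in a grouped ordering to promote L2 data reuse.
    # See the `L2 Cache Optimizations` section of the Triton matmul tutorial for details.
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    if GROUP_SIZE_M == 1:
        pid_m = pid // num_pid_n
        pid_n = pid % num_pid_n
    else:
        num_pid_in_group = GROUP_SIZE_M * num_pid_n
        group_id = pid // num_pid_in_group
        first_pid_m = group_id * GROUP_SIZE_M
        group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
        pid_m = first_pid_m + (pid % group_size_m)
        pid_n = (pid % num_pid_in_group) // group_size_m

    # ----------------------------------------------------------
    # Create pointers for the first blocks of A and B.
    # We will advance this pointer as we move in the K direction
    # and accumulate
    # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
    # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
    # See the `Pointer Arithmetics` section of the Triton matmul tutorial for details.
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    # -----------------------------------------------------------
    # Iterate to compute a block of the C matrix.
    # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
    # of fp32 values for higher accuracy.
    # `accumulator` will be converted back to fp16 after the loop.
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Load the next block of A and B, generate a mask by checking the K dimension.
        # If it is out of bounds, set it to 0.
        if EVEN_K:
            a = tl.load(a_ptrs)
            b = tl.load(b_ptrs)
        else:
            a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
            b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
        # We accumulate along the K dimension.
        accumulator += tl.dot(a, b)
        # Advance the ptrs to the next K block.
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk
    # You can fuse arbitrary activation functions here
    # while the accumulator is still in FP32!
    if ACTIVATION == "leaky_relu":
        accumulator = leaky_relu(accumulator)
    c = accumulator.to(tl.float16)

    # -----------------------------------------------------------
    # Write back the block of the output matrix C with masks.
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)
```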
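The grouped program-ID mapping at the top of the kernel is easy to check on the CPU. The sketch below copies that index arithmetic into plain Python, using `grouped_pid` as a hypothetical helper name. You can print the visiting order and confirm that every tile of C is assigned to exactly one program. This holds even when `GROUP_SIZE_M` does not divide the number of row tiles.

```python
from typing import Tuple


def cdiv(a: int, b: int) -> int:
    """Ceiling division for positive integers."""
    return (a + b - 1) // b


def grouped_pid(pid: int, M: int, N: int, bm: int, bn: int, group_size_m: int) -> Tuple[int, int]:
    """Plain-Python copy of the kernel's pid -> (pid_m, pid_n) mapping."""
    num_pid_m = cdiv(M, bm)
    num_pid_n = cdiv(N, bn)
    if group_size_m == 1:
        # Row-major order, as in the GROUP_SIZE_M == 1 branch of the kernel.
        return pid // num_pid_n, pid % num_pid_n
    num_pid_in_group = group_size_m * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * group_size_m
    # The last group may contain fewer than group_size_m rows of tiles.
    size_m = min(num_pid_m - first_pid_m, group_size_m)
    pid_m = first_pid_m + (pid % size_m)
    pid_n = (pid % num_pid_in_group) // size_m
    return pid_m, pid_n


if __name__ == "__main__":
    M = N = 512
    bm = bn = 128
    n_programs = cdiv(M, bm) * cdiv(N, bn)
    print([grouped_pid(p, M, N, bm, bn, 2) for p in range(n_programs)])
```

With a group size of 2, consecutive programs alternate between two rows of tiles before moving to the next column. Programs running at the same time therefore share columns of B and a small set of rows of A. That is the reuse the grouped ordering aims to exploit in L2.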
Next is the unit test. It checks that Triton's output matches PyTorch's output within the given tolerances:
```python
def test_correctness(M, N, K, in_dtype, out_dtype):
    torch.manual_seed(0)
    a = torch.randn((M, K), device='cuda', dtype=torch.float16)
    b = torch.randn((K, N), device='cuda', dtype=torch.float16)
    triton_output = matmul(a, b)
    torch_output = torch.matmul(a, b)
    print(f"triton_output={triton_output}")
    print(f"torch_output={torch_output}")
    rtol = 0 if torch.version.hip is None else 1e-2
    if torch.allclose(triton_output, torch_output, atol=1e-2, rtol=rtol):
        print("Triton and Torch match")
    else:
        print("Triton and Torch differ")
    assert torch.allclose(triton_output, torch_output, atol=1e-2, rtol=rtol)
```
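The tolerance check follows `torch.allclose`, which tests `|input - other| <= atol + rtol * |other|` element-wise. The pure-Python sketch below restates that rule and skips torch's NaN and broadcasting handling. It illustrates what switching `rtol` from 0 to 1e-2 on ROCm changes. With `rtol=0`, the fixed `atol=1e-2` applies no matter how large an output element is. With `rtol=1e-2`, the allowed error grows with the magnitude of the reference value. The sample values are hypothetical.

```python
from typing import Sequence


def allclose(actual: Sequence[float], expected: Sequence[float],
             atol: float = 1e-8, rtol: float = 1e-5) -> bool:
    """Element-wise |actual - expected| <= atol + rtol * |expected|, like torch.allclose(actual, expected)."""
    if len(actual) != len(expected):
        raise ValueError("length mismatch")
    return all(abs(a - e) <= atol + rtol * abs(e) for a, e in zip(actual, expected))


if __name__ == "__main__":
    expected = [0.5, 32.0]     # hypothetical reference (torch) outputs
    actual = [0.505, 32.02]    # hypothetical kernel (Triton) outputs
    print(allclose(actual, expected, atol=1e-2, rtol=0))     # absolute tolerance only
    print(allclose(actual, expected, atol=1e-2, rtol=1e-2))  # absolute plus relative slack
```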
Next, you only need to specify the GEMM sizes to benchmark. Sizes are still given in M, N, K order, and the rest is routine:
```python
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['M', 'N', 'K'],  # Argument names to use as an x-axis for the plot
        x_vals=[
            (1024, 1024, 1024),
            (2048, 2048, 2048),
            (4096, 4096, 4096),
            (8192, 8192, 8192),
            (9728, 8192, 65536)
        ],  # Different possible values for `x_name`
        line_arg='provider',  # Argument name whose value corresponds to a different line in the plot
        # Possible values for `line_arg`
        line_vals=['rocblas', 'triton'],
        # Label name for the lines
        line_names=["rocBLAS", "Triton"],
        # Line styles
        styles=[('green', '-'), ('blue', '-')],
        ylabel="TFLOPS",  # Label name for the y-axis
        plot_name="matmul-performance",  # Name for the plot, used also as a file name for saving the plot.
        args={},
    ))
def benchmark(M, N, K, provider):
    a = torch.randn((M, K), device='cuda', dtype=torch.float16)
    b = torch.randn((K, N), device='cuda', dtype=torch.float16)
    quantiles = [0.5, 0.2, 0.8]
    if provider == 'rocblas':
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), quantiles=quantiles)
    if provider == 'triton':
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b), quantiles=quantiles)
    global verbose
    if verbose:
        print(f'SIZE: {M},{N},{K} Best tuning config: ({matmul_kernel.get_best_config()})')
    perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3)
    return perf(ms), perf(max_ms), perf(min_ms)


def parse_args():
    parser = argparse.ArgumentParser(
        prog="GEMM tutorial example",
        allow_abbrev=False,
    )
    parser.add_argument("-v", action='store_true', default=False, help="Print out the best tuning config")
    args = parser.parse_args()
    return args


def main():
    # assign to a global verbose var to indicate whether print
    # best tuning config
    global verbose
    args = parse_args()
    verbose = args.v
    benchmark.run(show_plots=True, print_data=True)


if __name__ == '__main__':
    sys.exit(main())
```
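The `perf` lambda in `benchmark()` converts a runtime into TFLOPS. It counts 2*M*N*K floating-point operations (one multiply and one add per multiply-accumulate) and divides by the time in seconds. `do_bench` is called with `quantiles=[0.5, 0.2, 0.8]`, so `ms`, `min_ms` and `max_ms` hold the median, 20th and 80th percentile times. The slower 80th-percentile time therefore gives the lower TFLOPS bound. Below is a standalone sketch of the same arithmetic. `gemm_tflops` and `tflops_range` are hypothetical names, and the timings in the usage line are made up.

```python
from typing import Tuple


def gemm_tflops(M: int, N: int, K: int, ms: float) -> float:
    """Same formula as `perf` in benchmark(): 2*M*N*K FLOPs over `ms` milliseconds, in TFLOPS."""
    return 2 * M * N * K * 1e-12 / (ms * 1e-3)


def tflops_range(M: int, N: int, K: int,
                 median_ms: float, p20_ms: float, p80_ms: float) -> Tuple[float, float, float]:
    """Mirror benchmark()'s return order: (median, lower bound, upper bound) in TFLOPS."""
    return (gemm_tflops(M, N, K, median_ms),
            gemm_tflops(M, N, K, p80_ms),   # slower time -> lower throughput
            gemm_tflops(M, N, K, p20_ms))   # faster time -> higher throughput


if __name__ == "__main__":
    # Hypothetical timings, purely to exercise the formula.
    print(tflops_range(4096, 4096, 4096, median_ms=1.5, p20_ms=1.4, p80_ms=1.6))
```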
A more automated GEMM benchmarking and tuning script for AMD GPUs will be covered in a later chapter.
Editor: Liu Qing
Original title: OpenAI/Triton MLIR Chapter 4: ROCm-Triton Configuration
Source: WeChat official account GiantPandaCV (WeChat ID: GiantPandaCV). Please credit the source when reposting.