From 8f484d6a5f6322480da9b7aa55efd7697945503b Mon Sep 17 00:00:00 2001 From: Violet Date: Thu, 28 Aug 2025 17:54:07 -0700 Subject: [PATCH] Add support for fp8 to `cvt` (#468) This implements specifically the fp8 conversion instructions needed by llm.c: * `cvt.rn.satfinite{.relu}.f8x2type.f32` * `cvt.rn{.relu}.f16x2.f8x2type` It uses HIP's fp8 and fp16 headers: https://rocm.docs.amd.com/projects/HIP/en/docs-develop/reference/low_fp_types.html#fp8-quarter-precision. --- ptx/lib/zluda_ptx_impl.bc | Bin 12264 -> 15548 bytes ptx/lib/zluda_ptx_impl.cpp | 45 ++++- ptx/src/pass/llvm/emit.rs | 30 ++- ptx/src/pass/llvm/mod.rs | 10 +- ptx/src/pass/mod.rs | 2 + .../replace_instructions_with_functions.rs | 184 +++++++++++++----- ptx/src/test/ll/cvt_rn_f16x2_e4m3x2.ll | 35 ++++ ptx/src/test/ll/cvt_rn_f16x2_e5m2x2.ll | 35 ++++ .../test/ll/cvt_rn_satfinite_e4m3x2_f32.ll | 40 ++++ .../test/ll/cvt_rn_satfinite_e5m2x2_f32.ll | 40 ++++ .../test/spirv_run/cvt_rn_f16x2_e4m3x2.ptx | 23 +++ .../test/spirv_run/cvt_rn_f16x2_e5m2x2.ptx | 23 +++ .../spirv_run/cvt_rn_satfinite_e4m3x2_f32.ptx | 25 +++ .../spirv_run/cvt_rn_satfinite_e5m2x2_f32.ptx | 25 +++ ptx/src/test/spirv_run/mod.rs | 8 + ptx_parser/src/ast.rs | 19 +- ptx_parser/src/lib.rs | 26 ++- 17 files changed, 507 insertions(+), 63 deletions(-) create mode 100644 ptx/src/test/ll/cvt_rn_f16x2_e4m3x2.ll create mode 100644 ptx/src/test/ll/cvt_rn_f16x2_e5m2x2.ll create mode 100644 ptx/src/test/ll/cvt_rn_satfinite_e4m3x2_f32.ll create mode 100644 ptx/src/test/ll/cvt_rn_satfinite_e5m2x2_f32.ll create mode 100644 ptx/src/test/spirv_run/cvt_rn_f16x2_e4m3x2.ptx create mode 100644 ptx/src/test/spirv_run/cvt_rn_f16x2_e5m2x2.ptx create mode 100644 ptx/src/test/spirv_run/cvt_rn_satfinite_e4m3x2_f32.ptx create mode 100644 ptx/src/test/spirv_run/cvt_rn_satfinite_e5m2x2_f32.ptx diff --git a/ptx/lib/zluda_ptx_impl.bc b/ptx/lib/zluda_ptx_impl.bc index 853b9e152ba2034b0e8e646bb2c4fc6075e64387..440b232dcb2f228b666170222615d0f7d58cdbbb 100644 GIT binary patch delta 5079 zcmaD6zo&A73R4jGM71nN%ZU@^8ShQpxJb*uonh+1FME`D(iCmZfA)GG2x^R%~2@i_`qu7F54ht9= zKOAIJ;{Esk|NnZ1CA?iyEDV|~j3P@8O6LSzb71gLN$LuR`BVmREh%3xwFvDQOK?fg+TMRsGo259WGcpBhZ{EyV z!BWqAj)6gq4-^z2tiaqBvPovcK}m)KM;Ji@Ae?ew!m0VpMh?sjGweWo5N_a+2D=2K-O@8 z-1oOqAQDYk)7J%2%nA(0K)S%#806!pawcK|3{qeL1_lNPHlCGFY&&!m7))S74LlQW zCtIY|Q{-Ut)}7oUqRU>yroh3*V=;M$h>rLiMvf-N#e7~8ERk)DLdtA>QVj|cCa)*| z7O`Xp%O&$qwiY!M0kfHfxa3$$43Zoe6gb%2JSR7ZCda$i5((T zGMP!zf^Cf>M^oe0$qtfAOubH%QzWe*B2y5)1_lQ1NryE$961^SCQHg>Gai}TDr3R4!EW+?8BwW!4YO1=_-2;u zQ2cNpVcJ83Cc$PV4o8CnPK}e_$%sjB2wafHIVo99Hjm>715cWyn5u-KMuPwg^BR%K z^0H2B=>p9O1!pFw%bLokIPgU9Fd5wykT}%Q$k?I~(Qq)GjghUj!N;7%X(Qw0-Li(* z)VL$7;g)mMPfHNF%*!g+BO#G3(3HT)Cf~~A&ce#fcA3{%Fku7737FyKY>a%34LOsm z<*XTNCvTUtmBZqK=d2PU3`TbaIF7PS7LpItoiCbV@P?Hu<-uB32BV|~Q2ensgo6GWX785L>Rp$kopb6j;p3XpmF@(vY%2(cN&C0K;NVrwNV%%4}@lu(~fP zl7{TQ2n8hfr5rSPAjpB_dTBuggzF0qDj-~YL9)^4_6L>&InD_zn|T=(&zP~CnXIiS z!WcB!LD8^2OF$A7K`_4xuzG9)6-1yM`V{Q$Xt2BG85kH$L=+ShUO65N2xw$#nF6YD zc%KISV`kL|_`}!&lKaQZ`&8b+8gzDj%TH za)p66)yaeNLcn`QQ-^wx9NU!t|K)jA!Rnkud=_2^kZYO3Y^%&*GKG;|eRXIw%koZ${E-$^R^14)C(NYaV++geo^3jfi214 zbby4T8sixulf%4Iq#F+EI2W@V<~=9fa1fkQB_bAnU_0-bzd>RzBZrUqo`ViP<{{3F zOfm-?*gO(sSkwy|Wj=^DJ~CuCn7}CH!G4gH$6SC>;z`pHT}D6iJB`AL4;mc!m}53@ zbvQFHNxWg?xKqE8-Rh3O4JL^@LK$ow4}=;Hu+}e-FflN=)3Aj};*61k=Y&HH;&+(h z47jYk7sMEHS=|w6VV3YQxA1Rp)=W5{%6LvlgISchS>mJvYm%Z%5u?_Fg2N45X$E&2 zxSR~`1iWFA;4$ZT4f1ZI7RP2DZs8|#EC${edL$$qIGP&kMfp3m=RD-rGbsGOFWAfy z1qy%uDmJ?ZtUngCe_@WBz+dHP_eS<7DE#dj4(c$Au!g@l%xA8m&WH&A#xh@fQ26sF z8Jt$|1&4peVcseB4F`3a-B=Ivp0kICzu`yz^Pc}VNc?5y@G<{$7#jW_;PCeVg}(5B zH59hGBXEQj8veeJ@Q+~@We%3S=D?a1+EU8s2MT{-Q1}aj!ha1bDEzBI-fi^b*vu0l z1`huPy|D0?>hz!UP`Jyud32;@;(X3=^g@~FT89hzlP zwD8B8M>o#LBjr$4y{PK_;7rOyf&d2|jZf z35hdC5>5tp3~~|-PABjjEMT5E@t{sUb0CW{o2xXaX`EgVv%rzfBSz>>q05{DFLfLZ z0~$G4%y>e0*b;d)YRnDpDAZ&W9NI3y=E9NEaD~?ZLz9BYboR6cNtOwRKrWO6)ds~Z_3CdL(;AKl z)+`YJz{qDVQ^UZvnTJP+$DF4g#PSj1F&C(HVB5@dN~lLdBv+s*VZqu%g^ltKV0Sh$ z392_VofFbw7VTJ&aMIyuf<{==A&ol+3YrbLG71i9=uBf0WRg@UbZ)rJ+ZZTQ#8rM! zlObz;0kimmCO`89hZfYwDCDqt$Oz@IshtrrU{+0Q5N`98F=^yDE#s^;rBVNkA&Zk% zioqK;MMDM#1}0E^wlP@1A`sL{vIIw;H#qu0y+?5LL87iz-hj!$k#|C5!v|(jj=-i5 z%&dk?0gVkTq8<+KG8hz985rsr7&uucTogFW>%_vgna5RNj=;+TX2HgTI?jqL?k-k$ z1UwQ9P8$duEMQh_JXpZ2H1l9dvk2>DUPqHNLL%E;&AkJ7rp#!RU`P-UXm*G(Ph?|U zVsNCO`9tGDo#rGK^%|=?4mAk|rx(Z^6r9gz&XNEMRS!_8CNQw2v(2iHV_>tsD^PP- zr?DvWz)PJ*qktr~x_}RjjtWjj$1m9C7#tVih-EvR$H2f)sKCIW%)r24#3!KC!hDmf zBVbPh*E#MJ0ec(x?r?J$?rRWw#qIHNe}m{Go{Ej98YGVMtPnicApL{q#>evw^4+`= zMK>FiH}ZNE-DyyL$eU4gzd^lzB40<*s|L-3d@ByUZO~cHf8*1y27|-=8pVGbOg{5x z1amZ6%oLd6%+qM~S>QymRHN-S!55b;8|`~UIkv=u>p=#FR2~L~hl~siLd;Ws$T2W5 zfcilN|RaZbr={J)YuppKs0FRLqLUrAqdKc(J=8kHU^hC+lm-<>Zv*>J{f^=ILb`8-Y2c#idDl#fc>#fyuk|bpeoc1hfDE delta 1949 zcmdl}`67OT3R4xwM71nN(}@%18Lv&;xJW91HC$g^xS$Zs_#%NhiH&Ij19w2K!np>99|_EjY#_Dv442qcd6^nSni#k)J>WId$WmZfBES+b zfh~W-!H8y%68Q#G1{DSd22TkF1_K5L2H~cJ$-*41W)2e^C1)Iz^x?RW&@_Q9%Rn)r zne(U=k4vIY#I*zkuH;0i=Z$I{*#=H_y)3YPj! z4GcdRV8AHBW9f0n&y0)<7-2#w2PV9l()^5>$$=RrbfQ5~RF}O+k^xm2*Bs_IJPiyN zV9FSb5`rV79Be*%lR3n6#3z6h8S{Bbutc^o3MsSsNj4}*m^_*6EoR9Mmix>% zxm(Oo1k9ey&n3rFVvyv(puoY_W!k4`fBtC8O z=VIkxykX8(-}8R{w}G)*DBefAd!(Tz+=whB9Sxsr=$`)L@0N1kdy`60!NOf z#@5Lbq?DLGI!@jqWd#vQo%}^giRGnZK*N4wbX8lVF;xl3C<=UZEI4au z<2dc8KrxRKLoUPQgHpN>Q&K0F$S6Wgj+8Qlh_z0>AfpJf|FeuWL?m^xiL9as$bJDK zW>D(ko^)8F!;z!GZgQqpL-E6bglP{AngpAf zI2;WQd})|$CnqLtV7DNRb5gRJY#zrE2A(uYF;xjejRpZ0<}>`0^W~h_QUsb43N}pM zE@vt$YnWijaiYXIfkT;%O`s`(k?p>CLI6laK}348pn`_hqa+516ayQk;&_B&28PMe z@{aoR1AWXzTnv&DK*7u^Q6p$1BPhZn;bYF@+Hg=w0b&#*o4f>z1q;(A#>uDUtr-(1 zvn$xjVVZtlG$DZDM2WM4h&03GB!xiT`ywd@Z&d~YL63w&2H4kW0lWtd;3hVP za^>C$GMJ1#uf#MhDK?V3neYeED0MJ8Ng}Y z2$JS?1QDF+o0epw*dXmtAn%Yhu{1eVRbjEZN>Sk6qYRuW+hm^?wr zus;2P1gjuI^#c}(3>QF>3=9l=*%%lE z85kJOu`w_RF)%PlLe;_K85kIt7+g1B)Sb*)FT%jUpvK0)0HQ$^2Zs&=gB2SC11|%D zhKajF#bGo|JcNybfro*C0Y<~b6WACSBp4VNU^Gm;f(_!P3N{8Ku;XA76QK@=(J=8v pYzz#d3=9lldh%WqXU^S_pm`v>nbTC0ZSp1S?8)9XYLokIbOAH;Ab diff --git a/ptx/lib/zluda_ptx_impl.cpp b/ptx/lib/zluda_ptx_impl.cpp index 9de6f61..42fa23e 100644 --- a/ptx/lib/zluda_ptx_impl.cpp +++ b/ptx/lib/zluda_ptx_impl.cpp @@ -1,12 +1,13 @@ // Every time this file changes it must te rebuilt, you need `rocm-llvm-dev` and `llvm-17` // `fdenormal-fp-math=dynamic` is required to make functions eligible for inlining -// /opt/rocm/llvm/bin/clang -std=c++20 -Xclang -fdenormal-fp-math=dynamic -Wall -Wextra -Wsign-compare -Wconversion -x hip zluda_ptx_impl.cpp -nogpulib -O3 -mno-wavefrontsize64 -o zluda_ptx_impl.bc -emit-llvm -c --offload-device-only --offload-arch=gfx1010 && /opt/rocm/llvm/bin/llvm-dis zluda_ptx_impl.bc -o - | sed '/@llvm.used/d' | sed '/wchar_size/d' | sed '/llvm.module.flags/d' | sed 's/define hidden/define linkonce_odr/g' | sed 's/\"target-cpu\"=\"gfx1010\"//g' | sed -E 's/\"target-features\"=\"[^\"]+\"//g' | sed 's/ nneg / /g' | sed 's/ disjoint / /g' | sed '/__hip_cuid/d' | sed 's/external protected/external hidden/g' | llvm-as-17 - -o zluda_ptx_impl.bc && /opt/rocm/llvm/bin/llvm-dis zluda_ptx_impl.bc +// /opt/rocm/llvm/bin/clang -std=c++20 -Xclang -fdenormal-fp-math=dynamic -Wall -Wextra -Wsign-compare -Wconversion -x hip zluda_ptx_impl.cpp -nogpulib -O3 -mno-wavefrontsize64 -o zluda_ptx_impl.bc -emit-llvm -c --offload-device-only --offload-arch=gfx1010 && /opt/rocm/llvm/bin/llvm-dis zluda_ptx_impl.bc -o - | sed '/@llvm.used/d' | sed '/wchar_size/d' | sed '/llvm.module.flags/d' | sed 's/define hidden/define linkonce_odr/g' | sed 's/\"target-cpu\"=\"gfx1010\"//g' | sed -E 's/\"target-features\"=\"[^\"]+\"//g' | sed 's/ nneg / /g' | sed 's/ disjoint / /g' | sed '/__hip_cuid/d' | sed 's/external protected/external hidden/g' | sed 's/trunc nuw/trunc/' | sed 's/trunc nsw/trunc/' | llvm-as-17 - -o zluda_ptx_impl.bc && /opt/rocm/llvm/bin/llvm-dis zluda_ptx_impl.bc #include #include #include #include #include +#include #define CONSTANT_SPACE __attribute__((address_space(4))) @@ -476,4 +477,46 @@ typedef uint32_t ShflSyncResult __attribute__((ext_vector_type(2))); { return div_f32_part2(x, y, {fma_4, fma_1, fma_3, numerator_scaled_flag}); } + + __device__ static __hip_fp8_storage_t cvt_float_to_fp8(float f, __hip_fp8_interpretation_t interp) + { + const uint32_t bits = reinterpret_cast(f); + const uint8_t sign = (bits & 0x80000000) ? 0x80 : 0x0; + const uint32_t abs = bits & 0x7fffffff; + + const uint32_t min = interp == __HIP_E4M3 ? 0x3A800000 : 0x37000000; + if (abs < min) + { + return sign; // +/- 0 + } + + return __hip_cvt_float_to_fp8(f, __HIP_SATFINITE, interp); + } + + struct Fp8x2 + { + __hip_fp8_storage_t b : 8; + __hip_fp8_storage_t a : 8; + }; + + Fp8x2 FUNC(cvt_rn_satfinite_e4m3x2_f32)(float a, float b) + { + // If built-in support for fp8 formats is added to LLVM IR we should switch to use that. + return {cvt_float_to_fp8(b, __HIP_E4M3), cvt_float_to_fp8(a, __HIP_E4M3)}; + } + + Fp8x2 FUNC(cvt_rn_satfinite_e5m2x2_f32)(float a, float b) + { + return {cvt_float_to_fp8(b, __HIP_E5M2), cvt_float_to_fp8(a, __HIP_E5M2)}; + } + + __half2 FUNC(cvt_rn_f16x2_e4m3x2)(__hip_fp8x2_e4m3 in) + { + return in; + } + + __half2 FUNC(cvt_rn_f16x2_e5m2x2)(__hip_fp8x2_e5m2 in) + { + return in; + } } diff --git a/ptx/src/pass/llvm/emit.rs b/ptx/src/pass/llvm/emit.rs index f6b8ca0..177efa6 100644 --- a/ptx/src/pass/llvm/emit.rs +++ b/ptx/src/pass/llvm/emit.rs @@ -299,7 +299,9 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> { }, ptx_parser::ScalarType::S16 | ptx_parser::ScalarType::B16 - | ptx_parser::ScalarType::U16 => unsafe { + | ptx_parser::ScalarType::U16 + | ptx_parser::ScalarType::E4m3x2 + | ptx_parser::ScalarType::E5m2x2 => unsafe { LLVMConstInt(llvm_type, u16::from_le_bytes(bytes.try_into()?) as u64, 0) }, ptx_parser::ScalarType::S32 @@ -1586,6 +1588,26 @@ impl<'a> MethodEmitContext<'a> { data: ptx_parser::CvtDetails, arguments: ptx_parser::CvtArgs, ) -> Result<(), TranslateError> { + // Truncating conversions to FP8 types should be replaced by a function call. + match data { + ptx_parser::CvtDetails { + to: ast::ScalarType::E4m3x2 | ast::ScalarType::E5m2x2, + mode: ast::CvtMode::FPTruncate { .. }, + .. + } => return Err(error_unreachable()), + _ => {} + } + + // Extending conversions from FP8 types should be replaced by a function call. + match data { + ptx_parser::CvtDetails { + from: ast::ScalarType::E4m3x2 | ast::ScalarType::E5m2x2, + mode: ast::CvtMode::FPExtend { .. }, + .. + } => return Err(error_unreachable()), + _ => {} + } + let dst_type = get_scalar_type(self.context, data.to); let llvm_fn = match data.mode { ptx_parser::CvtMode::ZeroExtend => LLVMBuildZExt, @@ -3096,7 +3118,11 @@ impl std::fmt::Display for LLVMTypeDisplay { match self.0 { ast::ScalarType::Pred => write!(f, "i1"), ast::ScalarType::B8 | ast::ScalarType::U8 | ast::ScalarType::S8 => write!(f, "i8"), - ast::ScalarType::B16 | ast::ScalarType::U16 | ast::ScalarType::S16 => write!(f, "i16"), + ast::ScalarType::B16 + | ast::ScalarType::U16 + | ast::ScalarType::S16 + | ast::ScalarType::E4m3x2 + | ast::ScalarType::E5m2x2 => write!(f, "i16"), ast::ScalarType::B32 | ast::ScalarType::U32 | ast::ScalarType::S32 => write!(f, "i32"), ast::ScalarType::B64 | ast::ScalarType::U64 | ast::ScalarType::S64 => write!(f, "i64"), ptx_parser::ScalarType::B128 => write!(f, "i128"), diff --git a/ptx/src/pass/llvm/mod.rs b/ptx/src/pass/llvm/mod.rs index 3513e88..24f790e 100644 --- a/ptx/src/pass/llvm/mod.rs +++ b/ptx/src/pass/llvm/mod.rs @@ -153,9 +153,11 @@ fn get_scalar_type(context: LLVMContextRef, type_: ast::ScalarType) -> LLVMTypeR ast::ScalarType::S8 | ast::ScalarType::B8 | ast::ScalarType::U8 => unsafe { LLVMInt8TypeInContext(context) }, - ast::ScalarType::B16 | ast::ScalarType::U16 | ast::ScalarType::S16 => unsafe { - LLVMInt16TypeInContext(context) - }, + ast::ScalarType::B16 + | ast::ScalarType::U16 + | ast::ScalarType::S16 + | ast::ScalarType::E4m3x2 + | ast::ScalarType::E5m2x2 => unsafe { LLVMInt16TypeInContext(context) }, ast::ScalarType::S32 | ast::ScalarType::B32 | ast::ScalarType::U32 => unsafe { LLVMInt32TypeInContext(context) }, @@ -169,7 +171,7 @@ fn get_scalar_type(context: LLVMContextRef, type_: ast::ScalarType) -> LLVMTypeR ast::ScalarType::BF16 => unsafe { LLVMBFloatTypeInContext(context) }, ast::ScalarType::U16x2 => todo!(), ast::ScalarType::S16x2 => todo!(), - ast::ScalarType::F16x2 => todo!(), + ast::ScalarType::F16x2 => unsafe { LLVMVectorType(LLVMHalfTypeInContext(context), 2) }, ast::ScalarType::BF16x2 => todo!(), } } diff --git a/ptx/src/pass/mod.rs b/ptx/src/pass/mod.rs index 79f5e99..d5edccf 100644 --- a/ptx/src/pass/mod.rs +++ b/ptx/src/pass/mod.rs @@ -995,6 +995,8 @@ fn scalar_to_ptx_name(this: ast::ScalarType) -> &'static str { ast::ScalarType::BF16 => "bf16", ast::ScalarType::BF16x2 => "bf16x2", ast::ScalarType::Pred => "pred", + ast::ScalarType::E4m3x2 => "e4m3x2", + ast::ScalarType::E5m2x2 => "e5m2x2", } } diff --git a/ptx/src/pass/replace_instructions_with_functions.rs b/ptx/src/pass/replace_instructions_with_functions.rs index edcaaa1..4b9b2bb 100644 --- a/ptx/src/pass/replace_instructions_with_functions.rs +++ b/ptx/src/pass/replace_instructions_with_functions.rs @@ -57,6 +57,38 @@ fn run_directive<'input>( }) } +fn get_or_declare_function<'input, S: Into>>( + resolver: &mut GlobalStringIdentResolver2<'input>, + fn_declarations: &mut HashMap< + Cow<'input, str>, + ( + Vec>, + SpirvWord, + Vec>, + ), + rustc_hash::FxBuildHasher, + >, + name: S, + return_arguments: &Vec<(ptx_parser::Type, ptx_parser::StateSpace)>, + input_arguments: &Vec<(ptx_parser::Type, ptx_parser::StateSpace)>, +) -> SpirvWord { + let func = match fn_declarations.entry(name.into()) { + hash_map::Entry::Occupied(occupied_entry) => occupied_entry.get().1, + hash_map::Entry::Vacant(vacant_entry) => { + let name = vacant_entry.key().clone(); + let full_name = [ZLUDA_PTX_PREFIX, &*name].concat(); + let name = resolver.register_named(Cow::Owned(full_name.clone()), None); + vacant_entry.insert(( + to_variables(resolver, return_arguments), + name, + to_variables(resolver, input_arguments), + )); + name + } + }; + func +} + fn run_statements<'input>( resolver: &mut GlobalStringIdentResolver2<'input>, fn_declarations: &mut FxHashMap< @@ -99,7 +131,7 @@ fn run_statements<'input>( ast::Type::Scalar(ast::ScalarType::U32), ptx_parser::StateSpace::Reg, ))); - let full_name = [ZLUDA_PTX_PREFIX, "shfl_sync_", mode, "_b32_pred"].concat(); + let name = ["shfl_sync_", mode, "_b32_pred"].concat(); let return_arguments = vec![( ast::Type::Vector(2, ast::ScalarType::U32), ptx_parser::StateSpace::Reg, @@ -122,45 +154,19 @@ fn run_statements<'input>( ptx_parser::StateSpace::Reg, ), ]; - let func = match fn_declarations.entry(full_name.into()) { - hash_map::Entry::Occupied(occupied_entry) => occupied_entry.get().1, - hash_map::Entry::Vacant(vacant_entry) => { - let name = vacant_entry.key().clone(); - let name = resolver.register_named(name, None); - vacant_entry.insert(( - to_variables(resolver, &return_arguments), - name, - to_variables(resolver, &input_arguments), - )); - name - } - }; + let func = get_or_declare_function( + resolver, + fn_declarations, + name, + &return_arguments, + &input_arguments, + ); smallvec![ Statement::Instruction::<_, SpirvWord>(ast::Instruction::Call { data: ptx_parser::CallDetails { uniform: false, - return_arguments: vec![( - ast::Type::Vector(2, ast::ScalarType::U32), - ptx_parser::StateSpace::Reg, - )], - input_arguments: vec![ - ( - ast::Type::Scalar(ast::ScalarType::U32), - ptx_parser::StateSpace::Reg, - ), - ( - ast::Type::Scalar(ast::ScalarType::U32), - ptx_parser::StateSpace::Reg, - ), - ( - ast::Type::Scalar(ast::ScalarType::U32), - ptx_parser::StateSpace::Reg, - ), - ( - ast::Type::Scalar(ast::ScalarType::U32), - ptx_parser::StateSpace::Reg, - ), - ], + return_arguments, + input_arguments }, arguments: ptx_parser::CallArgs { return_arguments: vec![packed_var], @@ -184,6 +190,73 @@ fn run_statements<'input>( arguments: ast::CvtArgs { dst: dst_pred, src: dst_pred_wide, + src2: None, + }, + }) + ] + } + Statement::Instruction(ast::Instruction::Cvt { + data: + ast::CvtDetails { + from: from @ (ast::ScalarType::E4m3x2 | ast::ScalarType::E5m2x2), + to: ast::ScalarType::F16x2, + mode: _, + }, + arguments: + ast::CvtArgs { + dst, + src, + src2: None, + }, + }) => { + let from_str = match from { + ast::ScalarType::E4m3x2 => "e4m3x2", + ast::ScalarType::E5m2x2 => "e5m2x2", + _ => unreachable!(), + }; + let packed_output = resolver.register_unnamed(Some(( + ast::Type::Scalar(ast::ScalarType::B32), + ast::StateSpace::Reg, + ))); + let name = format!("cvt_rn_f16x2_{}", from_str); + let return_arguments = vec![( + ast::Type::Scalar(ast::ScalarType::B32), + ast::StateSpace::Reg, + )]; + let input_arguments = vec![( + ast::Type::Scalar(ast::ScalarType::B16), + ast::StateSpace::Reg, + )]; + let func = get_or_declare_function( + resolver, + fn_declarations, + name, + &return_arguments, + &input_arguments, + ); + smallvec![ + Statement::Instruction::<_, SpirvWord>(ast::Instruction::Call { + data: ptx_parser::CallDetails { + uniform: false, + return_arguments, + input_arguments, + }, + arguments: ptx_parser::CallArgs { + return_arguments: vec![packed_output], + func, + input_arguments: vec![src], + }, + }), + Statement::Instruction(ast::Instruction::Cvt { + data: ast::CvtDetails { + from: ast::ScalarType::B32, + to: ast::ScalarType::F16x2, + mode: ast::CvtMode::Bitcast + }, + arguments: ast::CvtArgs { + dst, + src: packed_output, + src2: None, }, }) ] @@ -335,6 +408,29 @@ fn run_instruction<'input>( i @ ptx_parser::Instruction::Nanosleep { .. } => { to_call(resolver, fn_declarations, "nanosleep_u32".into(), i)? } + i @ ptx_parser::Instruction::Cvt { + data: + ptx_parser::CvtDetails { + from: ast::ScalarType::F32, + to: to @ (ast::ScalarType::E4m3x2 | ast::ScalarType::E5m2x2), + mode: _, + }, + arguments: _, + } => { + let to = match to { + ptx_parser::ScalarType::E4m3x2 => "e4m3x2", + ptx_parser::ScalarType::E5m2x2 => "e5m2x2", + _ => unreachable!(), + }; + // Conversions from f32 to f8 must have two source arguments. + // satfinite is mandatory for conversions to e4m3x2. + to_call( + resolver, + fn_declarations, + format!("cvt_rn_satfinite_{}_f32", to).into(), + i, + )? + } i => i, }) } @@ -373,20 +469,8 @@ fn to_call<'input>( }; Ok::<_, TranslateError>(()) })?; - let fn_name = match fn_declarations.entry(name) { - hash_map::Entry::Occupied(occupied_entry) => occupied_entry.get().1, - hash_map::Entry::Vacant(vacant_entry) => { - let name = vacant_entry.key().clone(); - let full_name = [ZLUDA_PTX_PREFIX, &*name].concat(); - let name = resolver.register_named(Cow::Owned(full_name.clone()), None); - vacant_entry.insert(( - to_variables(resolver, &data_return), - name, - to_variables(resolver, &data_input), - )); - name - } - }; + let fn_name = + get_or_declare_function(resolver, fn_declarations, name, &data_return, &data_input); Ok(ast::Instruction::Call { data: ptx_parser::CallDetails { uniform: false, diff --git a/ptx/src/test/ll/cvt_rn_f16x2_e4m3x2.ll b/ptx/src/test/ll/cvt_rn_f16x2_e4m3x2.ll new file mode 100644 index 0000000..ffa7ecf --- /dev/null +++ b/ptx/src/test/ll/cvt_rn_f16x2_e4m3x2.ll @@ -0,0 +1,35 @@ +declare hidden i32 @__zluda_ptx_impl_cvt_rn_f16x2_e4m3x2(i16) #0 + +define amdgpu_kernel void @cvt_rn_f16x2_e4m3x2(ptr addrspace(4) byref(i64) %"31", ptr addrspace(4) byref(i64) %"32") #1 { + %"33" = alloca i64, align 8, addrspace(5) + %"34" = alloca i64, align 8, addrspace(5) + %"35" = alloca i16, align 2, addrspace(5) + %"36" = alloca i32, align 4, addrspace(5) + br label %1 + +1: ; preds = %0 + br label %"30" + +"30": ; preds = %1 + %"37" = load i64, ptr addrspace(4) %"31", align 8 + store i64 %"37", ptr addrspace(5) %"33", align 8 + %"38" = load i64, ptr addrspace(4) %"32", align 8 + store i64 %"38", ptr addrspace(5) %"34", align 8 + %"40" = load i64, ptr addrspace(5) %"33", align 8 + %"45" = inttoptr i64 %"40" to ptr + %"39" = load i16, ptr %"45", align 2 + store i16 %"39", ptr addrspace(5) %"35", align 2 + %"42" = load i16, ptr addrspace(5) %"35", align 2 + %"49" = call i32 @__zluda_ptx_impl_cvt_rn_f16x2_e4m3x2(i16 %"42") + %"46" = bitcast i32 %"49" to <2 x half> + %"41" = bitcast <2 x half> %"46" to i32 + store i32 %"41", ptr addrspace(5) %"36", align 4 + %"43" = load i64, ptr addrspace(5) %"34", align 8 + %"44" = load i32, ptr addrspace(5) %"36", align 4 + %"48" = inttoptr i64 %"43" to ptr + store i32 %"44", ptr %"48", align 4 + ret void +} + +attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="dynamic" "denormal-fp-math-f32"="dynamic" "no-trapping-math"="true" "uniform-work-group-size"="true" } +attributes #1 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="ieee" "denormal-fp-math-f32"="preserve-sign" "no-trapping-math"="true" "uniform-work-group-size"="true" } diff --git a/ptx/src/test/ll/cvt_rn_f16x2_e5m2x2.ll b/ptx/src/test/ll/cvt_rn_f16x2_e5m2x2.ll new file mode 100644 index 0000000..d63c684 --- /dev/null +++ b/ptx/src/test/ll/cvt_rn_f16x2_e5m2x2.ll @@ -0,0 +1,35 @@ +declare hidden i32 @__zluda_ptx_impl_cvt_rn_f16x2_e5m2x2(i16) #0 + +define amdgpu_kernel void @cvt_rn_f16x2_e5m2x2(ptr addrspace(4) byref(i64) %"31", ptr addrspace(4) byref(i64) %"32") #1 { + %"33" = alloca i64, align 8, addrspace(5) + %"34" = alloca i64, align 8, addrspace(5) + %"35" = alloca i16, align 2, addrspace(5) + %"36" = alloca i32, align 4, addrspace(5) + br label %1 + +1: ; preds = %0 + br label %"30" + +"30": ; preds = %1 + %"37" = load i64, ptr addrspace(4) %"31", align 8 + store i64 %"37", ptr addrspace(5) %"33", align 8 + %"38" = load i64, ptr addrspace(4) %"32", align 8 + store i64 %"38", ptr addrspace(5) %"34", align 8 + %"40" = load i64, ptr addrspace(5) %"33", align 8 + %"45" = inttoptr i64 %"40" to ptr + %"39" = load i16, ptr %"45", align 2 + store i16 %"39", ptr addrspace(5) %"35", align 2 + %"42" = load i16, ptr addrspace(5) %"35", align 2 + %"49" = call i32 @__zluda_ptx_impl_cvt_rn_f16x2_e5m2x2(i16 %"42") + %"46" = bitcast i32 %"49" to <2 x half> + %"41" = bitcast <2 x half> %"46" to i32 + store i32 %"41", ptr addrspace(5) %"36", align 4 + %"43" = load i64, ptr addrspace(5) %"34", align 8 + %"44" = load i32, ptr addrspace(5) %"36", align 4 + %"48" = inttoptr i64 %"43" to ptr + store i32 %"44", ptr %"48", align 4 + ret void +} + +attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="dynamic" "denormal-fp-math-f32"="dynamic" "no-trapping-math"="true" "uniform-work-group-size"="true" } +attributes #1 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="ieee" "denormal-fp-math-f32"="preserve-sign" "no-trapping-math"="true" "uniform-work-group-size"="true" } diff --git a/ptx/src/test/ll/cvt_rn_satfinite_e4m3x2_f32.ll b/ptx/src/test/ll/cvt_rn_satfinite_e4m3x2_f32.ll new file mode 100644 index 0000000..eaa932a --- /dev/null +++ b/ptx/src/test/ll/cvt_rn_satfinite_e4m3x2_f32.ll @@ -0,0 +1,40 @@ +declare hidden i16 @__zluda_ptx_impl_cvt_rn_satfinite_e4m3x2_f32(float, float) #0 + +define amdgpu_kernel void @cvt_rn_satfinite_e4m3x2_f32(ptr addrspace(4) byref(i64) %"34", ptr addrspace(4) byref(i64) %"35") #1 { + %"36" = alloca i64, align 8, addrspace(5) + %"37" = alloca i64, align 8, addrspace(5) + %"38" = alloca float, align 4, addrspace(5) + %"39" = alloca float, align 4, addrspace(5) + %"40" = alloca i16, align 2, addrspace(5) + br label %1 + +1: ; preds = %0 + br label %"33" + +"33": ; preds = %1 + %"41" = load i64, ptr addrspace(4) %"34", align 8 + store i64 %"41", ptr addrspace(5) %"36", align 8 + %"42" = load i64, ptr addrspace(4) %"35", align 8 + store i64 %"42", ptr addrspace(5) %"37", align 8 + %"44" = load i64, ptr addrspace(5) %"36", align 8 + %"52" = inttoptr i64 %"44" to ptr + %"43" = load float, ptr %"52", align 4 + store float %"43", ptr addrspace(5) %"38", align 4 + %"45" = load i64, ptr addrspace(5) %"36", align 8 + %"53" = inttoptr i64 %"45" to ptr + %"32" = getelementptr inbounds i8, ptr %"53", i64 4 + %"46" = load float, ptr %"32", align 4 + store float %"46", ptr addrspace(5) %"39", align 4 + %"48" = load float, ptr addrspace(5) %"38", align 4 + %"49" = load float, ptr addrspace(5) %"39", align 4 + %"54" = call i16 @__zluda_ptx_impl_cvt_rn_satfinite_e4m3x2_f32(float %"48", float %"49") + store i16 %"54", ptr addrspace(5) %"40", align 2 + %"50" = load i64, ptr addrspace(5) %"37", align 8 + %"51" = load i16, ptr addrspace(5) %"40", align 2 + %"55" = inttoptr i64 %"50" to ptr + store i16 %"51", ptr %"55", align 2 + ret void +} + +attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="dynamic" "denormal-fp-math-f32"="dynamic" "no-trapping-math"="true" "uniform-work-group-size"="true" } +attributes #1 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="ieee" "denormal-fp-math-f32"="ieee" "no-trapping-math"="true" "uniform-work-group-size"="true" } diff --git a/ptx/src/test/ll/cvt_rn_satfinite_e5m2x2_f32.ll b/ptx/src/test/ll/cvt_rn_satfinite_e5m2x2_f32.ll new file mode 100644 index 0000000..bec74d3 --- /dev/null +++ b/ptx/src/test/ll/cvt_rn_satfinite_e5m2x2_f32.ll @@ -0,0 +1,40 @@ +declare hidden i16 @__zluda_ptx_impl_cvt_rn_satfinite_e5m2x2_f32(float, float) #0 + +define amdgpu_kernel void @cvt_rn_satfinite_e5m2x2_f32(ptr addrspace(4) byref(i64) %"34", ptr addrspace(4) byref(i64) %"35") #1 { + %"36" = alloca i64, align 8, addrspace(5) + %"37" = alloca i64, align 8, addrspace(5) + %"38" = alloca float, align 4, addrspace(5) + %"39" = alloca float, align 4, addrspace(5) + %"40" = alloca i16, align 2, addrspace(5) + br label %1 + +1: ; preds = %0 + br label %"33" + +"33": ; preds = %1 + %"41" = load i64, ptr addrspace(4) %"34", align 8 + store i64 %"41", ptr addrspace(5) %"36", align 8 + %"42" = load i64, ptr addrspace(4) %"35", align 8 + store i64 %"42", ptr addrspace(5) %"37", align 8 + %"44" = load i64, ptr addrspace(5) %"36", align 8 + %"52" = inttoptr i64 %"44" to ptr + %"43" = load float, ptr %"52", align 4 + store float %"43", ptr addrspace(5) %"38", align 4 + %"45" = load i64, ptr addrspace(5) %"36", align 8 + %"53" = inttoptr i64 %"45" to ptr + %"32" = getelementptr inbounds i8, ptr %"53", i64 4 + %"46" = load float, ptr %"32", align 4 + store float %"46", ptr addrspace(5) %"39", align 4 + %"48" = load float, ptr addrspace(5) %"38", align 4 + %"49" = load float, ptr addrspace(5) %"39", align 4 + %"54" = call i16 @__zluda_ptx_impl_cvt_rn_satfinite_e5m2x2_f32(float %"48", float %"49") + store i16 %"54", ptr addrspace(5) %"40", align 2 + %"50" = load i64, ptr addrspace(5) %"37", align 8 + %"51" = load i16, ptr addrspace(5) %"40", align 2 + %"55" = inttoptr i64 %"50" to ptr + store i16 %"51", ptr %"55", align 2 + ret void +} + +attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="dynamic" "denormal-fp-math-f32"="dynamic" "no-trapping-math"="true" "uniform-work-group-size"="true" } +attributes #1 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="ieee" "denormal-fp-math-f32"="ieee" "no-trapping-math"="true" "uniform-work-group-size"="true" } diff --git a/ptx/src/test/spirv_run/cvt_rn_f16x2_e4m3x2.ptx b/ptx/src/test/spirv_run/cvt_rn_f16x2_e4m3x2.ptx new file mode 100644 index 0000000..946c498 --- /dev/null +++ b/ptx/src/test/spirv_run/cvt_rn_f16x2_e4m3x2.ptx @@ -0,0 +1,23 @@ +.version 7.8 +.target sm_90 +.address_size 64 + +.visible .entry cvt_rn_f16x2_e4m3x2( + .param .u64 input, + .param .u64 output +) +{ + .reg .u64 in_addr; + .reg .u64 out_addr; + .reg .b16 in; + .reg .b32 result; + + ld.param.u64 in_addr, [input]; + ld.param.u64 out_addr, [output]; + + ld.b16 in, [in_addr]; + + cvt.rn.f16x2.e4m3x2 result, in; + st.b32 [out_addr], result; + ret; +} diff --git a/ptx/src/test/spirv_run/cvt_rn_f16x2_e5m2x2.ptx b/ptx/src/test/spirv_run/cvt_rn_f16x2_e5m2x2.ptx new file mode 100644 index 0000000..7dcaee0 --- /dev/null +++ b/ptx/src/test/spirv_run/cvt_rn_f16x2_e5m2x2.ptx @@ -0,0 +1,23 @@ +.version 7.8 +.target sm_90 +.address_size 64 + +.visible .entry cvt_rn_f16x2_e5m2x2( + .param .u64 input, + .param .u64 output +) +{ + .reg .u64 in_addr; + .reg .u64 out_addr; + .reg .b16 in; + .reg .b32 result; + + ld.param.u64 in_addr, [input]; + ld.param.u64 out_addr, [output]; + + ld.b16 in, [in_addr]; + + cvt.rn.f16x2.e5m2x2 result, in; + st.b32 [out_addr], result; + ret; +} diff --git a/ptx/src/test/spirv_run/cvt_rn_satfinite_e4m3x2_f32.ptx b/ptx/src/test/spirv_run/cvt_rn_satfinite_e4m3x2_f32.ptx new file mode 100644 index 0000000..8a470cf --- /dev/null +++ b/ptx/src/test/spirv_run/cvt_rn_satfinite_e4m3x2_f32.ptx @@ -0,0 +1,25 @@ +.version 7.8 +.target sm_90 +.address_size 64 + +.visible .entry cvt_rn_satfinite_e4m3x2_f32( + .param .u64 input, + .param .u64 output +) +{ + .reg .u64 in_addr; + .reg .u64 out_addr; + .reg .f32 in_a; + .reg .f32 in_b; + .reg .b16 result; + + ld.param.u64 in_addr, [input]; + ld.param.u64 out_addr, [output]; + + ld.f32 in_a, [in_addr]; + ld.f32 in_b, [in_addr + 4]; + + cvt.rn.satfinite.e4m3x2.f32 result, in_a, in_b; + st.b16 [out_addr], result; + ret; +} diff --git a/ptx/src/test/spirv_run/cvt_rn_satfinite_e5m2x2_f32.ptx b/ptx/src/test/spirv_run/cvt_rn_satfinite_e5m2x2_f32.ptx new file mode 100644 index 0000000..0e5dc8e --- /dev/null +++ b/ptx/src/test/spirv_run/cvt_rn_satfinite_e5m2x2_f32.ptx @@ -0,0 +1,25 @@ +.version 7.8 +.target sm_90 +.address_size 64 + +.visible .entry cvt_rn_satfinite_e5m2x2_f32( + .param .u64 input, + .param .u64 output +) +{ + .reg .u64 in_addr; + .reg .u64 out_addr; + .reg .f32 in_a; + .reg .f32 in_b; + .reg .b16 result; + + ld.param.u64 in_addr, [input]; + ld.param.u64 out_addr, [output]; + + ld.f32 in_a, [in_addr]; + ld.f32 in_b, [in_addr + 4]; + + cvt.rn.satfinite.e5m2x2.f32 result, in_a, in_b; + st.b16 [out_addr], result; + ret; +} diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index ca412be..28a8112 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -185,6 +185,14 @@ test_ptx!(cvt_rni, [9.5f32, 10.5f32], [10f32, 10f32]); test_ptx!(cvt_rzi, [-13.8f32, 12.9f32], [-13f32, 12f32]); test_ptx!(cvt_s32_f32, [-13.8f32, 12.9f32], [-13i32, 13i32]); test_ptx!(cvt_rni_u16_f32, [0x477FFF80u32], [65535u16]); +test_ptx!(cvt_rn_satfinite_e4m3x2_f32, [0.40625, 12.9f32], [0x2D55u16]); +test_ptx!( + cvt_rn_satfinite_e5m2x2_f32, + [0.375, -5256.6f32], + [0x36EDu16] +); +test_ptx!(cvt_rn_f16x2_e4m3x2, [0x2D55u16], [0x36804a80u32]); +test_ptx!(cvt_rn_f16x2_e5m2x2, [0x36EDu16], [0x3600ED00u32]); test_ptx!(clz, [0b00000101_00101101_00010011_10101011u32], [5u32]); test_ptx!(popc, [0b10111100_10010010_01001001_10001010u32], [14u32]); test_ptx!( diff --git a/ptx_parser/src/ast.rs b/ptx_parser/src/ast.rs index f198795..bee81ac 100644 --- a/ptx_parser/src/ast.rs +++ b/ptx_parser/src/ast.rs @@ -233,6 +233,11 @@ ptx_parser_macros::generate_instruction_type!( type: { Type::Scalar(data.from) }, relaxed_type_check: true, }, + src2: { + repr: Option, + type: { Type::Scalar(data.from) }, + relaxed_type_check: true, + }, } }, Cvta { @@ -1047,7 +1052,9 @@ impl ScalarType { | ScalarType::S16 | ScalarType::B16 | ScalarType::F16 - | ScalarType::BF16 => 2, + | ScalarType::BF16 + | ScalarType::E4m3x2 + | ScalarType::E5m2x2 => 2, ScalarType::U32 | ScalarType::S32 | ScalarType::B32 @@ -1069,7 +1076,9 @@ impl ScalarType { | ScalarType::S16 | ScalarType::B16 | ScalarType::F16 - | ScalarType::BF16 => Layout::new::(), + | ScalarType::BF16 + | ScalarType::E4m3x2 + | ScalarType::E5m2x2 => Layout::new::(), ScalarType::U32 | ScalarType::S32 | ScalarType::B32 @@ -1110,6 +1119,8 @@ impl ScalarType { ScalarType::F64 => ScalarKind::Float, ScalarType::BF16 => ScalarKind::Float, ScalarType::BF16x2 => ScalarKind::Float, + ScalarType::E4m3x2 => ScalarKind::Float, + ScalarType::E5m2x2 => ScalarKind::Float, ScalarType::Pred => ScalarKind::Pred, } } @@ -1884,7 +1895,9 @@ impl CvtDetails { saturate, }, Ordering::Greater => { - if rounding.is_some() { + if rounding.is_some() + && !(src == ScalarType::E4m3x2 || src == ScalarType::E5m2x2) + { errors.push(PtxError::SyntaxError( "should not have rounding mode when dst is larger than src in cvt" .to_string(), diff --git a/ptx_parser/src/lib.rs b/ptx_parser/src/lib.rs index 9c08f95..ea458f6 100644 --- a/ptx_parser/src/lib.rs +++ b/ptx_parser/src/lib.rs @@ -2370,7 +2370,7 @@ derive_parser!( // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt cvt{.ifrnd}{.ftz}{.sat}.dtype.atype d, a => { let data = ast::CvtDetails::new(&mut state.errors, ifrnd, ftz, sat, dtype, atype); - let arguments = ast::CvtArgs { dst: d, src: a }; + let arguments = ast::CvtArgs { dst: d, src: a, src2: None }; ast::Instruction::Cvt { data, arguments } @@ -2381,18 +2381,38 @@ derive_parser!( // cvt.frnd2{.relu}{.satfinite}.bf16x2.f32 d, a, b; // cvt.rna{.satfinite}.tf32.f32 d, a; // cvt.frnd2{.relu}.tf32.f32 d, a; - // cvt.rn.satfinite{.relu}.f8x2type.f32 d, a, b; + cvt.rn.satfinite{.relu}.f8x2type.f32 d, a, b => { + if relu { + state.errors.push(PtxError::Todo); + } + let data = ast::CvtDetails::new(&mut state.errors, Some(rn), false, false, f8x2type, ScalarType::F32); + ast::Instruction::Cvt { + data, + arguments: ast::CvtArgs { dst: d, src: a, src2: Some(b) } + } + } // cvt.rn.satfinite{.relu}.f8x2type.f16x2 d, a; - // cvt.rn.{.relu}.f16x2.f8x2type d, a; + cvt.rn{.relu}.f16x2.f8x2type d, a => { + if relu { + state.errors.push(PtxError::Todo); + } + let data = ast::CvtDetails::new(&mut state.errors, Some(rn), false, false, ScalarType::F16x2, f8x2type); + ast::Instruction::Cvt { + data, + arguments: ast::CvtArgs { dst: d, src: a, src2: None } + } + } .ifrnd: RawRoundingMode = { .rn, .rz, .rm, .rp, .rni, .rzi, .rmi, .rpi }; .frnd2: RawRoundingMode = { .rn, .rz }; + RawRoundingMode = { .rn }; .dtype: ScalarType = { .u8, .u16, .u32, .u64, .s8, .s16, .s32, .s64, .bf16, .f16, .f32, .f64 }; .atype: ScalarType = { .u8, .u16, .u32, .u64, .s8, .s16, .s32, .s64, .bf16, .f16, .f32, .f64 }; + .f8x2type: ScalarType = { .e4m3x2, .e5m2x2 }; // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#logic-and-shift-instructions-shl shl.type d, a, b => { ast::Instruction::Shl { data: type_, arguments: ShlArgs { dst: d, src1: a, src2: b } }