diff --git a/elf.o b/elf.o new file mode 100644 index 0000000..3095bf4 Binary files /dev/null and b/elf.o differ diff --git a/ptx/src/pass/expand_operands.rs b/ptx/src/pass/expand_operands.rs index b21c343..5296391 100644 --- a/ptx/src/pass/expand_operands.rs +++ b/ptx/src/pass/expand_operands.rs @@ -230,15 +230,33 @@ impl<'a, 'input> FlattenArguments<'a, 'input> { fn vec_pack( &mut self, - vector_elements: Vec, + vector_elements: Vec>, type_space: Option<(&ast::Type, ast::StateSpace)>, is_dst: bool, relaxed_type_check: bool, ) -> Result { let (width, scalar_t, state_space) = match type_space { Some((ast::Type::Vector(width, scalar_t), space)) => (*width, *scalar_t, space), + Some((ast::Type::Scalar(scalar_t), space)) + if scalar_t.kind() == ast::ScalarKind::Bit => + { + let type_ = + ast::ScalarType::from_size(scalar_t.size_of() / (vector_elements.len() as u8)) + .ok_or_else(|| error_mismatched_type())?; + (vector_elements.len() as u8, type_, space) + } _ => return Err(error_mismatched_type()), }; + let vector_elements = vector_elements + .into_iter() + .map(|element| match element { + ast::RegOrImmediate::Reg(name) => self.reg(name), + ast::RegOrImmediate::Imm(immediate_value) => self.immediate( + immediate_value, + Some((&ast::Type::Scalar(scalar_t), state_space)), + ), + }) + .collect::, _>>()?; let temporary_vector = self .resolver .register_unnamed(Some((ast::Type::Vector(width, scalar_t), state_space))); diff --git a/ptx/src/pass/fix_special_registers.rs b/ptx/src/pass/fix_special_registers.rs index 70b468d..2b90b0f 100644 --- a/ptx/src/pass/fix_special_registers.rs +++ b/ptx/src/pass/fix_special_registers.rs @@ -198,10 +198,15 @@ pub fn map_operand( Some(ident) => ast::ParsedOperand::Reg(ident), None => ast::ParsedOperand::VecMember(ident, member), }, - ast::ParsedOperand::VecPack(idents) => ast::ParsedOperand::VecPack( - idents + ast::ParsedOperand::VecPack(elements) => ast::ParsedOperand::VecPack( + elements .into_iter() - .map(|ident| Ok(fn_(ident, None)?.unwrap_or(ident))) + .map(|element| match element { + ast::RegOrImmediate::Reg(ident) => { + Ok(ast::RegOrImmediate::Reg(fn_(ident, None)?.unwrap_or(ident))) + } + ast::RegOrImmediate::Imm(imm) => Ok(ast::RegOrImmediate::Imm(imm)), + }) .collect::, _>>()?, ), }) diff --git a/ptx/src/pass/insert_implicit_conversions2.rs b/ptx/src/pass/insert_implicit_conversions2.rs index 5b0fd3b..b2d3161 100644 --- a/ptx/src/pass/insert_implicit_conversions2.rs +++ b/ptx/src/pass/insert_implicit_conversions2.rs @@ -208,7 +208,7 @@ fn default_implicit_conversion_type( if should_bitcast(instruction_type, operand_type) { Ok(Some(ConversionKind::Default)) } else { - Err(TranslateError::MismatchedType) + Err(error_mismatched_type()) } } else { Ok(Some(ConversionKind::PtrToPtr)) @@ -264,14 +264,14 @@ pub(crate) fn should_convert_relaxed_dst_wrapper( (instruction_space, instruction_type): (ast::StateSpace, &ast::Type), ) -> Result, TranslateError> { if operand_space != instruction_space { - return Err(TranslateError::MismatchedType); + return Err(error_mismatched_type()); } if operand_type == instruction_type { return Ok(None); } match should_convert_relaxed_dst(operand_type, instruction_type) { conv @ Some(_) => Ok(conv), - None => Err(TranslateError::MismatchedType), + None => Err(error_mismatched_type()), } } diff --git a/ptx/src/pass/llvm/emit.rs b/ptx/src/pass/llvm/emit.rs index 2403d90..d3b81cd 100644 --- a/ptx/src/pass/llvm/emit.rs +++ b/ptx/src/pass/llvm/emit.rs @@ -24,8 +24,6 @@ // shows it fails inside amdgpu-isel. You can get a little bit furthr with "-mllvm -global-isel", // but it will too fail similarly, but with "unable to legalize instruction" -use std::array::TryFromSliceError; -use std::convert::TryInto; use std::ffi::{CStr, NulError}; use std::{i8, ptr, u64}; @@ -249,79 +247,51 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> { unsafe { LLVMSetAlignment(global, align) }; } if !var.array_init.is_empty() { - self.emit_array_init(&var.v_type, &*var.array_init, global)?; + let initializer = self.get_array_init(&var.v_type, &*var.array_init)?; + unsafe { LLVMSetInitializer(global, initializer) }; } Ok(()) } - // TODO: instead of Vec we should emit a typed initializer - fn emit_array_init( - &mut self, + fn get_array_init( + &self, type_: &ast::Type, - array_init: &[u8], - global: *mut llvm_zluda::LLVMValue, - ) -> Result<(), TranslateError> { - match type_ { + array_init: &[ast::RegOrImmediate], + ) -> Result<*mut LLVMValue, TranslateError> { + let initializer = match type_ { ast::Type::Array(None, scalar, dimensions) => { if dimensions.len() != 1 { todo!() } - if dimensions[0] as usize * scalar.size_of() as usize != array_init.len() { + if dimensions[0] as usize != array_init.len() { return Err(error_unreachable()); } let type_ = get_scalar_type(self.context, *scalar); let mut elements = array_init - .chunks(scalar.size_of() as usize) - .map(|chunk| self.constant_from_bytes(*scalar, chunk, type_)) - .collect::, _>>() - .map_err(|_| error_unreachable())?; - let initializer = - unsafe { LLVMConstArray2(type_, elements.as_mut_ptr(), elements.len() as u64) }; - unsafe { LLVMSetInitializer(global, initializer) }; + .iter() + .map(|elem| match elem { + ast::RegOrImmediate::Reg(reg) => { + Ok(unsafe { LLVMConstPtrToInt(self.resolver.value(*reg)?, type_) }) + } + ast::RegOrImmediate::Imm(imm) => { + Ok(get_immediate_value(self.context, scalar, imm)) + } + }) + .collect::, _>>()?; + unsafe { LLVMConstArray2(type_, elements.as_mut_ptr(), elements.len() as u64) } } - _ => todo!(), - } - Ok(()) - } - - fn constant_from_bytes( - &self, - scalar: ast::ScalarType, - bytes: &[u8], - llvm_type: LLVMTypeRef, - ) -> Result { - Ok(match scalar { - ptx_parser::ScalarType::Pred - | ptx_parser::ScalarType::S8 - | ptx_parser::ScalarType::B8 - | ptx_parser::ScalarType::U8 => unsafe { - LLVMConstInt(llvm_type, u8::from_le_bytes(bytes.try_into()?) as u64, 0) - }, - ptx_parser::ScalarType::S16 - | ptx_parser::ScalarType::B16 - | ptx_parser::ScalarType::U16 - | ptx_parser::ScalarType::E4m3x2 - | ptx_parser::ScalarType::E5m2x2 => unsafe { - LLVMConstInt(llvm_type, u16::from_le_bytes(bytes.try_into()?) as u64, 0) - }, - ptx_parser::ScalarType::S32 - | ptx_parser::ScalarType::B32 - | ptx_parser::ScalarType::U32 => unsafe { - LLVMConstInt(llvm_type, u32::from_le_bytes(bytes.try_into()?) as u64, 0) - }, - ptx_parser::ScalarType::F16 => todo!(), - ptx_parser::ScalarType::BF16 => todo!(), - ptx_parser::ScalarType::U64 => todo!(), - ptx_parser::ScalarType::S64 => todo!(), - ptx_parser::ScalarType::S16x2 => todo!(), - ptx_parser::ScalarType::F32 => todo!(), - ptx_parser::ScalarType::B64 => todo!(), - ptx_parser::ScalarType::F64 => todo!(), - ptx_parser::ScalarType::B128 => todo!(), - ptx_parser::ScalarType::U16x2 => todo!(), - ptx_parser::ScalarType::F16x2 => todo!(), - ptx_parser::ScalarType::BF16x2 => todo!(), - }) + ast::Type::Scalar(scalar) => { + let initializer = match array_init { + [ast::RegOrImmediate::Imm(init)] => init, + _ => return Err(error_mismatched_type()), + }; + get_immediate_value(self.context, scalar, initializer) + } + _ => { + todo!() + } + }; + Ok(initializer) } fn emit_fn_attribute(&self, llvm_object: LLVMValueRef, key: &str, value: &str) { @@ -338,6 +308,20 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> { } } +fn get_immediate_value( + context: LLVMContextRef, + scalar_type: &ast::ScalarType, + imm: &ast::ImmediateValue, +) -> *mut LLVMValue { + let type_ = get_scalar_type(context, *scalar_type); + match imm { + ast::ImmediateValue::U64(x) => unsafe { LLVMConstInt(type_, *x, 0) }, + ast::ImmediateValue::S64(x) => unsafe { LLVMConstInt(type_, *x as u64, 0) }, + ast::ImmediateValue::F32(x) => unsafe { LLVMConstReal(type_, *x as f64) }, + ast::ImmediateValue::F64(x) => unsafe { LLVMConstReal(type_, *x) }, + } +} + fn llvm_ftz(ftz: bool) -> &'static str { if ftz { "preserve-sign" @@ -404,7 +388,6 @@ impl<'a> MethodEmitContext<'a> { Statement::VectorWrite(vector_write) => self.emit_vector_write(vector_write)?, Statement::SetMode(mode_reg) => self.emit_set_mode(mode_reg)?, Statement::FpSaturate { dst, src, type_ } => self.emit_fp_saturate(type_, dst, src)?, - // No-op Statement::FpModeRequired { .. } => {} }) } @@ -445,7 +428,7 @@ impl<'a> MethodEmitContext<'a> { unsafe { LLVMSetAlignment(alloca, align) }; } if !var.array_init.is_empty() { - todo!() + return Err(error_unreachable()); } Ok(()) } @@ -722,13 +705,7 @@ impl<'a> MethodEmitContext<'a> { } fn emit_constant(&mut self, constant: ConstantDefinition) -> Result<(), TranslateError> { - let type_ = get_scalar_type(self.context, constant.typ); - let value = match constant.value { - ast::ImmediateValue::U64(x) => unsafe { LLVMConstInt(type_, x, 0) }, - ast::ImmediateValue::S64(x) => unsafe { LLVMConstInt(type_, x as u64, 0) }, - ast::ImmediateValue::F32(x) => unsafe { LLVMConstReal(type_, x as f64) }, - ast::ImmediateValue::F64(x) => unsafe { LLVMConstReal(type_, x) }, - }; + let value = get_immediate_value(self.context, &constant.typ, &constant.value); self.resolver.register(constant.dst, value); Ok(()) } @@ -3133,7 +3110,7 @@ impl std::fmt::Display for LLVMTypeDisplay { ast::ScalarType::F64 => write!(f, "f64"), ptx_parser::ScalarType::S16x2 | ptx_parser::ScalarType::U16x2 => write!(f, "v2i16"), ast::ScalarType::F16x2 => write!(f, "v2f16"), - ptx_parser::ScalarType::BF16x2 => write!(f, "v2bfloat"), + ptx_parser::ScalarType::BF16x2 => write!(f, "v2bf16"), } } } diff --git a/ptx/src/pass/llvm/mod.rs b/ptx/src/pass/llvm/mod.rs index 24f790e..40781fc 100644 --- a/ptx/src/pass/llvm/mod.rs +++ b/ptx/src/pass/llvm/mod.rs @@ -172,7 +172,7 @@ fn get_scalar_type(context: LLVMContextRef, type_: ast::ScalarType) -> LLVMTypeR ast::ScalarType::U16x2 => todo!(), ast::ScalarType::S16x2 => todo!(), ast::ScalarType::F16x2 => unsafe { LLVMVectorType(LLVMHalfTypeInContext(context), 2) }, - ast::ScalarType::BF16x2 => todo!(), + ast::ScalarType::BF16x2 => unsafe { LLVMVectorType(LLVMBFloatTypeInContext(context), 2) }, } } diff --git a/ptx/src/pass/mod.rs b/ptx/src/pass/mod.rs index 39bebb0..f743fad 100644 --- a/ptx/src/pass/mod.rs +++ b/ptx/src/pass/mod.rs @@ -224,7 +224,7 @@ fn error_unknown_symbol>(symbol: T) -> TranslateError { #[cfg(debug_assertions)] fn error_mismatched_type() -> TranslateError { - panic!() + panic!("Mismatched type") } #[cfg(not(debug_assertions))] @@ -613,6 +613,12 @@ struct ConstantDefinition { pub value: ast::ImmediateValue, } +impl std::fmt::Display for ConstantDefinition { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "zluda.constant{} {}", self.typ, self.value) + } +} + pub struct PtrAccess { underlying_type: ast::Type, state_space: ast::StateSpace, @@ -629,6 +635,22 @@ struct RepackVectorDetails { relaxed_type_check: bool, } +impl std::fmt::Display for RepackVectorDetails { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let extract = if self.is_extract { + ".extract" + } else { + ".composite" + }; + let relaxed = if self.relaxed_type_check { + ".relaxed" + } else { + "" + }; + write!(f, "zluda.repack_vector{}{}{}", extract, relaxed, self.typ) + } +} + struct FunctionPointerDetails { dst: SpirvWord, src: SpirvWord, diff --git a/ptx/src/pass/normalize_identifiers2.rs b/ptx/src/pass/normalize_identifiers2.rs index 05045b7..901f628 100644 --- a/ptx/src/pass/normalize_identifiers2.rs +++ b/ptx/src/pass/normalize_identifiers2.rs @@ -92,7 +92,7 @@ fn run_variable<'input, 'b>( align: variable.align, v_type: variable.v_type, state_space: variable.state_space, - array_init: variable.array_init, + array_init: run_array_init(resolver, &variable.array_init)?, }) } @@ -171,7 +171,7 @@ fn run_multivariable<'input, 'b>( v_type: variable.var.v_type.clone(), state_space: variable.var.state_space, name: ident, - array_init: variable.var.array_init.clone(), + array_init: run_array_init(resolver, &variable.var.array_init)?, })); } } @@ -186,9 +186,22 @@ fn run_multivariable<'input, 'b>( v_type: variable.var.v_type.clone(), state_space: variable.var.state_space, name: ident, - array_init: variable.var.array_init.clone(), + array_init: run_array_init(resolver, &variable.var.array_init)?, })); } } Ok(()) } + +fn run_array_init<'input, 'b>( + resolver: &mut ScopedResolver<'input, 'b>, + array_init: &[ast::RegOrImmediate<&'input str>], +) -> Result>, TranslateError> { + Ok(array_init + .iter() + .map(|elem| match elem { + ast::RegOrImmediate::Reg(name) => Ok(ast::RegOrImmediate::Reg(resolver.get(name)?)), + ast::RegOrImmediate::Imm(imm) => Ok(ast::RegOrImmediate::Imm(*imm)), + }) + .collect::, _>>()?) +} diff --git a/ptx/src/pass/test/expand_operands/mod.rs b/ptx/src/pass/test/expand_operands/mod.rs new file mode 100644 index 0000000..20efae8 --- /dev/null +++ b/ptx/src/pass/test/expand_operands/mod.rs @@ -0,0 +1,22 @@ +use crate::pass::{test::directive2_vec_to_string, *}; + +use super::test_pass; + +macro_rules! test_expand_operands { + ($test_name:ident) => { + test_pass!(run_expand_operands, $test_name); + }; +} + +fn run_expand_operands(ptx: ptx_parser::Module) -> String { + // We run the minimal number of passes required to produce the input expected by expand_operands + let mut flat_resolver = GlobalStringIdentResolver2::new(SpirvWord(1)); + let mut scoped_resolver = ScopedResolver::new(&mut flat_resolver); + let directives = normalize_identifiers2::run(&mut scoped_resolver, ptx.directives).unwrap(); + let directives = normalize_predicates2::run(&mut flat_resolver, directives).unwrap(); + let directives = expand_operands::run(&mut flat_resolver, directives).unwrap(); + directive2_vec_to_string(&flat_resolver, directives) +} + +test_expand_operands!(vector_operand); +test_expand_operands!(vector_operand_convert); diff --git a/ptx/src/pass/test/expand_operands/vector_operand.ptx b/ptx/src/pass/test/expand_operands/vector_operand.ptx new file mode 100644 index 0000000..e5cfdd9 --- /dev/null +++ b/ptx/src/pass/test/expand_operands/vector_operand.ptx @@ -0,0 +1,23 @@ +.version 6.5 +.target sm_60 +.address_size 64 + +.func (.reg .v2.b16 output) default ( + .reg .b16 input +) +{ + mov.v2.b16 output, {0x5678, input}; + ret; +} + +// %%% output %%% + +.func (.reg .v2 .b16 %2) %1 ( + .reg .b16 %3 +) +{ + .b16.reg %4 = zluda.constant.b16 22136; + .v2.b16.reg %5 = zluda.repack_vector.composite.b16 %4, %3; + mov.v2.b16 %2, %5; + ret; +} diff --git a/ptx/src/pass/test/expand_operands/vector_operand_convert.ptx b/ptx/src/pass/test/expand_operands/vector_operand_convert.ptx new file mode 100644 index 0000000..1c94806 --- /dev/null +++ b/ptx/src/pass/test/expand_operands/vector_operand_convert.ptx @@ -0,0 +1,23 @@ +.version 6.5 +.target sm_60 +.address_size 64 + +.func (.reg .b32 output) default ( + .reg .b16 input +) +{ + mov.b32 output, {0x5678, input}; + ret; +} + +// %%% output %%% + +.func (.reg .b32 %2) %1 ( + .reg .b16 %3 +) +{ + .b16.reg %4 = zluda.constant.b16 22136; + .v2.b16.reg %5 = zluda.repack_vector.composite.b16 %4, %3; + mov.b32 %2, %5; + ret; +} diff --git a/ptx/src/pass/test/insert_implicit_conversions/default_reg_b32_reg_v2_b16.ptx b/ptx/src/pass/test/insert_implicit_conversions/default_reg_b32_reg_v2_b16.ptx new file mode 100644 index 0000000..f334b83 --- /dev/null +++ b/ptx/src/pass/test/insert_implicit_conversions/default_reg_b32_reg_v2_b16.ptx @@ -0,0 +1,22 @@ +.version 6.5 +.target sm_30 +.address_size 64 + +.func (.reg .b32 output) default ( + .reg .v2.b16 input +) +{ + mov.b32 output, input; + ret; +} + +// %%% output %%% + +.func (.reg .b32 %2) %1 ( + .reg .v2 .b16 %3 +) +{ + .b32.reg %4 = zluda.convert_implicit.default.reg.b32.reg.v2.b16 %3; + mov.b32 %2, %4; + ret; +} diff --git a/ptx/src/pass/test/insert_implicit_conversions/mod.rs b/ptx/src/pass/test/insert_implicit_conversions/mod.rs index 1fb7a54..1b5fffb 100644 --- a/ptx/src/pass/test/insert_implicit_conversions/mod.rs +++ b/ptx/src/pass/test/insert_implicit_conversions/mod.rs @@ -21,3 +21,4 @@ fn run_insert_implicit_conversions(ptx: ptx_parser::Module) -> String { test_insert_implicit_conversions!(default); test_insert_implicit_conversions!(default_reg_b32_reg_f16x2); +test_insert_implicit_conversions!(default_reg_b32_reg_v2_b16); diff --git a/ptx/src/pass/test/mod.rs b/ptx/src/pass/test/mod.rs index 3a9ef1f..4014842 100644 --- a/ptx/src/pass/test/mod.rs +++ b/ptx/src/pass/test/mod.rs @@ -6,6 +6,7 @@ use std::{ path::Path, }; +mod expand_operands; mod insert_implicit_conversions; #[macro_export] @@ -202,6 +203,8 @@ fn statement_to_string( Statement::Variable(var) => format!("{}", var), Statement::Instruction(instr) => format!("{}", instr), Statement::Conversion(conv) => format!("{}", conv), + Statement::Constant(constant) => format!("{}", constant), + Statement::RepackVector(repack) => format!("{}", repack), _ => todo!(), }; let mut args_formatter = StatementFormatter::new(resolver); diff --git a/ptx/src/test/ll/const_ident.ll b/ptx/src/test/ll/const_ident.ll new file mode 100644 index 0000000..c927ef8 --- /dev/null +++ b/ptx/src/test/ll/const_ident.ll @@ -0,0 +1,55 @@ +@x = addrspace(4) global i64 1 +@y = addrspace(4) global [4 x i64] [i64 4, i64 5, i64 6, i64 0] +@constparams = addrspace(4) global [4 x i64] [i64 ptrtoint (ptr addrspace(4) @x to i64), i64 ptrtoint (ptr addrspace(4) @y to i64)] + +define amdgpu_kernel void @const_ident(ptr addrspace(4) byref(i64) %"49", ptr addrspace(4) byref(i64) %"50") #0 { + %"51" = alloca i64, align 8, addrspace(5) + %"52" = alloca i64, align 8, addrspace(5) + %"53" = alloca i64, align 8, addrspace(5) + %"54" = alloca i64, align 8, addrspace(5) + %"55" = alloca i64, align 8, addrspace(5) + %"56" = alloca i64, align 8, addrspace(5) + %"57" = alloca i64, align 8, addrspace(5) + %"58" = alloca i64, align 8, addrspace(5) + br label %1 + +1: ; preds = %0 + br label %"48" + +"48": ; preds = %1 + %"59" = load i64, ptr addrspace(4) %"49", align 8 + store i64 %"59", ptr addrspace(5) %"51", align 8 + %"60" = load i64, ptr addrspace(4) %"50", align 8 + store i64 %"60", ptr addrspace(5) %"52", align 8 + store i64 ptrtoint (ptr addrspace(4) @x to i64), ptr addrspace(5) %"53", align 8 + store i64 ptrtoint (ptr addrspace(4) @y to i64), ptr addrspace(5) %"54", align 8 + %"63" = load i64, ptr addrspace(4) @constparams, align 8 + store i64 %"63", ptr addrspace(5) %"55", align 8 + %"64" = load i64, ptr addrspace(4) getelementptr inbounds (i8, ptr addrspace(4) @constparams, i64 8), align 8 + store i64 %"64", ptr addrspace(5) %"56", align 8 + %"66" = load i64, ptr addrspace(5) %"53", align 8 + %"67" = load i64, ptr addrspace(5) %"55", align 8 + %"65" = xor i64 %"66", %"67" + store i64 %"65", ptr addrspace(5) %"57", align 8 + %"69" = load i64, ptr addrspace(5) %"54", align 8 + %"70" = load i64, ptr addrspace(5) %"56", align 8 + %"68" = xor i64 %"69", %"70" + store i64 %"68", ptr addrspace(5) %"58", align 8 + %"71" = load i64, ptr addrspace(5) %"52", align 8 + %"72" = load i64, ptr addrspace(5) %"57", align 8 + %"85" = inttoptr i64 %"71" to ptr + store i64 %"72", ptr %"85", align 8 + %"73" = load i64, ptr addrspace(5) %"52", align 8 + %"87" = inttoptr i64 %"73" to ptr + %"45" = getelementptr inbounds i8, ptr %"87", i64 8 + %"74" = load i64, ptr addrspace(5) %"58", align 8 + store i64 %"74", ptr %"45", align 8 + %"75" = load i64, ptr addrspace(5) %"52", align 8 + %"89" = inttoptr i64 %"75" to ptr + %"47" = getelementptr inbounds i8, ptr %"89", i64 8 + %"76" = load i64, ptr addrspace(5) %"58", align 8 + store i64 %"76", ptr %"47", align 8 + ret void +} + +attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="preserve-sign" "denormal-fp-math-f32"="preserve-sign" "no-trapping-math"="true" "uniform-work-group-size"="true" } diff --git a/ptx/src/test/ll/fma_bf16x2.ll b/ptx/src/test/ll/fma_bf16x2.ll new file mode 100644 index 0000000..ff7a638 --- /dev/null +++ b/ptx/src/test/ll/fma_bf16x2.ll @@ -0,0 +1,51 @@ +define amdgpu_kernel void @fma_bf16x2(ptr addrspace(4) byref(i64) %"39", ptr addrspace(4) byref(i64) %"40") #0 { + %"41" = alloca i64, align 8, addrspace(5) + %"42" = alloca i64, align 8, addrspace(5) + %"43" = alloca i32, align 4, addrspace(5) + %"44" = alloca i32, align 4, addrspace(5) + %"45" = alloca i32, align 4, addrspace(5) + br label %1 + +1: ; preds = %0 + br label %"38" + +"38": ; preds = %1 + %"46" = load i64, ptr addrspace(4) %"39", align 8 + store i64 %"46", ptr addrspace(5) %"41", align 8 + %"47" = load i64, ptr addrspace(4) %"40", align 8 + store i64 %"47", ptr addrspace(5) %"42", align 8 + %"49" = load i64, ptr addrspace(5) %"41", align 8 + %"60" = inttoptr i64 %"49" to ptr + %"48" = load i32, ptr %"60", align 4 + store i32 %"48", ptr addrspace(5) %"43", align 4 + %"50" = load i64, ptr addrspace(5) %"41", align 8 + %"61" = inttoptr i64 %"50" to ptr + %"35" = getelementptr inbounds i8, ptr %"61", i64 4 + %"51" = load i32, ptr %"35", align 4 + store i32 %"51", ptr addrspace(5) %"44", align 4 + %"52" = load i64, ptr addrspace(5) %"41", align 8 + %"62" = inttoptr i64 %"52" to ptr + %"37" = getelementptr inbounds i8, ptr %"62", i64 8 + %"53" = load i32, ptr %"37", align 4 + store i32 %"53", ptr addrspace(5) %"45", align 4 + %"55" = load i32, ptr addrspace(5) %"43", align 4 + %"56" = load i32, ptr addrspace(5) %"44", align 4 + %"57" = load i32, ptr addrspace(5) %"45", align 4 + %"64" = bitcast i32 %"55" to <2 x bfloat> + %"65" = bitcast i32 %"56" to <2 x bfloat> + %"66" = bitcast i32 %"57" to <2 x bfloat> + %"63" = call <2 x bfloat> @llvm.fma.v2bf16(<2 x bfloat> %"64", <2 x bfloat> %"65", <2 x bfloat> %"66") + %"54" = bitcast <2 x bfloat> %"63" to i32 + store i32 %"54", ptr addrspace(5) %"43", align 4 + %"58" = load i64, ptr addrspace(5) %"42", align 8 + %"59" = load i32, ptr addrspace(5) %"43", align 4 + %"67" = inttoptr i64 %"58" to ptr + store i32 %"59", ptr %"67", align 4 + ret void +} + +; Function Attrs: nocallback nofree nosync nounwind speculatable willreturn memory(none) +declare <2 x bfloat> @llvm.fma.v2bf16(<2 x bfloat>, <2 x bfloat>, <2 x bfloat>) #1 + +attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="preserve-sign" "denormal-fp-math-f32"="preserve-sign" "no-trapping-math"="true" "uniform-work-group-size"="true" } +attributes #1 = { nocallback nofree nosync nounwind speculatable willreturn memory(none) } diff --git a/ptx/src/test/ll/global_array_f32.ll b/ptx/src/test/ll/global_array_f32.ll new file mode 100644 index 0000000..201a754 --- /dev/null +++ b/ptx/src/test/ll/global_array_f32.ll @@ -0,0 +1,28 @@ +@foobar = addrspace(1) global [4 x float] [float 1.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00] + +define amdgpu_kernel void @global_array_f32(ptr addrspace(4) byref(i64) %"36", ptr addrspace(4) byref(i64) %"37") #0 { + %"38" = alloca i64, align 8, addrspace(5) + %"39" = alloca i64, align 8, addrspace(5) + %"40" = alloca float, align 4, addrspace(5) + br label %1 + +1: ; preds = %0 + br label %"35" + +"35": ; preds = %1 + store i64 ptrtoint (ptr addrspace(1) @foobar to i64), ptr addrspace(5) %"38", align 8 + %"42" = load i64, ptr addrspace(4) %"37", align 8 + store i64 %"42", ptr addrspace(5) %"39", align 8 + %"43" = load i64, ptr addrspace(5) %"38", align 8 + %"48" = inttoptr i64 %"43" to ptr addrspace(1) + %"34" = getelementptr inbounds i8, ptr addrspace(1) %"48", i64 4 + %"44" = load float, ptr addrspace(1) %"34", align 4 + store float %"44", ptr addrspace(5) %"40", align 4 + %"45" = load i64, ptr addrspace(5) %"39", align 8 + %"46" = load float, ptr addrspace(5) %"40", align 4 + %"49" = inttoptr i64 %"45" to ptr + store float %"46", ptr %"49", align 4 + ret void +} + +attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="preserve-sign" "denormal-fp-math-f32"="preserve-sign" "no-trapping-math"="true" "uniform-work-group-size"="true" } diff --git a/ptx/src/test/ll/vector_operand.ll b/ptx/src/test/ll/vector_operand.ll new file mode 100644 index 0000000..564d0da --- /dev/null +++ b/ptx/src/test/ll/vector_operand.ll @@ -0,0 +1,31 @@ +define amdgpu_kernel void @vector_operand(ptr addrspace(4) byref(i64) %"36", ptr addrspace(4) byref(i64) %"37") #0 { + %"38" = alloca i64, align 8, addrspace(5) + %"39" = alloca i64, align 8, addrspace(5) + %"40" = alloca i16, align 2, addrspace(5) + %"41" = alloca i32, align 4, addrspace(5) + br label %1 + +1: ; preds = %0 + br label %"35" + +"35": ; preds = %1 + %"42" = load i64, ptr addrspace(4) %"36", align 8 + store i64 %"42", ptr addrspace(5) %"38", align 8 + %"43" = load i64, ptr addrspace(4) %"37", align 8 + store i64 %"43", ptr addrspace(5) %"39", align 8 + %"45" = load i64, ptr addrspace(5) %"38", align 8 + %"50" = inttoptr i64 %"45" to ptr + %"44" = load i16, ptr %"50", align 2 + store i16 %"44", ptr addrspace(5) %"40", align 2 + %"46" = load i16, ptr addrspace(5) %"40", align 2 + %"34" = insertelement <2 x i16> , i16 %"46", i8 1 + %"51" = bitcast <2 x i16> %"34" to i32 + store i32 %"51", ptr addrspace(5) %"41", align 4 + %"48" = load i64, ptr addrspace(5) %"39", align 8 + %"49" = load i32, ptr addrspace(5) %"41", align 4 + %"52" = inttoptr i64 %"48" to ptr + store i32 %"49", ptr %"52", align 4 + ret void +} + +attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="preserve-sign" "denormal-fp-math-f32"="preserve-sign" "no-trapping-math"="true" "uniform-work-group-size"="true" } diff --git a/ptx/src/test/operands.ptx b/ptx/src/test/operands.ptx index 67c59f5..e074251 100644 --- a/ptx/src/test/operands.ptx +++ b/ptx/src/test/operands.ptx @@ -7,6 +7,7 @@ ) { .reg .u32 %reg<10>; + .reg .b16 %reg_16; .reg .u64 %reg_64; .reg .pred p; .reg .pred q; @@ -30,4 +31,7 @@ // vector index - only supported by mov (maybe: ld, st, tex) mov.u32 %reg0, %ntid.x; + + // vector operand + mov.u32 %reg0, {0, %reg_16}; } diff --git a/ptx/src/test/spirv_run/const_ident.ptx b/ptx/src/test/spirv_run/const_ident.ptx new file mode 100644 index 0000000..9e50e8e --- /dev/null +++ b/ptx/src/test/spirv_run/const_ident.ptx @@ -0,0 +1,39 @@ +.version 6.5 +.target sm_30 +.address_size 64 + +.const .u64 x = 1; +.const .u64 y[4] = {4,5,6}; +.const .u64 constparams[2] = { x, y }; + +.visible .entry const_ident( + .param .u64 input, + .param .u64 output +) +{ + .reg .u64 in_addr; + .reg .u64 out_addr; + .reg .u64 x_addr; + .reg .u64 y_addr; + .reg .u64 constparams_0; + .reg .u64 constparams_1; + .reg .b64 x_equal; + .reg .b64 y_equal; + + ld.param.u64 in_addr, [input]; + ld.param.u64 out_addr, [output]; + + mov.u64 x_addr, x; + mov.u64 y_addr, y; + + ld.const.u64 constparams_0, [constparams+0]; + ld.const.u64 constparams_1, [constparams+8]; + + xor.b64 x_equal, x_addr, constparams_0; + xor.b64 y_equal, y_addr, constparams_1; + + st.u64 [out_addr], x_equal; + st.u64 [out_addr+8], y_equal; + st.u64 [out_addr+8], y_equal; + ret; +} diff --git a/ptx/src/test/spirv_run/fma_bf16x2.ptx b/ptx/src/test/spirv_run/fma_bf16x2.ptx new file mode 100644 index 0000000..ac112bd --- /dev/null +++ b/ptx/src/test/spirv_run/fma_bf16x2.ptx @@ -0,0 +1,25 @@ +.version 7.0 +.target sm_80 +.address_size 64 + +.visible .entry fma_bf16x2( + .param .u64 input, + .param .u64 output +) +{ + .reg .u64 in_addr; + .reg .u64 out_addr; + .reg .b32 temp1; + .reg .b32 temp2; + .reg .b32 temp3; + + ld.param.u64 in_addr, [input]; + ld.param.u64 out_addr, [output]; + + ld.b32 temp1, [in_addr]; + ld.b32 temp2, [in_addr+4]; + ld.b32 temp3, [in_addr+8]; + fma.rn.bf16x2 temp1, temp1, temp2, temp3; + st.b32 [out_addr], temp1; + ret; +} diff --git a/ptx/src/test/spirv_run/global_array_f32.ptx b/ptx/src/test/spirv_run/global_array_f32.ptx new file mode 100644 index 0000000..070fdac --- /dev/null +++ b/ptx/src/test/spirv_run/global_array_f32.ptx @@ -0,0 +1,22 @@ +.version 6.5 +.target sm_30 +.address_size 64 + +.global .f32 foobar[4] = {0f3f800000}; + +.visible .entry global_array_f32( + .param .u64 input, + .param .u64 output +) +{ + .reg .u64 in_addr; + .reg .u64 out_addr; + .reg .f32 temp; + + mov.u64 in_addr, foobar; + ld.param.u64 out_addr, [output]; + + ld.global.f32 temp, [in_addr+4]; + st.f32 [out_addr], temp; + ret; +} \ No newline at end of file diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index 39e58a5..f413a23 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -137,6 +137,7 @@ test_ptx!( [0x1_00_00_00_00_00_00i64] ); test_ptx!(vector_extract, [1u8, 2u8, 3u8, 4u8], [3u8, 4u8, 1u8, 2u8]); +test_ptx!(vector_operand, [0x1234u16], [0x12345678]); test_ptx!(shr, [-2i32], [-1i32]); test_ptx!(shr_oob, [-32768i16], [-1i16]); test_ptx!(or, [1u64, 2u64], [3u64]); @@ -144,6 +145,7 @@ test_ptx!(sub, [2u64], [1u64]); test_ptx!(min, [555i32, 444i32], [444i32]); test_ptx!(max, [555i32, 444i32], [555i32]); test_ptx!(global_array, [0xDEADu32], [1u32]); +test_ptx!(global_array_f32, [0x0], [0f32]); test_ptx!(extern_shared, [127u64], [127u64]); test_ptx!(extern_shared_call, [121u64], [123u64]); test_ptx!(rcp, [2f32], [0.5f32]); @@ -166,6 +168,11 @@ test_ptx!(and, [6u32, 3u32], [2u32]); test_ptx!(selp, [100u16, 200u16], [200u16]); test_ptx!(selp_true, [100u16, 200u16], [100u16]); test_ptx!(fma, [2f32, 3f32, 5f32], [11f32]); +test_ptx!( + fma_bf16x2, + [0x40004040, 0x40404080, 0x40A04040], + [0x41304170] +); test_ptx!(shared_variable, [513u64], [513u64]); test_ptx!(shared_ptr_32, [513u64], [513u64]); test_ptx!(atom_cas, [91u32, 91u32], [91u32, 100u32]); @@ -256,6 +263,7 @@ test_ptx!( test_ptx!(non_scalar_ptr_offset, [1u32, 2u32, 3u32, 4u32], [7u32]); test_ptx!(stateful_neg_offset, [1237518u64], [1237518u64]); test_ptx!(const, [0u16], [10u16, 20, 30, 40]); +test_ptx!(const_ident, [0u16], [0u64, 0u64]); test_ptx!(cvt_s16_s8, [0x139231C2u32], [0xFFFFFFC2u32]); test_ptx!(cvt_f64_f32, [0.125f32], [0.125f64]); test_ptx!(prmt, [0x70c507d6u32, 0x6fbd4b5cu32], [0x6fbdd65cu32]); diff --git a/ptx/src/test/spirv_run/vector_operand.ptx b/ptx/src/test/spirv_run/vector_operand.ptx new file mode 100644 index 0000000..a83eeae --- /dev/null +++ b/ptx/src/test/spirv_run/vector_operand.ptx @@ -0,0 +1,25 @@ +.version 6.5 +.target sm_60 +.address_size 64 + +.visible .entry vector_operand( + .param .u64 input_p, + .param .u64 output_p +) +{ + .reg .u64 in_addr; + .reg .u64 out_addr; + + .reg .b16 in; + .reg .b32 out; + + ld.param.u64 in_addr, [input_p]; + ld.param.u64 out_addr, [output_p]; + + ld.b16 in, [in_addr]; + + mov.b32 out, {0x5678, in}; + + st.b32 [out_addr], out; + ret; +} \ No newline at end of file diff --git a/ptx_parser/src/ast.rs b/ptx_parser/src/ast.rs index e186653..0570eb4 100644 --- a/ptx_parser/src/ast.rs +++ b/ptx_parser/src/ast.rs @@ -7,6 +7,7 @@ use crate::{ ShuffleMode, VoteMode, }; use bitflags::bitflags; +use derive_more::Display; use std::{alloc::Layout, cmp::Ordering, fmt::Write, num::NonZeroU8}; pub enum Statement { @@ -807,7 +808,17 @@ where ), ParsedOperand::VecPack(vec) => ParsedOperand::VecPack( vec.into_iter() - .map(|ident| (self)(ident, type_space, is_dst, relaxed_type_check)) + .map(|reg_or_immediate| { + Ok(match reg_or_immediate { + RegOrImmediate::Reg(ident) => RegOrImmediate::Reg((self)( + ident, + type_space, + is_dst, + relaxed_type_check, + )?), + RegOrImmediate::Imm(imm) => RegOrImmediate::Imm(imm), + }) + }) .collect::, _>>()?, ), }) @@ -948,7 +959,7 @@ pub struct Variable { pub v_type: Type, pub state_space: StateSpace, pub name: ID, - pub array_init: Vec, + pub array_init: Vec>, } impl std::fmt::Display for Variable { @@ -1004,6 +1015,7 @@ impl std::fmt::Display for Type { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { Type::Scalar(scalar_type) => write!(f, "{}", scalar_type), + Type::Vector(count, scalar_type) => write!(f, ".v{}{}", count, scalar_type), _ => todo!(), } } @@ -1062,6 +1074,15 @@ impl Type { } impl ScalarType { + pub fn from_size(size: u8) -> Option { + Some(match size { + 1 => ScalarType::B8, + 2 => ScalarType::B16, + 4 => ScalarType::B32, + 16 => ScalarType::B128, + _ => return None, + }) + } pub fn size_of(self) -> u8 { match self { ScalarType::U8 | ScalarType::S8 | ScalarType::B8 => 1, @@ -1212,13 +1233,19 @@ pub struct ShfDetails { pub mode: FunnelShiftMode, } +#[derive(Clone, Copy, Display)] +pub enum RegOrImmediate { + Reg(Ident), + Imm(ImmediateValue), +} + #[derive(Clone)] pub enum ParsedOperand { Reg(Ident), RegOffset(Ident, i32), Imm(ImmediateValue), VecMember(Ident, u8), - VecPack(Vec), + VecPack(Vec>), } impl ParsedOperand { diff --git a/ptx_parser/src/lib.rs b/ptx_parser/src/lib.rs index 6201e43..29b348a 100644 --- a/ptx_parser/src/lib.rs +++ b/ptx_parser/src/lib.rs @@ -332,6 +332,19 @@ fn immediate_value<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult( + stream: &mut PtxParser<'a, 'input>, +) -> PResult> { + trace( + "reg_or_immediate", + alt(( + immediate_value.map(|imm| ast::RegOrImmediate::Imm(imm)), + ident.map(|id| ast::RegOrImmediate::Reg(id)), + )), + ) + .parse_next(stream) +} + pub fn parse_for_errors<'input>(text: &'input str) -> Vec> { let (tokens, mut errors) = lex_with_span_unchecked(text); let parse_result = { @@ -964,9 +977,9 @@ fn multi_variable<'a, 'input: 'a>( let initializer = match state_space { StateSpace::Global | StateSpace::Const => match array_dimensions { Some(ref mut dimensions) => { - opt(array_initializer(vector, type_, dimensions)).parse_next(stream)? + opt(array_initializer(type_, vector, dimensions)).parse_next(stream)? } - None => opt(value_initializer(vector, type_)).parse_next(stream)?, + None => opt(value_initializer(vector)).parse_next(stream)?, }, _ => None, }; @@ -990,10 +1003,10 @@ fn multi_variable<'a, 'input: 'a>( } fn array_initializer<'b, 'a: 'b, 'input: 'a>( - vector: Option, type_: ScalarType, + vector: Option, array_dimensions: &'b mut Vec, -) -> impl Parser, Vec, ContextError> + 'b { +) -> impl Parser, Vec>, ContextError> + 'b { trace( "array_initializer", move |stream: &mut PtxParser<'a, 'input>| { @@ -1007,15 +1020,24 @@ fn array_initializer<'b, 'a: 'b, 'input: 'a>( Token::LBrace, separated::<_, (), (), _, _, _, _>( 0..=array_dimensions[0] as usize, - single_value_append(&mut result, type_), + single_value_append(&mut result), Token::Comma, ), Token::RBrace, ) .parse_next(stream)?; // pad with zeros - let result_size = type_.size_of() as usize * array_dimensions[0] as usize; - result.extend(iter::repeat(0u8).take(result_size - result.len())); + let result_size = array_dimensions[0] as usize; + let default = match type_.kind() { + ScalarKind::Bit | ScalarKind::Unsigned | ScalarKind::Pred => { + ast::ImmediateValue::U64(0) + } + ScalarKind::Signed => ast::ImmediateValue::S64(0), + ScalarKind::Float => ast::ImmediateValue::F64(0.0), + }; + result.extend( + iter::repeat(ast::RegOrImmediate::Imm(default)).take(result_size - result.len()), + ); Ok(result) }, ) @@ -1023,8 +1045,7 @@ fn array_initializer<'b, 'a: 'b, 'input: 'a>( fn value_initializer<'a, 'input: 'a>( vector: Option, - type_: ScalarType, -) -> impl Parser, Vec, ContextError> { +) -> impl Parser, Vec>, ContextError> { trace( "value_initializer", move |stream: &mut PtxParser<'a, 'input>| { @@ -1034,77 +1055,20 @@ fn value_initializer<'a, 'input: 'a>( if vector.is_some() { return Err(ErrMode::from_error_kind(stream, ErrorKind::Verify)); } - single_value_append(&mut result, type_).parse_next(stream)?; + single_value_append(&mut result).parse_next(stream)?; Ok(result) }, ) } fn single_value_append<'b, 'a: 'b, 'input: 'a>( - accumulator: &'b mut Vec, - type_: ScalarType, + accumulator: &'b mut Vec>, ) -> impl Parser, (), ContextError> + 'b { trace( "single_value_append", move |stream: &mut PtxParser<'a, 'input>| { - let value = immediate_value.parse_next(stream)?; - match (type_, value) { - (ScalarType::U8 | ScalarType::B8, ImmediateValue::U64(x)) => { - accumulator.extend_from_slice(&(x as u8).to_le_bytes()) - } - (ScalarType::U8 | ScalarType::B8, ImmediateValue::S64(x)) => { - accumulator.extend_from_slice(&(x as u8).to_le_bytes()) - } - (ScalarType::U16 | ScalarType::B16, ImmediateValue::U64(x)) => { - accumulator.extend_from_slice(&(x as u16).to_le_bytes()) - } - (ScalarType::U16 | ScalarType::B16, ImmediateValue::S64(x)) => { - accumulator.extend_from_slice(&(x as u16).to_le_bytes()) - } - (ScalarType::U32 | ScalarType::B32, ImmediateValue::U64(x)) => { - accumulator.extend_from_slice(&(x as u32).to_le_bytes()) - } - (ScalarType::U32 | ScalarType::B32, ImmediateValue::S64(x)) => { - accumulator.extend_from_slice(&(x as u32).to_le_bytes()) - } - (ScalarType::U64 | ScalarType::B64, ImmediateValue::U64(x)) => { - accumulator.extend_from_slice(&(x as u64).to_le_bytes()) - } - (ScalarType::U64 | ScalarType::B64, ImmediateValue::S64(x)) => { - accumulator.extend_from_slice(&(x as u64).to_le_bytes()) - } - (ScalarType::S8, ImmediateValue::U64(x)) => { - accumulator.extend_from_slice(&(x as i8).to_le_bytes()) - } - (ScalarType::S8, ImmediateValue::S64(x)) => { - accumulator.extend_from_slice(&(x as i8).to_le_bytes()) - } - (ScalarType::S16, ImmediateValue::U64(x)) => { - accumulator.extend_from_slice(&(x as i16).to_le_bytes()) - } - (ScalarType::S16, ImmediateValue::S64(x)) => { - accumulator.extend_from_slice(&(x as i16).to_le_bytes()) - } - (ScalarType::S32, ImmediateValue::U64(x)) => { - accumulator.extend_from_slice(&(x as i32).to_le_bytes()) - } - (ScalarType::S32, ImmediateValue::S64(x)) => { - accumulator.extend_from_slice(&(x as i32).to_le_bytes()) - } - (ScalarType::S64, ImmediateValue::U64(x)) => { - accumulator.extend_from_slice(&(x as i64).to_le_bytes()) - } - (ScalarType::S64, ImmediateValue::S64(x)) => { - accumulator.extend_from_slice(&(x as i64).to_le_bytes()) - } - (ScalarType::F32, ImmediateValue::F32(x)) => { - accumulator.extend_from_slice(&x.to_le_bytes()) - } - (ScalarType::F64, ImmediateValue::F64(x)) => { - accumulator.extend_from_slice(&x.to_le_bytes()) - } - _ => return Err(ErrMode::from_error_kind(stream, ErrorKind::Verify)), - } + let value = reg_or_immediate.parse_next(stream)?; + accumulator.push(value); Ok(()) }, ) @@ -1379,12 +1343,18 @@ impl ast::ParsedOperand { } fn vector_operand<'a, 'input>( stream: &mut PtxParser<'a, 'input>, - ) -> PResult> { - let (_, r1, _, r2) = (Token::LBrace, ident, Token::Comma, ident).parse_next(stream)?; + ) -> PResult>> { + let (_, r1, _, r2) = ( + Token::LBrace, + reg_or_immediate, + Token::Comma, + reg_or_immediate, + ) + .parse_next(stream)?; // TODO: parse .v8 literals dispatch! {any; (Token::RBrace, _) => empty.map(|_| vec![r1, r2]), - (Token::Comma, _) => (ident, Token::Comma, ident, Token::RBrace).map(|(r3, _, r4, _)| vec![r1, r2, r3, r4]), + (Token::Comma, _) => (reg_or_immediate, Token::Comma, reg_or_immediate, Token::RBrace).map(|(r3, _, r4, _)| vec![r1, r2, r3, r4]), _ => fail } .parse_next(stream) @@ -1421,7 +1391,7 @@ pub enum PtxError<'input> { #[from] source: TokenError, }, - #[error("{0}")] + #[error("Context error: {0}")] Parser(ContextError), #[error("")] Todo, @@ -2772,14 +2742,30 @@ derive_parser!( arguments: FmaArgs { dst: d, src1: a, src2: b, src3: c } } } + .rnd: RawRoundingMode = { .rn }; + ScalarType = { .f16 }; //fma.rnd{.ftz}{.sat}.f16x2 d, a, b, c; //fma.rnd{.ftz}.relu.f16 d, a, b, c; //fma.rnd{.ftz}.relu.f16x2 d, a, b, c; //fma.rnd{.relu}.bf16 d, a, b, c; - //fma.rnd{.relu}.bf16x2 d, a, b, c; - //fma.rnd.oob.{relu}.type d, a, b, c; + fma.rnd{.relu}.bf16x2 d, a, b, c => { + if relu { + state.errors.push(PtxError::Todo); + } + ast::Instruction::Fma { + data: ast::ArithFloat { + type_: bf16x2, + rounding: rnd.into(), + flush_to_zero: None, + saturate: false, + is_fusable: false + }, + arguments: FmaArgs { dst: d, src1: a, src2: b, src3: c } + } + } .rnd: RawRoundingMode = { .rn }; - ScalarType = { .f16 }; + ScalarType = { .bf16x2 }; + //fma.rnd.oob.{relu}.type d, a, b, c; // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-sub // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-sub