From 40bdb83e6b80c169e9ab38e332dc3d633e8b0066 Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Mon, 26 Oct 2020 01:49:25 +0100 Subject: [PATCH] Support float constants --- ptx/src/ast.rs | 56 ++++--- ptx/src/ptx.lalrpop | 158 ++++++++++++++---- ptx/src/test/spirv_run/constant_f32.ptx | 21 +++ ptx/src/test/spirv_run/constant_f32.spvtxt | 57 +++++++ ptx/src/test/spirv_run/constant_negative.ptx | 21 +++ .../test/spirv_run/constant_negative.spvtxt | 56 +++++++ ptx/src/test/spirv_run/mod.rs | 2 + ptx/src/translate.rs | 96 ++++++++--- 8 files changed, 385 insertions(+), 82 deletions(-) create mode 100644 ptx/src/test/spirv_run/constant_f32.ptx create mode 100644 ptx/src/test/spirv_run/constant_f32.spvtxt create mode 100644 ptx/src/test/spirv_run/constant_negative.ptx create mode 100644 ptx/src/test/spirv_run/constant_negative.spvtxt diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs index b045a83..d858d06 100644 --- a/ptx/src/ast.rs +++ b/ptx/src/ast.rs @@ -621,17 +621,25 @@ pub struct Arg5 { pub src3: P::Operand, } +#[derive(Copy, Clone)] +pub enum ImmediateValue { + U64(u64), + S64(i64), + F32(f32), + F64(f64), +} + #[derive(Copy, Clone)] pub enum Operand { Reg(ID), RegOffset(ID, i32), - Imm(u32), + Imm(ImmediateValue), } #[derive(Copy, Clone)] pub enum CallOperand { Reg(ID), - Imm(u32), + Imm(ImmediateValue), } pub enum IdOrVector { @@ -642,7 +650,7 @@ pub enum IdOrVector { pub enum OperandOrVector { Reg(ID), RegOffset(ID, i32), - Imm(u32), + Imm(ImmediateValue), Vec(Vec), } @@ -1028,7 +1036,7 @@ pub struct MinMaxFloat { } pub enum NumsOrArrays<'a> { - Nums(Vec<&'a str>), + Nums(Vec<(&'a str, u32)>), Arrays(Vec>), } @@ -1076,8 +1084,8 @@ impl<'a> NumsOrArrays<'a> { if vec.len() > *dim as usize { return Err(PtxError::ZeroDimensionArray); } - for (idx, val) in vec.iter().enumerate() { - Self::parse_and_copy_single(t, idx, val, result)?; + for (idx, (val, radix)) in vec.iter().enumerate() { + Self::parse_and_copy_single(t, idx, val, *radix, result)?; } } NumsOrArrays::Arrays(_) => return Err(PtxError::ZeroDimensionArray), @@ -1107,42 +1115,43 @@ impl<'a> NumsOrArrays<'a> { t: SizedScalarType, idx: usize, str_val: &str, + radix: u32, output: &mut [u8], ) -> Result<(), PtxError> { match t { SizedScalarType::B8 | SizedScalarType::U8 => { - Self::parse_and_copy_single_t::(idx, str_val, output)?; + Self::parse_and_copy_single_t::(idx, str_val, radix, output)?; } SizedScalarType::B16 | SizedScalarType::U16 => { - Self::parse_and_copy_single_t::(idx, str_val, output)?; + Self::parse_and_copy_single_t::(idx, str_val, radix, output)?; } SizedScalarType::B32 | SizedScalarType::U32 => { - Self::parse_and_copy_single_t::(idx, str_val, output)?; + Self::parse_and_copy_single_t::(idx, str_val, radix, output)?; } SizedScalarType::B64 | SizedScalarType::U64 => { - Self::parse_and_copy_single_t::(idx, str_val, output)?; + Self::parse_and_copy_single_t::(idx, str_val, radix, output)?; } SizedScalarType::S8 => { - Self::parse_and_copy_single_t::(idx, str_val, output)?; + Self::parse_and_copy_single_t::(idx, str_val, radix, output)?; } SizedScalarType::S16 => { - Self::parse_and_copy_single_t::(idx, str_val, output)?; + Self::parse_and_copy_single_t::(idx, str_val, radix, output)?; } SizedScalarType::S32 => { - Self::parse_and_copy_single_t::(idx, str_val, output)?; + Self::parse_and_copy_single_t::(idx, str_val, radix, output)?; } SizedScalarType::S64 => { - Self::parse_and_copy_single_t::(idx, str_val, output)?; + Self::parse_and_copy_single_t::(idx, str_val, radix, output)?; } SizedScalarType::F16 => { - Self::parse_and_copy_single_t::(idx, str_val, output)?; + Self::parse_and_copy_single_t::(idx, str_val, radix, output)?; } SizedScalarType::F16x2 => todo!(), SizedScalarType::F32 => { - Self::parse_and_copy_single_t::(idx, str_val, output)?; + Self::parse_and_copy_single_t::(idx, str_val, radix, output)?; } SizedScalarType::F64 => { - Self::parse_and_copy_single_t::(idx, str_val, output)?; + Self::parse_and_copy_single_t::(idx, str_val, radix, output)?; } } Ok(()) @@ -1151,6 +1160,7 @@ impl<'a> NumsOrArrays<'a> { fn parse_and_copy_single_t( idx: usize, str_val: &str, + _radix: u32, // TODO: use this to properly support hex literals output: &mut [u8], ) -> Result<(), PtxError> where @@ -1200,8 +1210,8 @@ mod tests { #[test] fn array_auto_sizes_0_dimension() { let inp = NumsOrArrays::Arrays(vec![ - NumsOrArrays::Nums(vec!["1", "2"]), - NumsOrArrays::Nums(vec!["3", "4"]), + NumsOrArrays::Nums(vec![("1", 10), ("2", 10)]), + NumsOrArrays::Nums(vec![("3", 10), ("4", 10)]), ]); let mut dimensions = vec![0u32, 2]; assert_eq!( @@ -1214,8 +1224,8 @@ mod tests { #[test] fn array_fails_wrong_structure() { let inp = NumsOrArrays::Arrays(vec![ - NumsOrArrays::Nums(vec!["1", "2"]), - NumsOrArrays::Arrays(vec![NumsOrArrays::Nums(vec!["1"])]), + NumsOrArrays::Nums(vec![("1", 10), ("2", 10)]), + NumsOrArrays::Arrays(vec![NumsOrArrays::Nums(vec![("1", 10)])]), ]); let mut dimensions = vec![0u32, 2]; assert!(inp.to_vec(SizedScalarType::B8, &mut dimensions).is_err()); @@ -1224,8 +1234,8 @@ mod tests { #[test] fn array_fails_too_long_component() { let inp = NumsOrArrays::Arrays(vec![ - NumsOrArrays::Nums(vec!["1", "2", "3"]), - NumsOrArrays::Nums(vec!["4", "5"]), + NumsOrArrays::Nums(vec![("1", 10), ("2", 10), ("3", 10)]), + NumsOrArrays::Nums(vec![("4", 10), ("5", 10)]), ]); let mut dimensions = vec![0u32, 2]; assert!(inp.to_vec(SizedScalarType::B8, &mut dimensions).is_err()); diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop index 163a233..d445baa 100644 --- a/ptx/src/ptx.lalrpop +++ b/ptx/src/ptx.lalrpop @@ -15,12 +15,16 @@ match { r"\s+" => { }, r"//[^\n\r]*[\n\r]*" => { }, r"/\*([^\*]*\*+[^\*/])*([^\*]*\*+|[^\*])*\*/" => { }, - r"-?[?:0x]?[0-9]+" => Num, + r"0[fF][0-9a-zA-Z]{8}" => F32NumToken, + r"0[dD][0-9a-zA-Z]{16}" => F64NumToken, + r"0[xX][0-9a-zA-Z]+U?" => HexNumToken, + r"[0-9]+U?" => DecimalNumToken, r#""[^"]*""# => String, r"[0-9]+\.[0-9]+" => VersionNumber, "!", "(", ")", "+", + "-", ",", ".", ":", @@ -181,6 +185,74 @@ ExtendedID : &'input str = { ID } +NumToken: (&'input str, u32, bool) = { + => { + if s.ends_with('U') { + (&s[2..s.len() - 1], 16, true) + } else { + (&s[2..], 16, false) + } + }, + => { + let radix = if s.starts_with('0') { 8 } else { 10 }; + if s.ends_with('U') { + (&s[..s.len() - 1], radix, true) + } else { + (s, radix, false) + } + } +} + +F32Num: f32 = { + =>? { + match u32::from_str_radix(&s[2..], 16) { + Ok(x) => Ok(unsafe { std::mem::transmute::<_, f32>(x) }), + Err(err) => Err(ParseError::User { error: ast::PtxError::from(err) }) + } + + } +} + +F64Num: f64 = { + =>? { + match u64::from_str_radix(&s[2..], 16) { + Ok(x) => Ok(unsafe { std::mem::transmute::<_, f64>(x) }), + Err(err) => Err(ParseError::User { error: ast::PtxError::from(err) }) + } + } +} + +U8Num: u8 = { + =>? { + let (text, radix, _) = x; + match u8::from_str_radix(text, radix) { + Ok(x) => Ok(x), + Err(err) => Err(ParseError::User { error: ast::PtxError::from(err) }) + } + } +} + +U32Num: u32 = { + =>? { + let (text, radix, _) = x; + match u32::from_str_radix(text, radix) { + Ok(x) => Ok(x), + Err(err) => Err(ParseError::User { error: ast::PtxError::from(err) }) + } + } +} + +// TODO: handle negative number properly +S32Num: i32 = { + =>? { + let (text, radix, _) = x; + match i32::from_str_radix(text, radix) { + Ok(x) => Ok(if sign.is_some() { -x } else { x }), + Err(err) => Err(ParseError::User { error: ast::PtxError::from(err) }) + } + } +} + pub Module: ast::Module<'input> = { Target => { ast::Module { version: v, directives: without_none(d) } @@ -218,7 +290,7 @@ Directive: Option>> = { }; AddressSize = { - ".address_size" Num + ".address_size" U8Num }; Function: ast::Function<'input, &'input str, ast::Statement>> = { @@ -328,7 +400,7 @@ DebugDirective: () = { // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#debugging-directives-loc DebugLocation = { - ".loc" Num Num Num + ".loc" U32Num U32Num U32Num }; Label: &'input str = { @@ -336,10 +408,7 @@ Label: &'input str = { }; Align: u32 = { - ".align" => { - let align = a.parse::(); - align.unwrap_with(errors) - } + ".align" => x }; // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parameterized-variable-names @@ -348,10 +417,7 @@ MultiVariable: ast::MultiVariable<&'input str> = { } VariableParam: u32 = { - "<" ">" => { - let size = n.parse::(); - size.unwrap_with(errors) - } + "<" ">" => n } Variable: ast::Variable = { @@ -1239,29 +1305,51 @@ ArithFloat: ast::ArithFloat = { Operand: ast::Operand<&'input str> = { => ast::Operand::Reg(r), - "+" => { - let offset = o.parse::(); - let offset = offset.unwrap_with(errors); - ast::Operand::RegOffset(r, offset) - }, - // TODO: start parsing whole constants sub-language: - // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#constants - => { - let offset = o.parse::(); - let offset = offset.unwrap_with(errors); - ast::Operand::Imm(offset) - } + "+" => ast::Operand::RegOffset(r, offset), + => ast::Operand::Imm(x) }; CallOperand: ast::CallOperand<&'input str> = { => ast::CallOperand::Reg(r), - => { - let offset = o.parse::(); - let offset = offset.unwrap_with(errors); - ast::CallOperand::Imm(offset) - } + => ast::CallOperand::Imm(x) }; +// TODO: start parsing whole constants sub-language: +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#constants +ImmediateValue: ast::ImmediateValue = { + // TODO: treat negation correctly + =>? { + let (num, radix, is_unsigned) = x; + if neg.is_some() { + match i64::from_str_radix(num, radix) { + Ok(x) => Ok(ast::ImmediateValue::S64(-x)), + Err(err) => Err(ParseError::User { error: ast::PtxError::ParseInt(err) }) + } + } else if is_unsigned { + match u64::from_str_radix(num, radix) { + Ok(x) => Ok(ast::ImmediateValue::U64(x)), + Err(err) => Err(ParseError::User { error: ast::PtxError::ParseInt(err) }) + } + } else { + match i64::from_str_radix(num, radix) { + Ok(x) => Ok(ast::ImmediateValue::S64(x)), + Err(_) => { + match u64::from_str_radix(num, radix) { + Ok(x) => Ok(ast::ImmediateValue::U64(x)), + Err(err) => Err(ParseError::User { error: ast::PtxError::ParseInt(err) }) + } + } + } + } + }, + => { + ast::ImmediateValue::F32(f) + }, + => { + ast::ImmediateValue::F64(f) + } +} + Arg1: ast::Arg1> = { => ast::Arg1{<>} }; @@ -1332,7 +1420,7 @@ VectorPrefix: u8 = { // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#debugging-directives-file File = { - ".file" Num String ("," Num "," Num)? + ".file" U32Num String ("," U32Num "," U32Num)? }; // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#debugging-directives-section @@ -1341,11 +1429,11 @@ Section = { }; SectionDwarfLines: () = { - BitType Comma, + BitType Comma, ".b32" SectionLabel, ".b64" SectionLabel, - ".b32" SectionLabel "+" Num, - ".b64" SectionLabel "+" Num, + ".b32" SectionLabel "+" U32Num, + ".b64" SectionLabel "+" U32Num, }; SectionLabel = { @@ -1409,9 +1497,7 @@ ArrayEmptyDimension = { } ArrayDimension: u32 = { - "[" "]" =>? { - str::parse::(n).map_err(|e| ParseError::User { error: ast::PtxError::ParseInt(e) }) - } + "[" "]" => n, } ArrayInitializer: ast::NumsOrArrays<'input> = { @@ -1424,7 +1510,7 @@ NumsOrArraysBracket: ast::NumsOrArrays<'input> = { NumsOrArrays: ast::NumsOrArrays<'input> = { > => ast::NumsOrArrays::Arrays(n), - > => ast::NumsOrArrays::Nums(n), + > => ast::NumsOrArrays::Nums(n.into_iter().map(|(x,radix,_)| (x, radix)).collect()), } Comma: Vec = { diff --git a/ptx/src/test/spirv_run/constant_f32.ptx b/ptx/src/test/spirv_run/constant_f32.ptx new file mode 100644 index 0000000..8894658 --- /dev/null +++ b/ptx/src/test/spirv_run/constant_f32.ptx @@ -0,0 +1,21 @@ +.version 6.5 +.target sm_30 +.address_size 64 + +.visible .entry constant_f32( + .param .u64 input, + .param .u64 output +) +{ + .reg .u64 in_addr; + .reg .u64 out_addr; + .reg .f32 temp; + + ld.param.u64 in_addr, [input]; + ld.param.u64 out_addr, [output]; + + ld.f32 temp, [in_addr]; + mul.f32 temp, temp, 0f3f000000; // 0.5 + st.f32 [out_addr], temp; + ret; +} diff --git a/ptx/src/test/spirv_run/constant_f32.spvtxt b/ptx/src/test/spirv_run/constant_f32.spvtxt new file mode 100644 index 0000000..905bec4 --- /dev/null +++ b/ptx/src/test/spirv_run/constant_f32.spvtxt @@ -0,0 +1,57 @@ +; SPIR-V +; Version: 1.3 +; Generator: rspirv +; Bound: 32 +OpCapability GenericPointer +OpCapability Linkage +OpCapability Addresses +OpCapability Kernel +OpCapability Int8 +OpCapability Int16 +OpCapability Int64 +OpCapability Float16 +OpCapability Float64 +OpCapability FunctionFloatControlINTEL +OpExtension "SPV_INTEL_float_controls2" +%24 = OpExtInstImport "OpenCL.std" +OpMemoryModel Physical64 OpenCL +OpEntryPoint Kernel %1 "constant_f32" +OpDecorate %1 FunctionDenormModeINTEL 32 Preserve +%25 = OpTypeVoid +%26 = OpTypeInt 64 0 +%27 = OpTypeFunction %25 %26 %26 +%28 = OpTypePointer Function %26 +%29 = OpTypeFloat 32 +%30 = OpTypePointer Function %29 +%31 = OpTypePointer Generic %29 +%19 = OpConstant %29 0.5 +%1 = OpFunction %25 None %27 +%7 = OpFunctionParameter %26 +%8 = OpFunctionParameter %26 +%22 = OpLabel +%2 = OpVariable %28 Function +%3 = OpVariable %28 Function +%4 = OpVariable %28 Function +%5 = OpVariable %28 Function +%6 = OpVariable %30 Function +OpStore %2 %7 +OpStore %3 %8 +%10 = OpLoad %26 %2 +%9 = OpCopyObject %26 %10 +OpStore %4 %9 +%12 = OpLoad %26 %3 +%11 = OpCopyObject %26 %12 +OpStore %5 %11 +%14 = OpLoad %26 %4 +%20 = OpConvertUToPtr %31 %14 +%13 = OpLoad %29 %20 +OpStore %6 %13 +%16 = OpLoad %29 %6 +%15 = OpFMul %29 %16 %19 +OpStore %6 %15 +%17 = OpLoad %26 %5 +%18 = OpLoad %29 %6 +%21 = OpConvertUToPtr %31 %17 +OpStore %21 %18 +OpReturn +OpFunctionEnd \ No newline at end of file diff --git a/ptx/src/test/spirv_run/constant_negative.ptx b/ptx/src/test/spirv_run/constant_negative.ptx new file mode 100644 index 0000000..c723c38 --- /dev/null +++ b/ptx/src/test/spirv_run/constant_negative.ptx @@ -0,0 +1,21 @@ +.version 6.5 +.target sm_30 +.address_size 64 + +.visible .entry constant_negative( + .param .u64 input, + .param .u64 output +) +{ + .reg .u64 in_addr; + .reg .u64 out_addr; + .reg .s32 temp; + + ld.param.u64 in_addr, [input]; + ld.param.u64 out_addr, [output]; + + ld.s32 temp, [in_addr]; + mul.lo.s32 temp, temp, -1; + st.s32 [out_addr], temp; + ret; +} diff --git a/ptx/src/test/spirv_run/constant_negative.spvtxt b/ptx/src/test/spirv_run/constant_negative.spvtxt new file mode 100644 index 0000000..39e5d19 --- /dev/null +++ b/ptx/src/test/spirv_run/constant_negative.spvtxt @@ -0,0 +1,56 @@ +; SPIR-V +; Version: 1.3 +; Generator: rspirv +; Bound: 32 +OpCapability GenericPointer +OpCapability Linkage +OpCapability Addresses +OpCapability Kernel +OpCapability Int8 +OpCapability Int16 +OpCapability Int64 +OpCapability Float16 +OpCapability Float64 +OpCapability FunctionFloatControlINTEL +OpExtension "SPV_INTEL_float_controls2" +%24 = OpExtInstImport "OpenCL.std" +OpMemoryModel Physical64 OpenCL +OpEntryPoint Kernel %1 "constant_negative" +%25 = OpTypeVoid +%26 = OpTypeInt 64 0 +%27 = OpTypeFunction %25 %26 %26 +%28 = OpTypePointer Function %26 +%29 = OpTypeInt 32 0 +%30 = OpTypePointer Function %29 +%31 = OpTypePointer Generic %29 +%19 = OpConstant %29 4294967295 +%1 = OpFunction %25 None %27 +%7 = OpFunctionParameter %26 +%8 = OpFunctionParameter %26 +%22 = OpLabel +%2 = OpVariable %28 Function +%3 = OpVariable %28 Function +%4 = OpVariable %28 Function +%5 = OpVariable %28 Function +%6 = OpVariable %30 Function +OpStore %2 %7 +OpStore %3 %8 +%10 = OpLoad %26 %2 +%9 = OpCopyObject %26 %10 +OpStore %4 %9 +%12 = OpLoad %26 %3 +%11 = OpCopyObject %26 %12 +OpStore %5 %11 +%14 = OpLoad %26 %4 +%20 = OpConvertUToPtr %31 %14 +%13 = OpLoad %29 %20 +OpStore %6 %13 +%16 = OpLoad %29 %6 +%15 = OpIMul %29 %16 %19 +OpStore %6 %15 +%17 = OpLoad %26 %5 +%18 = OpLoad %29 %6 +%21 = OpConvertUToPtr %31 %17 +OpStore %21 %18 +OpReturn +OpFunctionEnd \ No newline at end of file diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index 658d2ef..40acd46 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -87,6 +87,8 @@ test_ptx!(rcp, [2f32], [0.5f32]); // TODO: mul_ftz fails because IGC does not yet handle SPV_INTEL_float_controls2 // test_ptx!(mul_ftz, [0b1_00000000_10000000000000000000000u32, 0x3f000000u32], [0u32]); test_ptx!(mul_non_ftz, [0b1_00000000_10000000000000000000000u32, 0x3f000000u32], [0b1_00000000_01000000000000000000000u32]); +test_ptx!(constant_f32, [10f32], [5f32]); +test_ptx!(constant_negative, [-101i32], [101i32]); struct DisplayError { err: T, diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index 20b5159..c0ff8f0 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -1681,7 +1681,7 @@ impl<'a, 'b> FlattenArguments<'a, 'b> { self.func.push(Statement::Constant(ConstantDefinition { dst: id_constant_stmt, typ: ast::ScalarType::from_parts(width, kind), - value: -(offset as i64), + value: ast::ImmediateValue::S64(-(offset as i64)), })); self.func.push(Statement::Instruction( ast::Instruction::::Sub( @@ -1697,7 +1697,7 @@ impl<'a, 'b> FlattenArguments<'a, 'b> { self.func.push(Statement::Constant(ConstantDefinition { dst: id_constant_stmt, typ: ast::ScalarType::from_parts(width, kind), - value: offset as i64, + value: ast::ImmediateValue::S64(offset as i64), })); self.func.push(Statement::Instruction( ast::Instruction::::Add( @@ -1724,7 +1724,7 @@ impl<'a, 'b> FlattenArguments<'a, 'b> { fn immediate( &mut self, - desc: ArgumentDescriptor, + desc: ArgumentDescriptor, typ: &ast::Type, ) -> Result { let scalar_t = if let ast::Type::Scalar(scalar) = typ { @@ -1736,7 +1736,7 @@ impl<'a, 'b> FlattenArguments<'a, 'b> { self.func.push(Statement::Constant(ConstantDefinition { dst: id, typ: scalar_t, - value: desc.op as i64, + value: desc.op, })); Ok(id) } @@ -2081,32 +2081,82 @@ fn emit_function_body_ops( } Statement::Constant(cnst) => { let typ_id = map.get_or_add_scalar(builder, cnst.typ); - match cnst.typ { - ast::ScalarType::B8 | ast::ScalarType::U8 => { - builder.constant_u32(typ_id, Some(cnst.dst), cnst.value as u8 as u32); + match (cnst.typ, cnst.value) { + (ast::ScalarType::B8, ast::ImmediateValue::U64(value)) + | (ast::ScalarType::U8, ast::ImmediateValue::U64(value)) => { + builder.constant_u32(typ_id, Some(cnst.dst), value as u8 as u32); } - ast::ScalarType::B16 | ast::ScalarType::U16 => { - builder.constant_u32(typ_id, Some(cnst.dst), cnst.value as u16 as u32); + (ast::ScalarType::B16, ast::ImmediateValue::U64(value)) + | (ast::ScalarType::U16, ast::ImmediateValue::U64(value)) => { + builder.constant_u32(typ_id, Some(cnst.dst), value as u16 as u32); } - ast::ScalarType::B32 | ast::ScalarType::U32 => { - builder.constant_u32(typ_id, Some(cnst.dst), cnst.value as u32); + (ast::ScalarType::B32, ast::ImmediateValue::U64(value)) + | (ast::ScalarType::U32, ast::ImmediateValue::U64(value)) => { + builder.constant_u32(typ_id, Some(cnst.dst), value as u32); } - ast::ScalarType::B64 | ast::ScalarType::U64 => { - builder.constant_u64(typ_id, Some(cnst.dst), cnst.value as u64); + (ast::ScalarType::B64, ast::ImmediateValue::U64(value)) + | (ast::ScalarType::U64, ast::ImmediateValue::U64(value)) => { + builder.constant_u64(typ_id, Some(cnst.dst), value); } - ast::ScalarType::S8 => { - builder.constant_u32(typ_id, Some(cnst.dst), cnst.value as i8 as u32); + (ast::ScalarType::S8, ast::ImmediateValue::U64(value)) => { + builder.constant_u32(typ_id, Some(cnst.dst), value as i8 as u32); } - ast::ScalarType::S16 => { - builder.constant_u32(typ_id, Some(cnst.dst), cnst.value as i16 as u32); + (ast::ScalarType::S16, ast::ImmediateValue::U64(value)) => { + builder.constant_u32(typ_id, Some(cnst.dst), value as i16 as u32); } - ast::ScalarType::S32 => { - builder.constant_u32(typ_id, Some(cnst.dst), cnst.value as i32 as u32); + (ast::ScalarType::S32, ast::ImmediateValue::U64(value)) => { + builder.constant_u32(typ_id, Some(cnst.dst), value as i32 as u32); } - ast::ScalarType::S64 => { - builder.constant_u64(typ_id, Some(cnst.dst), cnst.value as i64 as u64); + (ast::ScalarType::S64, ast::ImmediateValue::U64(value)) => { + builder.constant_u64(typ_id, Some(cnst.dst), value as i64 as u64); } - _ => unreachable!(), + (ast::ScalarType::B8, ast::ImmediateValue::S64(value)) + | (ast::ScalarType::U8, ast::ImmediateValue::S64(value)) => { + builder.constant_u32(typ_id, Some(cnst.dst), value as u8 as u32); + } + (ast::ScalarType::B16, ast::ImmediateValue::S64(value)) + | (ast::ScalarType::U16, ast::ImmediateValue::S64(value)) => { + builder.constant_u32(typ_id, Some(cnst.dst), value as u16 as u32); + } + (ast::ScalarType::B32, ast::ImmediateValue::S64(value)) + | (ast::ScalarType::U32, ast::ImmediateValue::S64(value)) => { + builder.constant_u32(typ_id, Some(cnst.dst), value as u32); + } + (ast::ScalarType::B64, ast::ImmediateValue::S64(value)) + | (ast::ScalarType::U64, ast::ImmediateValue::S64(value)) => { + builder.constant_u64(typ_id, Some(cnst.dst), value as u64); + } + (ast::ScalarType::S8, ast::ImmediateValue::S64(value)) => { + builder.constant_u32(typ_id, Some(cnst.dst), value as i8 as u32); + } + (ast::ScalarType::S16, ast::ImmediateValue::S64(value)) => { + builder.constant_u32(typ_id, Some(cnst.dst), value as i16 as u32); + } + (ast::ScalarType::S32, ast::ImmediateValue::S64(value)) => { + builder.constant_u32(typ_id, Some(cnst.dst), value as i32 as u32); + } + (ast::ScalarType::S64, ast::ImmediateValue::S64(value)) => { + builder.constant_u64(typ_id, Some(cnst.dst), value as u64); + } + (ast::ScalarType::F16, ast::ImmediateValue::F32(value)) => { + builder.constant_f32(typ_id, Some(cnst.dst), f16::from_f32(value).to_f32()); + } + (ast::ScalarType::F32, ast::ImmediateValue::F32(value)) => { + builder.constant_f32(typ_id, Some(cnst.dst), value); + } + (ast::ScalarType::F64, ast::ImmediateValue::F32(value)) => { + builder.constant_f64(typ_id, Some(cnst.dst), value as f64); + } + (ast::ScalarType::F16, ast::ImmediateValue::F64(value)) => { + builder.constant_f32(typ_id, Some(cnst.dst), f16::from_f64(value).to_f32()); + } + (ast::ScalarType::F32, ast::ImmediateValue::F64(value)) => { + builder.constant_f32(typ_id, Some(cnst.dst), value as f32); + } + (ast::ScalarType::F64, ast::ImmediateValue::F64(value)) => { + builder.constant_f64(typ_id, Some(cnst.dst), value); + } + _ => return Err(TranslateError::MismatchedType), } } Statement::Conversion(cv) => emit_implicit_conversion(builder, map, cv)?, @@ -4371,7 +4421,7 @@ impl VisitVariableExpanded for CompositeRead { struct ConstantDefinition { pub dst: spirv::Word, pub typ: ast::ScalarType, - pub value: i64, + pub value: ast::ImmediateValue, } struct BrachCondition {