diff --git a/elf.o b/elf.o new file mode 100644 index 0000000..3095bf4 Binary files /dev/null and b/elf.o differ diff --git a/ptx/src/pass/expand_operands.rs b/ptx/src/pass/expand_operands.rs index b21c343..5296391 100644 --- a/ptx/src/pass/expand_operands.rs +++ b/ptx/src/pass/expand_operands.rs @@ -230,15 +230,33 @@ impl<'a, 'input> FlattenArguments<'a, 'input> { fn vec_pack( &mut self, - vector_elements: Vec, + vector_elements: Vec>, type_space: Option<(&ast::Type, ast::StateSpace)>, is_dst: bool, relaxed_type_check: bool, ) -> Result { let (width, scalar_t, state_space) = match type_space { Some((ast::Type::Vector(width, scalar_t), space)) => (*width, *scalar_t, space), + Some((ast::Type::Scalar(scalar_t), space)) + if scalar_t.kind() == ast::ScalarKind::Bit => + { + let type_ = + ast::ScalarType::from_size(scalar_t.size_of() / (vector_elements.len() as u8)) + .ok_or_else(|| error_mismatched_type())?; + (vector_elements.len() as u8, type_, space) + } _ => return Err(error_mismatched_type()), }; + let vector_elements = vector_elements + .into_iter() + .map(|element| match element { + ast::RegOrImmediate::Reg(name) => self.reg(name), + ast::RegOrImmediate::Imm(immediate_value) => self.immediate( + immediate_value, + Some((&ast::Type::Scalar(scalar_t), state_space)), + ), + }) + .collect::, _>>()?; let temporary_vector = self .resolver .register_unnamed(Some((ast::Type::Vector(width, scalar_t), state_space))); diff --git a/ptx/src/pass/fix_special_registers.rs b/ptx/src/pass/fix_special_registers.rs index 70b468d..2b90b0f 100644 --- a/ptx/src/pass/fix_special_registers.rs +++ b/ptx/src/pass/fix_special_registers.rs @@ -198,10 +198,15 @@ pub fn map_operand( Some(ident) => ast::ParsedOperand::Reg(ident), None => ast::ParsedOperand::VecMember(ident, member), }, - ast::ParsedOperand::VecPack(idents) => ast::ParsedOperand::VecPack( - idents + ast::ParsedOperand::VecPack(elements) => ast::ParsedOperand::VecPack( + elements .into_iter() - .map(|ident| Ok(fn_(ident, None)?.unwrap_or(ident))) + .map(|element| match element { + ast::RegOrImmediate::Reg(ident) => { + Ok(ast::RegOrImmediate::Reg(fn_(ident, None)?.unwrap_or(ident))) + } + ast::RegOrImmediate::Imm(imm) => Ok(ast::RegOrImmediate::Imm(imm)), + }) .collect::, _>>()?, ), }) diff --git a/ptx/src/pass/mod.rs b/ptx/src/pass/mod.rs index 74bb53c..f743fad 100644 --- a/ptx/src/pass/mod.rs +++ b/ptx/src/pass/mod.rs @@ -613,6 +613,12 @@ struct ConstantDefinition { pub value: ast::ImmediateValue, } +impl std::fmt::Display for ConstantDefinition { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "zluda.constant{} {}", self.typ, self.value) + } +} + pub struct PtrAccess { underlying_type: ast::Type, state_space: ast::StateSpace, @@ -629,6 +635,22 @@ struct RepackVectorDetails { relaxed_type_check: bool, } +impl std::fmt::Display for RepackVectorDetails { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let extract = if self.is_extract { + ".extract" + } else { + ".composite" + }; + let relaxed = if self.relaxed_type_check { + ".relaxed" + } else { + "" + }; + write!(f, "zluda.repack_vector{}{}{}", extract, relaxed, self.typ) + } +} + struct FunctionPointerDetails { dst: SpirvWord, src: SpirvWord, diff --git a/ptx/src/pass/test/expand_operands/mod.rs b/ptx/src/pass/test/expand_operands/mod.rs new file mode 100644 index 0000000..20efae8 --- /dev/null +++ b/ptx/src/pass/test/expand_operands/mod.rs @@ -0,0 +1,22 @@ +use crate::pass::{test::directive2_vec_to_string, *}; + +use super::test_pass; + +macro_rules! test_expand_operands { + ($test_name:ident) => { + test_pass!(run_expand_operands, $test_name); + }; +} + +fn run_expand_operands(ptx: ptx_parser::Module) -> String { + // We run the minimal number of passes required to produce the input expected by expand_operands + let mut flat_resolver = GlobalStringIdentResolver2::new(SpirvWord(1)); + let mut scoped_resolver = ScopedResolver::new(&mut flat_resolver); + let directives = normalize_identifiers2::run(&mut scoped_resolver, ptx.directives).unwrap(); + let directives = normalize_predicates2::run(&mut flat_resolver, directives).unwrap(); + let directives = expand_operands::run(&mut flat_resolver, directives).unwrap(); + directive2_vec_to_string(&flat_resolver, directives) +} + +test_expand_operands!(vector_operand); +test_expand_operands!(vector_operand_convert); diff --git a/ptx/src/pass/test/expand_operands/vector_operand.ptx b/ptx/src/pass/test/expand_operands/vector_operand.ptx new file mode 100644 index 0000000..e5cfdd9 --- /dev/null +++ b/ptx/src/pass/test/expand_operands/vector_operand.ptx @@ -0,0 +1,23 @@ +.version 6.5 +.target sm_60 +.address_size 64 + +.func (.reg .v2.b16 output) default ( + .reg .b16 input +) +{ + mov.v2.b16 output, {0x5678, input}; + ret; +} + +// %%% output %%% + +.func (.reg .v2 .b16 %2) %1 ( + .reg .b16 %3 +) +{ + .b16.reg %4 = zluda.constant.b16 22136; + .v2.b16.reg %5 = zluda.repack_vector.composite.b16 %4, %3; + mov.v2.b16 %2, %5; + ret; +} diff --git a/ptx/src/pass/test/expand_operands/vector_operand_convert.ptx b/ptx/src/pass/test/expand_operands/vector_operand_convert.ptx new file mode 100644 index 0000000..1c94806 --- /dev/null +++ b/ptx/src/pass/test/expand_operands/vector_operand_convert.ptx @@ -0,0 +1,23 @@ +.version 6.5 +.target sm_60 +.address_size 64 + +.func (.reg .b32 output) default ( + .reg .b16 input +) +{ + mov.b32 output, {0x5678, input}; + ret; +} + +// %%% output %%% + +.func (.reg .b32 %2) %1 ( + .reg .b16 %3 +) +{ + .b16.reg %4 = zluda.constant.b16 22136; + .v2.b16.reg %5 = zluda.repack_vector.composite.b16 %4, %3; + mov.b32 %2, %5; + ret; +} diff --git a/ptx/src/pass/test/insert_implicit_conversions/default_reg_b32_reg_v2_b16.ptx b/ptx/src/pass/test/insert_implicit_conversions/default_reg_b32_reg_v2_b16.ptx new file mode 100644 index 0000000..f334b83 --- /dev/null +++ b/ptx/src/pass/test/insert_implicit_conversions/default_reg_b32_reg_v2_b16.ptx @@ -0,0 +1,22 @@ +.version 6.5 +.target sm_30 +.address_size 64 + +.func (.reg .b32 output) default ( + .reg .v2.b16 input +) +{ + mov.b32 output, input; + ret; +} + +// %%% output %%% + +.func (.reg .b32 %2) %1 ( + .reg .v2 .b16 %3 +) +{ + .b32.reg %4 = zluda.convert_implicit.default.reg.b32.reg.v2.b16 %3; + mov.b32 %2, %4; + ret; +} diff --git a/ptx/src/pass/test/insert_implicit_conversions/mod.rs b/ptx/src/pass/test/insert_implicit_conversions/mod.rs index 1fb7a54..1b5fffb 100644 --- a/ptx/src/pass/test/insert_implicit_conversions/mod.rs +++ b/ptx/src/pass/test/insert_implicit_conversions/mod.rs @@ -21,3 +21,4 @@ fn run_insert_implicit_conversions(ptx: ptx_parser::Module) -> String { test_insert_implicit_conversions!(default); test_insert_implicit_conversions!(default_reg_b32_reg_f16x2); +test_insert_implicit_conversions!(default_reg_b32_reg_v2_b16); diff --git a/ptx/src/pass/test/mod.rs b/ptx/src/pass/test/mod.rs index 3a9ef1f..4014842 100644 --- a/ptx/src/pass/test/mod.rs +++ b/ptx/src/pass/test/mod.rs @@ -6,6 +6,7 @@ use std::{ path::Path, }; +mod expand_operands; mod insert_implicit_conversions; #[macro_export] @@ -202,6 +203,8 @@ fn statement_to_string( Statement::Variable(var) => format!("{}", var), Statement::Instruction(instr) => format!("{}", instr), Statement::Conversion(conv) => format!("{}", conv), + Statement::Constant(constant) => format!("{}", constant), + Statement::RepackVector(repack) => format!("{}", repack), _ => todo!(), }; let mut args_formatter = StatementFormatter::new(resolver); diff --git a/ptx/src/test/ll/vector_operand.ll b/ptx/src/test/ll/vector_operand.ll new file mode 100644 index 0000000..564d0da --- /dev/null +++ b/ptx/src/test/ll/vector_operand.ll @@ -0,0 +1,31 @@ +define amdgpu_kernel void @vector_operand(ptr addrspace(4) byref(i64) %"36", ptr addrspace(4) byref(i64) %"37") #0 { + %"38" = alloca i64, align 8, addrspace(5) + %"39" = alloca i64, align 8, addrspace(5) + %"40" = alloca i16, align 2, addrspace(5) + %"41" = alloca i32, align 4, addrspace(5) + br label %1 + +1: ; preds = %0 + br label %"35" + +"35": ; preds = %1 + %"42" = load i64, ptr addrspace(4) %"36", align 8 + store i64 %"42", ptr addrspace(5) %"38", align 8 + %"43" = load i64, ptr addrspace(4) %"37", align 8 + store i64 %"43", ptr addrspace(5) %"39", align 8 + %"45" = load i64, ptr addrspace(5) %"38", align 8 + %"50" = inttoptr i64 %"45" to ptr + %"44" = load i16, ptr %"50", align 2 + store i16 %"44", ptr addrspace(5) %"40", align 2 + %"46" = load i16, ptr addrspace(5) %"40", align 2 + %"34" = insertelement <2 x i16> , i16 %"46", i8 1 + %"51" = bitcast <2 x i16> %"34" to i32 + store i32 %"51", ptr addrspace(5) %"41", align 4 + %"48" = load i64, ptr addrspace(5) %"39", align 8 + %"49" = load i32, ptr addrspace(5) %"41", align 4 + %"52" = inttoptr i64 %"48" to ptr + store i32 %"49", ptr %"52", align 4 + ret void +} + +attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="preserve-sign" "denormal-fp-math-f32"="preserve-sign" "no-trapping-math"="true" "uniform-work-group-size"="true" } diff --git a/ptx/src/test/operands.ptx b/ptx/src/test/operands.ptx index 67c59f5..e074251 100644 --- a/ptx/src/test/operands.ptx +++ b/ptx/src/test/operands.ptx @@ -7,6 +7,7 @@ ) { .reg .u32 %reg<10>; + .reg .b16 %reg_16; .reg .u64 %reg_64; .reg .pred p; .reg .pred q; @@ -30,4 +31,7 @@ // vector index - only supported by mov (maybe: ld, st, tex) mov.u32 %reg0, %ntid.x; + + // vector operand + mov.u32 %reg0, {0, %reg_16}; } diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index 1caf560..f413a23 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -137,6 +137,7 @@ test_ptx!( [0x1_00_00_00_00_00_00i64] ); test_ptx!(vector_extract, [1u8, 2u8, 3u8, 4u8], [3u8, 4u8, 1u8, 2u8]); +test_ptx!(vector_operand, [0x1234u16], [0x12345678]); test_ptx!(shr, [-2i32], [-1i32]); test_ptx!(shr_oob, [-32768i16], [-1i16]); test_ptx!(or, [1u64, 2u64], [3u64]); diff --git a/ptx/src/test/spirv_run/vector_operand.ptx b/ptx/src/test/spirv_run/vector_operand.ptx new file mode 100644 index 0000000..a83eeae --- /dev/null +++ b/ptx/src/test/spirv_run/vector_operand.ptx @@ -0,0 +1,25 @@ +.version 6.5 +.target sm_60 +.address_size 64 + +.visible .entry vector_operand( + .param .u64 input_p, + .param .u64 output_p +) +{ + .reg .u64 in_addr; + .reg .u64 out_addr; + + .reg .b16 in; + .reg .b32 out; + + ld.param.u64 in_addr, [input_p]; + ld.param.u64 out_addr, [output_p]; + + ld.b16 in, [in_addr]; + + mov.b32 out, {0x5678, in}; + + st.b32 [out_addr], out; + ret; +} \ No newline at end of file diff --git a/ptx_parser/src/ast.rs b/ptx_parser/src/ast.rs index 8285cd3..38d4aed 100644 --- a/ptx_parser/src/ast.rs +++ b/ptx_parser/src/ast.rs @@ -808,7 +808,17 @@ where ), ParsedOperand::VecPack(vec) => ParsedOperand::VecPack( vec.into_iter() - .map(|ident| (self)(ident, type_space, is_dst, relaxed_type_check)) + .map(|reg_or_immediate| { + Ok(match reg_or_immediate { + RegOrImmediate::Reg(ident) => RegOrImmediate::Reg((self)( + ident, + type_space, + is_dst, + relaxed_type_check, + )?), + RegOrImmediate::Imm(imm) => RegOrImmediate::Imm(imm), + }) + }) .collect::, _>>()?, ), }) @@ -1005,6 +1015,7 @@ impl std::fmt::Display for Type { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { Type::Scalar(scalar_type) => write!(f, "{}", scalar_type), + Type::Vector(count, scalar_type) => write!(f, ".v{}{}", count, scalar_type), _ => todo!(), } } @@ -1063,6 +1074,15 @@ impl Type { } impl ScalarType { + pub fn from_size(size: u8) -> Option { + Some(match size { + 1 => ScalarType::B8, + 2 => ScalarType::B16, + 4 => ScalarType::B32, + 16 => ScalarType::B128, + _ => return None, + }) + } pub fn size_of(self) -> u8 { match self { ScalarType::U8 | ScalarType::S8 | ScalarType::B8 => 1, @@ -1225,7 +1245,7 @@ pub enum ParsedOperand { RegOffset(Ident, i32), Imm(ImmediateValue), VecMember(Ident, u8), - VecPack(Vec), + VecPack(Vec>), } impl ParsedOperand { diff --git a/ptx_parser/src/lib.rs b/ptx_parser/src/lib.rs index 680a39d..b482fa5 100644 --- a/ptx_parser/src/lib.rs +++ b/ptx_parser/src/lib.rs @@ -1283,12 +1283,18 @@ impl ast::ParsedOperand { } fn vector_operand<'a, 'input>( stream: &mut PtxParser<'a, 'input>, - ) -> PResult> { - let (_, r1, _, r2) = (Token::LBrace, ident, Token::Comma, ident).parse_next(stream)?; + ) -> PResult>> { + let (_, r1, _, r2) = ( + Token::LBrace, + reg_or_immediate, + Token::Comma, + reg_or_immediate, + ) + .parse_next(stream)?; // TODO: parse .v8 literals dispatch! {any; (Token::RBrace, _) => empty.map(|_| vec![r1, r2]), - (Token::Comma, _) => (ident, Token::Comma, ident, Token::RBrace).map(|(r3, _, r4, _)| vec![r1, r2, r3, r4]), + (Token::Comma, _) => (reg_or_immediate, Token::Comma, reg_or_immediate, Token::RBrace).map(|(r3, _, r4, _)| vec![r1, r2, r3, r4]), _ => fail } .parse_next(stream) @@ -1325,7 +1331,7 @@ pub enum PtxError<'input> { #[from] source: TokenError, }, - #[error("{0}")] + #[error("Context error: {0}")] Parser(ContextError), #[error("")] Todo,