From 42bad8fcc22d3fd66bcdbfea7ce9a41268772e50 Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Wed, 16 Sep 2020 00:20:49 +0200 Subject: [PATCH] Fix st/ld offsets implement abs and fix remaining bugs from vectorAdd generation --- ptx/src/ast.rs | 1 + ptx/src/test/spirv_run/ld_st_offset.ptx | 23 ++ ptx/src/test/spirv_run/ld_st_offset.spvtxt | 58 +++++ ptx/src/test/spirv_run/mod.rs | 1 + ptx/src/translate.rs | 257 ++++++++++++++------- 5 files changed, 258 insertions(+), 82 deletions(-) create mode 100644 ptx/src/test/spirv_run/ld_st_offset.ptx create mode 100644 ptx/src/test/spirv_run/ld_st_offset.spvtxt diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs index da37ee3..9214944 100644 --- a/ptx/src/ast.rs +++ b/ptx/src/ast.rs @@ -338,6 +338,7 @@ pub struct MovVectorDetails { pub typ: MovVectorType, pub length: u8, } +#[derive(Copy, Clone)] pub struct AbsDetails { pub flush_to_zero: bool, pub typ: ScalarType, diff --git a/ptx/src/test/spirv_run/ld_st_offset.ptx b/ptx/src/test/spirv_run/ld_st_offset.ptx new file mode 100644 index 0000000..60cba13 --- /dev/null +++ b/ptx/src/test/spirv_run/ld_st_offset.ptx @@ -0,0 +1,23 @@ +.version 6.5 +.target sm_30 +.address_size 64 + +.visible .entry ld_st_offset( + .param .u64 input, + .param .u64 output +) +{ + .reg .u64 in_addr; + .reg .u64 out_addr; + .reg .u32 temp1; + .reg .u32 temp2; + + ld.param.u64 in_addr, [input]; + ld.param.u64 out_addr, [output]; + + ld.u32 temp1, [in_addr]; + ld.u32 temp2, [in_addr+4]; + st.u32 [out_addr], temp2; + st.u32 [out_addr+4], temp1; + ret; +} \ No newline at end of file diff --git a/ptx/src/test/spirv_run/ld_st_offset.spvtxt b/ptx/src/test/spirv_run/ld_st_offset.spvtxt new file mode 100644 index 0000000..f08e05f --- /dev/null +++ b/ptx/src/test/spirv_run/ld_st_offset.spvtxt @@ -0,0 +1,58 @@ + OpCapability GenericPointer + OpCapability Linkage + OpCapability Addresses + OpCapability Kernel + OpCapability Int64 + OpCapability Int8 + %34 = OpExtInstImport "OpenCL.std" + OpMemoryModel Physical64 OpenCL + OpEntryPoint Kernel %1 "ld_st_offset" + %void = OpTypeVoid + %ulong = OpTypeInt 64 0 + %37 = OpTypeFunction %void %ulong %ulong +%_ptr_Function_ulong = OpTypePointer Function %ulong + %uint = OpTypeInt 32 0 +%_ptr_Function_uint = OpTypePointer Function %uint +%_ptr_Generic_uint = OpTypePointer Generic %uint + %ulong_4 = OpConstant %ulong 4 + %ulong_4_0 = OpConstant %ulong 4 + %1 = OpFunction %void None %37 + %8 = OpFunctionParameter %ulong + %9 = OpFunctionParameter %ulong + %32 = OpLabel + %2 = OpVariable %_ptr_Function_ulong Function + %3 = OpVariable %_ptr_Function_ulong Function + %4 = OpVariable %_ptr_Function_ulong Function + %5 = OpVariable %_ptr_Function_ulong Function + %6 = OpVariable %_ptr_Function_uint Function + %7 = OpVariable %_ptr_Function_uint Function + OpStore %2 %8 + OpStore %3 %9 + %11 = OpLoad %ulong %2 + %10 = OpCopyObject %ulong %11 + OpStore %4 %10 + %13 = OpLoad %ulong %3 + %12 = OpCopyObject %ulong %13 + OpStore %5 %12 + %15 = OpLoad %ulong %4 + %26 = OpConvertUToPtr %_ptr_Generic_uint %15 + %14 = OpLoad %uint %26 + OpStore %6 %14 + %17 = OpLoad %ulong %4 + %27 = OpCopyObject %ulong %17 + %23 = OpIAdd %ulong %27 %ulong_4 + %28 = OpConvertUToPtr %_ptr_Generic_uint %23 + %16 = OpLoad %uint %28 + OpStore %7 %16 + %18 = OpLoad %ulong %5 + %19 = OpLoad %uint %7 + %29 = OpConvertUToPtr %_ptr_Generic_uint %18 + OpStore %29 %19 + %20 = OpLoad %ulong %5 + %21 = OpLoad %uint %6 + %30 = OpCopyObject %ulong %20 + %25 = OpIAdd %ulong %30 %ulong_4_0 + %31 = OpConvertUToPtr %_ptr_Generic_uint %25 + OpStore %31 %21 + OpReturn + OpFunctionEnd diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index fd50d3c..f1c3194 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -58,6 +58,7 @@ test_ptx!(block, [1u64], [2u64]); test_ptx!(local_align, [1u64], [1u64]); test_ptx!(call, [1u64], [2u64]); test_ptx!(vector, [1u32, 2u32], [3u32, 3u32]); +test_ptx!(ld_st_offset, [1u32, 2u32], [2u32, 1u32]); struct DisplayError { err: T, diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index 3ac5222..f5c0ecb 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -606,9 +606,9 @@ fn insert_mem_ssa_statement_default<'a, F: VisitVariable>( ) { let mut post_statements = Vec::new(); let new_statement = stmt.visit_variable(&mut |desc: ArgumentDescriptor, _| { - let id_type = match (id_def.get_type(desc.op), desc.is_pointer) { - (Some(t), false) => t, - (Some(_), true) => ast::Type::Scalar(ast::ScalarType::B64), + let id_type = match (id_def.get_type(desc.op), desc.sema) { + (Some(t), ArgumentSemantics::ParamPtr) | (Some(t), ArgumentSemantics::Default) => t, + (Some(t), ArgumentSemantics::Ptr) => ast::Type::Scalar(ast::ScalarType::B64), (None, _) => return desc.op, }; let generated_id = id_def.new_id(Some(id_type)); @@ -725,33 +725,71 @@ impl<'a, 'b> ArgumentMapVisitor id } ast::Operand::RegOffset(reg, offset) => { - let scalar_t = if let ast::Type::Scalar(scalar) = typ { - scalar - } else { - todo!() - }; - let id_constant_stmt = self.id_def.new_id(Some(ast::Type::Scalar(scalar_t))); - self.func.push(Statement::Constant(ConstantDefinition { - dst: id_constant_stmt, - typ: scalar_t, - value: offset as i128, - })); - let result_id = self.id_def.new_id(Some(typ)); - let int_type = ast::IntType::try_new(scalar_t).unwrap_or_else(|| todo!()); - self.func.push(Statement::Instruction( - ast::Instruction::::Add( - ast::AddDetails::Int(ast::AddIntDesc { - typ: int_type, - saturate: false, - }), - ast::Arg3 { - dst: result_id, - src1: reg, - src2: id_constant_stmt, - }, - ), - )); - result_id + match desc.sema { + ArgumentSemantics::Default => { + let scalar_t = if let ast::Type::Scalar(scalar) = typ { + scalar + } else { + todo!() + }; + let id_constant_stmt = + self.id_def.new_id(Some(ast::Type::Scalar(scalar_t))); + let result_id = self.id_def.new_id(Some(typ)); + self.func.push(Statement::Constant(ConstantDefinition { + dst: id_constant_stmt, + typ: scalar_t, + value: offset as i128, + })); + let int_type = ast::IntType::try_new(scalar_t).unwrap_or_else(|| todo!()); + self.func.push(Statement::Instruction( + ast::Instruction::::Add( + ast::AddDetails::Int(ast::AddIntDesc { + typ: int_type, + saturate: false, + }), + ast::Arg3 { + dst: result_id, + src1: reg, + src2: id_constant_stmt, + }, + ), + )); + result_id + } + ArgumentSemantics::Ptr => { + let scalar_t = ast::ScalarType::U64; + let id_constant_stmt = + self.id_def.new_id(Some(ast::Type::Scalar(scalar_t))); + let result_id = self.id_def.new_id(Some(typ)); + self.func.push(Statement::Constant(ConstantDefinition { + dst: id_constant_stmt, + typ: scalar_t, + value: offset as i128, + })); + let int_type = ast::IntType::U64; + self.func.push(Statement::Instruction( + ast::Instruction::::Add( + ast::AddDetails::Int(ast::AddIntDesc { + typ: int_type, + saturate: false, + }), + ast::Arg3 { + dst: result_id, + src1: reg, + src2: id_constant_stmt, + }, + ), + )); + result_id + } + ArgumentSemantics::ParamPtr => { + if offset == 0 { + return reg; + } + // Will be needed for arrays + todo!() + } + } } } } @@ -977,8 +1015,9 @@ fn emit_function_body_ops( let (result_type, result_id) = match &*call.ret_params { [(id, typ)] => ( map.get_or_add(builder, SpirvType::from(ast::Type::from(*typ))), - *id, + Some(*id), ), + [] => (map.void(), None), _ => todo!(), }; let arg_list = call @@ -986,7 +1025,7 @@ fn emit_function_body_ops( .iter() .map(|(id, _)| *id) .collect::>(); - builder.function_call(result_type, Some(result_id), call.func, arg_list)?; + builder.function_call(result_type, result_id, call.func, arg_list)?; } Statement::Variable(ast::Variable { align, @@ -1047,7 +1086,7 @@ fn emit_function_body_ops( builder.branch_conditional(bra.predicate, bra.if_true, bra.if_false, [])?; } Statement::Instruction(inst) => match inst { - ast::Instruction::Abs(_, _) => todo!(), + ast::Instruction::Abs(d, arg) => emit_abs(builder, map, opencl, d, arg)?, ast::Instruction::Call(_) => unreachable!(), // SPIR-V does not support marking jumps as guaranteed-converged ast::Instruction::Bra(_, arg) => { @@ -1098,10 +1137,8 @@ fn emit_function_body_ops( ast::MulDetails::Float(_) => todo!(), }, ast::Instruction::Add(add, arg) => match add { - ast::AddDetails::Int(ref desc) => { - emit_add_int(builder, map, desc, arg)?; - } - ast::AddDetails::Float(_) => todo!(), + ast::AddDetails::Int(ref desc) => emit_add_int(builder, map, desc, arg)?, + ast::AddDetails::Float(desc) => emit_add_float(builder, map, desc, arg)?, }, ast::Instruction::Setp(setp, arg) => { if arg.dst2.is_some() { @@ -1184,6 +1221,21 @@ fn emit_function_body_ops( Ok(()) } +fn emit_add_float( + builder: &mut dr::Builder, + map: &mut TypeWordMap, + desc: &ast::AddFloatDesc, + arg: &ast::Arg3, +) -> Result<(), dr::Error> { + if desc.flush_to_zero { + todo!() + } + let inst_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(desc.typ))); + builder.f_add(inst_type, Some(arg.dst), arg.src1, arg.src2)?; + emit_rounding_decoration(builder, arg.dst, desc.rounding); + Ok(()) +} + fn emit_cvt( builder: &mut dr::Builder, map: &mut TypeWordMap, @@ -1415,6 +1467,30 @@ fn emit_mul_int( Ok(()) } +fn emit_abs( + builder: &mut dr::Builder, + map: &mut TypeWordMap, + opencl: spirv::Word, + d: &ast::AbsDetails, + arg: &ast::Arg2, +) -> Result<(), dr::Error> { + let scalar_t = ast::ScalarType::from(d.typ); + let result_type = map.get_or_add(builder, SpirvType::from(scalar_t)); + let cl_abs = if scalar_t.kind() == ScalarKind::Signed { + spirv::CLOp::s_abs + } else { + spirv::CLOp::fabs + }; + builder.ext_inst( + result_type, + Some(arg.dst), + opencl, + cl_abs as spirv::Word, + [arg.src], + )?; + Ok(()) +} + fn emit_add_int( builder: &mut dr::Builder, map: &mut TypeWordMap, @@ -1821,7 +1897,7 @@ impl> ResolvedCall { ArgumentDescriptor { op: id, is_dst: true, - is_pointer: false, + sema: ArgumentSemantics::Default, }, Some(typ.into()), ); @@ -1832,7 +1908,7 @@ impl> ResolvedCall { ArgumentDescriptor { op: self.func, is_dst: false, - is_pointer: false, + sema: ArgumentSemantics::Default, }, None, ); @@ -1844,7 +1920,7 @@ impl> ResolvedCall { ArgumentDescriptor { op: id, is_dst: false, - is_pointer: false, + sema: ArgumentSemantics::Default, }, typ.into(), ); @@ -2024,7 +2100,14 @@ where struct ArgumentDescriptor { op: Op, is_dst: bool, - is_pointer: bool, + sema: ArgumentSemantics, +} + +#[derive(Copy, Clone, PartialEq, Eq)] +enum ArgumentSemantics { + Default, + Ptr, + ParamPtr, } impl ArgumentDescriptor { @@ -2032,7 +2115,7 @@ impl ArgumentDescriptor { ArgumentDescriptor { op: u, is_dst: self.is_dst, - is_pointer: self.is_pointer, + sema: self.sema, } } } @@ -2046,13 +2129,15 @@ impl ast::Instruction { ast::Instruction::MovVector(t, a) => { ast::Instruction::MovVector(t, a.map(visitor, (t.typ, t.length))) } - ast::Instruction::Abs(_, _) => todo!(), + ast::Instruction::Abs(d, arg) => { + ast::Instruction::Abs(d, arg.map(visitor, ast::Type::Scalar(d.typ))) + } // Call instruction is converted to a call statement early on ast::Instruction::Call(_) => unreachable!(), ast::Instruction::Ld(d, a) => { let inst_type = d.typ; - let src_is_pointer = d.state_space != ast::LdStateSpace::Param; - ast::Instruction::Ld(d, a.map_ld(visitor, inst_type, src_is_pointer)) + let is_param = d.state_space == ast::LdStateSpace::Param; + ast::Instruction::Ld(d, a.map_ld(visitor, inst_type, is_param)) } ast::Instruction::Mov(mov_type, a) => { ast::Instruction::Mov(mov_type, a.map(visitor, mov_type.into())) @@ -2100,8 +2185,8 @@ impl ast::Instruction { } ast::Instruction::St(d, a) => { let inst_type = d.typ; - let param_space = d.state_space == ast::StStateSpace::Param; - ast::Instruction::St(d, a.map(visitor, inst_type, param_space)) + let is_param = d.state_space == ast::StStateSpace::Param; + ast::Instruction::St(d, a.map(visitor, inst_type, is_param)) } ast::Instruction::Bra(d, a) => ast::Instruction::Bra(d, a.map(visitor, None)), ast::Instruction::Ret(d) => ast::Instruction::Ret(d), @@ -2350,7 +2435,7 @@ impl ast::Arg1 { ArgumentDescriptor { op: self.src, is_dst: false, - is_pointer: false, + sema: ArgumentSemantics::Default, }, t, ), @@ -2369,7 +2454,7 @@ impl ast::Arg2 { ArgumentDescriptor { op: self.dst, is_dst: true, - is_pointer: false, + sema: ArgumentSemantics::Default, }, Some(t), ), @@ -2377,7 +2462,7 @@ impl ast::Arg2 { ArgumentDescriptor { op: self.src, is_dst: false, - is_pointer: false, + sema: ArgumentSemantics::Default, }, t, ), @@ -2388,14 +2473,14 @@ impl ast::Arg2 { self, visitor: &mut V, t: ast::Type, - is_src_pointer: bool, + is_param: bool, ) -> ast::Arg2 { ast::Arg2 { dst: visitor.variable( ArgumentDescriptor { op: self.dst, is_dst: true, - is_pointer: false, + sema: ArgumentSemantics::Default, }, Some(t), ), @@ -2403,7 +2488,11 @@ impl ast::Arg2 { ArgumentDescriptor { op: self.src, is_dst: false, - is_pointer: is_src_pointer, + sema: if is_param { + ArgumentSemantics::ParamPtr + } else { + ArgumentSemantics::Ptr + }, }, t, ), @@ -2421,7 +2510,7 @@ impl ast::Arg2 { ArgumentDescriptor { op: self.dst, is_dst: true, - is_pointer: false, + sema: ArgumentSemantics::Default, }, Some(dst_t), ), @@ -2429,7 +2518,7 @@ impl ast::Arg2 { ArgumentDescriptor { op: self.src, is_dst: false, - is_pointer: false, + sema: ArgumentSemantics::Default, }, src_t, ), @@ -2442,14 +2531,18 @@ impl ast::Arg2St { self, visitor: &mut V, t: ast::Type, - param_space: bool, + is_param: bool, ) -> ast::Arg2St { ast::Arg2St { src1: visitor.operand( ArgumentDescriptor { op: self.src1, - is_dst: param_space, - is_pointer: !param_space, + is_dst: is_param, + sema: if is_param { + ArgumentSemantics::ParamPtr + } else { + ArgumentSemantics::Ptr + }, }, t, ), @@ -2457,7 +2550,7 @@ impl ast::Arg2St { ArgumentDescriptor { op: self.src2, is_dst: false, - is_pointer: false, + sema: ArgumentSemantics::Default, }, t, ), @@ -2486,7 +2579,7 @@ impl ast::Arg2Vec { ArgumentDescriptor { op: dst, is_dst: true, - is_pointer: false, + sema: ArgumentSemantics::Default, }, Some(ast::Type::Scalar(scalar_type.into())), ), @@ -2496,7 +2589,7 @@ impl ast::Arg2Vec { ArgumentDescriptor { op: composite_src, is_dst: false, - is_pointer: false, + sema: ArgumentSemantics::Default, }, Some(ast::Type::Scalar(scalar_type.into())), ), @@ -2504,7 +2597,7 @@ impl ast::Arg2Vec { ArgumentDescriptor { op: scalar_src, is_dst: false, - is_pointer: false, + sema: ArgumentSemantics::Default, }, Some(ast::Type::Scalar(scalar_type.into())), ), @@ -2514,7 +2607,7 @@ impl ast::Arg2Vec { ArgumentDescriptor { op: dst, is_dst: true, - is_pointer: false, + sema: ArgumentSemantics::Default, }, Some(ast::Type::Scalar(scalar_type.into())), ), @@ -2522,7 +2615,7 @@ impl ast::Arg2Vec { ArgumentDescriptor { op: src, is_dst: false, - is_pointer: false, + sema: ArgumentSemantics::Default, }, (scalar_type, vec_len), ), @@ -2533,7 +2626,7 @@ impl ast::Arg2Vec { ArgumentDescriptor { op: dst, is_dst: true, - is_pointer: false, + sema: ArgumentSemantics::Default, }, Some(ast::Type::Scalar(scalar_type.into())), ), @@ -2543,7 +2636,7 @@ impl ast::Arg2Vec { ArgumentDescriptor { op: composite_src, is_dst: false, - is_pointer: false, + sema: ArgumentSemantics::Default, }, Some(ast::Type::Scalar(scalar_type.into())), ), @@ -2551,7 +2644,7 @@ impl ast::Arg2Vec { ArgumentDescriptor { op: src, is_dst: false, - is_pointer: false, + sema: ArgumentSemantics::Default, }, (scalar_type, vec_len), ), @@ -2571,7 +2664,7 @@ impl ast::Arg3 { ArgumentDescriptor { op: self.dst, is_dst: true, - is_pointer: false, + sema: ArgumentSemantics::Default, }, Some(t), ), @@ -2579,7 +2672,7 @@ impl ast::Arg3 { ArgumentDescriptor { op: self.src1, is_dst: false, - is_pointer: false, + sema: ArgumentSemantics::Default, }, t, ), @@ -2587,7 +2680,7 @@ impl ast::Arg3 { ArgumentDescriptor { op: self.src2, is_dst: false, - is_pointer: false, + sema: ArgumentSemantics::Default, }, t, ), @@ -2604,7 +2697,7 @@ impl ast::Arg3 { ArgumentDescriptor { op: self.dst, is_dst: true, - is_pointer: false, + sema: ArgumentSemantics::Default, }, Some(t), ), @@ -2612,7 +2705,7 @@ impl ast::Arg3 { ArgumentDescriptor { op: self.src1, is_dst: false, - is_pointer: false, + sema: ArgumentSemantics::Default, }, t, ), @@ -2620,7 +2713,7 @@ impl ast::Arg3 { ArgumentDescriptor { op: self.src2, is_dst: false, - is_pointer: false, + sema: ArgumentSemantics::Default, }, ast::Type::Scalar(ast::ScalarType::U32), ), @@ -2639,7 +2732,7 @@ impl ast::Arg4 { ArgumentDescriptor { op: self.dst1, is_dst: true, - is_pointer: false, + sema: ArgumentSemantics::Default, }, Some(ast::Type::Scalar(ast::ScalarType::Pred)), ), @@ -2648,7 +2741,7 @@ impl ast::Arg4 { ArgumentDescriptor { op: dst2, is_dst: true, - is_pointer: false, + sema: ArgumentSemantics::Default, }, Some(ast::Type::Scalar(ast::ScalarType::Pred)), ) @@ -2657,7 +2750,7 @@ impl ast::Arg4 { ArgumentDescriptor { op: self.src1, is_dst: false, - is_pointer: false, + sema: ArgumentSemantics::Default, }, t, ), @@ -2665,7 +2758,7 @@ impl ast::Arg4 { ArgumentDescriptor { op: self.src2, is_dst: false, - is_pointer: false, + sema: ArgumentSemantics::Default, }, t, ), @@ -2684,7 +2777,7 @@ impl ast::Arg5 { ArgumentDescriptor { op: self.dst1, is_dst: true, - is_pointer: false, + sema: ArgumentSemantics::Default, }, Some(ast::Type::Scalar(ast::ScalarType::Pred)), ), @@ -2693,7 +2786,7 @@ impl ast::Arg5 { ArgumentDescriptor { op: dst2, is_dst: true, - is_pointer: false, + sema: ArgumentSemantics::Default, }, Some(ast::Type::Scalar(ast::ScalarType::Pred)), ) @@ -2702,7 +2795,7 @@ impl ast::Arg5 { ArgumentDescriptor { op: self.src1, is_dst: false, - is_pointer: false, + sema: ArgumentSemantics::Default, }, t, ), @@ -2710,7 +2803,7 @@ impl ast::Arg5 { ArgumentDescriptor { op: self.src2, is_dst: false, - is_pointer: false, + sema: ArgumentSemantics::Default, }, t, ), @@ -2718,7 +2811,7 @@ impl ast::Arg5 { ArgumentDescriptor { op: self.src3, is_dst: false, - is_pointer: false, + sema: ArgumentSemantics::Default, }, ast::Type::Scalar(ast::ScalarType::Pred), ),