From 82b5cef0bd03fd395dd213ea8386c26d16671894 Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Sat, 15 May 2021 15:58:11 +0200 Subject: [PATCH] Carry state space with pointer --- ptx/src/ast.rs | 41 ++++++++++- ptx/src/ptx.lalrpop | 6 +- ptx/src/translate.rs | 167 +++++++++++++++++-------------------------- 3 files changed, 108 insertions(+), 106 deletions(-) diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs index e45a6fb..e49e489 100644 --- a/ptx/src/ast.rs +++ b/ptx/src/ast.rs @@ -108,10 +108,49 @@ pub type ParsedFunction<'a> = Function<'a, &'a str, Statement OpTypeInt Scalar(ScalarType), + // .param.v2.b32 foo; + // -> OpTypeVector Vector(ScalarType, u8), + // .param.b32 foo[4]; + // -> OpTypeArray Array(ScalarType, Vec), - Pointer(ScalarType), + /* + Variables of this type almost never exist in the original .ptx and are + usually artificially created. Some examples below: + - extern pointers to the .shared memory in the form: + .extern .shared .b32 shared_mem[]; + which we first parse as + .extern .shared .b32 shared_mem; + and then convert to an additional function parameter: + .param .ptr<.b32.shared> shared_mem; + and do a load at the start of the function (and renames inside fn): + .reg .ptr<.b32.shared> temp; + ld.param.ptr<.b32.shared> temp, [shared_mem]; + note, we don't support non-.shared extern pointers, because there's + zero use for them in the ptxas + - artifical pointers created by stateful conversion, which work + similiarly to the above + - function parameters: + foobar(.param .align 4 .b8 numbers[]) + which get parsed to + foobar(.param .align 4 .b8 numbers) + and then converted to + foobar(.reg .align 4 .ptr<.b8.param> numbers) + - ld/st with offset: + .reg.b32 x; + .param.b64 arg0; + st.param.b32 [arg0+4], x; + Yes, this code is legal and actually emitted by the NV compiler! + We convert the st to: + .reg ptr<.b64.param> temp = ptr_offset(arg0, 4); + st.param.b32 [temp], x; + */ + // .reg ptr<.b64.param> + // -> OpTypePointer Function + Pointer(ScalarType, StateSpace), } #[derive(PartialEq, Eq, Hash, Clone, Copy)] diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop index 78ebf1d..2253f85 100644 --- a/ptx/src/ptx.lalrpop +++ b/ptx/src/ptx.lalrpop @@ -624,9 +624,9 @@ ModuleVariable: ast::Variable<&'input str> = { return Err(ParseError::User { error: ast::PtxError::NonExternPointer }); } if space == ".global" { - (ast::Type::Pointer(t), ast::StateSpace::Global, Vec::new()) + (ast::Type::Scalar(t), ast::StateSpace::Global, Vec::new()) } else { - (ast::Type::Pointer(t), ast::StateSpace::Shared, Vec::new()) + (ast::Type::Scalar(t), ast::StateSpace::Shared, Vec::new()) } } }; @@ -648,7 +648,7 @@ ParamVariable: (Option, Vec, ast::Type, &'input str) = { (ast::Type::Array(t, dimensions), init) } ast::ArrayOrPointer::Pointer => { - (ast::Type::Pointer(t), Vec::new()) + (ast::Type::Scalar(t), Vec::new()) } }; (align, array_init, v_type, name) diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index 88ef51b..ea6451e 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -56,33 +56,20 @@ enum SpirvType { } impl SpirvType { - fn new(t: ast::Type, decl_space: ast::StateSpace) -> Self { + fn new(t: ast::Type) -> Self { match t { ast::Type::Scalar(t) => SpirvType::Base(t.into()), ast::Type::Vector(typ, len) => SpirvType::Vector(typ.into(), len), ast::Type::Array(t, len) => SpirvType::Array(t.into(), len), - ast::Type::Pointer(pointer_t) => { - let spirv_space = match decl_space { - ast::StateSpace::Reg | ast::StateSpace::Param | ast::StateSpace::Local => { - spirv::StorageClass::Private - } - ast::StateSpace::Global => spirv::StorageClass::CrossWorkgroup, - ast::StateSpace::Const => spirv::StorageClass::UniformConstant, - ast::StateSpace::Shared => spirv::StorageClass::Workgroup, - ast::StateSpace::Generic => spirv::StorageClass::Generic, - ast::StateSpace::Sreg => spirv::StorageClass::Input, - }; - SpirvType::Pointer(Box::new(SpirvType::Base(pointer_t.into())), spirv_space) - } + ast::Type::Pointer(pointer_t, space) => SpirvType::Pointer( + Box::new(SpirvType::Base(pointer_t.into())), + space.to_spirv(), + ), } } - fn pointer_to( - t: ast::Type, - inner_space: ast::StateSpace, - outer_space: spirv::StorageClass, - ) -> Self { - let key = Self::new(t, inner_space); + fn pointer_to(t: ast::Type, outer_space: spirv::StorageClass) -> Self { + let key = Self::new(t); SpirvType::Pointer(Box::new(key), outer_space) } } @@ -394,7 +381,7 @@ impl TypeWordMap { b.constant_composite(result_type, None, components.into_iter()) } }, - ast::Type::Pointer(typ) => return Err(error_unreachable()), + ast::Type::Pointer(..) => return Err(error_unreachable()), }) } @@ -453,7 +440,7 @@ pub fn to_spirv_module<'a>(ast: ast::Module<'a>) -> Result, _>>()?; let must_link_ptx_impl = ptx_impl_imports.len() > 0; - let directives = ptx_impl_imports + let mut directives = ptx_impl_imports .into_iter() .map(|(_, v)| v) .chain(directives.into_iter()) @@ -461,7 +448,7 @@ pub fn to_spirv_module<'a>(ast: ast::Module<'a>) -> Result(m: &mut MultiHashMap, transformation has a semantical meaning - we emit additional "OpFunctionParameter ..." with type "OpTypePointer Workgroup ...") */ +/* fn convert_dynamic_shared_memory_usage<'input>( module: Vec>, new_id: &mut impl FnMut() -> spirv::Word, @@ -819,7 +807,7 @@ fn convert_dynamic_shared_memory_usage<'input>( ast::Variable { name: shared_id_param, align: None, - v_type: ast::Type::Pointer(ast::ScalarType::B8), + v_type: ast::Type::Pointer(ast::ScalarType::B8, new_todo!()), state_space: ast::StateSpace::Shared, array_init: Vec::new(), } @@ -937,6 +925,7 @@ fn get_callers_of_extern_shared_single<'a>( } } } +*/ type DenormCountMap = HashMap; @@ -1031,11 +1020,7 @@ fn emit_builtins( for (reg, id) in id_defs.special_registers.builtins() { let result_type = map.get_or_add( builder, - SpirvType::pointer_to( - reg.get_type(), - ast::StateSpace::Reg, - spirv::StorageClass::Input, - ), + SpirvType::pointer_to(reg.get_type(), spirv::StorageClass::Input), ); builder.variable(result_type, Some(id), spirv::StorageClass::Input, None); builder.decorate( @@ -1144,10 +1129,7 @@ fn emit_function_header<'a>( } */ for input in &func_decl.input_arguments { - let result_type = map.get_or_add( - builder, - SpirvType::new(input.v_type.clone(), input.state_space), - ); + let result_type = map.get_or_add(builder, SpirvType::new(input.v_type.clone())); builder.function_parameter(Some(input.name), result_type)?; } Ok(fn_id) @@ -1753,8 +1735,8 @@ fn to_ptx_impl_atomic_call( input_arguments: vec![ ast::Variable { align: None, - v_type: ast::Type::Pointer(typ), - state_space: ptr_space, + v_type: ast::Type::Pointer(typ, ptr_space), + state_space: ast::StateSpace::Reg, name: id_defs.register_intermediate(None), array_init: Vec::new(), }, @@ -1791,7 +1773,11 @@ fn to_ptx_impl_atomic_call( func: fn_id, ret_params: vec![(arg.dst, ast::Type::Scalar(scalar_typ), ast::StateSpace::Reg)], param_list: vec![ - (arg.src1, ast::Type::Pointer(typ), ptr_space), + ( + arg.src1, + ast::Type::Pointer(typ, ptr_space), + ast::StateSpace::Reg, + ), ( arg.src2, ast::Type::Scalar(scalar_typ), @@ -2629,8 +2615,8 @@ fn insert_implicit_conversions( is_dst: false, sema: ArgumentSemantics::PhysicalPointer, }, - typ: &ast::Type::Pointer(underlying_type), - state_space, + typ: &ast::Type::Pointer(underlying_type, state_space), + state_space: new_todo!(), stmt_ctor: |new_ptr_src| { Statement::PtrAccess(PtrAccess { underlying_type, @@ -2758,10 +2744,10 @@ fn get_function_type( builder, spirv_input .iter() - .map(|var| SpirvType::new(var.v_type.clone(), var.state_space)), + .map(|var| SpirvType::new(var.v_type.clone())), spirv_output .iter() - .map(|var| SpirvType::new(var.v_type.clone(), var.state_space)), + .map(|var| SpirvType::new(var.v_type.clone())), ) } @@ -2790,7 +2776,7 @@ fn emit_function_body_ops( Statement::Call(call) => { let (result_type, result_id) = match &*call.ret_params { [(id, typ, space)] => ( - map.get_or_add(builder, SpirvType::new(typ.clone(), *space)), + map.get_or_add(builder, SpirvType::new(typ.clone())), Some(*id), ), [] => (map.void(), None), @@ -2922,10 +2908,8 @@ fn emit_function_body_ops( if data.qualifier != ast::LdStQualifier::Weak { todo!() } - let result_type = map.get_or_add( - builder, - SpirvType::new(ast::Type::from(data.typ.clone()), data.state_space), - ); + let result_type = + map.get_or_add(builder, SpirvType::new(ast::Type::from(data.typ.clone()))); builder.load( result_type, Some(arg.dst), @@ -2956,10 +2940,8 @@ fn emit_function_body_ops( // SPIR-V does not support ret as guaranteed-converged ast::Instruction::Ret(_) => builder.ret()?, ast::Instruction::Mov(d, arg) => { - let result_type = map.get_or_add( - builder, - SpirvType::new(ast::Type::from(d.typ.clone()), ast::StateSpace::Reg), - ); + let result_type = + map.get_or_add(builder, SpirvType::new(ast::Type::from(d.typ.clone()))); builder.copy_object(result_type, Some(arg.dst), arg.src)?; } ast::Instruction::Mul(mul, arg) => match mul { @@ -3000,8 +2982,7 @@ fn emit_function_body_ops( ast::Instruction::Shl(t, a) => { let full_type = ast::Type::Scalar(*t); let size_of = full_type.size_of(); - let result_type = - map.get_or_add(builder, SpirvType::new(full_type, ast::StateSpace::Reg)); + let result_type = map.get_or_add(builder, SpirvType::new(full_type)); let offset_src = insert_shift_hack(builder, map, a.src2, size_of)?; builder.shift_left_logical(result_type, Some(a.dst), a.src1, offset_src)?; } @@ -3265,7 +3246,6 @@ fn emit_function_body_ops( builder, SpirvType::pointer_to( details.typ.clone(), - details.state_space, spirv::StorageClass::Function, ), ); @@ -3297,11 +3277,11 @@ fn emit_function_body_ops( }) => { let u8_pointer = map.get_or_add( builder, - SpirvType::new(ast::Type::Pointer(ast::ScalarType::U8), *state_space), + SpirvType::new(ast::Type::Pointer(ast::ScalarType::U8, *state_space)), ); let result_type = map.get_or_add( builder, - SpirvType::new(ast::Type::Pointer(*underlying_type), *state_space), + SpirvType::new(ast::Type::Pointer(*underlying_type, *state_space)), ); let ptr_src_u8 = builder.bitcast(u8_pointer, None, *ptr_src)?; let temp = builder.in_bounds_ptr_access_chain( @@ -3596,15 +3576,12 @@ fn emit_variable( &*var.array_init, )?) } else if must_init { - let type_id = map.get_or_add(builder, SpirvType::new(var.v_type.clone(), var.state_space)); + let type_id = map.get_or_add(builder, SpirvType::new(var.v_type.clone())); Some(builder.constant_null(type_id, None)) } else { None }; - let ptr_type_id = map.get_or_add( - builder, - SpirvType::pointer_to(var.v_type.clone(), var.state_space, st_class), - ); + let ptr_type_id = map.get_or_add(builder, SpirvType::pointer_to(var.v_type.clone(), st_class)); builder.variable(ptr_type_id, Some(var.name), st_class, initalizer); if let Some(align) = var.align { builder.decorate( @@ -3742,10 +3719,7 @@ fn emit_min( ast::MinMaxDetails::Unsigned(_) => spirv::CLOp::u_min, ast::MinMaxDetails::Float(_) => spirv::CLOp::fmin, }; - let inst_type = map.get_or_add( - builder, - SpirvType::new(desc.get_type(), ast::StateSpace::Reg), - ); + let inst_type = map.get_or_add(builder, SpirvType::new(desc.get_type())); builder.ext_inst( inst_type, Some(arg.dst), @@ -3770,10 +3744,7 @@ fn emit_max( ast::MinMaxDetails::Unsigned(_) => spirv::CLOp::u_max, ast::MinMaxDetails::Float(_) => spirv::CLOp::fmax, }; - let inst_type = map.get_or_add( - builder, - SpirvType::new(desc.get_type(), ast::StateSpace::Reg), - ); + let inst_type = map.get_or_add(builder, SpirvType::new(desc.get_type())); builder.ext_inst( inst_type, Some(arg.dst), @@ -4255,14 +4226,13 @@ fn emit_implicit_conversion( (_, _, ConversionKind::BitToPtr) => { let dst_type = map.get_or_add( builder, - SpirvType::pointer_to(cv.to_type.clone(), cv.from_space, cv.to_space.to_spirv()), + SpirvType::pointer_to(cv.to_type.clone(), cv.to_space.to_spirv()), ); builder.convert_u_to_ptr(dst_type, Some(cv.dst), cv.src)?; } (TypeKind::Scalar, TypeKind::Scalar, ConversionKind::Default) => { if from_parts.width == to_parts.width { - let dst_type = - map.get_or_add(builder, SpirvType::new(cv.to_type.clone(), cv.to_space)); + let dst_type = map.get_or_add(builder, SpirvType::new(cv.to_type.clone())); if from_parts.scalar_kind != ast::ScalarKind::Float && to_parts.scalar_kind != ast::ScalarKind::Float { @@ -4275,13 +4245,10 @@ fn emit_implicit_conversion( // This block is safe because it's illegal to implictly convert between floating point values let same_width_bit_type = map.get_or_add( builder, - SpirvType::new( - ast::Type::from_parts(TypeParts { - scalar_kind: ast::ScalarKind::Bit, - ..from_parts - }), - cv.from_space, - ), + SpirvType::new(ast::Type::from_parts(TypeParts { + scalar_kind: ast::ScalarKind::Bit, + ..from_parts + })), ); let same_width_bit_value = builder.bitcast(same_width_bit_type, None, cv.src)?; let wide_bit_type = ast::Type::from_parts(TypeParts { @@ -4289,7 +4256,7 @@ fn emit_implicit_conversion( ..to_parts }); let wide_bit_type_spirv = - map.get_or_add(builder, SpirvType::new(wide_bit_type.clone(), cv.to_space)); + map.get_or_add(builder, SpirvType::new(wide_bit_type.clone())); if to_parts.scalar_kind == ast::ScalarKind::Unsigned || to_parts.scalar_kind == ast::ScalarKind::Bit { @@ -4323,15 +4290,13 @@ fn emit_implicit_conversion( } } (TypeKind::Scalar, TypeKind::Scalar, ConversionKind::SignExtend) => { - let result_type = - map.get_or_add(builder, SpirvType::new(cv.to_type.clone(), cv.to_space)); + let result_type = map.get_or_add(builder, SpirvType::new(cv.to_type.clone())); builder.s_convert(result_type, Some(cv.dst), cv.src)?; } (TypeKind::Vector, TypeKind::Scalar, ConversionKind::Default) | (TypeKind::Scalar, TypeKind::Array, ConversionKind::Default) | (TypeKind::Array, TypeKind::Scalar, ConversionKind::Default) => { - let into_type = - map.get_or_add(builder, SpirvType::new(cv.to_type.clone(), cv.to_space)); + let into_type = map.get_or_add(builder, SpirvType::new(cv.to_type.clone())); builder.bitcast(into_type, Some(cv.dst), cv.src)?; } (_, _, ConversionKind::PtrToPtr { spirv_ptr }) => { @@ -4339,12 +4304,12 @@ fn emit_implicit_conversion( map.get_or_add( builder, SpirvType::Pointer( - Box::new(SpirvType::new(cv.to_type.clone(), cv.to_space)), + Box::new(SpirvType::new(cv.to_type.clone())), spirv::StorageClass::Function, ), ) } else { - map.get_or_add(builder, SpirvType::new(cv.to_type.clone(), cv.to_space)) + map.get_or_add(builder, SpirvType::new(cv.to_type.clone())) }; builder.bitcast(result_type, Some(cv.dst), cv.src)?; } @@ -4358,18 +4323,14 @@ fn emit_load_var( map: &mut TypeWordMap, details: &LoadVarDetails, ) -> Result<(), TranslateError> { - let result_type = map.get_or_add( - builder, - SpirvType::new(details.typ.clone(), details.state_space), - ); + let result_type = map.get_or_add(builder, SpirvType::new(details.typ.clone())); match details.member_index { Some((index, Some(width))) => { let vector_type = match details.typ { ast::Type::Scalar(scalar_t) => ast::Type::Vector(scalar_t, width), _ => return Err(TranslateError::MismatchedType), }; - let vector_type_spirv = - map.get_or_add(builder, SpirvType::new(vector_type, details.state_space)); + let vector_type_spirv = map.get_or_add(builder, SpirvType::new(vector_type)); let vector_temp = builder.load( vector_type_spirv, None, @@ -4387,11 +4348,7 @@ fn emit_load_var( Some((index, None)) => { let result_ptr_type = map.get_or_add( builder, - SpirvType::pointer_to( - details.typ.clone(), - details.state_space, - spirv::StorageClass::Function, - ), + SpirvType::pointer_to(details.typ.clone(), spirv::StorageClass::Function), ); let index_spirv = map.get_or_add_constant( builder, @@ -5661,7 +5618,7 @@ impl> PtrAccess

{ ast::StateSpace::Reg => new_todo!(), ast::StateSpace::Sreg => new_todo!(), }; - let ptr_type = ast::Type::Pointer(self.underlying_type.clone()); + let ptr_type = ast::Type::Pointer(self.underlying_type.clone(), new_todo!()); let new_dst = visitor.id( ArgumentDescriptor { op: self.dst, @@ -6231,24 +6188,28 @@ impl ast::Type { match self { ast::Type::Scalar(scalar) => TypeParts { kind: TypeKind::Scalar, + state_space: ast::StateSpace::Reg, scalar_kind: scalar.kind(), width: scalar.size_of(), components: Vec::new(), }, ast::Type::Vector(scalar, components) => TypeParts { kind: TypeKind::Vector, + state_space: ast::StateSpace::Reg, scalar_kind: scalar.kind(), width: scalar.size_of(), components: vec![*components as u32], }, ast::Type::Array(scalar, components) => TypeParts { kind: TypeKind::Array, + state_space: ast::StateSpace::Reg, scalar_kind: scalar.kind(), width: scalar.size_of(), components: components.clone(), }, - ast::Type::Pointer(scalar) => TypeParts { - kind: TypeKind::PointerScalar, + ast::Type::Pointer(scalar, space) => TypeParts { + kind: TypeKind::Pointer, + state_space: *space, scalar_kind: scalar.kind(), width: scalar.size_of(), components: Vec::new(), @@ -6269,9 +6230,10 @@ impl ast::Type { ast::ScalarType::from_parts(t.width, t.scalar_kind), t.components, ), - TypeKind::PointerScalar => { - ast::Type::Pointer(ast::ScalarType::from_parts(t.width, t.scalar_kind)) - } + TypeKind::Pointer => ast::Type::Pointer( + ast::ScalarType::from_parts(t.width, t.scalar_kind), + t.state_space, + ), } } @@ -6292,6 +6254,7 @@ struct TypeParts { kind: TypeKind, scalar_kind: ast::ScalarKind, width: u8, + state_space: ast::StateSpace, components: Vec, } @@ -6300,7 +6263,7 @@ enum TypeKind { Scalar, Vector, Array, - PointerScalar, + Pointer, } impl ast::Instruction {