diff --git a/ptx/src/pass/insert_implicit_conversions2.rs b/ptx/src/pass/insert_implicit_conversions2.rs index b2d3161..bd1675c 100644 --- a/ptx/src/pass/insert_implicit_conversions2.rs +++ b/ptx/src/pass/insert_implicit_conversions2.rs @@ -186,7 +186,7 @@ fn default_implicit_conversion_space( ast::Type::Scalar(ast::ScalarType::B32) | ast::Type::Scalar(ast::ScalarType::U32) | ast::Type::Scalar(ast::ScalarType::S32) => match instruction_space { - ast::StateSpace::Const | ast::StateSpace::Local | ast::StateSpace::Shared => { + ast::StateSpace::Local | ast::StateSpace::Shared => { Ok(Some(ConversionKind::BitToPtr)) } _ => Err(error_mismatched_type()), diff --git a/ptx/src/pass/llvm/emit.rs b/ptx/src/pass/llvm/emit.rs index d0e826c..bcc82e4 100644 --- a/ptx/src/pass/llvm/emit.rs +++ b/ptx/src/pass/llvm/emit.rs @@ -693,6 +693,7 @@ impl<'a> MethodEmitContext<'a> { } } (ast::Type::Vector(..), ast::Type::Scalar(..)) + | (ast::Type::Scalar(..), ast::Type::Vector(..)) | (ast::Type::Scalar(..), ast::Type::Array(..)) | (ast::Type::Array(..), ast::Type::Scalar(..)) => { let dst_type = get_type(self.context, to_type)?; @@ -701,7 +702,7 @@ impl<'a> MethodEmitContext<'a> { }); Ok(()) } - _ => todo!(), + _ => return Err(error_todo()), } } @@ -2409,7 +2410,7 @@ impl<'a> MethodEmitContext<'a> { (control >> 12) & 0b1111, ]; if components.iter().any(|&c| c > 7) { - return Err(TranslateError::Todo("".to_string())); + return Err(error_todo()); } let u32_type = get_scalar_type(self.context, ast::ScalarType::U32); let v4u8_type = get_type(self.context, &ast::Type::Vector(4, ast::ScalarType::U8))?; diff --git a/ptx/src/pass/llvm/mod.rs b/ptx/src/pass/llvm/mod.rs index 40781fc..5e5705c 100644 --- a/ptx/src/pass/llvm/mod.rs +++ b/ptx/src/pass/llvm/mod.rs @@ -169,8 +169,9 @@ fn get_scalar_type(context: LLVMContextRef, type_: ast::ScalarType) -> LLVMTypeR ast::ScalarType::F32 => unsafe { LLVMFloatTypeInContext(context) }, ast::ScalarType::F64 => unsafe { LLVMDoubleTypeInContext(context) }, ast::ScalarType::BF16 => unsafe { LLVMBFloatTypeInContext(context) }, - ast::ScalarType::U16x2 => todo!(), - ast::ScalarType::S16x2 => todo!(), + ast::ScalarType::U16x2 | ast::ScalarType::S16x2 => unsafe { + LLVMVectorType(LLVMInt16TypeInContext(context), 2) + }, ast::ScalarType::F16x2 => unsafe { LLVMVectorType(LLVMHalfTypeInContext(context), 2) }, ast::ScalarType::BF16x2 => unsafe { LLVMVectorType(LLVMBFloatTypeInContext(context), 2) }, } @@ -180,14 +181,14 @@ fn get_state_space(space: ast::StateSpace) -> Result { match space { ast::StateSpace::Reg => Ok(PRIVATE_ADDRESS_SPACE), ast::StateSpace::Generic => Ok(GENERIC_ADDRESS_SPACE), - ast::StateSpace::Param => Err(TranslateError::Todo("".to_string())), + ast::StateSpace::Param => Err(error_todo()), ast::StateSpace::ParamEntry => Ok(CONSTANT_ADDRESS_SPACE), - ast::StateSpace::ParamFunc => Err(TranslateError::Todo("".to_string())), + ast::StateSpace::ParamFunc => Err(error_todo()), ast::StateSpace::Local => Ok(PRIVATE_ADDRESS_SPACE), ast::StateSpace::Global => Ok(GLOBAL_ADDRESS_SPACE), ast::StateSpace::Const => Ok(CONSTANT_ADDRESS_SPACE), ast::StateSpace::Shared => Ok(SHARED_ADDRESS_SPACE), - ast::StateSpace::SharedCta => Err(TranslateError::Todo("".to_string())), - ast::StateSpace::SharedCluster => Err(TranslateError::Todo("".to_string())), + ast::StateSpace::SharedCta => Err(error_todo()), + ast::StateSpace::SharedCluster => Err(error_todo()), } } diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index bd72ca7..2ede637 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -339,6 +339,7 @@ test_ptx!( [0x8e2da590u32, 0xedeaee14, 0x248a9f70], [613065134u32] ); +test_ptx!(param_is_addressable, [0xDEAD], [0u64]); test_ptx!(assertfail); // TODO: not yet supported diff --git a/ptx/src/test/spirv_run/param_is_addressable.ptx b/ptx/src/test/spirv_run/param_is_addressable.ptx new file mode 100644 index 0000000..8d394b3 --- /dev/null +++ b/ptx/src/test/spirv_run/param_is_addressable.ptx @@ -0,0 +1,22 @@ +.version 6.5 +.target sm_30 +.address_size 64 + +.visible .entry param_is_addressable( + .param .u64 input, + .param .u64 output +) +{ + .reg .u64 in_addr; + .reg .u64 out_addr; + .reg .b64 temp; + + ld.param.u64 in_addr, [input]; + ld.param.u64 out_addr, [output]; + + mov.b64 temp, input; + ld.param.b64 temp, [temp]; + sub.u64 temp, temp, in_addr; + st.u64 [out_addr], temp; + ret; +}