diff --git a/ptx/src/pass/emit_llvm.rs b/ptx/src/pass/emit_llvm.rs index 54a07aa..cc40410 100644 --- a/ptx/src/pass/emit_llvm.rs +++ b/ptx/src/pass/emit_llvm.rs @@ -1118,7 +1118,7 @@ impl<'a> MethodEmitContext<'a> { c"llvm.cos.f32", Some(arguments.dst), &ast::ScalarType::F32.into(), - vec![(arguments.src, llvm_f32)], + vec![(self.resolver.value(arguments.src)?, llvm_f32)], )?; unsafe { LLVMZludaSetFastMathFlags(cos, LLVMZludaFastMathApproxFunc) } Ok(()) @@ -1371,7 +1371,7 @@ impl<'a> MethodEmitContext<'a> { c"llvm.sin.f32", Some(arguments.dst), &ast::ScalarType::F32.into(), - vec![(arguments.src, llvm_f32)], + vec![(self.resolver.value(arguments.src)?, llvm_f32)], )?; unsafe { LLVMZludaSetFastMathFlags(sin, LLVMZludaFastMathApproxFunc) } Ok(()) @@ -1382,7 +1382,7 @@ impl<'a> MethodEmitContext<'a> { name: &CStr, dst: Option, return_type: &ast::Type, - arguments: Vec<(SpirvWord, LLVMTypeRef)>, + arguments: Vec<(LLVMValueRef, LLVMTypeRef)>, ) -> Result { let fn_type = get_function_type( self.context, @@ -1393,10 +1393,7 @@ impl<'a> MethodEmitContext<'a> { if fn_ == ptr::null_mut() { fn_ = unsafe { LLVMAddFunction(self.module, name.as_ptr(), fn_type) }; } - let mut arguments = arguments - .iter() - .map(|(arg, _)| self.resolver.value(*arg)) - .collect::, _>>()?; + let mut arguments = arguments.iter().map(|(arg, _)| *arg).collect::>(); Ok(self.resolver.with_result_option(dst, |dst| unsafe { LLVMBuildCall2( self.builder, @@ -1538,11 +1535,11 @@ impl<'a> MethodEmitContext<'a> { ptx_parser::CvtMode::Bitcast => LLVMBuildBitCast, ptx_parser::CvtMode::SaturateUnsignedToSigned => todo!(), ptx_parser::CvtMode::SaturateSignedToUnsigned => todo!(), - ptx_parser::CvtMode::FPExtend { flush_to_zero } => todo!(), + ptx_parser::CvtMode::FPExtend { flush_to_zero } => LLVMBuildFPExt, ptx_parser::CvtMode::FPTruncate { rounding, flush_to_zero, - } => todo!(), + } => LLVMBuildFPTrunc, ptx_parser::CvtMode::FPRound { integer_rounding, flush_to_zero, @@ -1550,11 +1547,27 @@ impl<'a> MethodEmitContext<'a> { ptx_parser::CvtMode::SignedFromFP { rounding, flush_to_zero, - } => todo!(), + } => { + return self.emit_cvt_float_to_int( + data.from, + data.to, + rounding, + arguments, + "llvm.fptosi.sat", + ) + } ptx_parser::CvtMode::UnsignedFromFP { rounding, flush_to_zero, - } => todo!(), + } => { + return self.emit_cvt_float_to_int( + data.from, + data.to, + rounding, + arguments, + "llvm.fptoui.sat", + ) + } ptx_parser::CvtMode::FPFromSigned(rounding_mode) => todo!(), ptx_parser::CvtMode::FPFromUnsigned(rounding_mode) => todo!(), }; @@ -1565,6 +1578,45 @@ impl<'a> MethodEmitContext<'a> { Ok(()) } + fn emit_cvt_float_to_int( + &mut self, + from: ast::ScalarType, + to: ast::ScalarType, + rounding: ast::RoundingMode, + arguments: ptx_parser::CvtArgs, + llvm_cast: &str, + ) -> Result<(), TranslateError> { + let prefix = match rounding { + ptx_parser::RoundingMode::NearestEven => "llvm.roundeven", + ptx_parser::RoundingMode::Zero => "llvm.trunc", + ptx_parser::RoundingMode::NegativeInf => "llvm.floor", + ptx_parser::RoundingMode::PositiveInf => "llvm.ceil", + }; + let intrinsic = format!("{}.{}\0", prefix, LLVMTypeDisplay(from)); + let rounded_float = self.emit_intrinsic( + unsafe { CStr::from_bytes_with_nul_unchecked(intrinsic.as_bytes()) }, + None, + &from.into(), + vec![( + self.resolver.value(arguments.src)?, + get_scalar_type(self.context, from), + )], + )?; + let cast_intrinsic = format!( + "{}.{}.{}\0", + llvm_cast, + LLVMTypeDisplay(to), + LLVMTypeDisplay(from) + ); + self.emit_intrinsic( + unsafe { CStr::from_bytes_with_nul_unchecked(cast_intrinsic.as_bytes()) }, + Some(arguments.dst), + &to.into(), + vec![(rounded_float, get_scalar_type(self.context, from))], + )?; + Ok(()) + } + fn emit_rsqrt( &mut self, data: ptx_parser::TypeFtz, @@ -1580,7 +1632,7 @@ impl<'a> MethodEmitContext<'a> { intrinsic, Some(arguments.dst), &data.type_.into(), - vec![(arguments.src, type_)], + vec![(self.resolver.value(arguments.src)?, type_)], )?; Ok(()) } @@ -1601,7 +1653,7 @@ impl<'a> MethodEmitContext<'a> { intrinsic, Some(arguments.dst), &data.type_.into(), - vec![(arguments.src, type_)], + vec![(self.resolver.value(arguments.src)?, type_)], )?; Ok(()) } @@ -1623,7 +1675,7 @@ impl<'a> MethodEmitContext<'a> { intrinsic, Some(arguments.dst), &data.type_.into(), - vec![(arguments.src, type_)], + vec![(self.resolver.value(arguments.src)?, type_)], )?; Ok(()) } @@ -1745,7 +1797,10 @@ impl<'a> MethodEmitContext<'a> { intrinsic, Some(arguments.dst), &data.type_.into(), - vec![(arguments.src, get_scalar_type(self.context, data.type_))], + vec![( + self.resolver.value(arguments.src)?, + get_scalar_type(self.context, data.type_), + )], )?; Ok(()) } @@ -1760,7 +1815,7 @@ impl<'a> MethodEmitContext<'a> { Some(arguments.dst), &ast::ScalarType::F32.into(), vec![( - arguments.src, + self.resolver.value(arguments.src)?, get_scalar_type(self.context, ast::ScalarType::F32.into()), )], )?; @@ -1814,7 +1869,7 @@ impl<'a> MethodEmitContext<'a> { intrinsic, Some(arguments.dst), &type_.into(), - vec![(arguments.src, llvm_type)], + vec![(self.resolver.value(arguments.src)?, llvm_type)], )?; Ok(()) } @@ -1832,13 +1887,16 @@ impl<'a> MethodEmitContext<'a> { } ptx_parser::MinMaxDetails::Float(ptx_parser::MinMaxFloat { .. }) => "llvm.maxnum", }; - let intrinsic = format!("{}.{}\0", llvm_prefix, ScalarTypeInLLVM(data.type_())); + let intrinsic = format!("{}.{}\0", llvm_prefix, LLVMTypeDisplay(data.type_())); let llvm_type = get_scalar_type(self.context, data.type_()); self.emit_intrinsic( unsafe { CStr::from_bytes_with_nul_unchecked(intrinsic.as_bytes()) }, Some(arguments.dst), &data.type_().into(), - vec![(arguments.src1, llvm_type), (arguments.src2, llvm_type)], + vec![ + (self.resolver.value(arguments.src1)?, llvm_type), + (self.resolver.value(arguments.src2)?, llvm_type), + ], )?; Ok(()) } @@ -1856,13 +1914,16 @@ impl<'a> MethodEmitContext<'a> { } ptx_parser::MinMaxDetails::Float(ptx_parser::MinMaxFloat { .. }) => "llvm.minnum", }; - let intrinsic = format!("{}.{}\0", llvm_prefix, ScalarTypeInLLVM(data.type_())); + let intrinsic = format!("{}.{}\0", llvm_prefix, LLVMTypeDisplay(data.type_())); let llvm_type = get_scalar_type(self.context, data.type_()); self.emit_intrinsic( unsafe { CStr::from_bytes_with_nul_unchecked(intrinsic.as_bytes()) }, Some(arguments.dst), &data.type_().into(), - vec![(arguments.src1, llvm_type), (arguments.src2, llvm_type)], + vec![ + (self.resolver.value(arguments.src1)?, llvm_type), + (self.resolver.value(arguments.src2)?, llvm_type), + ], )?; Ok(()) } @@ -1872,15 +1933,24 @@ impl<'a> MethodEmitContext<'a> { data: ptx_parser::ArithFloat, arguments: ptx_parser::FmaArgs, ) -> Result<(), TranslateError> { - let intrinsic = format!("llvm.fma.{}\0", ScalarTypeInLLVM(data.type_)); + let intrinsic = format!("llvm.fma.{}\0", LLVMTypeDisplay(data.type_)); self.emit_intrinsic( unsafe { CStr::from_bytes_with_nul_unchecked(intrinsic.as_bytes()) }, Some(arguments.dst), &data.type_.into(), vec![ - (arguments.src1, get_scalar_type(self.context, data.type_)), - (arguments.src2, get_scalar_type(self.context, data.type_)), - (arguments.src3, get_scalar_type(self.context, data.type_)), + ( + self.resolver.value(arguments.src1)?, + get_scalar_type(self.context, data.type_), + ), + ( + self.resolver.value(arguments.src2)?, + get_scalar_type(self.context, data.type_), + ), + ( + self.resolver.value(arguments.src3)?, + get_scalar_type(self.context, data.type_), + ), ], )?; Ok(()) @@ -2238,9 +2308,9 @@ impl ResolveIdent { } } -struct ScalarTypeInLLVM(ast::ScalarType); +struct LLVMTypeDisplay(ast::ScalarType); -impl std::fmt::Display for ScalarTypeInLLVM { +impl std::fmt::Display for LLVMTypeDisplay { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self.0 { ast::ScalarType::Pred => write!(f, "i1"),