Add float-to-int cvt

This commit is contained in:
Andrzej Janik 2024-10-15 19:16:11 +02:00
parent 3105674618
commit 002a19354a

View file

@ -1118,7 +1118,7 @@ impl<'a> MethodEmitContext<'a> {
c"llvm.cos.f32",
Some(arguments.dst),
&ast::ScalarType::F32.into(),
vec![(arguments.src, llvm_f32)],
vec![(self.resolver.value(arguments.src)?, llvm_f32)],
)?;
unsafe { LLVMZludaSetFastMathFlags(cos, LLVMZludaFastMathApproxFunc) }
Ok(())
@ -1371,7 +1371,7 @@ impl<'a> MethodEmitContext<'a> {
c"llvm.sin.f32",
Some(arguments.dst),
&ast::ScalarType::F32.into(),
vec![(arguments.src, llvm_f32)],
vec![(self.resolver.value(arguments.src)?, llvm_f32)],
)?;
unsafe { LLVMZludaSetFastMathFlags(sin, LLVMZludaFastMathApproxFunc) }
Ok(())
@ -1382,7 +1382,7 @@ impl<'a> MethodEmitContext<'a> {
name: &CStr,
dst: Option<SpirvWord>,
return_type: &ast::Type,
arguments: Vec<(SpirvWord, LLVMTypeRef)>,
arguments: Vec<(LLVMValueRef, LLVMTypeRef)>,
) -> Result<LLVMValueRef, TranslateError> {
let fn_type = get_function_type(
self.context,
@ -1393,10 +1393,7 @@ impl<'a> MethodEmitContext<'a> {
if fn_ == ptr::null_mut() {
fn_ = unsafe { LLVMAddFunction(self.module, name.as_ptr(), fn_type) };
}
let mut arguments = arguments
.iter()
.map(|(arg, _)| self.resolver.value(*arg))
.collect::<Result<Vec<_>, _>>()?;
let mut arguments = arguments.iter().map(|(arg, _)| *arg).collect::<Vec<_>>();
Ok(self.resolver.with_result_option(dst, |dst| unsafe {
LLVMBuildCall2(
self.builder,
@ -1538,11 +1535,11 @@ impl<'a> MethodEmitContext<'a> {
ptx_parser::CvtMode::Bitcast => LLVMBuildBitCast,
ptx_parser::CvtMode::SaturateUnsignedToSigned => todo!(),
ptx_parser::CvtMode::SaturateSignedToUnsigned => todo!(),
ptx_parser::CvtMode::FPExtend { flush_to_zero } => todo!(),
ptx_parser::CvtMode::FPExtend { flush_to_zero } => LLVMBuildFPExt,
ptx_parser::CvtMode::FPTruncate {
rounding,
flush_to_zero,
} => todo!(),
} => LLVMBuildFPTrunc,
ptx_parser::CvtMode::FPRound {
integer_rounding,
flush_to_zero,
@ -1550,11 +1547,27 @@ impl<'a> MethodEmitContext<'a> {
ptx_parser::CvtMode::SignedFromFP {
rounding,
flush_to_zero,
} => todo!(),
} => {
return self.emit_cvt_float_to_int(
data.from,
data.to,
rounding,
arguments,
"llvm.fptosi.sat",
)
}
ptx_parser::CvtMode::UnsignedFromFP {
rounding,
flush_to_zero,
} => todo!(),
} => {
return self.emit_cvt_float_to_int(
data.from,
data.to,
rounding,
arguments,
"llvm.fptoui.sat",
)
}
ptx_parser::CvtMode::FPFromSigned(rounding_mode) => todo!(),
ptx_parser::CvtMode::FPFromUnsigned(rounding_mode) => todo!(),
};
@ -1565,6 +1578,45 @@ impl<'a> MethodEmitContext<'a> {
Ok(())
}
fn emit_cvt_float_to_int(
&mut self,
from: ast::ScalarType,
to: ast::ScalarType,
rounding: ast::RoundingMode,
arguments: ptx_parser::CvtArgs<SpirvWord>,
llvm_cast: &str,
) -> Result<(), TranslateError> {
let prefix = match rounding {
ptx_parser::RoundingMode::NearestEven => "llvm.roundeven",
ptx_parser::RoundingMode::Zero => "llvm.trunc",
ptx_parser::RoundingMode::NegativeInf => "llvm.floor",
ptx_parser::RoundingMode::PositiveInf => "llvm.ceil",
};
let intrinsic = format!("{}.{}\0", prefix, LLVMTypeDisplay(from));
let rounded_float = self.emit_intrinsic(
unsafe { CStr::from_bytes_with_nul_unchecked(intrinsic.as_bytes()) },
None,
&from.into(),
vec![(
self.resolver.value(arguments.src)?,
get_scalar_type(self.context, from),
)],
)?;
let cast_intrinsic = format!(
"{}.{}.{}\0",
llvm_cast,
LLVMTypeDisplay(to),
LLVMTypeDisplay(from)
);
self.emit_intrinsic(
unsafe { CStr::from_bytes_with_nul_unchecked(cast_intrinsic.as_bytes()) },
Some(arguments.dst),
&to.into(),
vec![(rounded_float, get_scalar_type(self.context, from))],
)?;
Ok(())
}
fn emit_rsqrt(
&mut self,
data: ptx_parser::TypeFtz,
@ -1580,7 +1632,7 @@ impl<'a> MethodEmitContext<'a> {
intrinsic,
Some(arguments.dst),
&data.type_.into(),
vec![(arguments.src, type_)],
vec![(self.resolver.value(arguments.src)?, type_)],
)?;
Ok(())
}
@ -1601,7 +1653,7 @@ impl<'a> MethodEmitContext<'a> {
intrinsic,
Some(arguments.dst),
&data.type_.into(),
vec![(arguments.src, type_)],
vec![(self.resolver.value(arguments.src)?, type_)],
)?;
Ok(())
}
@ -1623,7 +1675,7 @@ impl<'a> MethodEmitContext<'a> {
intrinsic,
Some(arguments.dst),
&data.type_.into(),
vec![(arguments.src, type_)],
vec![(self.resolver.value(arguments.src)?, type_)],
)?;
Ok(())
}
@ -1745,7 +1797,10 @@ impl<'a> MethodEmitContext<'a> {
intrinsic,
Some(arguments.dst),
&data.type_.into(),
vec![(arguments.src, get_scalar_type(self.context, data.type_))],
vec![(
self.resolver.value(arguments.src)?,
get_scalar_type(self.context, data.type_),
)],
)?;
Ok(())
}
@ -1760,7 +1815,7 @@ impl<'a> MethodEmitContext<'a> {
Some(arguments.dst),
&ast::ScalarType::F32.into(),
vec![(
arguments.src,
self.resolver.value(arguments.src)?,
get_scalar_type(self.context, ast::ScalarType::F32.into()),
)],
)?;
@ -1814,7 +1869,7 @@ impl<'a> MethodEmitContext<'a> {
intrinsic,
Some(arguments.dst),
&type_.into(),
vec![(arguments.src, llvm_type)],
vec![(self.resolver.value(arguments.src)?, llvm_type)],
)?;
Ok(())
}
@ -1832,13 +1887,16 @@ impl<'a> MethodEmitContext<'a> {
}
ptx_parser::MinMaxDetails::Float(ptx_parser::MinMaxFloat { .. }) => "llvm.maxnum",
};
let intrinsic = format!("{}.{}\0", llvm_prefix, ScalarTypeInLLVM(data.type_()));
let intrinsic = format!("{}.{}\0", llvm_prefix, LLVMTypeDisplay(data.type_()));
let llvm_type = get_scalar_type(self.context, data.type_());
self.emit_intrinsic(
unsafe { CStr::from_bytes_with_nul_unchecked(intrinsic.as_bytes()) },
Some(arguments.dst),
&data.type_().into(),
vec![(arguments.src1, llvm_type), (arguments.src2, llvm_type)],
vec![
(self.resolver.value(arguments.src1)?, llvm_type),
(self.resolver.value(arguments.src2)?, llvm_type),
],
)?;
Ok(())
}
@ -1856,13 +1914,16 @@ impl<'a> MethodEmitContext<'a> {
}
ptx_parser::MinMaxDetails::Float(ptx_parser::MinMaxFloat { .. }) => "llvm.minnum",
};
let intrinsic = format!("{}.{}\0", llvm_prefix, ScalarTypeInLLVM(data.type_()));
let intrinsic = format!("{}.{}\0", llvm_prefix, LLVMTypeDisplay(data.type_()));
let llvm_type = get_scalar_type(self.context, data.type_());
self.emit_intrinsic(
unsafe { CStr::from_bytes_with_nul_unchecked(intrinsic.as_bytes()) },
Some(arguments.dst),
&data.type_().into(),
vec![(arguments.src1, llvm_type), (arguments.src2, llvm_type)],
vec![
(self.resolver.value(arguments.src1)?, llvm_type),
(self.resolver.value(arguments.src2)?, llvm_type),
],
)?;
Ok(())
}
@ -1872,15 +1933,24 @@ impl<'a> MethodEmitContext<'a> {
data: ptx_parser::ArithFloat,
arguments: ptx_parser::FmaArgs<SpirvWord>,
) -> Result<(), TranslateError> {
let intrinsic = format!("llvm.fma.{}\0", ScalarTypeInLLVM(data.type_));
let intrinsic = format!("llvm.fma.{}\0", LLVMTypeDisplay(data.type_));
self.emit_intrinsic(
unsafe { CStr::from_bytes_with_nul_unchecked(intrinsic.as_bytes()) },
Some(arguments.dst),
&data.type_.into(),
vec![
(arguments.src1, get_scalar_type(self.context, data.type_)),
(arguments.src2, get_scalar_type(self.context, data.type_)),
(arguments.src3, get_scalar_type(self.context, data.type_)),
(
self.resolver.value(arguments.src1)?,
get_scalar_type(self.context, data.type_),
),
(
self.resolver.value(arguments.src2)?,
get_scalar_type(self.context, data.type_),
),
(
self.resolver.value(arguments.src3)?,
get_scalar_type(self.context, data.type_),
),
],
)?;
Ok(())
@ -2238,9 +2308,9 @@ impl ResolveIdent {
}
}
struct ScalarTypeInLLVM(ast::ScalarType);
struct LLVMTypeDisplay(ast::ScalarType);
impl std::fmt::Display for ScalarTypeInLLVM {
impl std::fmt::Display for LLVMTypeDisplay {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self.0 {
ast::ScalarType::Pred => write!(f, "i1"),