mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-04-20 00:14:45 +00:00
Add float-to-int cvt
This commit is contained in:
parent
3105674618
commit
002a19354a
1 changed files with 97 additions and 27 deletions
|
@ -1118,7 +1118,7 @@ impl<'a> MethodEmitContext<'a> {
|
|||
c"llvm.cos.f32",
|
||||
Some(arguments.dst),
|
||||
&ast::ScalarType::F32.into(),
|
||||
vec![(arguments.src, llvm_f32)],
|
||||
vec![(self.resolver.value(arguments.src)?, llvm_f32)],
|
||||
)?;
|
||||
unsafe { LLVMZludaSetFastMathFlags(cos, LLVMZludaFastMathApproxFunc) }
|
||||
Ok(())
|
||||
|
@ -1371,7 +1371,7 @@ impl<'a> MethodEmitContext<'a> {
|
|||
c"llvm.sin.f32",
|
||||
Some(arguments.dst),
|
||||
&ast::ScalarType::F32.into(),
|
||||
vec![(arguments.src, llvm_f32)],
|
||||
vec![(self.resolver.value(arguments.src)?, llvm_f32)],
|
||||
)?;
|
||||
unsafe { LLVMZludaSetFastMathFlags(sin, LLVMZludaFastMathApproxFunc) }
|
||||
Ok(())
|
||||
|
@ -1382,7 +1382,7 @@ impl<'a> MethodEmitContext<'a> {
|
|||
name: &CStr,
|
||||
dst: Option<SpirvWord>,
|
||||
return_type: &ast::Type,
|
||||
arguments: Vec<(SpirvWord, LLVMTypeRef)>,
|
||||
arguments: Vec<(LLVMValueRef, LLVMTypeRef)>,
|
||||
) -> Result<LLVMValueRef, TranslateError> {
|
||||
let fn_type = get_function_type(
|
||||
self.context,
|
||||
|
@ -1393,10 +1393,7 @@ impl<'a> MethodEmitContext<'a> {
|
|||
if fn_ == ptr::null_mut() {
|
||||
fn_ = unsafe { LLVMAddFunction(self.module, name.as_ptr(), fn_type) };
|
||||
}
|
||||
let mut arguments = arguments
|
||||
.iter()
|
||||
.map(|(arg, _)| self.resolver.value(*arg))
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
let mut arguments = arguments.iter().map(|(arg, _)| *arg).collect::<Vec<_>>();
|
||||
Ok(self.resolver.with_result_option(dst, |dst| unsafe {
|
||||
LLVMBuildCall2(
|
||||
self.builder,
|
||||
|
@ -1538,11 +1535,11 @@ impl<'a> MethodEmitContext<'a> {
|
|||
ptx_parser::CvtMode::Bitcast => LLVMBuildBitCast,
|
||||
ptx_parser::CvtMode::SaturateUnsignedToSigned => todo!(),
|
||||
ptx_parser::CvtMode::SaturateSignedToUnsigned => todo!(),
|
||||
ptx_parser::CvtMode::FPExtend { flush_to_zero } => todo!(),
|
||||
ptx_parser::CvtMode::FPExtend { flush_to_zero } => LLVMBuildFPExt,
|
||||
ptx_parser::CvtMode::FPTruncate {
|
||||
rounding,
|
||||
flush_to_zero,
|
||||
} => todo!(),
|
||||
} => LLVMBuildFPTrunc,
|
||||
ptx_parser::CvtMode::FPRound {
|
||||
integer_rounding,
|
||||
flush_to_zero,
|
||||
|
@ -1550,11 +1547,27 @@ impl<'a> MethodEmitContext<'a> {
|
|||
ptx_parser::CvtMode::SignedFromFP {
|
||||
rounding,
|
||||
flush_to_zero,
|
||||
} => todo!(),
|
||||
} => {
|
||||
return self.emit_cvt_float_to_int(
|
||||
data.from,
|
||||
data.to,
|
||||
rounding,
|
||||
arguments,
|
||||
"llvm.fptosi.sat",
|
||||
)
|
||||
}
|
||||
ptx_parser::CvtMode::UnsignedFromFP {
|
||||
rounding,
|
||||
flush_to_zero,
|
||||
} => todo!(),
|
||||
} => {
|
||||
return self.emit_cvt_float_to_int(
|
||||
data.from,
|
||||
data.to,
|
||||
rounding,
|
||||
arguments,
|
||||
"llvm.fptoui.sat",
|
||||
)
|
||||
}
|
||||
ptx_parser::CvtMode::FPFromSigned(rounding_mode) => todo!(),
|
||||
ptx_parser::CvtMode::FPFromUnsigned(rounding_mode) => todo!(),
|
||||
};
|
||||
|
@ -1565,6 +1578,45 @@ impl<'a> MethodEmitContext<'a> {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
fn emit_cvt_float_to_int(
|
||||
&mut self,
|
||||
from: ast::ScalarType,
|
||||
to: ast::ScalarType,
|
||||
rounding: ast::RoundingMode,
|
||||
arguments: ptx_parser::CvtArgs<SpirvWord>,
|
||||
llvm_cast: &str,
|
||||
) -> Result<(), TranslateError> {
|
||||
let prefix = match rounding {
|
||||
ptx_parser::RoundingMode::NearestEven => "llvm.roundeven",
|
||||
ptx_parser::RoundingMode::Zero => "llvm.trunc",
|
||||
ptx_parser::RoundingMode::NegativeInf => "llvm.floor",
|
||||
ptx_parser::RoundingMode::PositiveInf => "llvm.ceil",
|
||||
};
|
||||
let intrinsic = format!("{}.{}\0", prefix, LLVMTypeDisplay(from));
|
||||
let rounded_float = self.emit_intrinsic(
|
||||
unsafe { CStr::from_bytes_with_nul_unchecked(intrinsic.as_bytes()) },
|
||||
None,
|
||||
&from.into(),
|
||||
vec![(
|
||||
self.resolver.value(arguments.src)?,
|
||||
get_scalar_type(self.context, from),
|
||||
)],
|
||||
)?;
|
||||
let cast_intrinsic = format!(
|
||||
"{}.{}.{}\0",
|
||||
llvm_cast,
|
||||
LLVMTypeDisplay(to),
|
||||
LLVMTypeDisplay(from)
|
||||
);
|
||||
self.emit_intrinsic(
|
||||
unsafe { CStr::from_bytes_with_nul_unchecked(cast_intrinsic.as_bytes()) },
|
||||
Some(arguments.dst),
|
||||
&to.into(),
|
||||
vec![(rounded_float, get_scalar_type(self.context, from))],
|
||||
)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn emit_rsqrt(
|
||||
&mut self,
|
||||
data: ptx_parser::TypeFtz,
|
||||
|
@ -1580,7 +1632,7 @@ impl<'a> MethodEmitContext<'a> {
|
|||
intrinsic,
|
||||
Some(arguments.dst),
|
||||
&data.type_.into(),
|
||||
vec![(arguments.src, type_)],
|
||||
vec![(self.resolver.value(arguments.src)?, type_)],
|
||||
)?;
|
||||
Ok(())
|
||||
}
|
||||
|
@ -1601,7 +1653,7 @@ impl<'a> MethodEmitContext<'a> {
|
|||
intrinsic,
|
||||
Some(arguments.dst),
|
||||
&data.type_.into(),
|
||||
vec![(arguments.src, type_)],
|
||||
vec![(self.resolver.value(arguments.src)?, type_)],
|
||||
)?;
|
||||
Ok(())
|
||||
}
|
||||
|
@ -1623,7 +1675,7 @@ impl<'a> MethodEmitContext<'a> {
|
|||
intrinsic,
|
||||
Some(arguments.dst),
|
||||
&data.type_.into(),
|
||||
vec![(arguments.src, type_)],
|
||||
vec![(self.resolver.value(arguments.src)?, type_)],
|
||||
)?;
|
||||
Ok(())
|
||||
}
|
||||
|
@ -1745,7 +1797,10 @@ impl<'a> MethodEmitContext<'a> {
|
|||
intrinsic,
|
||||
Some(arguments.dst),
|
||||
&data.type_.into(),
|
||||
vec![(arguments.src, get_scalar_type(self.context, data.type_))],
|
||||
vec![(
|
||||
self.resolver.value(arguments.src)?,
|
||||
get_scalar_type(self.context, data.type_),
|
||||
)],
|
||||
)?;
|
||||
Ok(())
|
||||
}
|
||||
|
@ -1760,7 +1815,7 @@ impl<'a> MethodEmitContext<'a> {
|
|||
Some(arguments.dst),
|
||||
&ast::ScalarType::F32.into(),
|
||||
vec![(
|
||||
arguments.src,
|
||||
self.resolver.value(arguments.src)?,
|
||||
get_scalar_type(self.context, ast::ScalarType::F32.into()),
|
||||
)],
|
||||
)?;
|
||||
|
@ -1814,7 +1869,7 @@ impl<'a> MethodEmitContext<'a> {
|
|||
intrinsic,
|
||||
Some(arguments.dst),
|
||||
&type_.into(),
|
||||
vec![(arguments.src, llvm_type)],
|
||||
vec![(self.resolver.value(arguments.src)?, llvm_type)],
|
||||
)?;
|
||||
Ok(())
|
||||
}
|
||||
|
@ -1832,13 +1887,16 @@ impl<'a> MethodEmitContext<'a> {
|
|||
}
|
||||
ptx_parser::MinMaxDetails::Float(ptx_parser::MinMaxFloat { .. }) => "llvm.maxnum",
|
||||
};
|
||||
let intrinsic = format!("{}.{}\0", llvm_prefix, ScalarTypeInLLVM(data.type_()));
|
||||
let intrinsic = format!("{}.{}\0", llvm_prefix, LLVMTypeDisplay(data.type_()));
|
||||
let llvm_type = get_scalar_type(self.context, data.type_());
|
||||
self.emit_intrinsic(
|
||||
unsafe { CStr::from_bytes_with_nul_unchecked(intrinsic.as_bytes()) },
|
||||
Some(arguments.dst),
|
||||
&data.type_().into(),
|
||||
vec![(arguments.src1, llvm_type), (arguments.src2, llvm_type)],
|
||||
vec![
|
||||
(self.resolver.value(arguments.src1)?, llvm_type),
|
||||
(self.resolver.value(arguments.src2)?, llvm_type),
|
||||
],
|
||||
)?;
|
||||
Ok(())
|
||||
}
|
||||
|
@ -1856,13 +1914,16 @@ impl<'a> MethodEmitContext<'a> {
|
|||
}
|
||||
ptx_parser::MinMaxDetails::Float(ptx_parser::MinMaxFloat { .. }) => "llvm.minnum",
|
||||
};
|
||||
let intrinsic = format!("{}.{}\0", llvm_prefix, ScalarTypeInLLVM(data.type_()));
|
||||
let intrinsic = format!("{}.{}\0", llvm_prefix, LLVMTypeDisplay(data.type_()));
|
||||
let llvm_type = get_scalar_type(self.context, data.type_());
|
||||
self.emit_intrinsic(
|
||||
unsafe { CStr::from_bytes_with_nul_unchecked(intrinsic.as_bytes()) },
|
||||
Some(arguments.dst),
|
||||
&data.type_().into(),
|
||||
vec![(arguments.src1, llvm_type), (arguments.src2, llvm_type)],
|
||||
vec![
|
||||
(self.resolver.value(arguments.src1)?, llvm_type),
|
||||
(self.resolver.value(arguments.src2)?, llvm_type),
|
||||
],
|
||||
)?;
|
||||
Ok(())
|
||||
}
|
||||
|
@ -1872,15 +1933,24 @@ impl<'a> MethodEmitContext<'a> {
|
|||
data: ptx_parser::ArithFloat,
|
||||
arguments: ptx_parser::FmaArgs<SpirvWord>,
|
||||
) -> Result<(), TranslateError> {
|
||||
let intrinsic = format!("llvm.fma.{}\0", ScalarTypeInLLVM(data.type_));
|
||||
let intrinsic = format!("llvm.fma.{}\0", LLVMTypeDisplay(data.type_));
|
||||
self.emit_intrinsic(
|
||||
unsafe { CStr::from_bytes_with_nul_unchecked(intrinsic.as_bytes()) },
|
||||
Some(arguments.dst),
|
||||
&data.type_.into(),
|
||||
vec![
|
||||
(arguments.src1, get_scalar_type(self.context, data.type_)),
|
||||
(arguments.src2, get_scalar_type(self.context, data.type_)),
|
||||
(arguments.src3, get_scalar_type(self.context, data.type_)),
|
||||
(
|
||||
self.resolver.value(arguments.src1)?,
|
||||
get_scalar_type(self.context, data.type_),
|
||||
),
|
||||
(
|
||||
self.resolver.value(arguments.src2)?,
|
||||
get_scalar_type(self.context, data.type_),
|
||||
),
|
||||
(
|
||||
self.resolver.value(arguments.src3)?,
|
||||
get_scalar_type(self.context, data.type_),
|
||||
),
|
||||
],
|
||||
)?;
|
||||
Ok(())
|
||||
|
@ -2238,9 +2308,9 @@ impl ResolveIdent {
|
|||
}
|
||||
}
|
||||
|
||||
struct ScalarTypeInLLVM(ast::ScalarType);
|
||||
struct LLVMTypeDisplay(ast::ScalarType);
|
||||
|
||||
impl std::fmt::Display for ScalarTypeInLLVM {
|
||||
impl std::fmt::Display for LLVMTypeDisplay {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self.0 {
|
||||
ast::ScalarType::Pred => write!(f, "i1"),
|
||||
|
|
Loading…
Add table
Reference in a new issue