From 6f2944d9be9e5cdf54fe4f52c948161cc10eb94d Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Tue, 15 Oct 2024 04:42:44 +0200 Subject: [PATCH] Add or, mad, fma, min, max, selp, lg2, ex2, popc, rem --- ptx/src/pass/emit_llvm.rs | 309 +++++++++++++++++++++++++++++++++++--- ptx/src/pass/mod.rs | 10 ++ 2 files changed, 298 insertions(+), 21 deletions(-) diff --git a/ptx/src/pass/emit_llvm.rs b/ptx/src/pass/emit_llvm.rs index 1784745..209840f 100644 --- a/ptx/src/pass/emit_llvm.rs +++ b/ptx/src/pass/emit_llvm.rs @@ -518,7 +518,7 @@ impl<'a> MethodEmitContext<'a> { ast::Instruction::Setp { data, arguments } => self.emit_setp(data, arguments), ast::Instruction::SetpBool { .. } => todo!(), ast::Instruction::Not { data, arguments } => self.emit_not(data, arguments), - ast::Instruction::Or { .. } => todo!(), + ast::Instruction::Or { data, arguments } => self.emit_or(data, arguments), ast::Instruction::And { arguments, .. } => self.emit_and(arguments), ast::Instruction::Bra { arguments } => self.emit_bra(arguments), ast::Instruction::Call { data, arguments } => self.emit_call(data, arguments), @@ -528,15 +528,15 @@ impl<'a> MethodEmitContext<'a> { ast::Instruction::Ret { data } => Ok(self.emit_ret(data)), ast::Instruction::Cvta { data, arguments } => self.emit_cvta(data, arguments), ast::Instruction::Abs { .. } => todo!(), - ast::Instruction::Mad { .. } => todo!(), - ast::Instruction::Fma { .. } => todo!(), + ast::Instruction::Mad { data, arguments } => self.emit_mad(data, arguments), + ast::Instruction::Fma { data, arguments } => self.emit_fma(data, arguments), ast::Instruction::Sub { data, arguments } => self.emit_sub(data, arguments), - ast::Instruction::Min { .. } => todo!(), - ast::Instruction::Max { .. } => todo!(), + ast::Instruction::Min { data, arguments } => self.emit_min(data, arguments), + ast::Instruction::Max { data, arguments } => self.emit_max(data, arguments), ast::Instruction::Rcp { data, arguments } => self.emit_rcp(data, arguments), ast::Instruction::Sqrt { data, arguments } => self.emit_sqrt(data, arguments), ast::Instruction::Rsqrt { data, arguments } => self.emit_rsqrt(data, arguments), - ast::Instruction::Selp { .. } => todo!(), + ast::Instruction::Selp { data, arguments } => self.emit_selp(data, arguments), ast::Instruction::Bar { .. } => todo!(), ast::Instruction::Atom { data, arguments } => self.emit_atom(data, arguments), ast::Instruction::AtomCas { data, arguments } => self.emit_atom_cas(data, arguments), @@ -544,13 +544,13 @@ impl<'a> MethodEmitContext<'a> { ast::Instruction::Neg { data, arguments } => self.emit_neg(data, arguments), ast::Instruction::Sin { data, arguments } => self.emit_sin(data, arguments), ast::Instruction::Cos { data, arguments } => self.emit_cos(data, arguments), - ast::Instruction::Lg2 { .. } => todo!(), - ast::Instruction::Ex2 { .. } => todo!(), + ast::Instruction::Lg2 { data, arguments } => self.emit_lg2(data, arguments), + ast::Instruction::Ex2 { data, arguments } => self.emit_ex2(data, arguments), ast::Instruction::Clz { data, arguments } => self.emit_clz(data, arguments), ast::Instruction::Brev { data, arguments } => self.emit_brev(data, arguments), - ast::Instruction::Popc { .. } => todo!(), + ast::Instruction::Popc { data, arguments } => self.emit_popc(data, arguments), ast::Instruction::Xor { data, arguments } => self.emit_xor(data, arguments), - ast::Instruction::Rem { .. } => todo!(), + ast::Instruction::Rem { data, arguments } => self.emit_rem(data, arguments), ast::Instruction::PrmtSlow { .. } => todo!(), ast::Instruction::Prmt { .. } => todo!(), ast::Instruction::Membar { .. } => todo!(), @@ -664,7 +664,14 @@ impl<'a> MethodEmitContext<'a> { _ => todo!(), } } - ConversionKind::SignExtend => todo!(), + ConversionKind::SignExtend => { + let src = self.resolver.value(conversion.src)?; + let type_ = get_type(self.context, &conversion.to_type)?; + self.resolver.with_result(conversion.dst, |dst| unsafe { + LLVMBuildSExt(builder, src, type_, dst) + }); + Ok(()) + } ConversionKind::BitToPtr => { let src = self.resolver.value(conversion.src)?; let type_ = get_pointer_type(self.context, conversion.to_space)?; @@ -986,20 +993,82 @@ impl<'a> MethodEmitContext<'a> { data: ast::MulDetails, arguments: ast::MulArgs, ) -> Result<(), TranslateError> { + self.emit_mul_impl(data, Some(arguments.dst), arguments.src1, arguments.src2)?; + Ok(()) + } + + fn emit_mul_impl( + &mut self, + data: ast::MulDetails, + dst: Option, + src1: SpirvWord, + src2: SpirvWord, + ) -> Result { let mul_fn = match data { - ast::MulDetails::Integer { control, .. } => match control { + ast::MulDetails::Integer { control, type_ } => match control { ast::MulIntControl::Low => LLVMBuildMul, - ast::MulIntControl::High => todo!(), - ast::MulIntControl::Wide => todo!(), + ast::MulIntControl::High => return self.emit_mul_high(type_, dst, src1, src2), + ast::MulIntControl::Wide => { + return Ok(self.emit_mul_wide_impl(type_, dst, src1, src2)?.1) + } }, ast::MulDetails::Float(..) => LLVMBuildFMul, }; - let src1 = self.resolver.value(arguments.src1)?; - let src2 = self.resolver.value(arguments.src2)?; - self.resolver.with_result(arguments.dst, |dst| unsafe { - mul_fn(self.builder, src1, src2, dst) - }); - Ok(()) + let src1 = self.resolver.value(src1)?; + let src2 = self.resolver.value(src2)?; + Ok(self + .resolver + .with_result_option(dst, |dst| unsafe { mul_fn(self.builder, src1, src2, dst) })) + } + + fn emit_mul_high( + &mut self, + type_: ptx_parser::ScalarType, + dst: Option, + src1: SpirvWord, + src2: SpirvWord, + ) -> Result { + let (wide_type, wide_value) = self.emit_mul_wide_impl(type_, None, src1, src2)?; + let shift_constant = + unsafe { LLVMConstInt(wide_type, (type_.layout().size() * 8) as u64, 0) }; + let shifted = unsafe { + LLVMBuildLShr( + self.builder, + wide_value, + shift_constant, + LLVM_UNNAMED.as_ptr(), + ) + }; + let narrow_type = get_scalar_type(self.context, type_); + Ok(self.resolver.with_result_option(dst, |dst| unsafe { + LLVMBuildTrunc(self.builder, shifted, narrow_type, dst) + })) + } + + fn emit_mul_wide_impl( + &mut self, + type_: ptx_parser::ScalarType, + dst: Option, + src1: SpirvWord, + src2: SpirvWord, + ) -> Result<(LLVMTypeRef, LLVMValueRef), TranslateError> { + let src1 = self.resolver.value(src1)?; + let src2 = self.resolver.value(src2)?; + let wide_type = + unsafe { LLVMIntTypeInContext(self.context, (type_.layout().size() * 8 * 2) as u32) }; + let llvm_cast = match type_.kind() { + ptx_parser::ScalarKind::Signed => LLVMBuildSExt, + ptx_parser::ScalarKind::Unsigned => LLVMBuildZExt, + _ => return Err(error_unreachable()), + }; + let src1 = unsafe { llvm_cast(self.builder, src1, wide_type, LLVM_UNNAMED.as_ptr()) }; + let src2 = unsafe { llvm_cast(self.builder, src2, wide_type, LLVM_UNNAMED.as_ptr()) }; + Ok(( + wide_type, + self.resolver.with_result_option(dst, |dst| unsafe { + LLVMBuildMul(self.builder, src1, src2, dst) + }), + )) } fn emit_cos( @@ -1018,6 +1087,19 @@ impl<'a> MethodEmitContext<'a> { Ok(()) } + fn emit_or( + &mut self, + _data: ptx_parser::ScalarType, + arguments: ptx_parser::OrArgs, + ) -> Result<(), TranslateError> { + let src1 = self.resolver.value(arguments.src1)?; + let src2 = self.resolver.value(arguments.src2)?; + self.resolver.with_result(arguments.dst, |dst| unsafe { + LLVMBuildOr(self.builder, src1, src2, dst) + }); + Ok(()) + } + fn emit_xor( &mut self, _data: ptx_parser::ScalarType, @@ -1612,6 +1694,191 @@ impl<'a> MethodEmitContext<'a> { Ok(()) } + fn emit_ex2( + &mut self, + data: ptx_parser::TypeFtz, + arguments: ptx_parser::Ex2Args, + ) -> Result<(), TranslateError> { + let intrinsic = match data.type_ { + ast::ScalarType::F16 => c"llvm.amdgcn.exp2.f16", + ast::ScalarType::F32 => c"llvm.amdgcn.exp2.f32", + _ => return Err(error_unreachable()), + }; + self.emit_intrinsic( + intrinsic, + Some(arguments.dst), + &data.type_.into(), + vec![(arguments.src, get_scalar_type(self.context, data.type_))], + )?; + Ok(()) + } + + fn emit_lg2( + &mut self, + _data: ptx_parser::FlushToZero, + arguments: ptx_parser::Lg2Args, + ) -> Result<(), TranslateError> { + self.emit_intrinsic( + c"llvm.amdgcn.log.f32", + Some(arguments.dst), + &ast::ScalarType::F32.into(), + vec![( + arguments.src, + get_scalar_type(self.context, ast::ScalarType::F32.into()), + )], + )?; + Ok(()) + } + + fn emit_selp( + &mut self, + _data: ptx_parser::ScalarType, + arguments: ptx_parser::SelpArgs, + ) -> Result<(), TranslateError> { + let src1 = self.resolver.value(arguments.src1)?; + let src2 = self.resolver.value(arguments.src2)?; + let src3 = self.resolver.value(arguments.src3)?; + self.resolver.with_result(arguments.dst, |dst_name| unsafe { + LLVMBuildSelect(self.builder, src3, src1, src2, dst_name) + }); + Ok(()) + } + + fn emit_rem( + &mut self, + data: ptx_parser::ScalarType, + arguments: ptx_parser::RemArgs, + ) -> Result<(), TranslateError> { + let llvm_fn = match data.kind() { + ptx_parser::ScalarKind::Unsigned => LLVMBuildURem, + ptx_parser::ScalarKind::Signed => LLVMBuildSRem, + _ => return Err(error_unreachable()), + }; + let src1 = self.resolver.value(arguments.src1)?; + let src2 = self.resolver.value(arguments.src2)?; + self.resolver.with_result(arguments.dst, |dst_name| unsafe { + llvm_fn(self.builder, src1, src2, dst_name) + }); + Ok(()) + } + + fn emit_popc( + &mut self, + type_: ptx_parser::ScalarType, + arguments: ptx_parser::PopcArgs, + ) -> Result<(), TranslateError> { + let intrinsic = match type_ { + ast::ScalarType::B32 => c"llvm.ctpop.i32", + ast::ScalarType::B64 => c"llvm.ctpop.i64", + _ => return Err(error_unreachable()), + }; + let llvm_type = get_scalar_type(self.context, type_); + self.emit_intrinsic( + intrinsic, + Some(arguments.dst), + &type_.into(), + vec![(arguments.src, llvm_type)], + )?; + Ok(()) + } + + fn emit_min( + &mut self, + data: ptx_parser::MinMaxDetails, + arguments: ptx_parser::MinArgs, + ) -> Result<(), TranslateError> { + let llvm_prefix = match data { + ptx_parser::MinMaxDetails::Signed(..) => "llvm.smin", + ptx_parser::MinMaxDetails::Unsigned(..) => "llvm.umin", + ptx_parser::MinMaxDetails::Float(ptx_parser::MinMaxFloat { nan: true, .. }) => { + return Err(error_todo()) + } + ptx_parser::MinMaxDetails::Float(ptx_parser::MinMaxFloat { .. }) => "llvm.maxnum", + }; + let intrinsic = format!("{}.{}\0", llvm_prefix, ScalarTypeInLLVM(data.type_())); + let llvm_type = get_scalar_type(self.context, data.type_()); + self.emit_intrinsic( + unsafe { CStr::from_bytes_with_nul_unchecked(intrinsic.as_bytes()) }, + Some(arguments.dst), + &data.type_().into(), + vec![(arguments.src1, llvm_type), (arguments.src2, llvm_type)], + )?; + Ok(()) + } + + fn emit_max( + &mut self, + data: ptx_parser::MinMaxDetails, + arguments: ptx_parser::MaxArgs, + ) -> Result<(), TranslateError> { + let llvm_prefix = match data { + ptx_parser::MinMaxDetails::Signed(..) => "llvm.smax", + ptx_parser::MinMaxDetails::Unsigned(..) => "llvm.umax", + ptx_parser::MinMaxDetails::Float(ptx_parser::MinMaxFloat { nan: true, .. }) => { + return Err(error_todo()) + } + ptx_parser::MinMaxDetails::Float(ptx_parser::MinMaxFloat { .. }) => "llvm.minnum", + }; + let intrinsic = format!("{}.{}\0", llvm_prefix, ScalarTypeInLLVM(data.type_())); + let llvm_type = get_scalar_type(self.context, data.type_()); + self.emit_intrinsic( + unsafe { CStr::from_bytes_with_nul_unchecked(intrinsic.as_bytes()) }, + Some(arguments.dst), + &data.type_().into(), + vec![(arguments.src1, llvm_type), (arguments.src2, llvm_type)], + )?; + Ok(()) + } + + fn emit_fma( + &mut self, + data: ptx_parser::ArithFloat, + arguments: ptx_parser::FmaArgs, + ) -> Result<(), TranslateError> { + let intrinsic = format!("llvm.fma.{}\0", ScalarTypeInLLVM(data.type_)); + self.emit_intrinsic( + unsafe { CStr::from_bytes_with_nul_unchecked(intrinsic.as_bytes()) }, + Some(arguments.dst), + &data.type_.into(), + vec![ + (arguments.src1, get_scalar_type(self.context, data.type_)), + (arguments.src2, get_scalar_type(self.context, data.type_)), + (arguments.src3, get_scalar_type(self.context, data.type_)), + ], + )?; + Ok(()) + } + + fn emit_mad( + &mut self, + data: ptx_parser::MadDetails, + arguments: ptx_parser::MadArgs, + ) -> Result<(), TranslateError> { + let mul_control = match data { + ptx_parser::MadDetails::Float(mad_float) => { + return self.emit_fma( + mad_float, + ast::FmaArgs { + dst: arguments.dst, + src1: arguments.src1, + src2: arguments.src2, + src3: arguments.src3, + }, + ) + } + ptx_parser::MadDetails::Integer { saturate: true, .. } => return Err(error_todo()), + ptx_parser::MadDetails::Integer { type_, control, .. } => { + ast::MulDetails::Integer { control, type_ } + } + }; + let temp = self.emit_mul_impl(mul_control, None, arguments.src1, arguments.src2)?; + let src3 = self.resolver.value(arguments.src3)?; + self.resolver.with_result(arguments.dst, |dst| unsafe { + LLVMBuildAdd(self.builder, temp, src3, dst) + }); + Ok(()) + } + /* // Currently unused, LLVM 18 (ROCm 6.2) does not support `llvm.set.rounding` // Should be available in LLVM 19 @@ -1870,7 +2137,6 @@ impl ResolveIdent { } } -/* struct ScalarTypeInLLVM(ast::ScalarType); impl std::fmt::Display for ScalarTypeInLLVM { @@ -1893,6 +2159,7 @@ impl std::fmt::Display for ScalarTypeInLLVM { } } +/* fn rounding_to_llvm(this: ast::RoundingMode) -> u32 { match this { ptx_parser::RoundingMode::Zero => 0, diff --git a/ptx/src/pass/mod.rs b/ptx/src/pass/mod.rs index 65292eb..ef131b4 100644 --- a/ptx/src/pass/mod.rs +++ b/ptx/src/pass/mod.rs @@ -149,6 +149,16 @@ fn error_unreachable() -> TranslateError { TranslateError::Unreachable } +#[cfg(debug_assertions)] +fn error_todo() -> TranslateError { + unreachable!() +} + +#[cfg(not(debug_assertions))] +fn error_todo() -> TranslateError { + TranslateError::Todo +} + #[cfg(debug_assertions)] fn error_unknown_symbol() -> TranslateError { panic!()