From c8b88f4483eaf5ee68cd9306ca57dfaa5f7d0ce0 Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Fri, 11 Oct 2024 16:27:36 +0200 Subject: [PATCH] Implement div --- ptx/src/pass/emit_llvm.rs | 184 +++++++++++++++++++++++++++++++++++++- 1 file changed, 182 insertions(+), 2 deletions(-) diff --git a/ptx/src/pass/emit_llvm.rs b/ptx/src/pass/emit_llvm.rs index 7c6cbb7..15177bc 100644 --- a/ptx/src/pass/emit_llvm.rs +++ b/ptx/src/pass/emit_llvm.rs @@ -18,6 +18,12 @@ // while with plain LLVM-C it's just: // unsafe { LLVMBuildAdd(builder, lhs, rhs, dst) }; +// AMDGPU LLVM backend support for llvm.experimental.constrained.* is incomplete. +// Emitting @llvm.experimental.constrained.fdiv.f32(...) makes LLVM fail with +// "LLVM ERROR: unsupported libcall legalization". Running with "-mllvm -print-before-all" +// shows it fails inside amdgpu-isel. You can get a little bit further with "-mllvm -global-isel", +// but it too will fail similarly, with "unable to legalize instruction" + use std::array::TryFromSliceError; use std::convert::TryInto; use std::ffi::{CStr, NulError}; @@ -534,7 +540,7 @@ impl<'a> MethodEmitContext<'a> { ast::Instruction::Bar { .. } => todo!(), ast::Instruction::Atom { data, arguments } => self.emit_atom(data, arguments), ast::Instruction::AtomCas { data, arguments } => self.emit_atom_cas(data, arguments), - ast::Instruction::Div { .. } => todo!(), + ast::Instruction::Div { data, arguments } => self.emit_div(data, arguments), ast::Instruction::Neg { .. } => todo!(), ast::Instruction::Sin { .. 
} => todo!(), ast::Instruction::Cos { data, arguments } => self.emit_cos(data, arguments), @@ -626,7 +632,7 @@ impl<'a> MethodEmitContext<'a> { }); Ok(()) } else { - let conversion_fn = if from_type.kind() == ast::ScalarKind::Signed + let _conversion_fn = if from_type.kind() == ast::ScalarKind::Signed && to_type.kind() == ast::ScalarKind::Signed { if to_type.size_of() >= from_type.size_of() { @@ -1086,6 +1092,147 @@ impl<'a> MethodEmitContext<'a> { } Ok(()) } + + fn emit_div( + &mut self, + data: ptx_parser::DivDetails, + arguments: ptx_parser::DivArgs, + ) -> Result<(), TranslateError> { + let integer_div = match data { + ptx_parser::DivDetails::Unsigned(_) => LLVMBuildUDiv, + ptx_parser::DivDetails::Signed(_) => LLVMBuildSDiv, + ptx_parser::DivDetails::Float(float_div) => { + return self.emit_div_float(float_div, arguments) + } + }; + let src1 = self.resolver.value(arguments.src1)?; + let src2 = self.resolver.value(arguments.src2)?; + self.resolver.with_result(arguments.dst, |dst| unsafe { + integer_div(self.builder, src1, src2, dst) + }); + Ok(()) + } + + fn emit_div_float( + &mut self, + float_div: ptx_parser::DivFloatDetails, + arguments: ptx_parser::DivArgs, + ) -> Result<(), TranslateError> { + let builder = self.builder; + let src1 = self.resolver.value(arguments.src1)?; + let src2 = self.resolver.value(arguments.src2)?; + let _rnd = match float_div.kind { + ptx_parser::DivFloatKind::Approx => ast::RoundingMode::NearestEven, + ptx_parser::DivFloatKind::ApproxFull => ast::RoundingMode::NearestEven, + ptx_parser::DivFloatKind::Rounding(rounding_mode) => rounding_mode, + }; + let approx = match float_div.kind { + ptx_parser::DivFloatKind::Approx => { + LLVMZludaFastMathAllowReciprocal | LLVMZludaFastMathApproxFunc + } + ptx_parser::DivFloatKind::ApproxFull => LLVMZludaFastMathNone, + ptx_parser::DivFloatKind::Rounding(_) => LLVMZludaFastMathNone, + }; + let fdiv = self.resolver.with_result(arguments.dst, |dst| unsafe { + LLVMBuildFDiv(builder, src1, src2, dst) 
+ }); + unsafe { LLVMZludaSetFastMathFlags(fdiv, approx) }; + if let ptx_parser::DivFloatKind::ApproxFull = float_div.kind { + // https://docs.nvidia.com/cuda/parallel-thread-execution/#floating-point-instructions-div: + // div.full.f32 implements a relatively fast, full-range approximation that scales + // operands to achieve better accuracy, but is not fully IEEE 754 compliant and does not + // support rounding modifiers. The maximum ulp error is 2 across the full range of + // inputs. + // https://llvm.org/docs/LangRef.html#fpmath-metadata + let fpmath_value = + unsafe { LLVMConstReal(get_scalar_type(self.context, ast::ScalarType::F32), 2.0) }; + let fpmath_value = unsafe { LLVMValueAsMetadata(fpmath_value) }; + let mut md_node_content = [fpmath_value]; + let md_node = unsafe { + LLVMMDNodeInContext2( + self.context, + md_node_content.as_mut_ptr(), + md_node_content.len(), + ) + }; + let md_node = unsafe { LLVMMetadataAsValue(self.context, md_node) }; + let kind = unsafe { + LLVMGetMDKindIDInContext( + self.context, + "fpmath".as_ptr().cast(), + "fpmath".len() as u32, + ) + }; + unsafe { LLVMSetMetadata(fdiv, kind, md_node) }; + } + Ok(()) + } + + /* + // Currently unused, LLVM 18 (ROCm 6.2) does not support `llvm.set.rounding` + // Should be available in LLVM 19 + fn with_rounding(&mut self, rnd: ast::RoundingMode, fn_: impl FnOnce(&mut Self) -> T) -> T { + let mut u32_type = get_scalar_type(self.context, ast::ScalarType::U32); + let void_type = unsafe { LLVMVoidTypeInContext(self.context) }; + let get_rounding = c"llvm.get.rounding"; + let get_rounding_fn_type = unsafe { LLVMFunctionType(u32_type, ptr::null_mut(), 0, 0) }; + let mut get_rounding_fn = + unsafe { LLVMGetNamedFunction(self.module, get_rounding.as_ptr()) }; + if get_rounding_fn == ptr::null_mut() { + get_rounding_fn = unsafe { + LLVMAddFunction(self.module, get_rounding.as_ptr(), get_rounding_fn_type) + }; + } + let set_rounding = c"llvm.set.rounding"; + let set_rounding_fn_type = unsafe { 
LLVMFunctionType(void_type, &mut u32_type, 1, 0) }; + let mut set_rounding_fn = + unsafe { LLVMGetNamedFunction(self.module, set_rounding.as_ptr()) }; + if set_rounding_fn == ptr::null_mut() { + set_rounding_fn = unsafe { + LLVMAddFunction(self.module, set_rounding.as_ptr(), set_rounding_fn_type) + }; + } + let mut preserved_rounding_mode = unsafe { + LLVMBuildCall2( + self.builder, + get_rounding_fn_type, + get_rounding_fn, + ptr::null_mut(), + 0, + LLVM_UNNAMED.as_ptr(), + ) + }; + let mut requested_rounding = unsafe { + LLVMConstInt( + get_scalar_type(self.context, ast::ScalarType::B32), + rounding_to_llvm(rnd) as u64, + 0, + ) + }; + unsafe { + LLVMBuildCall2( + self.builder, + set_rounding_fn_type, + set_rounding_fn, + &mut requested_rounding, + 1, + LLVM_UNNAMED.as_ptr(), + ) + }; + let result = fn_(self); + unsafe { + LLVMBuildCall2( + self.builder, + set_rounding_fn_type, + set_rounding_fn, + &mut preserved_rounding_mode, + 1, + LLVM_UNNAMED.as_ptr(), + ) + }; + result + } + */ } fn get_pointer_type<'ctx>( @@ -1279,3 +1426,36 @@ impl ResolveIdent { } } } + +/* +struct ScalarTypeInLLVM(ast::ScalarType); + +impl std::fmt::Display for ScalarTypeInLLVM { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self.0 { + ast::ScalarType::Pred => write!(f, "i1"), + ast::ScalarType::B8 | ast::ScalarType::U8 | ast::ScalarType::S8 => write!(f, "i8"), + ast::ScalarType::B16 | ast::ScalarType::U16 | ast::ScalarType::S16 => write!(f, "i16"), + ast::ScalarType::B32 | ast::ScalarType::U32 | ast::ScalarType::S32 => write!(f, "i32"), + ast::ScalarType::B64 | ast::ScalarType::U64 | ast::ScalarType::S64 => write!(f, "i64"), + ptx_parser::ScalarType::B128 => write!(f, "i128"), + ast::ScalarType::F16 => write!(f, "f16"), + ptx_parser::ScalarType::BF16 => write!(f, "bfloat"), + ast::ScalarType::F32 => write!(f, "f32"), + ast::ScalarType::F64 => write!(f, "f64"), + ptx_parser::ScalarType::S16x2 | ptx_parser::ScalarType::U16x2 => write!(f, "v2i16"), + 
ast::ScalarType::F16x2 => write!(f, "v2f16"), + ptx_parser::ScalarType::BF16x2 => write!(f, "v2bfloat"), + } + } +} + +fn rounding_to_llvm(this: ast::RoundingMode) -> u32 { + match this { + ptx_parser::RoundingMode::Zero => 0, + ptx_parser::RoundingMode::NearestEven => 1, + ptx_parser::RoundingMode::PositiveInf => 2, + ptx_parser::RoundingMode::NegativeInf => 3, + } +} +*/