diff --git a/llvm_zluda/src/lib.cpp b/llvm_zluda/src/lib.cpp index cff2e31..073dba7 100644 --- a/llvm_zluda/src/lib.cpp +++ b/llvm_zluda/src/lib.cpp @@ -108,6 +108,38 @@ static AtomicOrdering mapFromLLVMOrdering(LLVMAtomicOrdering Ordering) llvm_unreachable("Invalid LLVMAtomicOrdering value!"); } +typedef unsigned LLVMFastMathFlags; + +enum +{ + LLVMFastMathAllowReassoc = (1 << 0), + LLVMFastMathNoNaNs = (1 << 1), + LLVMFastMathNoInfs = (1 << 2), + LLVMFastMathNoSignedZeros = (1 << 3), + LLVMFastMathAllowReciprocal = (1 << 4), + LLVMFastMathAllowContract = (1 << 5), + LLVMFastMathApproxFunc = (1 << 6), + LLVMFastMathNone = 0, + LLVMFastMathAll = LLVMFastMathAllowReassoc | LLVMFastMathNoNaNs | + LLVMFastMathNoInfs | LLVMFastMathNoSignedZeros | + LLVMFastMathAllowReciprocal | LLVMFastMathAllowContract | + LLVMFastMathApproxFunc, +}; + +static FastMathFlags mapFromLLVMFastMathFlags(LLVMFastMathFlags FMF) +{ + FastMathFlags NewFMF; + NewFMF.setAllowReassoc((FMF & LLVMFastMathAllowReassoc) != 0); + NewFMF.setNoNaNs((FMF & LLVMFastMathNoNaNs) != 0); + NewFMF.setNoInfs((FMF & LLVMFastMathNoInfs) != 0); + NewFMF.setNoSignedZeros((FMF & LLVMFastMathNoSignedZeros) != 0); + NewFMF.setAllowReciprocal((FMF & LLVMFastMathAllowReciprocal) != 0); + NewFMF.setAllowContract((FMF & LLVMFastMathAllowContract) != 0); + NewFMF.setApproxFunc((FMF & LLVMFastMathApproxFunc) != 0); + + return NewFMF; +} + LLVM_C_EXTERN_C_BEGIN LLVMValueRef LLVMZludaBuildAlloca(LLVMBuilderRef B, LLVMTypeRef Ty, unsigned AddrSpace, @@ -145,4 +177,10 @@ LLVMValueRef LLVMZludaBuildAtomicCmpXchg(LLVMBuilderRef B, LLVMValueRef Ptr, context.getOrInsertSyncScopeID(scope))); } +void LLVMZludaSetFastMathFlags(LLVMValueRef FPMathInst, LLVMFastMathFlags FMF) +{ + Value *P = unwrap(FPMathInst); + cast(P)->setFastMathFlags(mapFromLLVMFastMathFlags(FMF)); +} + LLVM_C_EXTERN_C_END \ No newline at end of file diff --git a/llvm_zluda/src/lib.rs b/llvm_zluda/src/lib.rs index 3d941fa..afcfd89 100644 --- a/llvm_zluda/src/lib.rs +++ b/llvm_zluda/src/lib.rs @@ -1,3 +1,4 @@ +#![allow(non_upper_case_globals)] use llvm_sys::prelude::*; pub use llvm_sys::*; @@ -23,6 +24,25 @@ pub enum LLVMZludaAtomicRMWBinOp { LLVMZludaAtomicRMWBinOpUDecWrap = 16, } +// Backport from LLVM 19 +pub const LLVMZludaFastMathAllowReassoc: ::std::ffi::c_uint = 1 << 0; +pub const LLVMZludaFastMathNoNaNs: ::std::ffi::c_uint = 1 << 1; +pub const LLVMZludaFastMathNoInfs: ::std::ffi::c_uint = 1 << 2; +pub const LLVMZludaFastMathNoSignedZeros: ::std::ffi::c_uint = 1 << 3; +pub const LLVMZludaFastMathAllowReciprocal: ::std::ffi::c_uint = 1 << 4; +pub const LLVMZludaFastMathAllowContract: ::std::ffi::c_uint = 1 << 5; +pub const LLVMZludaFastMathApproxFunc: ::std::ffi::c_uint = 1 << 6; +pub const LLVMZludaFastMathNone: ::std::ffi::c_uint = 0; +pub const LLVMZludaFastMathAll: ::std::ffi::c_uint = LLVMZludaFastMathAllowReassoc + | LLVMZludaFastMathNoNaNs + | LLVMZludaFastMathNoInfs + | LLVMZludaFastMathNoSignedZeros + | LLVMZludaFastMathAllowReciprocal + | LLVMZludaFastMathAllowContract + | LLVMZludaFastMathApproxFunc; + +pub type LLVMZludaFastMathFlags = std::ffi::c_uint; + extern "C" { pub fn LLVMZludaBuildAlloca( B: LLVMBuilderRef, @@ -49,4 +69,6 @@ extern "C" { SuccessOrdering: LLVMAtomicOrdering, FailureOrdering: LLVMAtomicOrdering, ) -> LLVMValueRef; + + pub fn LLVMZludaSetFastMathFlags(FPMathInst: LLVMValueRef, FMF: LLVMZludaFastMathFlags); } diff --git a/ptx/src/pass/emit_llvm.rs b/ptx/src/pass/emit_llvm.rs index 5271157..a2b2638 100644 --- a/ptx/src/pass/emit_llvm.rs +++ b/ptx/src/pass/emit_llvm.rs @@ -18,6 +18,7 @@ // while with plain LLVM-C it's just: // unsafe { LLVMBuildAdd(builder, lhs, rhs, dst) }; +use std::array::TryFromSliceError; use std::convert::{TryFrom, TryInto}; use std::ffi::{CStr, NulError}; use std::ops::Deref; @@ -26,10 +27,7 @@ use std::ptr; use super::*; use llvm_zluda::analysis::{LLVMVerifierFailureAction, LLVMVerifyModule}; use llvm_zluda::bit_writer::LLVMWriteBitcodeToMemoryBuffer; -use llvm_zluda::{ - core::*, LLVMAtomicOrdering, LLVMAtomicRMWBinOp, LLVMZludaAtomicRMWBinOp, - LLVMZludaBuildAtomicCmpXchg, -}; +use llvm_zluda::{core::*, *}; use llvm_zluda::{prelude::*, LLVMZludaBuildAtomicRMW}; use llvm_zluda::{LLVMCallConv, LLVMZludaBuildAlloca}; @@ -328,10 +326,76 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> { unsafe { LLVMSetAlignment(global, align) }; } if !var.array_init.is_empty() { - todo!() + self.emit_array_init(&var.v_type, &*var.array_init, global); } Ok(()) } + + // TODO: instead of Vec we should emit a typed initializer + fn emit_array_init( + &mut self, + type_: &ast::Type, + array_init: &[u8], + global: *mut llvm_zluda::LLVMValue, + ) -> Result<(), TranslateError> { + match type_ { + ast::Type::Array(None, scalar, dimensions) => { + if dimensions.len() != 1 { + todo!() + } + if dimensions[0] as usize * scalar.size_of() as usize != array_init.len() { + return Err(error_unreachable()); + } + let type_ = get_scalar_type(self.context, *scalar); + let mut elements = array_init + .chunks(scalar.size_of() as usize) + .map(|chunk| self.constant_from_bytes(*scalar, chunk, type_)) + .collect::, _>>() + .map_err(|_| error_unreachable())?; + let initializer = + unsafe { LLVMConstArray2(type_, elements.as_mut_ptr(), elements.len() as u64) }; + unsafe { LLVMSetInitializer(global, initializer) }; + } + _ => todo!(), + } + Ok(()) + } + + fn constant_from_bytes( + &self, + scalar: ast::ScalarType, + bytes: &[u8], + llvm_type: LLVMTypeRef, + ) -> Result { + Ok(match scalar { + ptx_parser::ScalarType::Pred + | ptx_parser::ScalarType::S8 + | ptx_parser::ScalarType::B8 + | ptx_parser::ScalarType::U8 => unsafe { + LLVMConstInt(llvm_type, u8::from_le_bytes(bytes.try_into()?) as u64, 0) + }, + ptx_parser::ScalarType::S16 + | ptx_parser::ScalarType::B16 + | ptx_parser::ScalarType::U16 => unsafe { + LLVMConstInt(llvm_type, u16::from_le_bytes(bytes.try_into()?) as u64, 0) + }, + ptx_parser::ScalarType::F16 => todo!(), + ptx_parser::ScalarType::BF16 => todo!(), + ptx_parser::ScalarType::S32 => todo!(), + ptx_parser::ScalarType::U64 => todo!(), + ptx_parser::ScalarType::S64 => todo!(), + ptx_parser::ScalarType::S16x2 => todo!(), + ptx_parser::ScalarType::B32 => todo!(), + ptx_parser::ScalarType::F32 => todo!(), + ptx_parser::ScalarType::B64 => todo!(), + ptx_parser::ScalarType::F64 => todo!(), + ptx_parser::ScalarType::B128 => todo!(), + ptx_parser::ScalarType::U16x2 => todo!(), + ptx_parser::ScalarType::F16x2 => todo!(), + ptx_parser::ScalarType::U32 => todo!(), + ptx_parser::ScalarType::BF16x2 => todo!(), + }) + } } fn get_input_argument_type( @@ -454,7 +518,7 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> { ast::Instruction::Ld { data, arguments } => self.emit_ld(data, arguments), ast::Instruction::Add { data, arguments } => self.emit_add(data, arguments), ast::Instruction::St { data, arguments } => self.emit_st(data, arguments), - ast::Instruction::Mul { data, arguments } => todo!(), + ast::Instruction::Mul { data, arguments } => self.emit_mul(data, arguments), ast::Instruction::Setp { data, arguments } => todo!(), ast::Instruction::SetpBool { data, arguments } => todo!(), ast::Instruction::Not { data, arguments } => todo!(), @@ -483,13 +547,13 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> { ast::Instruction::Div { data, arguments } => todo!(), ast::Instruction::Neg { data, arguments } => todo!(), ast::Instruction::Sin { data, arguments } => todo!(), - ast::Instruction::Cos { data, arguments } => todo!(), + ast::Instruction::Cos { data, arguments } => self.emit_cos(data, arguments), ast::Instruction::Lg2 { data, arguments } => todo!(), ast::Instruction::Ex2 { data, arguments } => todo!(), ast::Instruction::Clz { data, arguments } => self.emit_clz(data, arguments), ast::Instruction::Brev { data, arguments } => self.emit_brev(data, arguments), ast::Instruction::Popc { data, arguments } => todo!(), - ast::Instruction::Xor { data, arguments } => todo!(), + ast::Instruction::Xor { data, arguments } => self.emit_xor(data, arguments), ast::Instruction::Rem { data, arguments } => todo!(), ast::Instruction::PrmtSlow { arguments } => todo!(), ast::Instruction::Prmt { data, arguments } => todo!(), @@ -865,6 +929,63 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> { }); Ok(()) } + + fn emit_mul( + &mut self, + data: ast::MulDetails, + arguments: ast::MulArgs, + ) -> Result<(), TranslateError> { + let mul_fn = match data { + ast::MulDetails::Integer { type_, control } => match control { + ast::MulIntControl::Low => LLVMBuildMul, + ast::MulIntControl::High => todo!(), + ast::MulIntControl::Wide => todo!(), + }, + ast::MulDetails::Float(arith_float) => LLVMBuildFMul, + }; + let src1 = self.resolver.value(arguments.src1)?; + let src2 = self.resolver.value(arguments.src2)?; + self.resolver.with_result(arguments.dst, |dst| unsafe { + mul_fn(self.builder, src1, src2, dst) + }); + Ok(()) + } + + fn emit_cos( + &mut self, + data: ast::FlushToZero, + arguments: ast::CosArgs, + ) -> Result<(), TranslateError> { + let llvm_fn = c"llvm.cos.f32"; + let mut fn_ = unsafe { LLVMGetNamedFunction(self.module, llvm_fn.as_ptr()) }; + let fn_type = get_function_type( + self.context, + iter::once(&ast::ScalarType::F32.into()), + iter::once(Ok(get_scalar_type(self.context, ast::ScalarType::F32))), + )?; + if fn_ == ptr::null_mut() { + fn_ = unsafe { LLVMAddFunction(self.module, llvm_fn.as_ptr(), fn_type) }; + } + let mut src = self.resolver.value(arguments.src)?; + let cos = self.resolver.with_result(arguments.dst, |dst| unsafe { + LLVMBuildCall2(self.builder, fn_type, fn_, &mut src, 1, dst) + }); + unsafe { LLVMZludaSetFastMathFlags(cos, LLVMZludaFastMathApproxFunc) } + Ok(()) + } + + fn emit_xor( + &mut self, + _data: ptx_parser::ScalarType, + arguments: ptx_parser::XorArgs, + ) -> Result<(), TranslateError> { + let src1 = self.resolver.value(arguments.src1)?; + let src2 = self.resolver.value(arguments.src2)?; + self.resolver.with_result(arguments.dst, |dst| unsafe { + LLVMBuildXor(self.builder, src1, src2, dst) + }); + Ok(()) + } } fn get_pointer_type<'ctx>( @@ -1037,8 +1158,13 @@ impl ResolveIdent { .ok_or_else(|| error_unreachable()) } - fn with_result(&mut self, word: SpirvWord, fn_: impl FnOnce(*const i8) -> LLVMValueRef) { + fn with_result( + &mut self, + word: SpirvWord, + fn_: impl FnOnce(*const i8) -> LLVMValueRef, + ) -> LLVMValueRef { let t = self.get_or_ad_impl(word, |dst| fn_(dst.as_ptr().cast())); self.register(word, t); + t } }