Add mul, cos, xor, some constants

This commit is contained in:
Andrzej Janik 2024-10-06 03:25:02 +02:00
parent d173828492
commit 56c41b5690
3 changed files with 195 additions and 9 deletions

View file

@ -108,6 +108,38 @@ static AtomicOrdering mapFromLLVMOrdering(LLVMAtomicOrdering Ordering)
llvm_unreachable("Invalid LLVMAtomicOrdering value!");
}
// Bitmask type: a bitwise OR of the LLVMFastMath* values declared below.
typedef unsigned LLVMFastMathFlags;
// Individual fast-math flag bits. Each one is translated to the FastMathFlags
// setter of the same name by mapFromLLVMFastMathFlags().
enum
{
// One bit per flag, occupying bits 0 through 6.
LLVMFastMathAllowReassoc = (1 << 0),
LLVMFastMathNoNaNs = (1 << 1),
LLVMFastMathNoInfs = (1 << 2),
LLVMFastMathNoSignedZeros = (1 << 3),
LLVMFastMathAllowReciprocal = (1 << 4),
LLVMFastMathAllowContract = (1 << 5),
LLVMFastMathApproxFunc = (1 << 6),
// No flag set.
LLVMFastMathNone = 0,
// Every flag above set at once.
LLVMFastMathAll = LLVMFastMathAllowReassoc | LLVMFastMathNoNaNs |
LLVMFastMathNoInfs | LLVMFastMathNoSignedZeros |
LLVMFastMathAllowReciprocal | LLVMFastMathAllowContract |
LLVMFastMathApproxFunc,
};
// Converts a C-level LLVMFastMathFlags bitmask into an llvm::FastMathFlags
// object. Every LLVMFastMath* bit is forwarded to the setter of the same name;
// bits that are not set in FMF leave the corresponding flag disabled.
static FastMathFlags mapFromLLVMFastMathFlags(LLVMFastMathFlags FMF)
{
  // True when the given flag bit is present in the incoming mask.
  const auto IsSet = [FMF](LLVMFastMathFlags Bit) { return (FMF & Bit) != 0; };

  FastMathFlags Result;
  Result.setAllowReassoc(IsSet(LLVMFastMathAllowReassoc));
  Result.setNoNaNs(IsSet(LLVMFastMathNoNaNs));
  Result.setNoInfs(IsSet(LLVMFastMathNoInfs));
  Result.setNoSignedZeros(IsSet(LLVMFastMathNoSignedZeros));
  Result.setAllowReciprocal(IsSet(LLVMFastMathAllowReciprocal));
  Result.setAllowContract(IsSet(LLVMFastMathAllowContract));
  Result.setApproxFunc(IsSet(LLVMFastMathApproxFunc));
  return Result;
}
LLVM_C_EXTERN_C_BEGIN
LLVMValueRef LLVMZludaBuildAlloca(LLVMBuilderRef B, LLVMTypeRef Ty, unsigned AddrSpace,
@ -145,4 +177,10 @@ LLVMValueRef LLVMZludaBuildAtomicCmpXchg(LLVMBuilderRef B, LLVMValueRef Ptr,
context.getOrInsertSyncScopeID(scope)));
}
// Applies the fast-math flags encoded in FMF to the instruction wrapped by
// FPMathInst. The value is converted with cast<Instruction>, so callers are
// expected to pass an instruction handle.
void LLVMZludaSetFastMathFlags(LLVMValueRef FPMathInst, LLVMFastMathFlags FMF)
{
  Instruction *Inst = cast<Instruction>(unwrap<Value>(FPMathInst));
  const FastMathFlags Flags = mapFromLLVMFastMathFlags(FMF);
  Inst->setFastMathFlags(Flags);
}
LLVM_C_EXTERN_C_END

View file

@ -1,3 +1,4 @@
#![allow(non_upper_case_globals)]
use llvm_sys::prelude::*;
pub use llvm_sys::*;
@ -23,6 +24,25 @@ pub enum LLVMZludaAtomicRMWBinOp {
LLVMZludaAtomicRMWBinOpUDecWrap = 16,
}
// Backport from LLVM 19
// The constants below carry the same bit values as the `LLVMFastMath*` enum in
// the C++ shim and are meant to be OR-ed together into a
// `LLVMZludaFastMathFlags` value for `LLVMZludaSetFastMathFlags`.
/// Bit 0; same value as the shim's `LLVMFastMathAllowReassoc`.
pub const LLVMZludaFastMathAllowReassoc: ::std::ffi::c_uint = 1 << 0;
/// Bit 1; same value as the shim's `LLVMFastMathNoNaNs`.
pub const LLVMZludaFastMathNoNaNs: ::std::ffi::c_uint = 1 << 1;
/// Bit 2; same value as the shim's `LLVMFastMathNoInfs`.
pub const LLVMZludaFastMathNoInfs: ::std::ffi::c_uint = 1 << 2;
/// Bit 3; same value as the shim's `LLVMFastMathNoSignedZeros`.
pub const LLVMZludaFastMathNoSignedZeros: ::std::ffi::c_uint = 1 << 3;
/// Bit 4; same value as the shim's `LLVMFastMathAllowReciprocal`.
pub const LLVMZludaFastMathAllowReciprocal: ::std::ffi::c_uint = 1 << 4;
/// Bit 5; same value as the shim's `LLVMFastMathAllowContract`.
pub const LLVMZludaFastMathAllowContract: ::std::ffi::c_uint = 1 << 5;
/// Bit 6; same value as the shim's `LLVMFastMathApproxFunc`.
pub const LLVMZludaFastMathApproxFunc: ::std::ffi::c_uint = 1 << 6;
/// Empty mask: no fast-math flag set.
pub const LLVMZludaFastMathNone: ::std::ffi::c_uint = 0;
/// Mask with all seven flags above set.
pub const LLVMZludaFastMathAll: ::std::ffi::c_uint = LLVMZludaFastMathAllowReassoc
| LLVMZludaFastMathNoNaNs
| LLVMZludaFastMathNoInfs
| LLVMZludaFastMathNoSignedZeros
| LLVMZludaFastMathAllowReciprocal
| LLVMZludaFastMathAllowContract
| LLVMZludaFastMathApproxFunc;
/// Bitmask type accepted by `LLVMZludaSetFastMathFlags`.
pub type LLVMZludaFastMathFlags = std::ffi::c_uint;
extern "C" {
pub fn LLVMZludaBuildAlloca(
B: LLVMBuilderRef,
@ -49,4 +69,6 @@ extern "C" {
SuccessOrdering: LLVMAtomicOrdering,
FailureOrdering: LLVMAtomicOrdering,
) -> LLVMValueRef;
pub fn LLVMZludaSetFastMathFlags(FPMathInst: LLVMValueRef, FMF: LLVMZludaFastMathFlags);
}

View file

@ -18,6 +18,7 @@
// while with plain LLVM-C it's just:
// unsafe { LLVMBuildAdd(builder, lhs, rhs, dst) };
use std::array::TryFromSliceError;
use std::convert::{TryFrom, TryInto};
use std::ffi::{CStr, NulError};
use std::ops::Deref;
@ -26,10 +27,7 @@ use std::ptr;
use super::*;
use llvm_zluda::analysis::{LLVMVerifierFailureAction, LLVMVerifyModule};
use llvm_zluda::bit_writer::LLVMWriteBitcodeToMemoryBuffer;
use llvm_zluda::{
core::*, LLVMAtomicOrdering, LLVMAtomicRMWBinOp, LLVMZludaAtomicRMWBinOp,
LLVMZludaBuildAtomicCmpXchg,
};
use llvm_zluda::{core::*, *};
use llvm_zluda::{prelude::*, LLVMZludaBuildAtomicRMW};
use llvm_zluda::{LLVMCallConv, LLVMZludaBuildAlloca};
@ -328,10 +326,76 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> {
unsafe { LLVMSetAlignment(global, align) };
}
if !var.array_init.is_empty() {
todo!()
self.emit_array_init(&var.v_type, &*var.array_init, global);
}
Ok(())
}
// TODO: instead of Vec<u8> we should emit a typed initializer
/// Decodes the raw bytes in `array_init` into per-element constants and
/// installs them as the initializer of `global`.
///
/// Only one-dimensional arrays of scalars are handled; any other type, and
/// arrays with more than one dimension, hit `todo!()`.
///
/// # Errors
///
/// Returns the "unreachable" error when `array_init` is not exactly
/// `dimensions[0] * scalar.size_of()` bytes long, or when an element cannot
/// be decoded from its chunk of bytes.
fn emit_array_init(
    &mut self,
    type_: &ast::Type,
    array_init: &[u8],
    global: *mut llvm_zluda::LLVMValue,
) -> Result<(), TranslateError> {
    let (scalar, dimensions) = match type_ {
        ast::Type::Array(None, scalar, dimensions) => (*scalar, dimensions),
        _ => todo!(),
    };
    if dimensions.len() != 1 {
        todo!()
    }
    let element_size = scalar.size_of() as usize;
    let expected_len = dimensions[0] as usize * element_size;
    if expected_len != array_init.len() {
        return Err(error_unreachable());
    }
    let element_type = get_scalar_type(self.context, scalar);
    // One LLVM constant per `element_size`-byte slice of the raw initializer.
    let mut constants = Vec::new();
    for raw_element in array_init.chunks(element_size) {
        match self.constant_from_bytes(scalar, raw_element, element_type) {
            Ok(constant) => constants.push(constant),
            Err(_) => return Err(error_unreachable()),
        }
    }
    let element_count = constants.len() as u64;
    unsafe {
        let initializer = LLVMConstArray2(element_type, constants.as_mut_ptr(), element_count);
        LLVMSetInitializer(global, initializer);
    }
    Ok(())
}
fn constant_from_bytes(
&self,
scalar: ast::ScalarType,
bytes: &[u8],
llvm_type: LLVMTypeRef,
) -> Result<LLVMValueRef, TryFromSliceError> {
Ok(match scalar {
ptx_parser::ScalarType::Pred
| ptx_parser::ScalarType::S8
| ptx_parser::ScalarType::B8
| ptx_parser::ScalarType::U8 => unsafe {
LLVMConstInt(llvm_type, u8::from_le_bytes(bytes.try_into()?) as u64, 0)
},
ptx_parser::ScalarType::S16
| ptx_parser::ScalarType::B16
| ptx_parser::ScalarType::U16 => unsafe {
LLVMConstInt(llvm_type, u16::from_le_bytes(bytes.try_into()?) as u64, 0)
},
ptx_parser::ScalarType::F16 => todo!(),
ptx_parser::ScalarType::BF16 => todo!(),
ptx_parser::ScalarType::S32 => todo!(),
ptx_parser::ScalarType::U64 => todo!(),
ptx_parser::ScalarType::S64 => todo!(),
ptx_parser::ScalarType::S16x2 => todo!(),
ptx_parser::ScalarType::B32 => todo!(),
ptx_parser::ScalarType::F32 => todo!(),
ptx_parser::ScalarType::B64 => todo!(),
ptx_parser::ScalarType::F64 => todo!(),
ptx_parser::ScalarType::B128 => todo!(),
ptx_parser::ScalarType::U16x2 => todo!(),
ptx_parser::ScalarType::F16x2 => todo!(),
ptx_parser::ScalarType::U32 => todo!(),
ptx_parser::ScalarType::BF16x2 => todo!(),
})
}
}
fn get_input_argument_type(
@ -454,7 +518,7 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> {
ast::Instruction::Ld { data, arguments } => self.emit_ld(data, arguments),
ast::Instruction::Add { data, arguments } => self.emit_add(data, arguments),
ast::Instruction::St { data, arguments } => self.emit_st(data, arguments),
ast::Instruction::Mul { data, arguments } => todo!(),
ast::Instruction::Mul { data, arguments } => self.emit_mul(data, arguments),
ast::Instruction::Setp { data, arguments } => todo!(),
ast::Instruction::SetpBool { data, arguments } => todo!(),
ast::Instruction::Not { data, arguments } => todo!(),
@ -483,13 +547,13 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> {
ast::Instruction::Div { data, arguments } => todo!(),
ast::Instruction::Neg { data, arguments } => todo!(),
ast::Instruction::Sin { data, arguments } => todo!(),
ast::Instruction::Cos { data, arguments } => todo!(),
ast::Instruction::Cos { data, arguments } => self.emit_cos(data, arguments),
ast::Instruction::Lg2 { data, arguments } => todo!(),
ast::Instruction::Ex2 { data, arguments } => todo!(),
ast::Instruction::Clz { data, arguments } => self.emit_clz(data, arguments),
ast::Instruction::Brev { data, arguments } => self.emit_brev(data, arguments),
ast::Instruction::Popc { data, arguments } => todo!(),
ast::Instruction::Xor { data, arguments } => todo!(),
ast::Instruction::Xor { data, arguments } => self.emit_xor(data, arguments),
ast::Instruction::Rem { data, arguments } => todo!(),
ast::Instruction::PrmtSlow { arguments } => todo!(),
ast::Instruction::Prmt { data, arguments } => todo!(),
@ -865,6 +929,63 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> {
});
Ok(())
}
/// Emits a multiplication of `arguments.src1` by `arguments.src2` and binds
/// the result to `arguments.dst`.
///
/// Integer multiplies are lowered with `LLVMBuildMul` and are only supported
/// in `Low` mode; `High` and `Wide` hit `todo!()`. Floating-point multiplies
/// are lowered with `LLVMBuildFMul`; the details carried by
/// `MulDetails::Float` are currently not consulted.
fn emit_mul(
    &mut self,
    data: ast::MulDetails,
    arguments: ast::MulArgs<SpirvWord>,
) -> Result<(), TranslateError> {
    // Choose the LLVM builder entry point before touching any operand.
    let build = match data {
        ast::MulDetails::Integer {
            control: ast::MulIntControl::Low,
            ..
        } => LLVMBuildMul,
        ast::MulDetails::Integer {
            control: ast::MulIntControl::High,
            ..
        } => todo!(),
        ast::MulDetails::Integer {
            control: ast::MulIntControl::Wide,
            ..
        } => todo!(),
        ast::MulDetails::Float(_) => LLVMBuildFMul,
    };
    let lhs = self.resolver.value(arguments.src1)?;
    let rhs = self.resolver.value(arguments.src2)?;
    let builder = self.builder;
    self.resolver
        .with_result(arguments.dst, |name| unsafe { build(builder, lhs, rhs, name) });
    Ok(())
}
fn emit_cos(
&mut self,
data: ast::FlushToZero,
arguments: ast::CosArgs<SpirvWord>,
) -> Result<(), TranslateError> {
let llvm_fn = c"llvm.cos.f32";
let mut fn_ = unsafe { LLVMGetNamedFunction(self.module, llvm_fn.as_ptr()) };
let fn_type = get_function_type(
self.context,
iter::once(&ast::ScalarType::F32.into()),
iter::once(Ok(get_scalar_type(self.context, ast::ScalarType::F32))),
)?;
if fn_ == ptr::null_mut() {
fn_ = unsafe { LLVMAddFunction(self.module, llvm_fn.as_ptr(), fn_type) };
}
let mut src = self.resolver.value(arguments.src)?;
let cos = self.resolver.with_result(arguments.dst, |dst| unsafe {
LLVMBuildCall2(self.builder, fn_type, fn_, &mut src, 1, dst)
});
unsafe { LLVMZludaSetFastMathFlags(cos, LLVMZludaFastMathApproxFunc) }
Ok(())
}
/// Emits a bitwise XOR of `arguments.src1` and `arguments.src2` via
/// `LLVMBuildXor` and binds the result to `arguments.dst`.
///
/// The scalar type in `_data` is not consulted.
fn emit_xor(
    &mut self,
    _data: ptx_parser::ScalarType,
    arguments: ptx_parser::XorArgs<SpirvWord>,
) -> Result<(), TranslateError> {
    let lhs = self.resolver.value(arguments.src1)?;
    let rhs = self.resolver.value(arguments.src2)?;
    let builder = self.builder;
    self.resolver
        .with_result(arguments.dst, |name| unsafe { LLVMBuildXor(builder, lhs, rhs, name) });
    Ok(())
}
}
fn get_pointer_type<'ctx>(
@ -1037,8 +1158,13 @@ impl ResolveIdent {
.ok_or_else(|| error_unreachable())
}
fn with_result(&mut self, word: SpirvWord, fn_: impl FnOnce(*const i8) -> LLVMValueRef) {
/// Produces the value for `word` by invoking `fn_` with the name pointer
/// supplied by `get_or_ad_impl`, registers that value under `word`, and
/// returns it so the caller can keep working with the emitted value.
fn with_result(
    &mut self,
    word: SpirvWord,
    fn_: impl FnOnce(*const i8) -> LLVMValueRef,
) -> LLVMValueRef {
    let value = self.get_or_ad_impl(word, |name| {
        let name_ptr: *const i8 = name.as_ptr().cast();
        fn_(name_ptr)
    });
    self.register(word, value);
    value
}
}