mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-04-20 00:14:45 +00:00
Add mul, cos, xor, some constants
This commit is contained in:
parent
d173828492
commit
56c41b5690
3 changed files with 195 additions and 9 deletions
|
@ -108,6 +108,38 @@ static AtomicOrdering mapFromLLVMOrdering(LLVMAtomicOrdering Ordering)
|
|||
llvm_unreachable("Invalid LLVMAtomicOrdering value!");
|
||||
}
|
||||
|
||||
// Bit set of fast-math flags accepted by LLVMZludaSetFastMathFlags.
// Each enumerator below occupies its own bit so values can be OR-ed together.
typedef unsigned LLVMFastMathFlags;

enum
{
    // Empty set: no fast-math flag enabled.
    LLVMFastMathNone = 0,

    // Individual flags, one bit each.
    LLVMFastMathAllowReassoc = (1 << 0),
    LLVMFastMathNoNaNs = (1 << 1),
    LLVMFastMathNoInfs = (1 << 2),
    LLVMFastMathNoSignedZeros = (1 << 3),
    LLVMFastMathAllowReciprocal = (1 << 4),
    LLVMFastMathAllowContract = (1 << 5),
    LLVMFastMathApproxFunc = (1 << 6),

    // Union of every individual flag above.
    LLVMFastMathAll = LLVMFastMathAllowReassoc |
                      LLVMFastMathNoNaNs |
                      LLVMFastMathNoInfs |
                      LLVMFastMathNoSignedZeros |
                      LLVMFastMathAllowReciprocal |
                      LLVMFastMathAllowContract |
                      LLVMFastMathApproxFunc,
};
|
||||
|
||||
// Converts the C-level LLVMFastMathFlags bit set into an llvm::FastMathFlags
// object, enabling each property exactly when its bit is present in FMF.
static FastMathFlags mapFromLLVMFastMathFlags(LLVMFastMathFlags FMF)
{
    // True when the given flag bit is set in the incoming mask.
    const auto IsSet = [FMF](LLVMFastMathFlags Bit) { return (FMF & Bit) != 0; };

    FastMathFlags Result;
    Result.setAllowReassoc(IsSet(LLVMFastMathAllowReassoc));
    Result.setNoNaNs(IsSet(LLVMFastMathNoNaNs));
    Result.setNoInfs(IsSet(LLVMFastMathNoInfs));
    Result.setNoSignedZeros(IsSet(LLVMFastMathNoSignedZeros));
    Result.setAllowReciprocal(IsSet(LLVMFastMathAllowReciprocal));
    Result.setAllowContract(IsSet(LLVMFastMathAllowContract));
    Result.setApproxFunc(IsSet(LLVMFastMathApproxFunc));
    return Result;
}
|
||||
|
||||
LLVM_C_EXTERN_C_BEGIN
|
||||
|
||||
LLVMValueRef LLVMZludaBuildAlloca(LLVMBuilderRef B, LLVMTypeRef Ty, unsigned AddrSpace,
|
||||
|
@ -145,4 +177,10 @@ LLVMValueRef LLVMZludaBuildAtomicCmpXchg(LLVMBuilderRef B, LLVMValueRef Ptr,
|
|||
context.getOrInsertSyncScopeID(scope)));
|
||||
}
|
||||
|
||||
// Applies the fast-math flag bit set FMF to the instruction wrapped by
// FPMathInst. The wrapped value is converted with cast<Instruction>, so the
// caller has to pass a value that really is an instruction.
void LLVMZludaSetFastMathFlags(LLVMValueRef FPMathInst, LLVMFastMathFlags FMF)
{
    Instruction *Inst = cast<Instruction>(unwrap<Value>(FPMathInst));
    Inst->setFastMathFlags(mapFromLLVMFastMathFlags(FMF));
}
|
||||
|
||||
LLVM_C_EXTERN_C_END
|
|
@ -1,3 +1,4 @@
|
|||
#![allow(non_upper_case_globals)]
|
||||
use llvm_sys::prelude::*;
|
||||
pub use llvm_sys::*;
|
||||
|
||||
|
@ -23,6 +24,25 @@ pub enum LLVMZludaAtomicRMWBinOp {
|
|||
LLVMZludaAtomicRMWBinOpUDecWrap = 16,
|
||||
}
|
||||
|
||||
// Backport from LLVM 19

/// Bit set of fast-math flags passed to `LLVMZludaSetFastMathFlags`; a C `unsigned`.
pub type LLVMZludaFastMathFlags = std::ffi::c_uint;

/// Empty set: no fast-math flag enabled.
pub const LLVMZludaFastMathNone: LLVMZludaFastMathFlags = 0;

// Individual flags, one bit each, so they can be OR-ed together.
pub const LLVMZludaFastMathAllowReassoc: LLVMZludaFastMathFlags = 1 << 0;
pub const LLVMZludaFastMathNoNaNs: LLVMZludaFastMathFlags = 1 << 1;
pub const LLVMZludaFastMathNoInfs: LLVMZludaFastMathFlags = 1 << 2;
pub const LLVMZludaFastMathNoSignedZeros: LLVMZludaFastMathFlags = 1 << 3;
pub const LLVMZludaFastMathAllowReciprocal: LLVMZludaFastMathFlags = 1 << 4;
pub const LLVMZludaFastMathAllowContract: LLVMZludaFastMathFlags = 1 << 5;
pub const LLVMZludaFastMathApproxFunc: LLVMZludaFastMathFlags = 1 << 6;

/// Union of every individual flag above.
pub const LLVMZludaFastMathAll: LLVMZludaFastMathFlags = LLVMZludaFastMathAllowReassoc
    | LLVMZludaFastMathNoNaNs
    | LLVMZludaFastMathNoInfs
    | LLVMZludaFastMathNoSignedZeros
    | LLVMZludaFastMathAllowReciprocal
    | LLVMZludaFastMathAllowContract
    | LLVMZludaFastMathApproxFunc;
|
||||
|
||||
extern "C" {
|
||||
pub fn LLVMZludaBuildAlloca(
|
||||
B: LLVMBuilderRef,
|
||||
|
@ -49,4 +69,6 @@ extern "C" {
|
|||
SuccessOrdering: LLVMAtomicOrdering,
|
||||
FailureOrdering: LLVMAtomicOrdering,
|
||||
) -> LLVMValueRef;
|
||||
|
||||
pub fn LLVMZludaSetFastMathFlags(FPMathInst: LLVMValueRef, FMF: LLVMZludaFastMathFlags);
|
||||
}
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
// while with plain LLVM-C it's just:
|
||||
// unsafe { LLVMBuildAdd(builder, lhs, rhs, dst) };
|
||||
|
||||
use std::array::TryFromSliceError;
|
||||
use std::convert::{TryFrom, TryInto};
|
||||
use std::ffi::{CStr, NulError};
|
||||
use std::ops::Deref;
|
||||
|
@ -26,10 +27,7 @@ use std::ptr;
|
|||
use super::*;
|
||||
use llvm_zluda::analysis::{LLVMVerifierFailureAction, LLVMVerifyModule};
|
||||
use llvm_zluda::bit_writer::LLVMWriteBitcodeToMemoryBuffer;
|
||||
use llvm_zluda::{
|
||||
core::*, LLVMAtomicOrdering, LLVMAtomicRMWBinOp, LLVMZludaAtomicRMWBinOp,
|
||||
LLVMZludaBuildAtomicCmpXchg,
|
||||
};
|
||||
use llvm_zluda::{core::*, *};
|
||||
use llvm_zluda::{prelude::*, LLVMZludaBuildAtomicRMW};
|
||||
use llvm_zluda::{LLVMCallConv, LLVMZludaBuildAlloca};
|
||||
|
||||
|
@ -328,10 +326,76 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> {
|
|||
unsafe { LLVMSetAlignment(global, align) };
|
||||
}
|
||||
if !var.array_init.is_empty() {
|
||||
todo!()
|
||||
self.emit_array_init(&var.v_type, &*var.array_init, global);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// TODO: instead of Vec<u8> we should emit a typed initializer
|
||||
fn emit_array_init(
|
||||
&mut self,
|
||||
type_: &ast::Type,
|
||||
array_init: &[u8],
|
||||
global: *mut llvm_zluda::LLVMValue,
|
||||
) -> Result<(), TranslateError> {
|
||||
match type_ {
|
||||
ast::Type::Array(None, scalar, dimensions) => {
|
||||
if dimensions.len() != 1 {
|
||||
todo!()
|
||||
}
|
||||
if dimensions[0] as usize * scalar.size_of() as usize != array_init.len() {
|
||||
return Err(error_unreachable());
|
||||
}
|
||||
let type_ = get_scalar_type(self.context, *scalar);
|
||||
let mut elements = array_init
|
||||
.chunks(scalar.size_of() as usize)
|
||||
.map(|chunk| self.constant_from_bytes(*scalar, chunk, type_))
|
||||
.collect::<Result<Vec<_>, _>>()
|
||||
.map_err(|_| error_unreachable())?;
|
||||
let initializer =
|
||||
unsafe { LLVMConstArray2(type_, elements.as_mut_ptr(), elements.len() as u64) };
|
||||
unsafe { LLVMSetInitializer(global, initializer) };
|
||||
}
|
||||
_ => todo!(),
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn constant_from_bytes(
|
||||
&self,
|
||||
scalar: ast::ScalarType,
|
||||
bytes: &[u8],
|
||||
llvm_type: LLVMTypeRef,
|
||||
) -> Result<LLVMValueRef, TryFromSliceError> {
|
||||
Ok(match scalar {
|
||||
ptx_parser::ScalarType::Pred
|
||||
| ptx_parser::ScalarType::S8
|
||||
| ptx_parser::ScalarType::B8
|
||||
| ptx_parser::ScalarType::U8 => unsafe {
|
||||
LLVMConstInt(llvm_type, u8::from_le_bytes(bytes.try_into()?) as u64, 0)
|
||||
},
|
||||
ptx_parser::ScalarType::S16
|
||||
| ptx_parser::ScalarType::B16
|
||||
| ptx_parser::ScalarType::U16 => unsafe {
|
||||
LLVMConstInt(llvm_type, u16::from_le_bytes(bytes.try_into()?) as u64, 0)
|
||||
},
|
||||
ptx_parser::ScalarType::F16 => todo!(),
|
||||
ptx_parser::ScalarType::BF16 => todo!(),
|
||||
ptx_parser::ScalarType::S32 => todo!(),
|
||||
ptx_parser::ScalarType::U64 => todo!(),
|
||||
ptx_parser::ScalarType::S64 => todo!(),
|
||||
ptx_parser::ScalarType::S16x2 => todo!(),
|
||||
ptx_parser::ScalarType::B32 => todo!(),
|
||||
ptx_parser::ScalarType::F32 => todo!(),
|
||||
ptx_parser::ScalarType::B64 => todo!(),
|
||||
ptx_parser::ScalarType::F64 => todo!(),
|
||||
ptx_parser::ScalarType::B128 => todo!(),
|
||||
ptx_parser::ScalarType::U16x2 => todo!(),
|
||||
ptx_parser::ScalarType::F16x2 => todo!(),
|
||||
ptx_parser::ScalarType::U32 => todo!(),
|
||||
ptx_parser::ScalarType::BF16x2 => todo!(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn get_input_argument_type(
|
||||
|
@ -454,7 +518,7 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> {
|
|||
ast::Instruction::Ld { data, arguments } => self.emit_ld(data, arguments),
|
||||
ast::Instruction::Add { data, arguments } => self.emit_add(data, arguments),
|
||||
ast::Instruction::St { data, arguments } => self.emit_st(data, arguments),
|
||||
ast::Instruction::Mul { data, arguments } => todo!(),
|
||||
ast::Instruction::Mul { data, arguments } => self.emit_mul(data, arguments),
|
||||
ast::Instruction::Setp { data, arguments } => todo!(),
|
||||
ast::Instruction::SetpBool { data, arguments } => todo!(),
|
||||
ast::Instruction::Not { data, arguments } => todo!(),
|
||||
|
@ -483,13 +547,13 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> {
|
|||
ast::Instruction::Div { data, arguments } => todo!(),
|
||||
ast::Instruction::Neg { data, arguments } => todo!(),
|
||||
ast::Instruction::Sin { data, arguments } => todo!(),
|
||||
ast::Instruction::Cos { data, arguments } => todo!(),
|
||||
ast::Instruction::Cos { data, arguments } => self.emit_cos(data, arguments),
|
||||
ast::Instruction::Lg2 { data, arguments } => todo!(),
|
||||
ast::Instruction::Ex2 { data, arguments } => todo!(),
|
||||
ast::Instruction::Clz { data, arguments } => self.emit_clz(data, arguments),
|
||||
ast::Instruction::Brev { data, arguments } => self.emit_brev(data, arguments),
|
||||
ast::Instruction::Popc { data, arguments } => todo!(),
|
||||
ast::Instruction::Xor { data, arguments } => todo!(),
|
||||
ast::Instruction::Xor { data, arguments } => self.emit_xor(data, arguments),
|
||||
ast::Instruction::Rem { data, arguments } => todo!(),
|
||||
ast::Instruction::PrmtSlow { arguments } => todo!(),
|
||||
ast::Instruction::Prmt { data, arguments } => todo!(),
|
||||
|
@ -865,6 +929,63 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> {
|
|||
});
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn emit_mul(
|
||||
&mut self,
|
||||
data: ast::MulDetails,
|
||||
arguments: ast::MulArgs<SpirvWord>,
|
||||
) -> Result<(), TranslateError> {
|
||||
let mul_fn = match data {
|
||||
ast::MulDetails::Integer { type_, control } => match control {
|
||||
ast::MulIntControl::Low => LLVMBuildMul,
|
||||
ast::MulIntControl::High => todo!(),
|
||||
ast::MulIntControl::Wide => todo!(),
|
||||
},
|
||||
ast::MulDetails::Float(arith_float) => LLVMBuildFMul,
|
||||
};
|
||||
let src1 = self.resolver.value(arguments.src1)?;
|
||||
let src2 = self.resolver.value(arguments.src2)?;
|
||||
self.resolver.with_result(arguments.dst, |dst| unsafe {
|
||||
mul_fn(self.builder, src1, src2, dst)
|
||||
});
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn emit_cos(
|
||||
&mut self,
|
||||
data: ast::FlushToZero,
|
||||
arguments: ast::CosArgs<SpirvWord>,
|
||||
) -> Result<(), TranslateError> {
|
||||
let llvm_fn = c"llvm.cos.f32";
|
||||
let mut fn_ = unsafe { LLVMGetNamedFunction(self.module, llvm_fn.as_ptr()) };
|
||||
let fn_type = get_function_type(
|
||||
self.context,
|
||||
iter::once(&ast::ScalarType::F32.into()),
|
||||
iter::once(Ok(get_scalar_type(self.context, ast::ScalarType::F32))),
|
||||
)?;
|
||||
if fn_ == ptr::null_mut() {
|
||||
fn_ = unsafe { LLVMAddFunction(self.module, llvm_fn.as_ptr(), fn_type) };
|
||||
}
|
||||
let mut src = self.resolver.value(arguments.src)?;
|
||||
let cos = self.resolver.with_result(arguments.dst, |dst| unsafe {
|
||||
LLVMBuildCall2(self.builder, fn_type, fn_, &mut src, 1, dst)
|
||||
});
|
||||
unsafe { LLVMZludaSetFastMathFlags(cos, LLVMZludaFastMathApproxFunc) }
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn emit_xor(
|
||||
&mut self,
|
||||
_data: ptx_parser::ScalarType,
|
||||
arguments: ptx_parser::XorArgs<SpirvWord>,
|
||||
) -> Result<(), TranslateError> {
|
||||
let src1 = self.resolver.value(arguments.src1)?;
|
||||
let src2 = self.resolver.value(arguments.src2)?;
|
||||
self.resolver.with_result(arguments.dst, |dst| unsafe {
|
||||
LLVMBuildXor(self.builder, src1, src2, dst)
|
||||
});
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
fn get_pointer_type<'ctx>(
|
||||
|
@ -1037,8 +1158,13 @@ impl ResolveIdent {
|
|||
.ok_or_else(|| error_unreachable())
|
||||
}
|
||||
|
||||
fn with_result(&mut self, word: SpirvWord, fn_: impl FnOnce(*const i8) -> LLVMValueRef) {
|
||||
fn with_result(
|
||||
&mut self,
|
||||
word: SpirvWord,
|
||||
fn_: impl FnOnce(*const i8) -> LLVMValueRef,
|
||||
) -> LLVMValueRef {
|
||||
let t = self.get_or_ad_impl(word, |dst| fn_(dst.as_ptr().cast()));
|
||||
self.register(word, t);
|
||||
t
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue