// ZLUDA/ptx/src/pass/emit_llvm.rs

// We use raw LLVM-C bindings here because using inkwell is just not worth it.
// Specifically the issue is with builder functions. We maintain the mapping
// between ZLUDA identifiers and LLVM values. When using inkwell, LLVM values
// are kept as instances of `AnyValueEnum`. Now look at the signature of
// `Builder::build_int_add(...)`:
// pub fn build_int_add<T: IntMathValue<'ctx>>(&self, lhs: T, rhs: T, name: &str) -> Result<T, BuilderError>
// At this point both lhs and rhs are `AnyValueEnum`. To call
// `build_int_add(...)` we would have to do something like this:
// if let (Ok(lhs), Ok(rhs)) = (lhs.as_int(), rhs.as_int()) {
// builder.build_int_add(lhs, rhs, dst)?;
// } else if let (Ok(lhs), Ok(rhs)) = (lhs.as_pointer(), rhs.as_pointer()) {
// builder.build_int_add(lhs, rhs, dst)?;
// } else if let (Ok(lhs), Ok(rhs)) = (lhs.as_vector(), rhs.as_vector()) {
// builder.build_int_add(lhs, rhs, dst)?;
// } else {
// return Err(error_unreachable());
// }
// while with plain LLVM-C it's just:
// unsafe { LLVMBuildAdd(builder, lhs, rhs, dst) };
// AMDGPU LLVM backend support for llvm.experimental.constrained.* intrinsics
// is incomplete. Emitting @llvm.experimental.constrained.fdiv.f32(...) makes
// LLVM fail with "LLVM ERROR: unsupported libcall legalization". Running with
// "-mllvm -print-before-all" shows that it fails inside amdgpu-isel. You can
// get a little bit further with "-mllvm -global-isel", but it will eventually
// fail as well, with "unable to legalize instruction".
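// That is also why emit_div_float below lowers divisions to a plain fdiv
// (plus fast-math flags or fpmath metadata) and ignores the PTX rounding
// mode instead of emitting a constrained intrinsic.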
use std::array::TryFromSliceError;
use std::convert::TryInto;
use std::ffi::{CStr, NulError};
use std::ops::Deref;
use std::{i8, ptr};
use super::*;
use llvm_zluda::analysis::{LLVMVerifierFailureAction, LLVMVerifyModule};
use llvm_zluda::bit_writer::LLVMWriteBitcodeToMemoryBuffer;
use llvm_zluda::{core::*, *};
use llvm_zluda::{prelude::*, LLVMZludaBuildAtomicRMW};
use llvm_zluda::{LLVMCallConv, LLVMZludaBuildAlloca};
const LLVM_UNNAMED: &CStr = c"";
// https://llvm.org/docs/AMDGPUUsage.html#address-spaces
const GENERIC_ADDRESS_SPACE: u32 = 0;
const GLOBAL_ADDRESS_SPACE: u32 = 1;
const SHARED_ADDRESS_SPACE: u32 = 3;
const CONSTANT_ADDRESS_SPACE: u32 = 4;
const PRIVATE_ADDRESS_SPACE: u32 = 5;
struct Context(LLVMContextRef);
impl Context {
fn new() -> Self {
Self(unsafe { LLVMContextCreate() })
}
fn get(&self) -> LLVMContextRef {
self.0
}
}
impl Drop for Context {
fn drop(&mut self) {
unsafe {
LLVMContextDispose(self.0);
}
}
}
struct Module(LLVMModuleRef);
impl Module {
fn new(ctx: &Context, name: &CStr) -> Self {
Self(unsafe { LLVMModuleCreateWithNameInContext(name.as_ptr(), ctx.get()) })
}
fn get(&self) -> LLVMModuleRef {
self.0
}
fn verify(&self) -> Result<(), Message> {
let mut err = ptr::null_mut();
let error = unsafe {
LLVMVerifyModule(
self.get(),
LLVMVerifierFailureAction::LLVMReturnStatusAction,
&mut err,
)
};
if error == 1 && err != ptr::null_mut() {
Err(Message(unsafe { CStr::from_ptr(err) }))
} else {
Ok(())
}
}
fn write_bitcode_to_memory(&self) -> MemoryBuffer {
let memory_buffer = unsafe { LLVMWriteBitcodeToMemoryBuffer(self.get()) };
MemoryBuffer(memory_buffer)
}
}
impl Drop for Module {
fn drop(&mut self) {
unsafe {
LLVMDisposeModule(self.0);
}
}
}
struct Builder(LLVMBuilderRef);
impl Builder {
fn new(ctx: &Context) -> Self {
Self::new_raw(ctx.get())
}
fn new_raw(ctx: LLVMContextRef) -> Self {
Self(unsafe { LLVMCreateBuilderInContext(ctx) })
}
fn get(&self) -> LLVMBuilderRef {
self.0
}
}
impl Drop for Builder {
fn drop(&mut self) {
unsafe {
LLVMDisposeBuilder(self.0);
}
}
}
struct Message(&'static CStr);
impl Drop for Message {
fn drop(&mut self) {
unsafe {
LLVMDisposeMessage(self.0.as_ptr().cast_mut());
}
}
}
impl std::fmt::Debug for Message {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
std::fmt::Debug::fmt(&self.0, f)
}
}
pub struct MemoryBuffer(LLVMMemoryBufferRef);
impl Drop for MemoryBuffer {
fn drop(&mut self) {
unsafe {
LLVMDisposeMemoryBuffer(self.0);
}
}
}
impl Deref for MemoryBuffer {
type Target = [u8];
fn deref(&self) -> &Self::Target {
let data = unsafe { LLVMGetBufferStart(self.0) };
let len = unsafe { LLVMGetBufferSize(self.0) };
unsafe { std::slice::from_raw_parts(data.cast(), len) }
}
}
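// Entry point of this pass: lowers every directive into a fresh LLVM module,
// verifies it and returns the serialized bitcode.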
pub(super) fn run<'input>(
id_defs: GlobalStringIdentResolver2<'input>,
directives: Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>,
) -> Result<MemoryBuffer, TranslateError> {
let context = Context::new();
let module = Module::new(&context, LLVM_UNNAMED);
let mut emit_ctx = ModuleEmitContext::new(&context, &module, &id_defs);
for directive in directives {
match directive {
Directive2::Variable(linking, variable) => emit_ctx.emit_global(linking, variable)?,
Directive2::Method(method) => emit_ctx.emit_method(method)?,
}
}
if let Err(err) = module.verify() {
panic!("{:?}", err);
}
Ok(module.write_bitcode_to_memory())
}
struct ModuleEmitContext<'a, 'input> {
context: LLVMContextRef,
module: LLVMModuleRef,
builder: Builder,
id_defs: &'a GlobalStringIdentResolver2<'input>,
resolver: ResolveIdent,
}
impl<'a, 'input> ModuleEmitContext<'a, 'input> {
fn new(
context: &Context,
module: &Module,
id_defs: &'a GlobalStringIdentResolver2<'input>,
) -> Self {
ModuleEmitContext {
context: context.get(),
module: module.get(),
builder: Builder::new(context),
id_defs,
resolver: ResolveIdent::new(&id_defs),
}
}
fn kernel_call_convention() -> u32 {
LLVMCallConv::LLVMAMDGPUKERNELCallConv as u32
}
fn func_call_convention() -> u32 {
LLVMCallConv::LLVMCCallConv as u32
}
fn emit_method(
&mut self,
method: Function2<'input, ast::Instruction<SpirvWord>, SpirvWord>,
) -> Result<(), TranslateError> {
let func_decl = method.func_decl;
let name = method
.import_as
.as_deref()
.or_else(|| match func_decl.name {
ast::MethodName::Kernel(name) => Some(name),
ast::MethodName::Func(id) => self.id_defs.ident_map[&id].name.as_deref(),
})
.ok_or_else(|| error_unreachable())?;
let name = CString::new(name).map_err(|_| error_unreachable())?;
let mut fn_ = unsafe { LLVMGetNamedFunction(self.module, name.as_ptr()) };
if fn_ == ptr::null_mut() {
let fn_type = get_function_type(
self.context,
func_decl.return_arguments.iter().map(|v| &v.v_type),
func_decl
.input_arguments
.iter()
.map(|v| get_input_argument_type(self.context, &v.v_type, v.state_space)),
)?;
fn_ = unsafe { LLVMAddFunction(self.module, name.as_ptr(), fn_type) };
self.emit_fn_attribute(fn_, "amdgpu-unsafe-fp-atomics", "true");
self.emit_fn_attribute(fn_, "uniform-work-group-size", "true");
self.emit_fn_attribute(fn_, "no-trapping-math", "true");
}
if let ast::MethodName::Func(name) = func_decl.name {
self.resolver.register(name, fn_);
}
for (i, param) in func_decl.input_arguments.iter().enumerate() {
let value = unsafe { LLVMGetParam(fn_, i as u32) };
let name = self.resolver.get_or_add(param.name);
unsafe { LLVMSetValueName2(value, name.as_ptr().cast(), name.len()) };
self.resolver.register(param.name, value);
if func_decl.name.is_kernel() {
let attr_kind = unsafe {
LLVMGetEnumAttributeKindForName(b"byref".as_ptr().cast(), b"byref".len())
};
let attr = unsafe {
LLVMCreateTypeAttribute(
self.context,
attr_kind,
get_type(self.context, &param.v_type)?,
)
};
unsafe { LLVMAddAttributeAtIndex(fn_, i as u32 + 1, attr) };
}
}
let call_conv = if func_decl.name.is_kernel() {
Self::kernel_call_convention()
} else {
Self::func_call_convention()
};
unsafe { LLVMSetFunctionCallConv(fn_, call_conv) };
if let Some(statements) = method.body {
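// All allocas go through `variables_builder` into a dedicated leading basic
// block, while statements are emitted through `self.builder` into the blocks
// that follow; the variables block is terminated with a branch to the first
// real block at the very end (the LLVMBuildBr below).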
let variables_bb =
unsafe { LLVMAppendBasicBlockInContext(self.context, fn_, LLVM_UNNAMED.as_ptr()) };
let variables_builder = Builder::new_raw(self.context);
unsafe { LLVMPositionBuilderAtEnd(variables_builder.get(), variables_bb) };
let real_bb =
unsafe { LLVMAppendBasicBlockInContext(self.context, fn_, LLVM_UNNAMED.as_ptr()) };
unsafe { LLVMPositionBuilderAtEnd(self.builder.get(), real_bb) };
let mut method_emitter = MethodEmitContext::new(self, fn_, variables_builder);
for var in func_decl.return_arguments {
method_emitter.emit_variable(var)?;
}
for statement in statements.iter() {
if let Statement::Label(label) = statement {
method_emitter.emit_label_initial(*label);
}
}
for statement in statements {
method_emitter.emit_statement(statement)?;
}
unsafe { LLVMBuildBr(method_emitter.variables_builder.get(), real_bb) };
}
Ok(())
}
fn emit_global(
&mut self,
_linking: ast::LinkingDirective,
var: ast::Variable<SpirvWord>,
) -> Result<(), TranslateError> {
let name = self
.id_defs
.ident_map
.get(&var.name)
.map(|entry| {
entry
.name
.as_ref()
.map(|text| Ok::<_, NulError>(Cow::Owned(CString::new(&**text)?)))
})
.flatten()
.transpose()
.map_err(|_| error_unreachable())?
.unwrap_or(Cow::Borrowed(LLVM_UNNAMED));
let global = unsafe {
LLVMAddGlobalInAddressSpace(
self.module,
get_type(self.context, &var.v_type)?,
name.as_ptr(),
get_state_space(var.state_space)?,
)
};
self.resolver.register(var.name, global);
if let Some(align) = var.align {
unsafe { LLVMSetAlignment(global, align) };
}
if !var.array_init.is_empty() {
self.emit_array_init(&var.v_type, &*var.array_init, global)?;
}
Ok(())
}
// TODO: instead of Vec<u8> we should emit a typed initializer
fn emit_array_init(
&mut self,
type_: &ast::Type,
array_init: &[u8],
global: *mut llvm_zluda::LLVMValue,
) -> Result<(), TranslateError> {
match type_ {
ast::Type::Array(None, scalar, dimensions) => {
if dimensions.len() != 1 {
todo!()
}
if dimensions[0] as usize * scalar.size_of() as usize != array_init.len() {
return Err(error_unreachable());
}
let type_ = get_scalar_type(self.context, *scalar);
let mut elements = array_init
.chunks(scalar.size_of() as usize)
.map(|chunk| self.constant_from_bytes(*scalar, chunk, type_))
.collect::<Result<Vec<_>, _>>()
.map_err(|_| error_unreachable())?;
let initializer =
unsafe { LLVMConstArray2(type_, elements.as_mut_ptr(), elements.len() as u64) };
unsafe { LLVMSetInitializer(global, initializer) };
}
_ => todo!(),
}
Ok(())
}
fn constant_from_bytes(
&self,
scalar: ast::ScalarType,
bytes: &[u8],
llvm_type: LLVMTypeRef,
) -> Result<LLVMValueRef, TryFromSliceError> {
Ok(match scalar {
ptx_parser::ScalarType::Pred
| ptx_parser::ScalarType::S8
| ptx_parser::ScalarType::B8
| ptx_parser::ScalarType::U8 => unsafe {
LLVMConstInt(llvm_type, u8::from_le_bytes(bytes.try_into()?) as u64, 0)
},
ptx_parser::ScalarType::S16
| ptx_parser::ScalarType::B16
| ptx_parser::ScalarType::U16 => unsafe {
LLVMConstInt(llvm_type, u16::from_le_bytes(bytes.try_into()?) as u64, 0)
},
ptx_parser::ScalarType::S32
| ptx_parser::ScalarType::B32
| ptx_parser::ScalarType::U32 => unsafe {
LLVMConstInt(llvm_type, u32::from_le_bytes(bytes.try_into()?) as u64, 0)
},
ptx_parser::ScalarType::F16 => todo!(),
ptx_parser::ScalarType::BF16 => todo!(),
ptx_parser::ScalarType::U64 => todo!(),
ptx_parser::ScalarType::S64 => todo!(),
ptx_parser::ScalarType::S16x2 => todo!(),
ptx_parser::ScalarType::F32 => todo!(),
ptx_parser::ScalarType::B64 => todo!(),
ptx_parser::ScalarType::F64 => todo!(),
ptx_parser::ScalarType::B128 => todo!(),
ptx_parser::ScalarType::U16x2 => todo!(),
ptx_parser::ScalarType::F16x2 => todo!(),
ptx_parser::ScalarType::BF16x2 => todo!(),
})
}
fn emit_fn_attribute(&self, llvm_object: LLVMValueRef, key: &str, value: &str) {
let attribute = unsafe {
LLVMCreateStringAttribute(
self.context,
key.as_ptr() as _,
key.len() as u32,
value.as_ptr() as _,
value.len() as u32,
)
};
unsafe { LLVMAddAttributeAtIndex(llvm_object, LLVMAttributeFunctionIndex, attribute) };
}
}
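// Kernel .param arguments are lowered to pointers (and marked byref in
// emit_method above); .reg arguments are passed by value.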
fn get_input_argument_type(
context: LLVMContextRef,
v_type: &ast::Type,
state_space: ast::StateSpace,
) -> Result<LLVMTypeRef, TranslateError> {
match state_space {
ast::StateSpace::ParamEntry => {
Ok(unsafe { LLVMPointerTypeInContext(context, get_state_space(state_space)?) })
}
ast::StateSpace::Reg => get_type(context, v_type),
_ => return Err(error_unreachable()),
}
}
struct MethodEmitContext<'a> {
context: LLVMContextRef,
module: LLVMModuleRef,
method: LLVMValueRef,
builder: LLVMBuilderRef,
variables_builder: Builder,
resolver: &'a mut ResolveIdent,
}
impl<'a> MethodEmitContext<'a> {
fn new(
parent: &'a mut ModuleEmitContext,
method: LLVMValueRef,
variables_builder: Builder,
) -> MethodEmitContext<'a> {
MethodEmitContext {
context: parent.context,
module: parent.module,
builder: parent.builder.get(),
variables_builder,
resolver: &mut parent.resolver,
method,
}
}
fn emit_statement(
&mut self,
statement: Statement<ast::Instruction<SpirvWord>, SpirvWord>,
) -> Result<(), TranslateError> {
Ok(match statement {
Statement::Variable(var) => self.emit_variable(var)?,
Statement::Label(label) => self.emit_label_delayed(label)?,
Statement::Instruction(inst) => self.emit_instruction(inst)?,
Statement::Conditional(cond) => self.emit_conditional(cond)?,
Statement::Conversion(conversion) => self.emit_conversion(conversion)?,
Statement::Constant(constant) => self.emit_constant(constant)?,
Statement::RetValue(_, values) => self.emit_ret_value(values)?,
Statement::PtrAccess(ptr_access) => self.emit_ptr_access(ptr_access)?,
Statement::RepackVector(repack) => self.emit_vector_repack(repack)?,
Statement::FunctionPointer(_) => todo!(),
Statement::VectorRead(vector_read) => self.emit_vector_read(vector_read)?,
Statement::VectorWrite(vector_write) => self.emit_vector_write(vector_write)?,
})
}
fn emit_variable(&mut self, var: ast::Variable<SpirvWord>) -> Result<(), TranslateError> {
let alloca = unsafe {
LLVMZludaBuildAlloca(
self.variables_builder.get(),
get_type(self.context, &var.v_type)?,
get_state_space(var.state_space)?,
self.resolver.get_or_add_raw(var.name),
)
};
self.resolver.register(var.name, alloca);
if let Some(align) = var.align {
unsafe { LLVMSetAlignment(alloca, align) };
}
if !var.array_init.is_empty() {
todo!()
}
Ok(())
}
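// PTX allows forward branches, so every label is first materialized as a
// basic block in a pre-pass (emit_label_initial) and only positioned and
// wired up when actually reached (emit_label_delayed below).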
fn emit_label_initial(&mut self, label: SpirvWord) {
let block = unsafe {
LLVMAppendBasicBlockInContext(
self.context,
self.method,
self.resolver.get_or_add_raw(label),
)
};
self.resolver
.register(label, unsafe { LLVMBasicBlockAsValue(block) });
}
fn emit_label_delayed(&mut self, label: SpirvWord) -> Result<(), TranslateError> {
let block = self.resolver.value(label)?;
let block = unsafe { LLVMValueAsBasicBlock(block) };
let last_block = unsafe { LLVMGetInsertBlock(self.builder) };
if unsafe { LLVMGetBasicBlockTerminator(last_block) } == ptr::null_mut() {
unsafe { LLVMBuildBr(self.builder, block) };
}
unsafe { LLVMPositionBuilderAtEnd(self.builder, block) };
Ok(())
}
fn emit_instruction(
&mut self,
inst: ast::Instruction<SpirvWord>,
) -> Result<(), TranslateError> {
match inst {
ast::Instruction::Mov { data, arguments } => self.emit_mov(data, arguments),
ast::Instruction::Ld { data, arguments } => self.emit_ld(data, arguments),
ast::Instruction::Add { data, arguments } => self.emit_add(data, arguments),
ast::Instruction::St { data, arguments } => self.emit_st(data, arguments),
ast::Instruction::Mul { data, arguments } => self.emit_mul(data, arguments),
ast::Instruction::Setp { data, arguments } => self.emit_setp(data, arguments),
ast::Instruction::SetpBool { .. } => todo!(),
ast::Instruction::Not { data, arguments } => self.emit_not(data, arguments),
ast::Instruction::Or { data, arguments } => self.emit_or(data, arguments),
ast::Instruction::And { arguments, .. } => self.emit_and(arguments),
ast::Instruction::Bra { arguments } => self.emit_bra(arguments),
ast::Instruction::Call { data, arguments } => self.emit_call(data, arguments),
ast::Instruction::Cvt { data, arguments } => self.emit_cvt(data, arguments),
ast::Instruction::Shr { data, arguments } => self.emit_shr(data, arguments),
ast::Instruction::Shl { data, arguments } => self.emit_shl(data, arguments),
ast::Instruction::Ret { data } => Ok(self.emit_ret(data)),
ast::Instruction::Cvta { data, arguments } => self.emit_cvta(data, arguments),
ast::Instruction::Abs { data, arguments } => self.emit_abs(data, arguments),
ast::Instruction::Mad { data, arguments } => self.emit_mad(data, arguments),
ast::Instruction::Fma { data, arguments } => self.emit_fma(data, arguments),
ast::Instruction::Sub { data, arguments } => self.emit_sub(data, arguments),
ast::Instruction::Min { data, arguments } => self.emit_min(data, arguments),
ast::Instruction::Max { data, arguments } => self.emit_max(data, arguments),
ast::Instruction::Rcp { data, arguments } => self.emit_rcp(data, arguments),
ast::Instruction::Sqrt { data, arguments } => self.emit_sqrt(data, arguments),
ast::Instruction::Rsqrt { data, arguments } => self.emit_rsqrt(data, arguments),
ast::Instruction::Selp { data, arguments } => self.emit_selp(data, arguments),
ast::Instruction::Atom { data, arguments } => self.emit_atom(data, arguments),
ast::Instruction::AtomCas { data, arguments } => self.emit_atom_cas(data, arguments),
ast::Instruction::Div { data, arguments } => self.emit_div(data, arguments),
ast::Instruction::Neg { data, arguments } => self.emit_neg(data, arguments),
ast::Instruction::Sin { data, arguments } => self.emit_sin(data, arguments),
ast::Instruction::Cos { data, arguments } => self.emit_cos(data, arguments),
ast::Instruction::Lg2 { data, arguments } => self.emit_lg2(data, arguments),
ast::Instruction::Ex2 { data, arguments } => self.emit_ex2(data, arguments),
ast::Instruction::Clz { data, arguments } => self.emit_clz(data, arguments),
ast::Instruction::Brev { data, arguments } => self.emit_brev(data, arguments),
ast::Instruction::Popc { data, arguments } => self.emit_popc(data, arguments),
ast::Instruction::Xor { data, arguments } => self.emit_xor(data, arguments),
ast::Instruction::Rem { data, arguments } => self.emit_rem(data, arguments),
ast::Instruction::PrmtSlow { .. } => todo!(),
ast::Instruction::Prmt { data, arguments } => self.emit_prmt(data, arguments),
ast::Instruction::Membar { data } => self.emit_membar(data),
ast::Instruction::Trap {} => todo!(),
// these instructions are replaced with function calls in an earlier pass
ast::Instruction::Bfe { .. }
| ast::Instruction::Bar { .. }
| ast::Instruction::Bfi { .. }
| ast::Instruction::Activemask { .. } => return Err(error_unreachable()),
}
}
fn emit_ld(
&mut self,
data: ast::LdDetails,
arguments: ast::LdArgs<SpirvWord>,
) -> Result<(), TranslateError> {
if data.qualifier != ast::LdStQualifier::Weak {
todo!()
}
let builder = self.builder;
let type_ = get_type(self.context, &data.typ)?;
let ptr = self.resolver.value(arguments.src)?;
self.resolver.with_result(arguments.dst, |dst| unsafe {
LLVMBuildLoad2(builder, type_, ptr, dst)
});
Ok(())
}
fn emit_conversion(&mut self, conversion: ImplicitConversion) -> Result<(), TranslateError> {
let builder = self.builder;
match conversion.kind {
ConversionKind::Default => self.emit_conversion_default(
self.resolver.value(conversion.src)?,
conversion.dst,
&conversion.from_type,
conversion.from_space,
&conversion.to_type,
conversion.to_space,
),
ConversionKind::SignExtend => {
let src = self.resolver.value(conversion.src)?;
let type_ = get_type(self.context, &conversion.to_type)?;
self.resolver.with_result(conversion.dst, |dst| unsafe {
LLVMBuildSExt(builder, src, type_, dst)
});
Ok(())
}
ConversionKind::BitToPtr => {
let src = self.resolver.value(conversion.src)?;
let type_ = get_pointer_type(self.context, conversion.to_space)?;
self.resolver.with_result(conversion.dst, |dst| unsafe {
LLVMBuildIntToPtr(builder, src, type_, dst)
});
Ok(())
}
ConversionKind::PtrToPtr => {
let src = self.resolver.value(conversion.src)?;
let dst_type = get_pointer_type(self.context, conversion.to_space)?;
self.resolver.with_result(conversion.dst, |dst| unsafe {
LLVMBuildAddrSpaceCast(builder, src, dst_type, dst)
});
Ok(())
}
ConversionKind::AddressOf => {
let src = self.resolver.value(conversion.src)?;
let dst_type = get_type(self.context, &conversion.to_type)?;
self.resolver.with_result(conversion.dst, |dst| unsafe {
LLVMBuildPtrToInt(self.builder, src, dst_type, dst)
});
Ok(())
}
}
}
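// Default conversions between differently-sized scalars first bitcast the
// source to an integer of the same width, resize it as an integer, and then
// recurse once more in case the destination is a float. E.g. a u16 -> f32
// conversion would lower to roughly:
// %wide = zext i16 %src to i32
// %dst = bitcast i32 %wide to float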
fn emit_conversion_default(
&mut self,
src: LLVMValueRef,
dst: SpirvWord,
from_type: &ast::Type,
from_space: ast::StateSpace,
to_type: &ast::Type,
to_space: ast::StateSpace,
) -> Result<(), TranslateError> {
match (from_type, to_type) {
(ast::Type::Scalar(from_type), ast::Type::Scalar(to_type_scalar)) => {
let from_layout = from_type.layout();
let to_layout = to_type.layout();
if from_layout.size() == to_layout.size() {
let dst_type = get_type(self.context, &to_type)?;
if from_type.kind() != ast::ScalarKind::Float
&& to_type_scalar.kind() != ast::ScalarKind::Float
{
// It is a no-op, but another instruction expects the result of this conversion
self.resolver.register(dst, src);
} else {
self.resolver.with_result(dst, |dst| unsafe {
LLVMBuildBitCast(self.builder, src, dst_type, dst)
});
}
Ok(())
} else {
// This block is safe because it's illegal to implicitly convert between floating-point values of different sizes
let same_width_bit_type = unsafe {
LLVMIntTypeInContext(self.context, (from_layout.size() * 8) as u32)
};
let same_width_bit_value = unsafe {
LLVMBuildBitCast(
self.builder,
src,
same_width_bit_type,
LLVM_UNNAMED.as_ptr(),
)
};
let wide_bit_type = match to_type_scalar.layout().size() {
1 => ast::ScalarType::B8,
2 => ast::ScalarType::B16,
4 => ast::ScalarType::B32,
8 => ast::ScalarType::B64,
_ => return Err(error_unreachable()),
};
let wide_bit_type_llvm = unsafe {
LLVMIntTypeInContext(self.context, (to_layout.size() * 8) as u32)
};
if to_type_scalar.kind() == ast::ScalarKind::Unsigned
|| to_type_scalar.kind() == ast::ScalarKind::Bit
{
let llvm_fn = if to_type_scalar.size_of() >= from_type.size_of() {
LLVMBuildZExtOrBitCast
} else {
LLVMBuildTrunc
};
self.resolver.with_result(dst, |dst| unsafe {
llvm_fn(self.builder, same_width_bit_value, wide_bit_type_llvm, dst)
});
Ok(())
} else {
let conversion_fn = if from_type.kind() == ast::ScalarKind::Signed
&& to_type_scalar.kind() == ast::ScalarKind::Signed
{
if to_type_scalar.size_of() >= from_type.size_of() {
LLVMBuildSExtOrBitCast
} else {
LLVMBuildTrunc
}
} else {
if to_type_scalar.size_of() >= from_type.size_of() {
LLVMBuildZExtOrBitCast
} else {
LLVMBuildTrunc
}
};
let wide_bit_value = unsafe {
conversion_fn(
self.builder,
same_width_bit_value,
wide_bit_type_llvm,
LLVM_UNNAMED.as_ptr(),
)
};
self.emit_conversion_default(
wide_bit_value,
dst,
&wide_bit_type.into(),
from_space,
to_type,
to_space,
)
}
}
}
(ast::Type::Vector(..), ast::Type::Scalar(..))
| (ast::Type::Scalar(..), ast::Type::Array(..))
| (ast::Type::Array(..), ast::Type::Scalar(..)) => {
let dst_type = get_type(self.context, to_type)?;
self.resolver.with_result(dst, |dst| unsafe {
LLVMBuildBitCast(self.builder, src, dst_type, dst)
});
Ok(())
}
_ => todo!(),
}
}
fn emit_constant(&mut self, constant: ConstantDefinition) -> Result<(), TranslateError> {
let type_ = get_scalar_type(self.context, constant.typ);
let value = match constant.value {
ast::ImmediateValue::U64(x) => unsafe { LLVMConstInt(type_, x, 0) },
ast::ImmediateValue::S64(x) => unsafe { LLVMConstInt(type_, x as u64, 0) },
ast::ImmediateValue::F32(x) => unsafe { LLVMConstReal(type_, x as f64) },
ast::ImmediateValue::F64(x) => unsafe { LLVMConstReal(type_, x) },
};
self.resolver.register(constant.dst, value);
Ok(())
}
fn emit_add(
&mut self,
data: ast::ArithDetails,
arguments: ast::AddArgs<SpirvWord>,
) -> Result<(), TranslateError> {
let builder = self.builder;
let src1 = self.resolver.value(arguments.src1)?;
let src2 = self.resolver.value(arguments.src2)?;
let fn_ = match data {
ast::ArithDetails::Integer(..) => LLVMBuildAdd,
ast::ArithDetails::Float(..) => LLVMBuildFAdd,
};
self.resolver.with_result(arguments.dst, |dst| unsafe {
fn_(builder, src1, src2, dst)
});
Ok(())
}
fn emit_st(
&self,
data: ast::StData,
arguments: ast::StArgs<SpirvWord>,
) -> Result<(), TranslateError> {
let ptr = self.resolver.value(arguments.src1)?;
let value = self.resolver.value(arguments.src2)?;
if data.qualifier != ast::LdStQualifier::Weak {
todo!()
}
unsafe { LLVMBuildStore(self.builder, value, ptr) };
Ok(())
}
fn emit_ret(&self, _data: ast::RetData) {
unsafe { LLVMBuildRetVoid(self.builder) };
}
fn emit_call(
&mut self,
data: ast::CallDetails,
arguments: ast::CallArgs<SpirvWord>,
) -> Result<(), TranslateError> {
if cfg!(debug_assertions) {
for (_, space) in data.return_arguments.iter() {
if *space != ast::StateSpace::Reg {
panic!()
}
}
for (_, space) in data.input_arguments.iter() {
if *space != ast::StateSpace::Reg {
panic!()
}
}
}
let name = match &*arguments.return_arguments {
[] => LLVM_UNNAMED.as_ptr(),
[dst] => self.resolver.get_or_add_raw(*dst),
_ => todo!(),
};
let type_ = get_function_type(
self.context,
data.return_arguments.iter().map(|(type_, ..)| type_),
data.input_arguments
.iter()
.map(|(type_, space)| get_input_argument_type(self.context, &type_, *space)),
)?;
let mut input_arguments = arguments
.input_arguments
.iter()
.map(|arg| self.resolver.value(*arg))
.collect::<Result<Vec<_>, _>>()?;
let llvm_fn = unsafe {
LLVMBuildCall2(
self.builder,
type_,
self.resolver.value(arguments.func)?,
input_arguments.as_mut_ptr(),
input_arguments.len() as u32,
name,
)
};
match &*arguments.return_arguments {
[] => {}
[name] => {
self.resolver.register(*name, llvm_fn);
}
_ => todo!(),
}
Ok(())
}
fn emit_mov(
&mut self,
_data: ast::MovDetails,
arguments: ast::MovArgs<SpirvWord>,
) -> Result<(), TranslateError> {
self.resolver
.register(arguments.dst, self.resolver.value(arguments.src)?);
Ok(())
}
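// Pointer arithmetic is emitted as a byte-level GEP: the pointee type is i8,
// so the offset operand is interpreted in bytes.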
fn emit_ptr_access(&mut self, ptr_access: PtrAccess<SpirvWord>) -> Result<(), TranslateError> {
let ptr_src = self.resolver.value(ptr_access.ptr_src)?;
let mut offset_src = self.resolver.value(ptr_access.offset_src)?;
let pointee_type = get_scalar_type(self.context, ast::ScalarType::B8);
self.resolver.with_result(ptr_access.dst, |dst| unsafe {
LLVMBuildInBoundsGEP2(self.builder, pointee_type, ptr_src, &mut offset_src, 1, dst)
});
Ok(())
}
fn emit_and(&mut self, arguments: ast::AndArgs<SpirvWord>) -> Result<(), TranslateError> {
let builder = self.builder;
let src1 = self.resolver.value(arguments.src1)?;
let src2 = self.resolver.value(arguments.src2)?;
self.resolver.with_result(arguments.dst, |dst| unsafe {
LLVMBuildAnd(builder, src1, src2, dst)
});
Ok(())
}
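// atom.* maps onto LLVM's atomicrmw; the scope and ordering are derived from
// the PTX scope/semantics qualifiers via get_scope/get_ordering and passed
// through the LLVMZludaBuildAtomicRMW extension.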
fn emit_atom(
&mut self,
data: ast::AtomDetails,
arguments: ast::AtomArgs<SpirvWord>,
) -> Result<(), TranslateError> {
let builder = self.builder;
let src1 = self.resolver.value(arguments.src1)?;
let src2 = self.resolver.value(arguments.src2)?;
let op = match data.op {
ast::AtomicOp::And => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpAnd,
ast::AtomicOp::Or => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpOr,
ast::AtomicOp::Xor => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpXor,
ast::AtomicOp::Exchange => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpXchg,
ast::AtomicOp::Add => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpAdd,
ast::AtomicOp::IncrementWrap => {
LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpUIncWrap
}
ast::AtomicOp::DecrementWrap => {
LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpUDecWrap
}
ast::AtomicOp::SignedMin => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpMin,
ast::AtomicOp::UnsignedMin => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpUMin,
ast::AtomicOp::SignedMax => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpMax,
ast::AtomicOp::UnsignedMax => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpUMax,
ast::AtomicOp::FloatAdd => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpFAdd,
ast::AtomicOp::FloatMin => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpFMin,
ast::AtomicOp::FloatMax => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpFMax,
};
self.resolver.register(arguments.dst, unsafe {
LLVMZludaBuildAtomicRMW(
builder,
op,
src1,
src2,
get_scope(data.scope)?,
get_ordering(data.semantics),
)
});
Ok(())
}
fn emit_atom_cas(
&mut self,
data: ast::AtomCasDetails,
arguments: ast::AtomCasArgs<SpirvWord>,
) -> Result<(), TranslateError> {
let src1 = self.resolver.value(arguments.src1)?;
let src2 = self.resolver.value(arguments.src2)?;
let src3 = self.resolver.value(arguments.src3)?;
let success_ordering = get_ordering(data.semantics);
let failure_ordering = get_ordering_failure(data.semantics);
let temp = unsafe {
LLVMZludaBuildAtomicCmpXchg(
self.builder,
src1,
src2,
src3,
get_scope(data.scope)?,
success_ordering,
failure_ordering,
)
};
self.resolver.with_result(arguments.dst, |dst| unsafe {
LLVMBuildExtractValue(self.builder, temp, 0, dst)
});
Ok(())
}
fn emit_bra(&self, arguments: ast::BraArgs<SpirvWord>) -> Result<(), TranslateError> {
let src = self.resolver.value(arguments.src)?;
let src = unsafe { LLVMValueAsBasicBlock(src) };
unsafe { LLVMBuildBr(self.builder, src) };
Ok(())
}
fn emit_brev(
&mut self,
data: ast::ScalarType,
arguments: ast::BrevArgs<SpirvWord>,
) -> Result<(), TranslateError> {
let llvm_fn = match data.size_of() {
4 => c"llvm.bitreverse.i32",
8 => c"llvm.bitreverse.i64",
_ => return Err(error_unreachable()),
};
let mut fn_ = unsafe { LLVMGetNamedFunction(self.module, llvm_fn.as_ptr()) };
let type_ = get_scalar_type(self.context, data);
let fn_type = get_function_type(
self.context,
iter::once(&data.into()),
iter::once(Ok(type_)),
)?;
if fn_ == ptr::null_mut() {
fn_ = unsafe { LLVMAddFunction(self.module, llvm_fn.as_ptr(), fn_type) };
}
let mut src = self.resolver.value(arguments.src)?;
self.resolver.with_result(arguments.dst, |dst| unsafe {
LLVMBuildCall2(self.builder, fn_type, fn_, &mut src, 1, dst)
});
Ok(())
}
fn emit_ret_value(
&mut self,
values: Vec<(SpirvWord, ptx_parser::Type)>,
) -> Result<(), TranslateError> {
match &*values {
[] => unsafe { LLVMBuildRetVoid(self.builder) },
[(value, type_)] => {
let value = self.resolver.value(*value)?;
let type_ = get_type(self.context, type_)?;
let value =
unsafe { LLVMBuildLoad2(self.builder, type_, value, LLVM_UNNAMED.as_ptr()) };
unsafe { LLVMBuildRet(self.builder, value) }
}
_ => todo!(),
};
Ok(())
}
fn emit_clz(
&mut self,
data: ptx_parser::ScalarType,
arguments: ptx_parser::ClzArgs<SpirvWord>,
) -> Result<(), TranslateError> {
let llvm_fn = match data.size_of() {
4 => c"llvm.ctlz.i32",
8 => c"llvm.ctlz.i64",
_ => return Err(error_unreachable()),
};
let type_ = get_scalar_type(self.context, data.into());
let pred = get_scalar_type(self.context, ast::ScalarType::Pred);
let fn_type = get_function_type(
self.context,
iter::once(&ast::ScalarType::U32.into()),
[Ok(type_), Ok(pred)].into_iter(),
)?;
let mut fn_ = unsafe { LLVMGetNamedFunction(self.module, llvm_fn.as_ptr()) };
if fn_ == ptr::null_mut() {
fn_ = unsafe { LLVMAddFunction(self.module, llvm_fn.as_ptr(), fn_type) };
}
let src = self.resolver.value(arguments.src)?;
let false_ = unsafe { LLVMConstInt(pred, 0, 0) };
let mut args = [src, false_];
self.resolver.with_result(arguments.dst, |dst| unsafe {
LLVMBuildCall2(
self.builder,
fn_type,
fn_,
args.as_mut_ptr(),
args.len() as u32,
dst,
)
});
Ok(())
}
fn emit_mul(
&mut self,
data: ast::MulDetails,
arguments: ast::MulArgs<SpirvWord>,
) -> Result<(), TranslateError> {
self.emit_mul_impl(data, Some(arguments.dst), arguments.src1, arguments.src2)?;
Ok(())
}
fn emit_mul_impl(
&mut self,
data: ast::MulDetails,
dst: Option<SpirvWord>,
src1: SpirvWord,
src2: SpirvWord,
) -> Result<LLVMValueRef, TranslateError> {
let mul_fn = match data {
ast::MulDetails::Integer { control, type_ } => match control {
ast::MulIntControl::Low => LLVMBuildMul,
ast::MulIntControl::High => return self.emit_mul_high(type_, dst, src1, src2),
ast::MulIntControl::Wide => {
return Ok(self.emit_mul_wide_impl(type_, dst, src1, src2)?.1)
}
},
ast::MulDetails::Float(..) => LLVMBuildFMul,
};
let src1 = self.resolver.value(src1)?;
let src2 = self.resolver.value(src2)?;
Ok(self
.resolver
.with_result_option(dst, |dst| unsafe { mul_fn(self.builder, src1, src2, dst) }))
}
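// mul.hi is a widening multiply followed by a logical shift right by the
// operand width and a truncate back, e.g. for a 32-bit type:
// %wide = mul i64 %src1_ext, %src2_ext
// %hi = lshr i64 %wide, 32
// %dst = trunc i64 %hi to i32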
fn emit_mul_high(
&mut self,
type_: ptx_parser::ScalarType,
dst: Option<SpirvWord>,
src1: SpirvWord,
src2: SpirvWord,
) -> Result<LLVMValueRef, TranslateError> {
let (wide_type, wide_value) = self.emit_mul_wide_impl(type_, None, src1, src2)?;
let shift_constant =
unsafe { LLVMConstInt(wide_type, (type_.layout().size() * 8) as u64, 0) };
let shifted = unsafe {
LLVMBuildLShr(
self.builder,
wide_value,
shift_constant,
LLVM_UNNAMED.as_ptr(),
)
};
let narrow_type = get_scalar_type(self.context, type_);
Ok(self.resolver.with_result_option(dst, |dst| unsafe {
LLVMBuildTrunc(self.builder, shifted, narrow_type, dst)
}))
}
fn emit_mul_wide_impl(
&mut self,
type_: ptx_parser::ScalarType,
dst: Option<SpirvWord>,
src1: SpirvWord,
src2: SpirvWord,
) -> Result<(LLVMTypeRef, LLVMValueRef), TranslateError> {
let src1 = self.resolver.value(src1)?;
let src2 = self.resolver.value(src2)?;
let wide_type =
unsafe { LLVMIntTypeInContext(self.context, (type_.layout().size() * 8 * 2) as u32) };
let llvm_cast = match type_.kind() {
ptx_parser::ScalarKind::Signed => LLVMBuildSExt,
ptx_parser::ScalarKind::Unsigned => LLVMBuildZExt,
_ => return Err(error_unreachable()),
};
let src1 = unsafe { llvm_cast(self.builder, src1, wide_type, LLVM_UNNAMED.as_ptr()) };
let src2 = unsafe { llvm_cast(self.builder, src2, wide_type, LLVM_UNNAMED.as_ptr()) };
Ok((
wide_type,
self.resolver.with_result_option(dst, |dst| unsafe {
LLVMBuildMul(self.builder, src1, src2, dst)
}),
))
}
fn emit_cos(
&mut self,
_data: ast::FlushToZero,
arguments: ast::CosArgs<SpirvWord>,
) -> Result<(), TranslateError> {
let llvm_f32 = get_scalar_type(self.context, ast::ScalarType::F32);
let cos = self.emit_intrinsic(
c"llvm.cos.f32",
Some(arguments.dst),
&ast::ScalarType::F32.into(),
vec![(self.resolver.value(arguments.src)?, llvm_f32)],
)?;
unsafe { LLVMZludaSetFastMathFlags(cos, LLVMZludaFastMathApproxFunc) }
Ok(())
}
fn emit_or(
&mut self,
_data: ptx_parser::ScalarType,
arguments: ptx_parser::OrArgs<SpirvWord>,
) -> Result<(), TranslateError> {
let src1 = self.resolver.value(arguments.src1)?;
let src2 = self.resolver.value(arguments.src2)?;
self.resolver.with_result(arguments.dst, |dst| unsafe {
LLVMBuildOr(self.builder, src1, src2, dst)
});
Ok(())
}
fn emit_xor(
&mut self,
_data: ptx_parser::ScalarType,
arguments: ptx_parser::XorArgs<SpirvWord>,
) -> Result<(), TranslateError> {
let src1 = self.resolver.value(arguments.src1)?;
let src2 = self.resolver.value(arguments.src2)?;
self.resolver.with_result(arguments.dst, |dst| unsafe {
LLVMBuildXor(self.builder, src1, src2, dst)
});
Ok(())
}
fn emit_vector_read(&mut self, vec_access: VectorRead) -> Result<(), TranslateError> {
let src = self.resolver.value(vec_access.vector_src)?;
let index = unsafe {
LLVMConstInt(
get_scalar_type(self.context, ast::ScalarType::B8),
vec_access.member as _,
0,
)
};
self.resolver
.with_result(vec_access.scalar_dst, |dst| unsafe {
LLVMBuildExtractElement(self.builder, src, index, dst)
});
Ok(())
}
fn emit_vector_write(&mut self, vector_write: VectorWrite) -> Result<(), TranslateError> {
let vector_src = self.resolver.value(vector_write.vector_src)?;
let scalar_src = self.resolver.value(vector_write.scalar_src)?;
let index = unsafe {
LLVMConstInt(
get_scalar_type(self.context, ast::ScalarType::B8),
vector_write.member as _,
0,
)
};
self.resolver
.with_result(vector_write.vector_dst, |dst| unsafe {
LLVMBuildInsertElement(self.builder, vector_src, scalar_src, index, dst)
});
Ok(())
}
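// Repacking converts between a vector value and its scalar components:
// extraction reads each lane with extractelement, packing builds the vector
// up from undef with a chain of insertelements.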
fn emit_vector_repack(&mut self, repack: RepackVectorDetails) -> Result<(), TranslateError> {
let i8_type = get_scalar_type(self.context, ast::ScalarType::B8);
if repack.is_extract {
let src = self.resolver.value(repack.packed)?;
for (index, dst) in repack.unpacked.iter().enumerate() {
let index: *mut LLVMValue = unsafe { LLVMConstInt(i8_type, index as _, 0) };
self.resolver.with_result(*dst, |dst| unsafe {
LLVMBuildExtractElement(self.builder, src, index, dst)
});
}
} else {
let vector_type = get_type(
self.context,
&ast::Type::Vector(repack.unpacked.len() as u8, repack.typ),
)?;
let mut temp_vec = unsafe { LLVMGetUndef(vector_type) };
for (index, src_id) in repack.unpacked.iter().enumerate() {
let dst = if index == repack.unpacked.len() - 1 {
Some(repack.packed)
} else {
None
};
let scalar_src = self.resolver.value(*src_id)?;
let index = unsafe { LLVMConstInt(i8_type, index as _, 0) };
temp_vec = self.resolver.with_result_option(dst, |dst| unsafe {
LLVMBuildInsertElement(self.builder, temp_vec, scalar_src, index, dst)
});
}
}
Ok(())
}
fn emit_div(
&mut self,
data: ptx_parser::DivDetails,
arguments: ptx_parser::DivArgs<SpirvWord>,
) -> Result<(), TranslateError> {
let integer_div = match data {
ptx_parser::DivDetails::Unsigned(_) => LLVMBuildUDiv,
ptx_parser::DivDetails::Signed(_) => LLVMBuildSDiv,
ptx_parser::DivDetails::Float(float_div) => {
return self.emit_div_float(float_div, arguments)
}
};
let src1 = self.resolver.value(arguments.src1)?;
let src2 = self.resolver.value(arguments.src2)?;
self.resolver.with_result(arguments.dst, |dst| unsafe {
integer_div(self.builder, src1, src2, dst)
});
Ok(())
}
fn emit_div_float(
&mut self,
float_div: ptx_parser::DivFloatDetails,
arguments: ptx_parser::DivArgs<SpirvWord>,
) -> Result<(), TranslateError> {
let builder = self.builder;
let src1 = self.resolver.value(arguments.src1)?;
let src2 = self.resolver.value(arguments.src2)?;
let _rnd = match float_div.kind {
ptx_parser::DivFloatKind::Approx => ast::RoundingMode::NearestEven,
ptx_parser::DivFloatKind::ApproxFull => ast::RoundingMode::NearestEven,
ptx_parser::DivFloatKind::Rounding(rounding_mode) => rounding_mode,
};
let approx = match float_div.kind {
ptx_parser::DivFloatKind::Approx => {
LLVMZludaFastMathAllowReciprocal | LLVMZludaFastMathApproxFunc
}
ptx_parser::DivFloatKind::ApproxFull => LLVMZludaFastMathNone,
ptx_parser::DivFloatKind::Rounding(_) => LLVMZludaFastMathNone,
};
let fdiv = self.resolver.with_result(arguments.dst, |dst| unsafe {
LLVMBuildFDiv(builder, src1, src2, dst)
});
unsafe { LLVMZludaSetFastMathFlags(fdiv, approx) };
if let ptx_parser::DivFloatKind::ApproxFull = float_div.kind {
// https://docs.nvidia.com/cuda/parallel-thread-execution/#floating-point-instructions-div:
// div.full.f32 implements a relatively fast, full-range approximation that scales
// operands to achieve better accuracy, but is not fully IEEE 754 compliant and does not
// support rounding modifiers. The maximum ulp error is 2 across the full range of
// inputs.
// https://llvm.org/docs/LangRef.html#fpmath-metadata
let fpmath_value =
unsafe { LLVMConstReal(get_scalar_type(self.context, ast::ScalarType::F32), 2.0) };
let fpmath_value = unsafe { LLVMValueAsMetadata(fpmath_value) };
let mut md_node_content = [fpmath_value];
let md_node = unsafe {
LLVMMDNodeInContext2(
self.context,
md_node_content.as_mut_ptr(),
md_node_content.len(),
)
};
let md_node = unsafe { LLVMMetadataAsValue(self.context, md_node) };
let kind = unsafe {
LLVMGetMDKindIDInContext(
self.context,
"fpmath".as_ptr().cast(),
"fpmath".len() as u32,
)
};
unsafe { LLVMSetMetadata(fdiv, kind, md_node) };
}
Ok(())
}
fn emit_cvta(
&mut self,
data: ptx_parser::CvtaDetails,
arguments: ptx_parser::CvtaArgs<SpirvWord>,
) -> Result<(), TranslateError> {
let (from_space, to_space) = match data.direction {
ptx_parser::CvtaDirection::GenericToExplicit => {
(ast::StateSpace::Generic, data.state_space)
}
ptx_parser::CvtaDirection::ExplicitToGeneric => {
(data.state_space, ast::StateSpace::Generic)
}
};
let from_type = get_pointer_type(self.context, from_space)?;
let dest_type = get_pointer_type(self.context, to_space)?;
let src = self.resolver.value(arguments.src)?;
let temp_ptr =
unsafe { LLVMBuildIntToPtr(self.builder, src, from_type, LLVM_UNNAMED.as_ptr()) };
self.resolver.with_result(arguments.dst, |dst| unsafe {
LLVMBuildAddrSpaceCast(self.builder, temp_ptr, dest_type, dst)
});
Ok(())
}
fn emit_sub(
&mut self,
data: ptx_parser::ArithDetails,
arguments: ptx_parser::SubArgs<SpirvWord>,
) -> Result<(), TranslateError> {
match data {
ptx_parser::ArithDetails::Integer(arith_integer) => {
self.emit_sub_integer(arith_integer, arguments)
}
ptx_parser::ArithDetails::Float(arith_float) => {
self.emit_sub_float(arith_float, arguments)
}
}
}
fn emit_sub_integer(
&mut self,
arith_integer: ptx_parser::ArithInteger,
arguments: ptx_parser::SubArgs<SpirvWord>,
) -> Result<(), TranslateError> {
if arith_integer.saturate {
todo!()
}
let src1 = self.resolver.value(arguments.src1)?;
let src2 = self.resolver.value(arguments.src2)?;
self.resolver.with_result(arguments.dst, |dst| unsafe {
LLVMBuildSub(self.builder, src1, src2, dst)
});
Ok(())
}
fn emit_sub_float(
&mut self,
arith_float: ptx_parser::ArithFloat,
arguments: ptx_parser::SubArgs<SpirvWord>,
) -> Result<(), TranslateError> {
if arith_float.saturate {
todo!()
}
let src1 = self.resolver.value(arguments.src1)?;
let src2 = self.resolver.value(arguments.src2)?;
self.resolver.with_result(arguments.dst, |dst| unsafe {
LLVMBuildFSub(self.builder, src1, src2, dst)
});
Ok(())
}
fn emit_sin(
&mut self,
_data: ptx_parser::FlushToZero,
arguments: ptx_parser::SinArgs<SpirvWord>,
) -> Result<(), TranslateError> {
let llvm_f32 = get_scalar_type(self.context, ast::ScalarType::F32);
let sin = self.emit_intrinsic(
c"llvm.sin.f32",
Some(arguments.dst),
&ast::ScalarType::F32.into(),
vec![(self.resolver.value(arguments.src)?, llvm_f32)],
)?;
unsafe { LLVMZludaSetFastMathFlags(sin, LLVMZludaFastMathApproxFunc) }
Ok(())
}
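// Declares (on first use) and calls a named LLVM intrinsic; subsequent calls
// reuse the declaration found via LLVMGetNamedFunction.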
fn emit_intrinsic(
&mut self,
name: &CStr,
dst: Option<SpirvWord>,
return_type: &ast::Type,
arguments: Vec<(LLVMValueRef, LLVMTypeRef)>,
) -> Result<LLVMValueRef, TranslateError> {
let fn_type = get_function_type(
self.context,
iter::once(return_type),
arguments.iter().map(|(_, type_)| Ok(*type_)),
)?;
let mut fn_ = unsafe { LLVMGetNamedFunction(self.module, name.as_ptr()) };
if fn_ == ptr::null_mut() {
fn_ = unsafe { LLVMAddFunction(self.module, name.as_ptr(), fn_type) };
}
let mut arguments = arguments.iter().map(|(arg, _)| *arg).collect::<Vec<_>>();
Ok(self.resolver.with_result_option(dst, |dst| unsafe {
LLVMBuildCall2(
self.builder,
fn_type,
fn_,
arguments.as_mut_ptr(),
arguments.len() as u32,
dst,
)
}))
}
fn emit_neg(
&mut self,
data: ptx_parser::TypeFtz,
arguments: ptx_parser::NegArgs<SpirvWord>,
) -> Result<(), TranslateError> {
let src = self.resolver.value(arguments.src)?;
let llvm_fn = if data.type_.kind() == ptx_parser::ScalarKind::Float {
LLVMBuildFNeg
} else {
LLVMBuildNeg
};
self.resolver.with_result(arguments.dst, |dst| unsafe {
llvm_fn(self.builder, src, dst)
});
Ok(())
}
fn emit_not(
&mut self,
_data: ptx_parser::ScalarType,
arguments: ptx_parser::NotArgs<SpirvWord>,
) -> Result<(), TranslateError> {
let src = self.resolver.value(arguments.src)?;
self.resolver.with_result(arguments.dst, |dst| unsafe {
LLVMBuildNot(self.builder, src, dst)
});
Ok(())
}
fn emit_setp(
&mut self,
data: ptx_parser::SetpData,
arguments: ptx_parser::SetpArgs<SpirvWord>,
) -> Result<(), TranslateError> {
if arguments.dst2.is_some() {
todo!()
}
match data.cmp_op {
ptx_parser::SetpCompareOp::Integer(setp_compare_int) => {
self.emit_setp_int(setp_compare_int, arguments)
}
ptx_parser::SetpCompareOp::Float(setp_compare_float) => {
self.emit_setp_float(setp_compare_float, arguments)
}
}
}
fn emit_setp_int(
&mut self,
setp: ptx_parser::SetpCompareInt,
arguments: ptx_parser::SetpArgs<SpirvWord>,
) -> Result<(), TranslateError> {
let op = match setp {
ptx_parser::SetpCompareInt::Eq => LLVMIntPredicate::LLVMIntEQ,
ptx_parser::SetpCompareInt::NotEq => LLVMIntPredicate::LLVMIntNE,
ptx_parser::SetpCompareInt::UnsignedLess => LLVMIntPredicate::LLVMIntULT,
ptx_parser::SetpCompareInt::UnsignedLessOrEq => LLVMIntPredicate::LLVMIntULE,
ptx_parser::SetpCompareInt::UnsignedGreater => LLVMIntPredicate::LLVMIntUGT,
ptx_parser::SetpCompareInt::UnsignedGreaterOrEq => LLVMIntPredicate::LLVMIntUGE,
ptx_parser::SetpCompareInt::SignedLess => LLVMIntPredicate::LLVMIntSLT,
ptx_parser::SetpCompareInt::SignedLessOrEq => LLVMIntPredicate::LLVMIntSLE,
ptx_parser::SetpCompareInt::SignedGreater => LLVMIntPredicate::LLVMIntSGT,
ptx_parser::SetpCompareInt::SignedGreaterOrEq => LLVMIntPredicate::LLVMIntSGE,
};
let src1 = self.resolver.value(arguments.src1)?;
let src2 = self.resolver.value(arguments.src2)?;
self.resolver.with_result(arguments.dst1, |dst1| unsafe {
LLVMBuildICmp(self.builder, op, src1, src2, dst1)
});
Ok(())
}
fn emit_setp_float(
&mut self,
setp: ptx_parser::SetpCompareFloat,
arguments: ptx_parser::SetpArgs<SpirvWord>,
) -> Result<(), TranslateError> {
let op = match setp {
ptx_parser::SetpCompareFloat::Eq => LLVMRealPredicate::LLVMRealOEQ,
ptx_parser::SetpCompareFloat::NotEq => LLVMRealPredicate::LLVMRealONE,
ptx_parser::SetpCompareFloat::Less => LLVMRealPredicate::LLVMRealOLT,
ptx_parser::SetpCompareFloat::LessOrEq => LLVMRealPredicate::LLVMRealOLE,
ptx_parser::SetpCompareFloat::Greater => LLVMRealPredicate::LLVMRealOGT,
ptx_parser::SetpCompareFloat::GreaterOrEq => LLVMRealPredicate::LLVMRealOGE,
ptx_parser::SetpCompareFloat::NanEq => LLVMRealPredicate::LLVMRealUEQ,
ptx_parser::SetpCompareFloat::NanNotEq => LLVMRealPredicate::LLVMRealUNE,
ptx_parser::SetpCompareFloat::NanLess => LLVMRealPredicate::LLVMRealULT,
ptx_parser::SetpCompareFloat::NanLessOrEq => LLVMRealPredicate::LLVMRealULE,
ptx_parser::SetpCompareFloat::NanGreater => LLVMRealPredicate::LLVMRealUGT,
ptx_parser::SetpCompareFloat::NanGreaterOrEq => LLVMRealPredicate::LLVMRealUGE,
ptx_parser::SetpCompareFloat::IsNotNan => LLVMRealPredicate::LLVMRealORD,
ptx_parser::SetpCompareFloat::IsAnyNan => LLVMRealPredicate::LLVMRealUNO,
};
let src1 = self.resolver.value(arguments.src1)?;
let src2 = self.resolver.value(arguments.src2)?;
self.resolver.with_result(arguments.dst1, |dst1| unsafe {
LLVMBuildFCmp(self.builder, op, src1, src2, dst1)
});
Ok(())
}
fn emit_conditional(&mut self, cond: BrachCondition) -> Result<(), TranslateError> {
let predicate = self.resolver.value(cond.predicate)?;
let if_true = self.resolver.value(cond.if_true)?;
let if_false = self.resolver.value(cond.if_false)?;
unsafe {
LLVMBuildCondBr(
self.builder,
predicate,
LLVMValueAsBasicBlock(if_true),
LLVMValueAsBasicBlock(if_false),
)
};
Ok(())
}
fn emit_cvt(
&mut self,
data: ptx_parser::CvtDetails,
arguments: ptx_parser::CvtArgs<SpirvWord>,
) -> Result<(), TranslateError> {
let dst_type = get_scalar_type(self.context, data.to);
let llvm_fn = match data.mode {
ptx_parser::CvtMode::ZeroExtend => LLVMBuildZExt,
ptx_parser::CvtMode::SignExtend => LLVMBuildSExt,
ptx_parser::CvtMode::Truncate => LLVMBuildTrunc,
ptx_parser::CvtMode::Bitcast => LLVMBuildBitCast,
ptx_parser::CvtMode::SaturateUnsignedToSigned => {
return self.emit_cvt_unsigned_to_signed_sat(data.from, data.to, arguments)
}
ptx_parser::CvtMode::SaturateSignedToUnsigned => {
return self.emit_cvt_signed_to_unsigned_sat(data.from, data.to, arguments)
}
ptx_parser::CvtMode::FPExtend { .. } => LLVMBuildFPExt,
ptx_parser::CvtMode::FPTruncate { .. } => LLVMBuildFPTrunc,
ptx_parser::CvtMode::FPRound {
integer_rounding, ..
} => {
return self.emit_cvt_float_to_int(
data.from,
data.to,
integer_rounding.unwrap_or(ast::RoundingMode::NearestEven),
arguments,
Some(LLVMBuildFPToSI),
)
}
ptx_parser::CvtMode::SignedFromFP { rounding, .. } => {
return self.emit_cvt_float_to_int(
data.from,
data.to,
rounding,
arguments,
Some(LLVMBuildFPToSI),
)
}
ptx_parser::CvtMode::UnsignedFromFP { rounding, .. } => {
return self.emit_cvt_float_to_int(
data.from,
data.to,
rounding,
arguments,
Some(LLVMBuildFPToUI),
)
}
ptx_parser::CvtMode::FPFromSigned(_) => {
return self.emit_cvt_int_to_float(data.to, arguments, LLVMBuildSIToFP)
}
ptx_parser::CvtMode::FPFromUnsigned(_) => {
return self.emit_cvt_int_to_float(data.to, arguments, LLVMBuildUIToFP)
}
};
let src = self.resolver.value(arguments.src)?;
self.resolver.with_result(arguments.dst, |dst| unsafe {
llvm_fn(self.builder, src, dst_type, dst)
});
Ok(())
}
fn emit_cvt_unsigned_to_signed_sat(
&mut self,
from: ptx_parser::ScalarType,
to: ptx_parser::ScalarType,
arguments: ptx_parser::CvtArgs<SpirvWord>,
) -> Result<(), TranslateError> {
// This looks dodgy, but it's fine. The low bits of a signed MAX are all
// ones, so if the constant gets truncated into a smaller source type it
// becomes that type's (unsigned) maximum value and the clamp is a no-op,
// which is correct because every source value then fits in the destination
let max_value = match to {
ptx_parser::ScalarType::S8 => i8::MAX as u64,
ptx_parser::ScalarType::S16 => i16::MAX as u64,
ptx_parser::ScalarType::S32 => i32::MAX as u64,
ptx_parser::ScalarType::S64 => i64::MAX as u64,
_ => return Err(error_unreachable()),
};
let from_llvm = get_scalar_type(self.context, from);
let max = unsafe { LLVMConstInt(from_llvm, max_value, 0) };
// llvm.umin is overloaded, so the intrinsic name needs a type suffix
// (same pattern as in emit_cvt_signed_to_unsigned_sat below)
let clamp_intrinsic = format!("llvm.umin.{}\0", LLVMTypeDisplay(from));
let clamped = self.emit_intrinsic(
unsafe { CStr::from_bytes_with_nul_unchecked(clamp_intrinsic.as_bytes()) },
None,
&from.into(),
vec![
(self.resolver.value(arguments.src)?, from_llvm),
(max, from_llvm),
],
)?;
let resize_fn = if to.layout().size() >= from.layout().size() {
// the source value is unsigned, so widening must zero-extend
LLVMBuildZExtOrBitCast
} else {
LLVMBuildTrunc
};
let to_llvm = get_scalar_type(self.context, to);
self.resolver.with_result(arguments.dst, |dst| unsafe {
resize_fn(self.builder, clamped, to_llvm, dst)
});
Ok(())
}
fn emit_cvt_signed_to_unsigned_sat(
&mut self,
from: ptx_parser::ScalarType,
to: ptx_parser::ScalarType,
arguments: ptx_parser::CvtArgs<SpirvWord>,
) -> Result<(), TranslateError> {
let from_llvm = get_scalar_type(self.context, from);
let zero = unsafe { LLVMConstInt(from_llvm, 0, 0) };
let zero_clamp_intrinsic = format!("llvm.smax.{}\0", LLVMTypeDisplay(from));
let zero_clamped = self.emit_intrinsic(
unsafe { CStr::from_bytes_with_nul_unchecked(zero_clamp_intrinsic.as_bytes()) },
None,
&from.into(),
vec![
(self.resolver.value(arguments.src)?, from_llvm),
(zero, from_llvm),
],
)?;
// zero_clamped is now non-negative, so it can be treated as an unsigned value
let max_value = match to {
ptx_parser::ScalarType::U8 => u8::MAX as u64,
ptx_parser::ScalarType::U16 => u16::MAX as u64,
ptx_parser::ScalarType::U32 => u32::MAX as u64,
ptx_parser::ScalarType::U64 => u64::MAX as u64,
_ => return Err(error_unreachable()),
};
let max = unsafe { LLVMConstInt(from_llvm, max_value, 0) };
let max_clamp_intrinsic = format!("llvm.umin.{}\0", LLVMTypeDisplay(from));
let fully_clamped = self.emit_intrinsic(
unsafe { CStr::from_bytes_with_nul_unchecked(max_clamp_intrinsic.as_bytes()) },
None,
&from.into(),
vec![(zero_clamped, from_llvm), (max, from_llvm)],
)?;
let resize_fn = if to.layout().size() >= from.layout().size() {
LLVMBuildZExtOrBitCast
} else {
LLVMBuildTrunc
};
let to_llvm = get_scalar_type(self.context, to);
self.resolver.with_result(arguments.dst, |dst| unsafe {
resize_fn(self.builder, fully_clamped, to_llvm, dst)
});
Ok(())
}
fn emit_cvt_float_to_int(
&mut self,
from: ast::ScalarType,
to: ast::ScalarType,
rounding: ast::RoundingMode,
arguments: ptx_parser::CvtArgs<SpirvWord>,
llvm_cast: Option<
unsafe extern "C" fn(
arg1: LLVMBuilderRef,
Val: LLVMValueRef,
DestTy: LLVMTypeRef,
Name: *const i8,
) -> LLVMValueRef,
>,
) -> Result<(), TranslateError> {
let prefix = match rounding {
ptx_parser::RoundingMode::NearestEven => "llvm.roundeven",
ptx_parser::RoundingMode::Zero => "llvm.trunc",
ptx_parser::RoundingMode::NegativeInf => "llvm.floor",
ptx_parser::RoundingMode::PositiveInf => "llvm.ceil",
};
let intrinsic = format!("{}.{}\0", prefix, LLVMTypeDisplay(from));
let rounded_float = self.emit_intrinsic(
unsafe { CStr::from_bytes_with_nul_unchecked(intrinsic.as_bytes()) },
None,
&from.into(),
vec![(
self.resolver.value(arguments.src)?,
get_scalar_type(self.context, from),
)],
)?;
if let Some(llvm_cast) = llvm_cast {
let to = get_scalar_type(self.context, to);
let poisoned_dst =
unsafe { llvm_cast(self.builder, rounded_float, to, LLVM_UNNAMED.as_ptr()) };
self.resolver.with_result(arguments.dst, |dst| unsafe {
LLVMBuildFreeze(self.builder, poisoned_dst, dst)
});
} else {
self.resolver.register(arguments.dst, rounded_float);
}
// Using explicit saturation gives us worse codegen: it explicitly checks for out of bound
// values and NaNs. Using non-saturated fptosi/fptoui emits v_cvt_<TO>_<FROM> which
// saturates by default and we don't care about NaNs anyway
/*
let cast_intrinsic = format!(
"{}.{}.{}\0",
llvm_cast,
LLVMTypeDisplay(to),
LLVMTypeDisplay(from)
);
self.emit_intrinsic(
unsafe { CStr::from_bytes_with_nul_unchecked(cast_intrinsic.as_bytes()) },
Some(arguments.dst),
&to.into(),
vec![(rounded_float, get_scalar_type(self.context, from))],
)?;
*/
Ok(())
}
fn emit_cvt_int_to_float(
&mut self,
to: ptx_parser::ScalarType,
arguments: ptx_parser::CvtArgs<SpirvWord>,
llvm_func: unsafe extern "C" fn(
arg1: LLVMBuilderRef,
Val: LLVMValueRef,
DestTy: LLVMTypeRef,
Name: *const i8,
) -> LLVMValueRef,
) -> Result<(), TranslateError> {
let type_ = get_scalar_type(self.context, to);
let src = self.resolver.value(arguments.src)?;
self.resolver.with_result(arguments.dst, |dst| unsafe {
llvm_func(self.builder, src, type_, dst)
});
Ok(())
}
fn emit_rsqrt(
&mut self,
data: ptx_parser::TypeFtz,
arguments: ptx_parser::RsqrtArgs<SpirvWord>,
) -> Result<(), TranslateError> {
let type_ = get_scalar_type(self.context, data.type_);
let intrinsic = match data.type_ {
ast::ScalarType::F32 => c"llvm.amdgcn.rsq.f32",
ast::ScalarType::F64 => c"llvm.amdgcn.rsq.f64",
_ => return Err(error_unreachable()),
};
self.emit_intrinsic(
intrinsic,
Some(arguments.dst),
&data.type_.into(),
vec![(self.resolver.value(arguments.src)?, type_)],
)?;
Ok(())
}
fn emit_sqrt(
&mut self,
data: ptx_parser::RcpData,
arguments: ptx_parser::SqrtArgs<SpirvWord>,
) -> Result<(), TranslateError> {
let type_ = get_scalar_type(self.context, data.type_);
let intrinsic = match (data.type_, data.kind) {
(ast::ScalarType::F32, ast::RcpKind::Approx) => c"llvm.amdgcn.sqrt.f32",
(ast::ScalarType::F32, ast::RcpKind::Compliant(..)) => c"llvm.sqrt.f32",
(ast::ScalarType::F64, ast::RcpKind::Compliant(..)) => c"llvm.sqrt.f64",
_ => return Err(error_unreachable()),
};
self.emit_intrinsic(
intrinsic,
Some(arguments.dst),
&data.type_.into(),
vec![(self.resolver.value(arguments.src)?, type_)],
)?;
Ok(())
}
fn emit_rcp(
&mut self,
data: ptx_parser::RcpData,
arguments: ptx_parser::RcpArgs<SpirvWord>,
) -> Result<(), TranslateError> {
let type_ = get_scalar_type(self.context, data.type_);
let intrinsic = match (data.type_, data.kind) {
(ast::ScalarType::F32, ast::RcpKind::Approx) => c"llvm.amdgcn.rcp.f32",
(_, ast::RcpKind::Compliant(rnd)) => {
return self.emit_rcp_compliant(data, arguments, rnd)
}
_ => return Err(error_unreachable()),
};
self.emit_intrinsic(
intrinsic,
Some(arguments.dst),
&data.type_.into(),
vec![(self.resolver.value(arguments.src)?, type_)],
)?;
Ok(())
}
fn emit_rcp_compliant(
&mut self,
data: ptx_parser::RcpData,
arguments: ptx_parser::RcpArgs<SpirvWord>,
_rnd: ast::RoundingMode,
) -> Result<(), TranslateError> {
let type_ = get_scalar_type(self.context, data.type_);
let one = unsafe { LLVMConstReal(type_, 1.0) };
let src = self.resolver.value(arguments.src)?;
let rcp = self.resolver.with_result(arguments.dst, |dst| unsafe {
LLVMBuildFDiv(self.builder, one, src, dst)
});
unsafe { LLVMZludaSetFastMathFlags(rcp, LLVMZludaFastMathAllowReciprocal) };
Ok(())
}
fn emit_shr(
&mut self,
data: ptx_parser::ShrData,
arguments: ptx_parser::ShrArgs<SpirvWord>,
) -> Result<(), TranslateError> {
let shift_fn = match data.kind {
ptx_parser::RightShiftKind::Arithmetic => LLVMBuildAShr,
ptx_parser::RightShiftKind::Logical => LLVMBuildLShr,
};
self.emit_shift(
data.type_,
arguments.dst,
arguments.src1,
arguments.src2,
shift_fn,
)
}
fn emit_shl(
&mut self,
type_: ptx_parser::ScalarType,
arguments: ptx_parser::ShlArgs<SpirvWord>,
) -> Result<(), TranslateError> {
self.emit_shift(
type_,
arguments.dst,
arguments.src1,
arguments.src2,
LLVMBuildShl,
)
}
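// LLVM shift instructions produce poison when the shift amount is >= the bit
// width of the operand, so the amount is compared up front and the result is
// clamped to zero in that case via select.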
fn emit_shift(
&mut self,
type_: ast::ScalarType,
dst: SpirvWord,
src1: SpirvWord,
src2: SpirvWord,
llvm_fn: unsafe extern "C" fn(
LLVMBuilderRef,
LLVMValueRef,
LLVMValueRef,
*const i8,
) -> LLVMValueRef,
) -> Result<(), TranslateError> {
let src1 = self.resolver.value(src1)?;
let shift_size = self.resolver.value(src2)?;
let integer_bits = type_.layout().size() * 8;
let integer_bits_constant = unsafe {
LLVMConstInt(
get_scalar_type(self.context, ast::ScalarType::U32),
integer_bits as u64,
0,
)
};
let should_clamp = unsafe {
LLVMBuildICmp(
self.builder,
LLVMIntPredicate::LLVMIntUGE,
shift_size,
integer_bits_constant,
LLVM_UNNAMED.as_ptr(),
)
};
let llvm_type = get_scalar_type(self.context, type_);
let zero = unsafe { LLVMConstNull(llvm_type) };
let normalized_shift_size = if type_.layout().size() >= 4 {
unsafe {
LLVMBuildZExtOrBitCast(self.builder, shift_size, llvm_type, LLVM_UNNAMED.as_ptr())
}
} else {
unsafe { LLVMBuildTrunc(self.builder, shift_size, llvm_type, LLVM_UNNAMED.as_ptr()) }
};
let shifted = unsafe {
llvm_fn(
self.builder,
src1,
normalized_shift_size,
LLVM_UNNAMED.as_ptr(),
)
};
self.resolver.with_result(dst, |dst| unsafe {
LLVMBuildSelect(self.builder, should_clamp, zero, shifted, dst)
});
Ok(())
}
fn emit_ex2(
&mut self,
data: ptx_parser::TypeFtz,
arguments: ptx_parser::Ex2Args<SpirvWord>,
) -> Result<(), TranslateError> {
let intrinsic = match data.type_ {
ast::ScalarType::F16 => c"llvm.amdgcn.exp2.f16",
ast::ScalarType::F32 => c"llvm.amdgcn.exp2.f32",
_ => return Err(error_unreachable()),
};
self.emit_intrinsic(
intrinsic,
Some(arguments.dst),
&data.type_.into(),
vec![(
self.resolver.value(arguments.src)?,
get_scalar_type(self.context, data.type_),
)],
)?;
Ok(())
}
fn emit_lg2(
&mut self,
_data: ptx_parser::FlushToZero,
arguments: ptx_parser::Lg2Args<SpirvWord>,
) -> Result<(), TranslateError> {
self.emit_intrinsic(
c"llvm.amdgcn.log.f32",
Some(arguments.dst),
&ast::ScalarType::F32.into(),
vec![(
self.resolver.value(arguments.src)?,
get_scalar_type(self.context, ast::ScalarType::F32),
)],
)?;
Ok(())
}
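// selp picks src1 when the predicate (src3) is true and src2 otherwise,
// which is exactly LLVM's select.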
fn emit_selp(
&mut self,
_data: ptx_parser::ScalarType,
arguments: ptx_parser::SelpArgs<SpirvWord>,
) -> Result<(), TranslateError> {
let src1 = self.resolver.value(arguments.src1)?;
let src2 = self.resolver.value(arguments.src2)?;
let src3 = self.resolver.value(arguments.src3)?;
self.resolver.with_result(arguments.dst, |dst_name| unsafe {
LLVMBuildSelect(self.builder, src3, src1, src2, dst_name)
});
Ok(())
}
fn emit_rem(
&mut self,
data: ptx_parser::ScalarType,
arguments: ptx_parser::RemArgs<SpirvWord>,
) -> Result<(), TranslateError> {
let llvm_fn = match data.kind() {
ptx_parser::ScalarKind::Unsigned => LLVMBuildURem,
ptx_parser::ScalarKind::Signed => LLVMBuildSRem,
_ => return Err(error_unreachable()),
};
let src1 = self.resolver.value(arguments.src1)?;
let src2 = self.resolver.value(arguments.src2)?;
self.resolver.with_result(arguments.dst, |dst_name| unsafe {
llvm_fn(self.builder, src1, src2, dst_name)
});
Ok(())
}
fn emit_popc(
&mut self,
type_: ptx_parser::ScalarType,
arguments: ptx_parser::PopcArgs<SpirvWord>,
) -> Result<(), TranslateError> {
let intrinsic = match type_ {
ast::ScalarType::B32 => c"llvm.ctpop.i32",
ast::ScalarType::B64 => c"llvm.ctpop.i64",
_ => return Err(error_unreachable()),
};
let llvm_type = get_scalar_type(self.context, type_);
self.emit_intrinsic(
intrinsic,
Some(arguments.dst),
&type_.into(),
vec![(self.resolver.value(arguments.src)?, llvm_type)],
)?;
Ok(())
}
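// min maps to the llvm.smin/umin/minnum family (emit_max below is analogous);
// the NaN-propagating float variant is still TODO.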
fn emit_min(
&mut self,
data: ptx_parser::MinMaxDetails,
arguments: ptx_parser::MinArgs<SpirvWord>,
) -> Result<(), TranslateError> {
let llvm_prefix = match data {
ptx_parser::MinMaxDetails::Signed(..) => "llvm.smin",
ptx_parser::MinMaxDetails::Unsigned(..) => "llvm.umin",
ptx_parser::MinMaxDetails::Float(ptx_parser::MinMaxFloat { nan: true, .. }) => {
return Err(error_todo())
}
ptx_parser::MinMaxDetails::Float(ptx_parser::MinMaxFloat { .. }) => "llvm.minnum",
};
let intrinsic = format!("{}.{}\0", llvm_prefix, LLVMTypeDisplay(data.type_()));
let llvm_type = get_scalar_type(self.context, data.type_());
self.emit_intrinsic(
unsafe { CStr::from_bytes_with_nul_unchecked(intrinsic.as_bytes()) },
Some(arguments.dst),
&data.type_().into(),
vec![
(self.resolver.value(arguments.src1)?, llvm_type),
(self.resolver.value(arguments.src2)?, llvm_type),
],
)?;
Ok(())
}
fn emit_max(
&mut self,
data: ptx_parser::MinMaxDetails,
arguments: ptx_parser::MaxArgs<SpirvWord>,
) -> Result<(), TranslateError> {
let llvm_prefix = match data {
ptx_parser::MinMaxDetails::Signed(..) => "llvm.smax",
ptx_parser::MinMaxDetails::Unsigned(..) => "llvm.umax",
ptx_parser::MinMaxDetails::Float(ptx_parser::MinMaxFloat { nan: true, .. }) => {
return Err(error_todo())
}
ptx_parser::MinMaxDetails::Float(ptx_parser::MinMaxFloat { .. }) => "llvm.maxnum",
};
let intrinsic = format!("{}.{}\0", llvm_prefix, LLVMTypeDisplay(data.type_()));
let llvm_type = get_scalar_type(self.context, data.type_());
self.emit_intrinsic(
unsafe { CStr::from_bytes_with_nul_unchecked(intrinsic.as_bytes()) },
Some(arguments.dst),
&data.type_().into(),
vec![
(self.resolver.value(arguments.src1)?, llvm_type),
(self.resolver.value(arguments.src2)?, llvm_type),
],
)?;
Ok(())
}
fn emit_fma(
&mut self,
data: ptx_parser::ArithFloat,
arguments: ptx_parser::FmaArgs<SpirvWord>,
) -> Result<(), TranslateError> {
let intrinsic = format!("llvm.fma.{}\0", LLVMTypeDisplay(data.type_));
self.emit_intrinsic(
unsafe { CStr::from_bytes_with_nul_unchecked(intrinsic.as_bytes()) },
Some(arguments.dst),
&data.type_.into(),
vec![
(
self.resolver.value(arguments.src1)?,
get_scalar_type(self.context, data.type_),
),
(
self.resolver.value(arguments.src2)?,
get_scalar_type(self.context, data.type_),
),
(
self.resolver.value(arguments.src3)?,
get_scalar_type(self.context, data.type_),
),
],
)?;
Ok(())
}
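// Floating-point mad is forwarded to emit_fma; integer mad is lowered as a
// multiply followed by an add. mad.sat is still TODO.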
fn emit_mad(
&mut self,
data: ptx_parser::MadDetails,
arguments: ptx_parser::MadArgs<SpirvWord>,
) -> Result<(), TranslateError> {
let mul_control = match data {
ptx_parser::MadDetails::Float(mad_float) => {
return self.emit_fma(
mad_float,
ast::FmaArgs {
dst: arguments.dst,
src1: arguments.src1,
src2: arguments.src2,
src3: arguments.src3,
},
)
}
ptx_parser::MadDetails::Integer { saturate: true, .. } => return Err(error_todo()),
ptx_parser::MadDetails::Integer { type_, control, .. } => {
ast::MulDetails::Integer { control, type_ }
}
};
let temp = self.emit_mul_impl(mul_control, None, arguments.src1, arguments.src2)?;
let src3 = self.resolver.value(arguments.src3)?;
self.resolver.with_result(arguments.dst, |dst| unsafe {
LLVMBuildAdd(self.builder, temp, src3, dst)
});
Ok(())
}
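// membar becomes a sequentially consistent fence at the requested scope.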
fn emit_membar(&self, data: ptx_parser::MemScope) -> Result<(), TranslateError> {
unsafe {
LLVMZludaBuildFence(
self.builder,
LLVMAtomicOrdering::LLVMAtomicOrderingSequentiallyConsistent,
get_scope_membar(data)?,
LLVM_UNNAMED.as_ptr(),
)
};
Ok(())
}
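// prmt in its default mode: each 4-bit selector in `control` picks one of the
// eight bytes of the {src1, src2} pool. The two u32 sources are reinterpreted
// as <4 x i8> and permuted with a single shufflevector. Selectors with the
// sign-replicate bit set (values 8..=15) are not supported yet.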
fn emit_prmt(
&mut self,
control: u16,
arguments: ptx_parser::PrmtArgs<SpirvWord>,
) -> Result<(), TranslateError> {
let components = [
(control >> 0) & 0b1111,
(control >> 4) & 0b1111,
(control >> 8) & 0b1111,
(control >> 12) & 0b1111,
];
if components.iter().any(|&c| c > 7) {
return Err(TranslateError::Todo);
}
let u32_type = get_scalar_type(self.context, ast::ScalarType::U32);
let v4u8_type = get_type(self.context, &ast::Type::Vector(4, ast::ScalarType::U8))?;
let mut components = [
unsafe { LLVMConstInt(u32_type, components[0] as _, 0) },
unsafe { LLVMConstInt(u32_type, components[1] as _, 0) },
unsafe { LLVMConstInt(u32_type, components[2] as _, 0) },
unsafe { LLVMConstInt(u32_type, components[3] as _, 0) },
];
let components_indices =
unsafe { LLVMConstVector(components.as_mut_ptr(), components.len() as u32) };
let src1 = self.resolver.value(arguments.src1)?;
let src1_vector =
unsafe { LLVMBuildBitCast(self.builder, src1, v4u8_type, LLVM_UNNAMED.as_ptr()) };
let src2 = self.resolver.value(arguments.src2)?;
let src2_vector =
unsafe { LLVMBuildBitCast(self.builder, src2, v4u8_type, LLVM_UNNAMED.as_ptr()) };
self.resolver.with_result(arguments.dst, |dst| unsafe {
LLVMBuildShuffleVector(
self.builder,
src1_vector,
src2_vector,
components_indices,
dst,
)
});
Ok(())
}
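// Float abs maps to llvm.fabs. Integer abs maps to llvm.abs, whose extra i1
// operand (`is_int_min_poison`) is passed as false so abs(INT_MIN) stays defined.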
fn emit_abs(
&mut self,
data: ast::TypeFtz,
arguments: ptx_parser::AbsArgs<SpirvWord>,
) -> Result<(), TranslateError> {
let llvm_type = get_scalar_type(self.context, data.type_);
let src = self.resolver.value(arguments.src)?;
let (prefix, intrinsic_arguments) = if data.type_.kind() == ast::ScalarKind::Float {
("llvm.fabs", vec![(src, llvm_type)])
} else {
let pred = get_scalar_type(self.context, ast::ScalarType::Pred);
let zero = unsafe { LLVMConstInt(pred, 0, 0) };
("llvm.abs", vec![(src, llvm_type), (zero, pred)])
};
let llvm_intrinsic = format!("{}.{}\0", prefix, LLVMTypeDisplay(data.type_));
self.emit_intrinsic(
unsafe { CStr::from_bytes_with_nul_unchecked(llvm_intrinsic.as_bytes()) },
Some(arguments.dst),
&data.type_.into(),
intrinsic_arguments,
)?;
Ok(())
}
/*
// Currently unused, LLVM 18 (ROCm 6.2) does not support `llvm.set.rounding`
// Should be available in LLVM 19
fn with_rounding<T>(&mut self, rnd: ast::RoundingMode, fn_: impl FnOnce(&mut Self) -> T) -> T {
let mut u32_type = get_scalar_type(self.context, ast::ScalarType::U32);
let void_type = unsafe { LLVMVoidTypeInContext(self.context) };
let get_rounding = c"llvm.get.rounding";
let get_rounding_fn_type = unsafe { LLVMFunctionType(u32_type, ptr::null_mut(), 0, 0) };
let mut get_rounding_fn =
unsafe { LLVMGetNamedFunction(self.module, get_rounding.as_ptr()) };
if get_rounding_fn == ptr::null_mut() {
get_rounding_fn = unsafe {
LLVMAddFunction(self.module, get_rounding.as_ptr(), get_rounding_fn_type)
};
}
let set_rounding = c"llvm.set.rounding";
let set_rounding_fn_type = unsafe { LLVMFunctionType(void_type, &mut u32_type, 1, 0) };
let mut set_rounding_fn =
unsafe { LLVMGetNamedFunction(self.module, set_rounding.as_ptr()) };
if set_rounding_fn == ptr::null_mut() {
set_rounding_fn = unsafe {
LLVMAddFunction(self.module, set_rounding.as_ptr(), set_rounding_fn_type)
};
}
let mut preserved_rounding_mode = unsafe {
LLVMBuildCall2(
self.builder,
get_rounding_fn_type,
get_rounding_fn,
ptr::null_mut(),
0,
LLVM_UNNAMED.as_ptr(),
)
};
let mut requested_rounding = unsafe {
LLVMConstInt(
get_scalar_type(self.context, ast::ScalarType::B32),
rounding_to_llvm(rnd) as u64,
0,
)
};
unsafe {
LLVMBuildCall2(
self.builder,
set_rounding_fn_type,
set_rounding_fn,
&mut requested_rounding,
1,
LLVM_UNNAMED.as_ptr(),
)
};
let result = fn_(self);
unsafe {
LLVMBuildCall2(
self.builder,
set_rounding_fn_type,
set_rounding_fn,
&mut preserved_rounding_mode,
1,
LLVM_UNNAMED.as_ptr(),
)
};
result
}
*/
}
fn get_pointer_type(
context: LLVMContextRef,
to_space: ast::StateSpace,
) -> Result<LLVMTypeRef, TranslateError> {
Ok(unsafe { LLVMPointerTypeInContext(context, get_state_space(to_space)?) })
}
// https://llvm.org/docs/AMDGPUUsage.html#memory-scopes
fn get_scope(scope: ast::MemScope) -> Result<*const i8, TranslateError> {
Ok(match scope {
ast::MemScope::Cta => c"workgroup-one-as",
ast::MemScope::Gpu => c"agent-one-as",
ast::MemScope::Sys => c"one-as",
ast::MemScope::Cluster => todo!(),
}
.as_ptr())
}
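// Fences use the plain scope names (without the "-one-as" suffix used for
// atomics); an empty string selects the system scope.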
fn get_scope_membar(scope: ast::MemScope) -> Result<*const i8, TranslateError> {
Ok(match scope {
ast::MemScope::Cta => c"workgroup",
ast::MemScope::Gpu => c"agent",
ast::MemScope::Sys => c"",
ast::MemScope::Cluster => todo!(),
}
.as_ptr())
}
fn get_ordering(semantics: ast::AtomSemantics) -> LLVMAtomicOrdering {
match semantics {
ast::AtomSemantics::Relaxed => LLVMAtomicOrdering::LLVMAtomicOrderingMonotonic,
ast::AtomSemantics::Acquire => LLVMAtomicOrdering::LLVMAtomicOrderingAcquire,
ast::AtomSemantics::Release => LLVMAtomicOrdering::LLVMAtomicOrderingRelease,
ast::AtomSemantics::AcqRel => LLVMAtomicOrdering::LLVMAtomicOrderingAcquireRelease,
}
}
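// The failure ordering of an atomic compare-exchange cannot include release
// semantics, so release-like orderings are downgraded to acquire.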
fn get_ordering_failure(semantics: ast::AtomSemantics) -> LLVMAtomicOrdering {
match semantics {
ast::AtomSemantics::Relaxed => LLVMAtomicOrdering::LLVMAtomicOrderingMonotonic,
ast::AtomSemantics::Acquire => LLVMAtomicOrdering::LLVMAtomicOrderingAcquire,
ast::AtomSemantics::Release => LLVMAtomicOrdering::LLVMAtomicOrderingAcquire,
ast::AtomSemantics::AcqRel => LLVMAtomicOrdering::LLVMAtomicOrderingAcquire,
}
}
fn get_type(context: LLVMContextRef, type_: &ast::Type) -> Result<LLVMTypeRef, TranslateError> {
Ok(match type_ {
ast::Type::Scalar(scalar) => get_scalar_type(context, *scalar),
ast::Type::Vector(size, scalar) => {
let base_type = get_scalar_type(context, *scalar);
unsafe { LLVMVectorType(base_type, *size as u32) }
}
ast::Type::Array(vec, scalar, dimensions) => {
let mut underlying_type = get_scalar_type(context, *scalar);
if let Some(size) = vec {
underlying_type = unsafe { LLVMVectorType(underlying_type, size.get() as u32) };
}
if dimensions.is_empty() {
return Ok(unsafe { LLVMArrayType2(underlying_type, 0) });
}
dimensions
.iter()
.rfold(underlying_type, |result, dimension| unsafe {
LLVMArrayType2(result, *dimension as u64)
})
}
ast::Type::Pointer(_, space) => get_pointer_type(context, *space)?,
})
}
fn get_scalar_type(context: LLVMContextRef, type_: ast::ScalarType) -> LLVMTypeRef {
match type_ {
ast::ScalarType::Pred => unsafe { LLVMInt1TypeInContext(context) },
ast::ScalarType::S8 | ast::ScalarType::B8 | ast::ScalarType::U8 => unsafe {
LLVMInt8TypeInContext(context)
},
ast::ScalarType::B16 | ast::ScalarType::U16 | ast::ScalarType::S16 => unsafe {
LLVMInt16TypeInContext(context)
},
ast::ScalarType::S32 | ast::ScalarType::B32 | ast::ScalarType::U32 => unsafe {
LLVMInt32TypeInContext(context)
},
ast::ScalarType::U64 | ast::ScalarType::S64 | ast::ScalarType::B64 => unsafe {
LLVMInt64TypeInContext(context)
},
ast::ScalarType::B128 => unsafe { LLVMInt128TypeInContext(context) },
ast::ScalarType::F16 => unsafe { LLVMHalfTypeInContext(context) },
ast::ScalarType::F32 => unsafe { LLVMFloatTypeInContext(context) },
ast::ScalarType::F64 => unsafe { LLVMDoubleTypeInContext(context) },
ast::ScalarType::BF16 => unsafe { LLVMBFloatTypeInContext(context) },
ast::ScalarType::U16x2 => todo!(),
ast::ScalarType::S16x2 => todo!(),
ast::ScalarType::F16x2 => todo!(),
ast::ScalarType::BF16x2 => todo!(),
}
}
fn get_function_type<'a>(
context: LLVMContextRef,
mut return_args: impl ExactSizeIterator<Item = &'a ast::Type>,
input_args: impl ExactSizeIterator<Item = Result<LLVMTypeRef, TranslateError>>,
) -> Result<LLVMTypeRef, TranslateError> {
let mut input_args = input_args.collect::<Result<Vec<_>, _>>()?;
let return_type = match return_args.len() {
0 => unsafe { LLVMVoidTypeInContext(context) },
1 => get_type(context, return_args.next().unwrap())?,
_ => todo!(),
};
Ok(unsafe {
LLVMFunctionType(
return_type,
input_args.as_mut_ptr(),
input_args.len() as u32,
0,
)
})
}
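// Maps a PTX state space to an AMDGPU address space number (see the constants
// at the top of this file). Spaces without a direct mapping are still TODO.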
fn get_state_space(space: ast::StateSpace) -> Result<u32, TranslateError> {
match space {
ast::StateSpace::Reg => Ok(PRIVATE_ADDRESS_SPACE),
ast::StateSpace::Generic => Ok(GENERIC_ADDRESS_SPACE),
ast::StateSpace::Param => Err(TranslateError::Todo),
ast::StateSpace::ParamEntry => Ok(CONSTANT_ADDRESS_SPACE),
ast::StateSpace::ParamFunc => Err(TranslateError::Todo),
ast::StateSpace::Local => Ok(PRIVATE_ADDRESS_SPACE),
ast::StateSpace::Global => Ok(GLOBAL_ADDRESS_SPACE),
ast::StateSpace::Const => Ok(CONSTANT_ADDRESS_SPACE),
ast::StateSpace::Shared => Ok(SHARED_ADDRESS_SPACE),
ast::StateSpace::SharedCta => Err(TranslateError::Todo),
ast::StateSpace::SharedCluster => Err(TranslateError::Todo),
}
}
struct ResolveIdent {
words: HashMap<SpirvWord, String>,
values: HashMap<SpirvWord, LLVMValueRef>,
}
impl ResolveIdent {
fn new<'input>(_id_defs: &GlobalStringIdentResolver2<'input>) -> Self {
ResolveIdent {
words: HashMap::new(),
values: HashMap::new(),
}
}
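// Identifier strings are stored with a trailing NUL so the returned &str
// points at a valid C string that can be passed straight to LLVM-C.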
fn get_or_add_impl<'a, T>(&'a mut self, word: SpirvWord, fn_: impl FnOnce(&'a str) -> T) -> T {
let str = match self.words.entry(word) {
hash_map::Entry::Occupied(entry) => entry.into_mut(),
hash_map::Entry::Vacant(entry) => {
let mut text = word.0.to_string();
text.push('\0');
entry.insert(text)
}
};
fn_(&str[..str.len() - 1])
}
fn get_or_add(&mut self, word: SpirvWord) -> &str {
self.get_or_add_impl(word, |x| x)
}
fn get_or_add_raw(&mut self, word: SpirvWord) -> *const i8 {
self.get_or_add(word).as_ptr().cast()
}
fn register(&mut self, word: SpirvWord, v: LLVMValueRef) {
self.values.insert(word, v);
}
fn value(&self, word: SpirvWord) -> Result<LLVMValueRef, TranslateError> {
self.values
.get(&word)
.copied()
.ok_or_else(|| error_unreachable())
}
fn with_result(
&mut self,
word: SpirvWord,
fn_: impl FnOnce(*const i8) -> LLVMValueRef,
) -> LLVMValueRef {
let t = self.get_or_add_impl(word, |dst| fn_(dst.as_ptr().cast()));
self.register(word, t);
t
}
fn with_result_option(
&mut self,
word: Option<SpirvWord>,
fn_: impl FnOnce(*const i8) -> LLVMValueRef,
) -> LLVMValueRef {
match word {
Some(word) => self.with_result(word, fn_),
None => fn_(LLVM_UNNAMED.as_ptr()),
}
}
}
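// Formats a scalar type with the suffix LLVM uses when mangling overloaded
// intrinsic names (e.g. "llvm.fma.f32", "llvm.ctpop.i64").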
struct LLVMTypeDisplay(ast::ScalarType);
impl std::fmt::Display for LLVMTypeDisplay {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self.0 {
ast::ScalarType::Pred => write!(f, "i1"),
ast::ScalarType::B8 | ast::ScalarType::U8 | ast::ScalarType::S8 => write!(f, "i8"),
ast::ScalarType::B16 | ast::ScalarType::U16 | ast::ScalarType::S16 => write!(f, "i16"),
ast::ScalarType::B32 | ast::ScalarType::U32 | ast::ScalarType::S32 => write!(f, "i32"),
ast::ScalarType::B64 | ast::ScalarType::U64 | ast::ScalarType::S64 => write!(f, "i64"),
ast::ScalarType::B128 => write!(f, "i128"),
ast::ScalarType::F16 => write!(f, "f16"),
ast::ScalarType::BF16 => write!(f, "bfloat"),
ast::ScalarType::F32 => write!(f, "f32"),
ast::ScalarType::F64 => write!(f, "f64"),
ast::ScalarType::S16x2 | ast::ScalarType::U16x2 => write!(f, "v2i16"),
ast::ScalarType::F16x2 => write!(f, "v2f16"),
ast::ScalarType::BF16x2 => write!(f, "v2bfloat"),
}
}
}
/*
fn rounding_to_llvm(this: ast::RoundingMode) -> u32 {
match this {
ptx_parser::RoundingMode::Zero => 0,
ptx_parser::RoundingMode::NearestEven => 1,
ptx_parser::RoundingMode::PositiveInf => 2,
ptx_parser::RoundingMode::NegativeInf => 3,
}
}
*/