From 631417b405a3bc21325c31d7a61e3d22eee1b87c Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Thu, 12 Sep 2024 04:37:31 +0200 Subject: [PATCH] Remove inkwell --- .cargo/config.toml | 2 - llvm_zluda/Cargo.toml | 5 - llvm_zluda/src/lib.rs | 23 +- ptx/Cargo.toml | 2 +- ptx/src/pass/emit_llvm.rs | 754 ++++++++++++++++------------------ ptx/src/pass/mod.rs | 7 +- ptx/src/test/spirv_run/mod.rs | 9 +- 7 files changed, 366 insertions(+), 436 deletions(-) diff --git a/.cargo/config.toml b/.cargo/config.toml index f03ab8c..e69de29 100644 --- a/.cargo/config.toml +++ b/.cargo/config.toml @@ -1,2 +0,0 @@ -[patch.crates-io] -inkwell = { git = "https://github.com/vosen/inkwell.git", rev = "46027c2afb7e98976438cdcc41a2949dedb60b2e" } diff --git a/llvm_zluda/Cargo.toml b/llvm_zluda/Cargo.toml index 0e7f8a0..b285fc7 100644 --- a/llvm_zluda/Cargo.toml +++ b/llvm_zluda/Cargo.toml @@ -15,8 +15,3 @@ features = [ "disable-alltargets-init", "no-llvm-linking" ] [build-dependencies] cmake = "0.1" cc = "1.0.69" - -[dependencies.inkwell] -version = "0.5" -default-features = false # default features contain all LLVM targets (x86, mips, riscv, etc.) -features = [ "llvm17-0-no-llvm-linking", "no-libffi-linking" ] diff --git a/llvm_zluda/src/lib.rs b/llvm_zluda/src/lib.rs index c72e261..18072a8 100644 --- a/llvm_zluda/src/lib.rs +++ b/llvm_zluda/src/lib.rs @@ -1,15 +1,10 @@ -pub mod inkwell { - pub use inkwell::*; -} -pub mod llvm { - use llvm_sys::prelude::*; - pub use llvm_sys::*; - extern "C" { - pub fn LLVMZludaBuildAlloca( - B: LLVMBuilderRef, - Ty: LLVMTypeRef, - AddrSpace: u32, - Name: *const i8, - ) -> LLVMValueRef; - } +use llvm_sys::prelude::*; +pub use llvm_sys::*; +extern "C" { + pub fn LLVMZludaBuildAlloca( + B: LLVMBuilderRef, + Ty: LLVMTypeRef, + AddrSpace: u32, + Name: *const i8, + ) -> LLVMValueRef; } diff --git a/ptx/Cargo.toml b/ptx/Cargo.toml index 12f47ea..d99a9f6 100644 --- a/ptx/Cargo.toml +++ b/ptx/Cargo.toml @@ -2,7 +2,7 @@ name = "ptx" version = "0.0.0" authors = ["Andrzej Janik "] -edition = "2018" +edition = "2021" [lib] diff --git a/ptx/src/pass/emit_llvm.rs b/ptx/src/pass/emit_llvm.rs index 5d1b0cd..44debba 100644 --- a/ptx/src/pass/emit_llvm.rs +++ b/ptx/src/pass/emit_llvm.rs @@ -1,74 +1,207 @@ +// We use Raw LLVM-C bindings here because using inkwell is just not worth it. +// Specifically the issue is with builder functions. We maintain the mapping +// between ZLUDA identifiers and LLVM values. When using inkwell, LLVM values +// are kept as instances `AnyValueEnum`. Now look at the signature of +// `Builder::build_int_add(...)`: +// pub fn build_int_add>(&self, lhs: T, rhs: T, name: &str, ) -> Result +// At this point both lhs and rhs are `AnyValueEnum`. To call +// `build_int_add(...)` we would have to do something like this: +// if let (Ok(lhs), Ok(rhs)) = (lhs.as_int(), rhs.as_int()) { +// builder.build_int_add(lhs, rhs, dst)?; +// } else if let (Ok(lhs), Ok(rhs)) = (lhs.as_pointer(), rhs.as_pointer()) { +// builder.build_int_add(lhs, rhs, dst)?; +// } else if let (Ok(lhs), Ok(rhs)) = (lhs.as_vector(), rhs.as_vector()) { +// builder.build_int_add(lhs, rhs, dst)?; +// } else { +// return Err(error_unrachable()); +// } +// while with plain LLVM-C it's just: +// unsafe { LLVMBuildAdd(builder, lhs, rhs, dst) }; + use std::convert::{TryFrom, TryInto}; +use std::ffi::CStr; +use std::ops::Deref; use std::ptr; use super::*; -use llvm_zluda::inkwell::builder::{Builder, BuilderError}; -use llvm_zluda::inkwell::context::{AsContextRef, Context}; -use llvm_zluda::inkwell::memory_buffer::MemoryBuffer; -use llvm_zluda::inkwell::types::{ - ArrayType, AsTypeRef, BasicMetadataTypeEnum, BasicType, BasicTypeEnum, FloatType, FunctionType, - IntType, PointerType, VectorType, VoidType, -}; -use llvm_zluda::inkwell::values::{ - AnyValue, AnyValueEnum, ArrayValue, BasicValueEnum, FloatMathValue, FloatValue, FunctionValue, - InstructionValue, IntMathValue, IntValue, PhiValue, PointerValue, StructValue, VectorValue, -}; -use llvm_zluda::inkwell::{self, module, AddressSpace}; -use llvm_zluda::llvm::core::{ - LLVMArrayType2, LLVMBFloatType, LLVMBFloatTypeInContext, LLVMVectorType, -}; -use llvm_zluda::llvm::prelude::*; -use llvm_zluda::llvm::{LLVMCallConv, LLVMZludaBuildAlloca}; +use llvm_zluda::analysis::{LLVMVerifierFailureAction, LLVMVerifyModule}; +use llvm_zluda::bit_writer::LLVMWriteBitcodeToMemoryBuffer; +use llvm_zluda::core::*; +use llvm_zluda::prelude::*; +use llvm_zluda::{LLVMCallConv, LLVMZludaBuildAlloca}; -const LLVM_UNNAMED: &str = "\0"; +const LLVM_UNNAMED: &CStr = c""; // https://llvm.org/docs/AMDGPUUsage.html#address-spaces -const GENERIC_ADDRESS_SPACE: u16 = 0; -const GLOBAL_ADDRESS_SPACE: u16 = 1; -const SHARED_ADDRESS_SPACE: u16 = 3; -const CONSTANT_ADDRESS_SPACE: u16 = 4; -const PRIVATE_ADDRESS_SPACE: u16 = 5; +const GENERIC_ADDRESS_SPACE: u32 = 0; +const GLOBAL_ADDRESS_SPACE: u32 = 1; +const SHARED_ADDRESS_SPACE: u32 = 3; +const CONSTANT_ADDRESS_SPACE: u32 = 4; +const PRIVATE_ADDRESS_SPACE: u32 = 5; + +struct Context(LLVMContextRef); + +impl Context { + fn new() -> Self { + Self(unsafe { LLVMContextCreate() }) + } + + fn get(&self) -> LLVMContextRef { + self.0 + } +} + +impl Drop for Context { + fn drop(&mut self) { + unsafe { + LLVMContextDispose(self.0); + } + } +} + +struct Module(LLVMModuleRef); + +impl Module { + fn new(ctx: &Context, name: &CStr) -> Self { + Self(unsafe { LLVMModuleCreateWithNameInContext(name.as_ptr(), ctx.get()) }) + } + + fn get(&self) -> LLVMModuleRef { + self.0 + } + + fn verify(&self) -> Result<(), Message> { + let mut err = ptr::null_mut(); + let error = unsafe { + LLVMVerifyModule( + self.get(), + LLVMVerifierFailureAction::LLVMReturnStatusAction, + &mut err, + ) + }; + if error == 1 && err != ptr::null_mut() { + Err(Message(unsafe { CStr::from_ptr(err) })) + } else { + Ok(()) + } + } + + fn write_bitcode_to_memory(&self) -> MemoryBuffer { + let memory_buffer = unsafe { LLVMWriteBitcodeToMemoryBuffer(self.get()) }; + MemoryBuffer(memory_buffer) + } + + fn write_to_stderr(&self) { + unsafe { LLVMDumpModule(self.get()) }; + } +} + +impl Drop for Module { + fn drop(&mut self) { + unsafe { + LLVMDisposeModule(self.0); + } + } +} + +struct Builder(LLVMBuilderRef); + +impl Builder { + fn new(ctx: &Context) -> Self { + Self::new_raw(ctx.get()) + } + + fn new_raw(ctx: LLVMContextRef) -> Self { + Self(unsafe { LLVMCreateBuilderInContext(ctx) }) + } + + fn get(&self) -> LLVMBuilderRef { + self.0 + } +} + +impl Drop for Builder { + fn drop(&mut self) { + unsafe { + LLVMDisposeBuilder(self.0); + } + } +} + +struct Message(&'static CStr); + +impl Drop for Message { + fn drop(&mut self) { + unsafe { + LLVMDisposeMessage(self.0.as_ptr().cast_mut()); + } + } +} + +impl std::fmt::Debug for Message { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + std::fmt::Debug::fmt(&self.0, f) + } +} + +pub struct MemoryBuffer(LLVMMemoryBufferRef); + +impl Drop for MemoryBuffer { + fn drop(&mut self) { + unsafe { + LLVMDisposeMemoryBuffer(self.0); + } + } +} + +impl Deref for MemoryBuffer { + type Target = [u8]; + + fn deref(&self) -> &Self::Target { + let data = unsafe { LLVMGetBufferStart(self.0) }; + let len = unsafe { LLVMGetBufferSize(self.0) }; + unsafe { std::slice::from_raw_parts(data.cast(), len) } + } +} pub(super) fn run<'input>( id_defs: &GlobalStringIdResolver<'input>, call_map: MethodsCallMap<'input>, directives: Vec>, ) -> Result { - let context = inkwell::context::Context::create(); - let module = context.create_module(LLVM_UNNAMED); - let builder = context.create_builder(); - let mut emit_ctx = ModuleEmitContext::new(&context, module, builder, id_defs); + let context = Context::new(); + let module = Module::new(&context, LLVM_UNNAMED); + let mut emit_ctx = ModuleEmitContext::new(&context, &module, id_defs); for directive in directives { match directive { Directive::Variable(..) => todo!(), Directive::Method(method) => emit_ctx.emit_method(method)?, } } - if let Err(err) = emit_ctx.module.verify() { - emit_ctx.module.print_to_stderr(); - panic!("{}", err); + module.write_to_stderr(); + if let Err(err) = module.verify() { + panic!("{:?}", err); } - Ok(emit_ctx.module.write_bitcode_to_memory()) + Ok(module.write_bitcode_to_memory()) } -struct ModuleEmitContext<'ctx, 'input> { - context: &'ctx Context, - module: module::Module<'ctx>, - builder: Builder<'ctx>, - id_defs: &'ctx GlobalStringIdResolver<'input>, - resolver: ResolveIdent<'ctx>, +struct ModuleEmitContext<'a, 'input> { + context: LLVMContextRef, + module: LLVMModuleRef, + builder: Builder, + id_defs: &'a GlobalStringIdResolver<'input>, + resolver: ResolveIdent, } -impl<'ctx, 'input> ModuleEmitContext<'ctx, 'input> { +impl<'a, 'input> ModuleEmitContext<'a, 'input> { fn new( - context: &'ctx Context, - module: module::Module<'ctx>, - builder: Builder<'ctx>, - id_defs: &'ctx GlobalStringIdResolver<'input>, + context: &Context, + module: &Module, + id_defs: &'a GlobalStringIdResolver<'input>, ) -> Self { ModuleEmitContext { - context: &context, - module, - builder, + context: context.get(), + module: module.get(), + builder: Builder::new(context), id_defs, resolver: ResolveIdent::new(&id_defs), } @@ -84,85 +217,86 @@ impl<'ctx, 'input> ModuleEmitContext<'ctx, 'input> { fn emit_method(&mut self, method: Function<'input>) -> Result<(), TranslateError> { let func_decl = method.func_decl.borrow(); - let fn_ = self.module.add_function( - method - .import_as - .as_deref() - .unwrap_or_else(|| match func_decl.name { - ast::MethodName::Kernel(name) => name, - ast::MethodName::Func(id) => self.id_defs.reverse_variables[&id], - }), - self.function_type( - func_decl.return_arguments.iter().map(|v| &v.v_type), - func_decl.input_arguments.iter().map(|v| &v.v_type), - ), - None, + let name = method + .import_as + .as_deref() + .unwrap_or_else(|| match func_decl.name { + ast::MethodName::Kernel(name) => name, + ast::MethodName::Func(id) => self.id_defs.reverse_variables[&id], + }); + let name = CString::new(name).map_err(|_| error_unreachable())?; + let fn_type = self.function_type( + func_decl.return_arguments.iter().map(|v| &v.v_type), + func_decl.input_arguments.iter().map(|v| &v.v_type), ); + let fn_ = unsafe { LLVMAddFunction(self.module, name.as_ptr(), fn_type) }; for (i, param) in func_decl.input_arguments.iter().enumerate() { - let value = fn_ - .get_nth_param(i as u32) - .ok_or_else(|| error_unreachable())?; - value.set_name(self.resolver.get_or_add(param.name)); + let value = unsafe { LLVMGetParam(fn_, i as u32) }; + let name = self.resolver.get_or_add(param.name); + unsafe { LLVMSetValueName2(value, name.as_ptr().cast(), name.len()) }; self.resolver.register(param.name, value); } - fn_.set_call_conventions(if func_decl.name.is_kernel() { + let call_conv = if func_decl.name.is_kernel() { Self::kernel_call_convention() } else { Self::func_call_convention() - }); + }; + unsafe { LLVMSetFunctionCallConv(fn_, call_conv) }; if let Some(statements) = method.body { - let variables_bb = self.context.append_basic_block(fn_, LLVM_UNNAMED); - let variables_builder = self.context.create_builder(); - variables_builder.position_at_end(variables_bb); - let real_bb = self.context.append_basic_block(fn_, LLVM_UNNAMED); - self.builder.position_at_end(real_bb); + let variables_bb = + unsafe { LLVMAppendBasicBlockInContext(self.context, fn_, LLVM_UNNAMED.as_ptr()) }; + let variables_builder = Builder::new_raw(self.context); + unsafe { LLVMPositionBuilderAtEnd(variables_builder.get(), variables_bb) }; + let real_bb = + unsafe { LLVMAppendBasicBlockInContext(self.context, fn_, LLVM_UNNAMED.as_ptr()) }; + unsafe { LLVMPositionBuilderAtEnd(self.builder.get(), real_bb) }; let mut method_emitter = MethodEmitContext::new(self, fn_, variables_builder); for statement in statements { method_emitter.emit_statement(statement)?; } - method_emitter.variables_builder.build_unconditional_branch(real_bb); + unsafe { LLVMBuildBr(method_emitter.variables_builder.get(), real_bb) }; } Ok(()) } - fn function_type<'a>( + fn function_type( &self, return_args: impl ExactSizeIterator, input_args: impl ExactSizeIterator, - ) -> FunctionType<'ctx> { + ) -> LLVMTypeRef { if return_args.len() == 0 { - let input_args = input_args + let mut input_args = input_args .map(|type_| match type_ { ast::Type::Scalar(scalar) => match scalar { ast::ScalarType::Pred => { - BasicMetadataTypeEnum::from(self.context.bool_type()) + unsafe { LLVMInt1TypeInContext(self.context) } } ast::ScalarType::S8 | ast::ScalarType::B8 | ast::ScalarType::U8 => { - BasicMetadataTypeEnum::from(self.context.i8_type()) + unsafe { LLVMInt8TypeInContext(self.context) } } ast::ScalarType::B16 | ast::ScalarType::U16 | ast::ScalarType::S16 => { - BasicMetadataTypeEnum::from(self.context.i16_type()) + unsafe { LLVMInt16TypeInContext(self.context) } } ast::ScalarType::S32 | ast::ScalarType::B32 | ast::ScalarType::U32 => { - BasicMetadataTypeEnum::from(self.context.i32_type()) + unsafe { LLVMInt32TypeInContext(self.context) } } ast::ScalarType::U64 | ast::ScalarType::S64 | ast::ScalarType::B64 => { - BasicMetadataTypeEnum::from(self.context.i64_type()) + unsafe { LLVMInt64TypeInContext(self.context) } } ast::ScalarType::B128 => { - BasicMetadataTypeEnum::from(self.context.i128_type()) + unsafe { LLVMInt128TypeInContext(self.context) } } ast::ScalarType::F16 => { - BasicMetadataTypeEnum::from(self.context.f16_type()) + unsafe { LLVMHalfTypeInContext(self.context) } } ast::ScalarType::F32 => { - BasicMetadataTypeEnum::from(self.context.f32_type()) + unsafe { LLVMFloatTypeInContext(self.context) } } ast::ScalarType::F64 => { - BasicMetadataTypeEnum::from(self.context.f64_type()) + unsafe { LLVMDoubleTypeInContext(self.context) } } ast::ScalarType::BF16 => { - BasicMetadataTypeEnum::from(unsafe { FloatType::new(LLVMBFloatType()) }) + unsafe { LLVMBFloatTypeInContext(self.context) } } ast::ScalarType::U16x2 => todo!(), ast::ScalarType::S16x2 => todo!(), @@ -174,41 +308,39 @@ impl<'ctx, 'input> ModuleEmitContext<'ctx, 'input> { ast::Type::Pointer(_, _) => todo!(), }) .collect::>(); - return self.context.void_type().fn_type(&*input_args, false); + return unsafe { + LLVMFunctionType( + LLVMVoidTypeInContext(self.context), + input_args.as_mut_ptr(), + input_args.len() as u32, + 0, + ) + }; } todo!() } - - fn get_type(&self, type_: &ast::Type) -> FunctionType<'ctx> { - match type_ { - ast::Type::Scalar(_) => todo!(), - ast::Type::Vector(_, _) => todo!(), - ast::Type::Array(_, _, _) => todo!(), - ast::Type::Pointer(_, _) => todo!(), - } - } } -struct MethodEmitContext<'a, 'ctx, 'input> { - context: &'ctx Context, - module: &'a module::Module<'ctx>, - method: FunctionValue<'ctx>, - builder: &'a Builder<'ctx>, +struct MethodEmitContext<'a, 'input> { + context: LLVMContextRef, + module: LLVMModuleRef, + method: LLVMValueRef, + builder: LLVMBuilderRef, id_defs: &'a GlobalStringIdResolver<'input>, - variables_builder: Builder<'ctx>, - resolver: &'a mut ResolveIdent<'ctx>, + variables_builder: Builder, + resolver: &'a mut ResolveIdent, } -impl<'a, 'ctx, 'input> MethodEmitContext<'a, 'ctx, 'input> { - fn new( - parent: &'a mut ModuleEmitContext<'ctx, 'input>, - method: FunctionValue<'ctx>, - variables_builder: Builder<'ctx>, - ) -> MethodEmitContext<'a, 'ctx, 'input> { +impl<'a, 'input> MethodEmitContext<'a, 'input> { + fn new<'x>( + parent: &'a mut ModuleEmitContext<'x, 'input>, + method: LLVMValueRef, + variables_builder: Builder, + ) -> MethodEmitContext<'a, 'input> { MethodEmitContext { - context: &parent.context, - module: &parent.module, - builder: &parent.builder, + context: parent.context, + module: parent.module, + builder: parent.builder.get(), id_defs: parent.id_defs, variables_builder, resolver: &mut parent.resolver, @@ -238,19 +370,16 @@ impl<'a, 'ctx, 'input> MethodEmitContext<'a, 'ctx, 'input> { fn emit_variable(&mut self, var: ast::Variable) -> Result<(), TranslateError> { let alloca = unsafe { - PointerValue::new(LLVMZludaBuildAlloca( - self.variables_builder.as_mut_ptr(), - get_type::(&self.context, &var.v_type)?.as_type_ref(), - get_state_space(var.state_space)? as u32, + LLVMZludaBuildAlloca( + self.variables_builder.get(), + get_type(self.context, &var.v_type)?, + get_state_space(var.state_space)?, self.resolver.get_or_add_raw(var.name), - )) + ) }; self.resolver.register(var.name, alloca); if let Some(align) = var.align { - let alloca = alloca.as_instruction().ok_or_else(|| error_unreachable())?; - alloca - .set_alignment(align) - .map_err(|_| error_unreachable())?; + unsafe { LLVMSetAlignment(alloca, align) }; } if !var.array_init.is_empty() { todo!() @@ -259,27 +388,24 @@ impl<'a, 'ctx, 'input> MethodEmitContext<'a, 'ctx, 'input> { } fn emit_label(&mut self, label: SpirvWord) { - let block = self - .context - .append_basic_block(self.method, self.resolver.get_or_add(label)); - if self - .builder - .get_insert_block() - .unwrap() - .get_terminator() - .is_none() - { - self.builder.build_unconditional_branch(block); + let block = unsafe { + LLVMAppendBasicBlockInContext( + self.context, + self.method, + self.resolver.get_or_add_raw(label), + ) + }; + let last_block = unsafe { LLVMGetInsertBlock(self.builder) }; + if unsafe { LLVMGetBasicBlockTerminator(last_block) } == ptr::null_mut() { + unsafe { LLVMBuildBr(self.builder, block) }; } - self.builder.position_at_end(block); + unsafe { LLVMPositionBuilderAtEnd(self.builder, block) }; } fn emit_store_var(&mut self, store: StoreVarDetails) -> Result<(), TranslateError> { - let src1 = self.resolver.value(store.arg.src1)?; - let src2 = self.resolver.value(store.arg.src2)?; - self.builder - .build_store(src1.as_pointer()?, src2.as_basic()?) - .map_err(|_| error_unreachable())?; + let ptr = self.resolver.value(store.arg.src1)?; + let value = self.resolver.value(store.arg.src2)?; + unsafe { LLVMBuildStore(self.builder, value, ptr) }; Ok(()) } @@ -303,7 +429,7 @@ impl<'a, 'ctx, 'input> MethodEmitContext<'a, 'ctx, 'input> { ast::Instruction::Cvt { data, arguments } => todo!(), ast::Instruction::Shr { data, arguments } => todo!(), ast::Instruction::Shl { data, arguments } => todo!(), - ast::Instruction::Ret { data } => self.emit_ret(data), + ast::Instruction::Ret { data } => Ok(self.emit_ret(data)), ast::Instruction::Cvta { data, arguments } => todo!(), ast::Instruction::Abs { data, arguments } => todo!(), ast::Instruction::Mad { data, arguments } => todo!(), @@ -351,10 +477,12 @@ impl<'a, 'ctx, 'input> MethodEmitContext<'a, 'ctx, 'input> { todo!() } let builder = self.builder; - let type_ = get_type::(&self.context, &data.typ)?; - let ptr = self.resolver.value(arguments.src)?.as_pointer()?; - self.resolver - .with_result(arguments.dst, |dst| builder.build_load(type_, ptr, dst)) + let type_ = get_type(self.context, &data.typ)?; + let ptr = self.resolver.value(arguments.src)?; + self.resolver.with_result(arguments.dst, |dst| unsafe { + LLVMBuildLoad2(builder, type_, ptr, dst) + }); + Ok(()) } fn emit_load_variable(&mut self, var: LoadVarDetails) -> Result<(), TranslateError> { @@ -362,10 +490,12 @@ impl<'a, 'ctx, 'input> MethodEmitContext<'a, 'ctx, 'input> { todo!() } let builder = self.builder; - let type_ = get_type::(&self.context, &var.typ)?; - let ptr = self.resolver.value(var.arg.src)?.as_pointer()?; - self.resolver - .with_result(var.arg.dst, |dst| builder.build_load(type_, ptr, dst)) + let type_ = get_type(self.context, &var.typ)?; + let ptr = self.resolver.value(var.arg.src)?; + self.resolver.with_result(var.arg.dst, |dst| unsafe { + LLVMBuildLoad2(builder, type_, ptr, dst) + }); + Ok(()) } fn emit_conversion(&mut self, conversion: ImplicitConversion) -> Result<(), TranslateError> { @@ -374,11 +504,12 @@ impl<'a, 'ctx, 'input> MethodEmitContext<'a, 'ctx, 'input> { ConversionKind::Default => todo!(), ConversionKind::SignExtend => todo!(), ConversionKind::BitToPtr => { - let src = self.resolver.value(conversion.src)?.as_int()?; + let src = self.resolver.value(conversion.src)?; let type_ = get_pointer_type(self.context, conversion.to_space)?; - self.resolver.with_result(conversion.dst, |dst| { - builder.build_int_to_ptr(src, type_, dst) - }) + self.resolver.with_result(conversion.dst, |dst| unsafe { + LLVMBuildIntToPtr(builder, src, type_, dst) + }); + Ok(()) } ConversionKind::PtrToPtr => todo!(), ConversionKind::AddressOf => todo!(), @@ -386,21 +517,12 @@ impl<'a, 'ctx, 'input> MethodEmitContext<'a, 'ctx, 'input> { } fn emit_constant(&mut self, constant: ConstantDefinition) -> Result<(), TranslateError> { - let type_ = get_scalar_type::(&self.context, constant.typ); - let value: AnyValueEnum = match (type_, constant.value) { - (BasicTypeEnum::IntType(type_), ast::ImmediateValue::U64(x)) => { - type_.const_int(x, false).into() - } - (BasicTypeEnum::IntType(type_), ast::ImmediateValue::S64(x)) => { - type_.const_int(x as u64, false).into() - } - (BasicTypeEnum::FloatType(type_), ast::ImmediateValue::F32(x)) => { - type_.const_float(x as f64).into() - } - (BasicTypeEnum::FloatType(type_), ast::ImmediateValue::F64(x)) => { - type_.const_float(x).into() - } - _ => return Err(error_unreachable()), + let type_ = get_scalar_type(self.context, constant.typ); + let value = match constant.value { + ast::ImmediateValue::U64(x) => unsafe { LLVMConstInt(type_, x, 0) }, + ast::ImmediateValue::S64(x) => unsafe { LLVMConstInt(type_, x as u64, 0) }, + ast::ImmediateValue::F32(x) => unsafe { LLVMConstReal(type_, x as f64) }, + ast::ImmediateValue::F64(x) => unsafe { LLVMConstReal(type_, x) }, }; self.resolver.register(constant.dst, value); Ok(()) @@ -412,14 +534,16 @@ impl<'a, 'ctx, 'input> MethodEmitContext<'a, 'ctx, 'input> { arguments: ast::AddArgs, ) -> Result<(), TranslateError> { let builder = self.builder; - let src1 = self.resolver.value(arguments.src1)?.as_int()?; - let src2 = self.resolver.value(arguments.src2)?.as_int()?; + let src1 = self.resolver.value(arguments.src1)?; + let src2 = self.resolver.value(arguments.src2)?; let fn_ = match data { - ast::ArithDetails::Integer(integer) => Builder::build_int_add, - ast::ArithDetails::Float(float) => todo!(), + ast::ArithDetails::Integer(integer) => LLVMBuildAdd, + ast::ArithDetails::Float(float) => LLVMBuildFAdd, }; - self.resolver - .with_result(arguments.dst, |dst| fn_(builder, src1, src2, dst)) + self.resolver.with_result(arguments.dst, |dst| unsafe { + fn_(builder, src1, src2, dst) + }); + Ok(()) } fn emit_st( @@ -427,129 +551,80 @@ impl<'a, 'ctx, 'input> MethodEmitContext<'a, 'ctx, 'input> { data: ptx_parser::StData, arguments: ptx_parser::StArgs, ) -> Result<(), TranslateError> { - let builder = self.builder; - let src1 = self.resolver.value(arguments.src1)?.as_pointer()?; - let src2 = self.resolver.value(arguments.src2)?.as_basic()?; + let ptr = self.resolver.value(arguments.src1)?; + let value = self.resolver.value(arguments.src2)?; if data.qualifier != ast::LdStQualifier::Weak { todo!() } - self.builder - .build_store(src1, src2) - .map_err(|_| error_unreachable())?; + unsafe { LLVMBuildStore(self.builder, value, ptr) }; Ok(()) } - fn emit_ret(&self, _data: ptx_parser::RetData) -> Result<(), TranslateError> { - self.builder - .build_return(None) - .map_err(|_| error_unreachable())?; - Ok(()) + fn emit_ret(&self, _data: ptx_parser::RetData) { + unsafe { LLVMBuildRetVoid(self.builder) }; } } fn get_pointer_type<'ctx>( - context: &'ctx Context, + context: LLVMContextRef, to_space: ast::StateSpace, -) -> Result, TranslateError> { - Ok(context.ptr_type(AddressSpace::from(get_state_space(to_space)?))) +) -> Result { + Ok(unsafe { LLVMPointerTypeInContext(context, get_state_space(to_space)?) }) } -fn get_type< - 'ctx, - T: From> - + From> - + From> - + From> - + From>, ->( - context: &'ctx Context, - type_: &ast::Type, -) -> Result { +fn get_type(context: LLVMContextRef, type_: &ast::Type) -> Result { Ok(match type_ { ast::Type::Scalar(scalar) => get_scalar_type(context, *scalar), ast::Type::Vector(size, scalar) => { - let base_type = get_scalar_type::(context, *scalar); - let base_type = match base_type { - BasicTypeEnum::FloatType(t) => t.as_type_ref(), - BasicTypeEnum::IntType(t) => t.as_type_ref(), - _ => return Err(error_unreachable()), - }; - T::from(unsafe { VectorType::new(LLVMVectorType(base_type, *size as u32)) }) + let base_type = get_scalar_type(context, *scalar); + unsafe { LLVMVectorType(base_type, *size as u32) } } ast::Type::Array(vec, scalar, dimensions) => { - let mut underlying_type = get_scalar_type::(context, *scalar); + let mut underlying_type = get_scalar_type(context, *scalar); if let Some(size) = vec { - underlying_type = BasicTypeEnum::VectorType(unsafe { - VectorType::new(LLVMVectorType( - match underlying_type { - BasicTypeEnum::FloatType(t) => t.as_type_ref(), - BasicTypeEnum::IntType(t) => t.as_type_ref(), - _ => return Err(error_unreachable()), - }, - size.get() as u32, - )) - }); + underlying_type = unsafe { LLVMVectorType(underlying_type, size.get() as u32) }; } if dimensions.is_empty() { - return Ok(T::from(underlying_type.array_type(0))); + return Ok(unsafe { LLVMArrayType2(underlying_type, 0) }); } - let llvm_type = dimensions + dimensions .iter() - .rfold(underlying_type.as_type_ref(), |result, dimension| unsafe { + .rfold(underlying_type, |result, dimension| unsafe { LLVMArrayType2(result, *dimension as u64) - }); - T::from(unsafe { ArrayType::new(llvm_type) }) - } - ast::Type::Pointer(_, space) => { - T::from(context.ptr_type(AddressSpace::from(get_state_space(*space)?))) + }) } + ast::Type::Pointer(_, space) => get_pointer_type(context, *space)?, }) } -fn get_scalar_type< - 'ctx, - T: From> + From> + From>, ->( - context: &'ctx Context, - type_: ast::ScalarType, -) -> T { +fn get_scalar_type(context: LLVMContextRef, type_: ast::ScalarType) -> LLVMTypeRef { match type_ { - ast::ScalarType::Pred => T::from(context.bool_type()), - ast::ScalarType::S8 | ast::ScalarType::B8 | ast::ScalarType::U8 => { - T::from(context.i8_type()) - } - ast::ScalarType::B16 | ast::ScalarType::U16 | ast::ScalarType::S16 => { - T::from(context.i16_type()) - } - ast::ScalarType::S32 | ast::ScalarType::B32 | ast::ScalarType::U32 => { - T::from(context.i32_type()) - } - ast::ScalarType::U64 | ast::ScalarType::S64 | ast::ScalarType::B64 => { - T::from(context.i64_type()) - } - ast::ScalarType::B128 => T::from(context.i128_type()), - ast::ScalarType::F16 => T::from(context.f16_type()), - ast::ScalarType::F32 => T::from(context.f32_type()), - ast::ScalarType::F64 => T::from(context.f64_type()), - ast::ScalarType::BF16 => { - T::from(unsafe { FloatType::new(LLVMBFloatTypeInContext(context.as_ctx_ref())) }) - } - ast::ScalarType::U16x2 | ast::ScalarType::S16x2 => { - T::from(unsafe { VectorType::new(LLVMVectorType(context.i16_type().as_type_ref(), 2)) }) - } - ast::ScalarType::F16x2 => { - T::from(unsafe { VectorType::new(LLVMVectorType(context.f16_type().as_type_ref(), 2)) }) - } - ast::ScalarType::BF16x2 => T::from(unsafe { - VectorType::new(LLVMVectorType( - LLVMBFloatTypeInContext(context.as_ctx_ref()), - 2, - )) - }), + ast::ScalarType::Pred => unsafe { LLVMInt1TypeInContext(context) }, + ast::ScalarType::S8 | ast::ScalarType::B8 | ast::ScalarType::U8 => unsafe { + LLVMInt8TypeInContext(context) + }, + ast::ScalarType::B16 | ast::ScalarType::U16 | ast::ScalarType::S16 => unsafe { + LLVMInt16TypeInContext(context) + }, + ast::ScalarType::S32 | ast::ScalarType::B32 | ast::ScalarType::U32 => unsafe { + LLVMInt32TypeInContext(context) + }, + ast::ScalarType::U64 | ast::ScalarType::S64 | ast::ScalarType::B64 => unsafe { + LLVMInt64TypeInContext(context) + }, + ast::ScalarType::B128 => unsafe { LLVMInt128TypeInContext(context) }, + ast::ScalarType::F16 => unsafe { LLVMHalfTypeInContext(context) }, + ast::ScalarType::F32 => unsafe { LLVMFloatTypeInContext(context) }, + ast::ScalarType::F64 => unsafe { LLVMDoubleTypeInContext(context) }, + ast::ScalarType::BF16 => unsafe { LLVMBFloatTypeInContext(context) }, + ast::ScalarType::U16x2 => todo!(), + ast::ScalarType::S16x2 => todo!(), + ast::ScalarType::F16x2 => todo!(), + ast::ScalarType::BF16x2 => todo!(), } } -fn get_state_space(space: ast::StateSpace) -> Result { +fn get_state_space(space: ast::StateSpace) -> Result { match space { ast::StateSpace::Reg => Ok(PRIVATE_ADDRESS_SPACE), ast::StateSpace::Generic => Ok(GENERIC_ADDRESS_SPACE), @@ -566,12 +641,12 @@ fn get_state_space(space: ast::StateSpace) -> Result { } } -struct ResolveIdent<'ctx> { +struct ResolveIdent { words: HashMap, - values: HashMap>, + values: HashMap, } -impl<'ctx> ResolveIdent<'ctx> { +impl ResolveIdent { fn new<'input>(_id_defs: &GlobalStringIdResolver<'input>) -> Self { ResolveIdent { words: HashMap::new(), @@ -580,14 +655,15 @@ impl<'ctx> ResolveIdent<'ctx> { } fn get_or_ad_impl<'a, T>(&'a mut self, word: SpirvWord, fn_: impl FnOnce(&'a str) -> T) -> T { - match self.words.entry(word) { - hash_map::Entry::Occupied(entry) => fn_(entry.into_mut()), + let str = match self.words.entry(word) { + hash_map::Entry::Occupied(entry) => entry.into_mut(), hash_map::Entry::Vacant(entry) => { let mut text = word.0.to_string(); text.push('\0'); - fn_(entry.insert(text)) + entry.insert(text) } - } + }; + fn_(&str[..str.len() - 1]) } fn get_or_add(&mut self, word: SpirvWord) -> &str { @@ -598,153 +674,19 @@ impl<'ctx> ResolveIdent<'ctx> { self.get_or_add(word).as_ptr().cast() } - fn register(&mut self, word: SpirvWord, t: impl AnyValue<'ctx>) { - self.values.insert(word, t.as_any_value_enum()); + fn register(&mut self, word: SpirvWord, v: LLVMValueRef) { + self.values.insert(word, v); } - fn value(&self, word: SpirvWord) -> Result, TranslateError> { + fn value(&self, word: SpirvWord) -> Result { self.values .get(&word) .copied() .ok_or_else(|| error_unreachable()) } - fn with_result>( - &mut self, - word: SpirvWord, - fn_: impl FnOnce(&str) -> Result, - ) -> Result<(), TranslateError> { - let t = self - .get_or_ad_impl(word, fn_) - .map_err(|_| error_unreachable())?; + fn with_result(&mut self, word: SpirvWord, fn_: impl FnOnce(*const i8) -> LLVMValueRef) { + let t = self.get_or_ad_impl(word, |dst| fn_(dst.as_ptr().cast())); self.register(word, t); - Ok(()) - } - - fn build_int_math( - &mut self, - builder: &Builder<'ctx>, - dst: SpirvWord, - src1: SpirvWord, - src2: SpirvWord, - fn_: impl IntMathOp<'ctx>, - ) -> Result<(), TranslateError> { - let src1 = self.value(src1)?; - let src2 = self.value(src2)?; - self.with_result(dst, |dst| { - Ok(match (src1, src2) { - (AnyValueEnum::IntValue(src1), AnyValueEnum::IntValue(src2)) => { - AnyValueEnum::from(fn_.call(builder, src1, src2, dst)?) - } - (AnyValueEnum::PointerValue(src1), AnyValueEnum::PointerValue(src2)) => { - AnyValueEnum::from(fn_.call(builder, src1, src2, dst)?) - } - (AnyValueEnum::VectorValue(src1), AnyValueEnum::VectorValue(src2)) => { - AnyValueEnum::from(fn_.call(builder, src1, src2, dst)?) - } - _ => return todo!(), - }) - }) - } -} - -trait IntMathOp<'ctx> { - fn call>( - self, - builder: &Builder<'ctx>, - src1: T, - src2: T, - dst: &str, - ) -> Result; -} - -trait AnyValueEnumExt<'ctx> { - fn as_array(self) -> Result, TranslateError>; - fn as_int(self) -> Result, TranslateError>; - fn as_float(self) -> Result, TranslateError>; - fn as_phi(self) -> Result, TranslateError>; - fn as_function(self) -> Result, TranslateError>; - fn as_pointer(self) -> Result, TranslateError>; - fn as_struct(self) -> Result, TranslateError>; - fn as_vector(self) -> Result, TranslateError>; - fn as_instruction(self) -> Result, TranslateError>; - fn as_basic(self) -> Result, TranslateError>; -} - -impl<'ctx> AnyValueEnumExt<'ctx> for AnyValueEnum<'ctx> { - fn as_array(self) -> Result, TranslateError> { - if let AnyValueEnum::ArrayValue(x) = self { - Ok(x) - } else { - Err(error_unreachable()) - } - } - - fn as_int(self) -> Result, TranslateError> { - if let AnyValueEnum::IntValue(x) = self { - Ok(x) - } else { - Err(error_unreachable()) - } - } - - fn as_float(self) -> Result, TranslateError> { - if let AnyValueEnum::FloatValue(x) = self { - Ok(x) - } else { - Err(error_unreachable()) - } - } - - fn as_phi(self) -> Result, TranslateError> { - if let AnyValueEnum::PhiValue(x) = self { - Ok(x) - } else { - Err(error_unreachable()) - } - } - - fn as_function(self) -> Result, TranslateError> { - if let AnyValueEnum::FunctionValue(x) = self { - Ok(x) - } else { - Err(error_unreachable()) - } - } - - fn as_pointer(self) -> Result, TranslateError> { - if let AnyValueEnum::PointerValue(x) = self { - Ok(x) - } else { - Err(error_unreachable()) - } - } - - fn as_struct(self) -> Result, TranslateError> { - if let AnyValueEnum::StructValue(x) = self { - Ok(x) - } else { - Err(error_unreachable()) - } - } - - fn as_vector(self) -> Result, TranslateError> { - if let AnyValueEnum::VectorValue(x) = self { - Ok(x) - } else { - Err(error_unreachable()) - } - } - - fn as_instruction(self) -> Result, TranslateError> { - if let AnyValueEnum::InstructionValue(x) = self { - Ok(x) - } else { - Err(error_unreachable()) - } - } - - fn as_basic(self) -> Result, TranslateError> { - BasicValueEnum::try_from(self).map_err(|_| error_unreachable()) } } diff --git a/ptx/src/pass/mod.rs b/ptx/src/pass/mod.rs index 6693434..3aa3b0a 100644 --- a/ptx/src/pass/mod.rs +++ b/ptx/src/pass/mod.rs @@ -1,4 +1,3 @@ -use llvm_zluda::inkwell::memory_buffer::MemoryBuffer; use ptx_parser as ast; use rspirv::{binary::Assemble, dr}; use std::hash::Hash; @@ -17,7 +16,7 @@ use std::{ mod convert_dynamic_shared_memory_usage; mod convert_to_stateful_memory_access; mod convert_to_typed; -mod emit_llvm; +pub(crate) mod emit_llvm; mod emit_spirv; mod expand_arguments; mod extract_globals; @@ -182,7 +181,7 @@ fn to_ssa<'input, 'b>( } pub struct Module { - pub llvm_ir: MemoryBuffer, + pub llvm_ir: emit_llvm::MemoryBuffer, pub kernel_info: HashMap, } @@ -598,6 +597,7 @@ fn error_unreachable() -> TranslateError { TranslateError::Unreachable } +#[cfg(debug_assertions)] fn error_unknown_symbol() -> TranslateError { panic!() } @@ -607,6 +607,7 @@ fn error_unknown_symbol() -> TranslateError { TranslateError::UnknownSymbol } +#[cfg(debug_assertions)] fn error_mismatched_type() -> TranslateError { panic!() } diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index e0982ff..2e6c910 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -2,7 +2,6 @@ use crate::pass; use crate::ptx; use crate::translate; use hip_runtime_sys::hipError_t; -use llvm_zluda::inkwell::memory_buffer::MemoryBuffer; use rspirv::{ binary::{Assemble, Disassemble}, dr::{Block, Function, Instruction, Loader, Operand}, @@ -379,21 +378,21 @@ fn run_hip + Copy + Debug, Output: From + Copy + Debug + Def Ok(result) } -unsafe fn compile_amd(buffer: &MemoryBuffer) -> Vec { +unsafe fn compile_amd(buffer: &pass::emit_llvm::MemoryBuffer) -> Vec { use amd_comgr_sys::*; let mut data_set = mem::zeroed(); amd_comgr_create_data_set(&mut data_set).unwrap(); let mut data = mem::zeroed(); amd_comgr_create_data(amd_comgr_data_kind_t::AMD_COMGR_DATA_KIND_BC, &mut data).unwrap(); - let buffer = buffer.as_slice(); + let buffer = &**buffer; amd_comgr_set_data(data, buffer.len(), buffer.as_ptr().cast()).unwrap(); - amd_comgr_set_data_name(data, "zluda.bc\0".as_ptr().cast()).unwrap(); + amd_comgr_set_data_name(data, c"zluda.bc".as_ptr()).unwrap(); amd_comgr_data_set_add(data_set, data).unwrap(); let mut reloc_data = mem::zeroed(); amd_comgr_create_data_set(&mut reloc_data).unwrap(); let mut action_info = mem::zeroed(); amd_comgr_create_action_info(&mut action_info).unwrap(); - amd_comgr_action_info_set_isa_name(action_info, "amdgcn-amd-amdhsa--gfx1030\0".as_ptr().cast()) + amd_comgr_action_info_set_isa_name(action_info, c"amdgcn-amd-amdhsa--gfx1030".as_ptr()) .unwrap(); amd_comgr_do_action( amd_comgr_action_kind_t::AMD_COMGR_ACTION_CODEGEN_BC_TO_RELOCATABLE,