mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-04-19 16:04:44 +00:00
Remove inkwell
This commit is contained in:
parent
fb68c67adb
commit
631417b405
7 changed files with 366 additions and 436 deletions
|
@ -1,2 +0,0 @@
|
|||
[patch.crates-io]
|
||||
inkwell = { git = "https://github.com/vosen/inkwell.git", rev = "46027c2afb7e98976438cdcc41a2949dedb60b2e" }
|
|
@ -15,8 +15,3 @@ features = [ "disable-alltargets-init", "no-llvm-linking" ]
|
|||
[build-dependencies]
|
||||
cmake = "0.1"
|
||||
cc = "1.0.69"
|
||||
|
||||
[dependencies.inkwell]
|
||||
version = "0.5"
|
||||
default-features = false # default features contain all LLVM targets (x86, mips, riscv, etc.)
|
||||
features = [ "llvm17-0-no-llvm-linking", "no-libffi-linking" ]
|
||||
|
|
|
@ -1,15 +1,10 @@
|
|||
pub mod inkwell {
|
||||
pub use inkwell::*;
|
||||
}
|
||||
pub mod llvm {
|
||||
use llvm_sys::prelude::*;
|
||||
pub use llvm_sys::*;
|
||||
extern "C" {
|
||||
pub fn LLVMZludaBuildAlloca(
|
||||
B: LLVMBuilderRef,
|
||||
Ty: LLVMTypeRef,
|
||||
AddrSpace: u32,
|
||||
Name: *const i8,
|
||||
) -> LLVMValueRef;
|
||||
}
|
||||
use llvm_sys::prelude::*;
|
||||
pub use llvm_sys::*;
|
||||
extern "C" {
|
||||
pub fn LLVMZludaBuildAlloca(
|
||||
B: LLVMBuilderRef,
|
||||
Ty: LLVMTypeRef,
|
||||
AddrSpace: u32,
|
||||
Name: *const i8,
|
||||
) -> LLVMValueRef;
|
||||
}
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
name = "ptx"
|
||||
version = "0.0.0"
|
||||
authors = ["Andrzej Janik <vosen@vosen.pl>"]
|
||||
edition = "2018"
|
||||
edition = "2021"
|
||||
|
||||
[lib]
|
||||
|
||||
|
|
|
@ -1,74 +1,207 @@
|
|||
// We use Raw LLVM-C bindings here because using inkwell is just not worth it.
|
||||
// Specifically the issue is with builder functions. We maintain the mapping
|
||||
// between ZLUDA identifiers and LLVM values. When using inkwell, LLVM values
|
||||
// are kept as instances `AnyValueEnum`. Now look at the signature of
|
||||
// `Builder::build_int_add(...)`:
|
||||
// pub fn build_int_add<T: IntMathValue<'ctx>>(&self, lhs: T, rhs: T, name: &str, ) -> Result<T, BuilderError>
|
||||
// At this point both lhs and rhs are `AnyValueEnum`. To call
|
||||
// `build_int_add(...)` we would have to do something like this:
|
||||
// if let (Ok(lhs), Ok(rhs)) = (lhs.as_int(), rhs.as_int()) {
|
||||
// builder.build_int_add(lhs, rhs, dst)?;
|
||||
// } else if let (Ok(lhs), Ok(rhs)) = (lhs.as_pointer(), rhs.as_pointer()) {
|
||||
// builder.build_int_add(lhs, rhs, dst)?;
|
||||
// } else if let (Ok(lhs), Ok(rhs)) = (lhs.as_vector(), rhs.as_vector()) {
|
||||
// builder.build_int_add(lhs, rhs, dst)?;
|
||||
// } else {
|
||||
// return Err(error_unrachable());
|
||||
// }
|
||||
// while with plain LLVM-C it's just:
|
||||
// unsafe { LLVMBuildAdd(builder, lhs, rhs, dst) };
|
||||
|
||||
use std::convert::{TryFrom, TryInto};
|
||||
use std::ffi::CStr;
|
||||
use std::ops::Deref;
|
||||
use std::ptr;
|
||||
|
||||
use super::*;
|
||||
use llvm_zluda::inkwell::builder::{Builder, BuilderError};
|
||||
use llvm_zluda::inkwell::context::{AsContextRef, Context};
|
||||
use llvm_zluda::inkwell::memory_buffer::MemoryBuffer;
|
||||
use llvm_zluda::inkwell::types::{
|
||||
ArrayType, AsTypeRef, BasicMetadataTypeEnum, BasicType, BasicTypeEnum, FloatType, FunctionType,
|
||||
IntType, PointerType, VectorType, VoidType,
|
||||
};
|
||||
use llvm_zluda::inkwell::values::{
|
||||
AnyValue, AnyValueEnum, ArrayValue, BasicValueEnum, FloatMathValue, FloatValue, FunctionValue,
|
||||
InstructionValue, IntMathValue, IntValue, PhiValue, PointerValue, StructValue, VectorValue,
|
||||
};
|
||||
use llvm_zluda::inkwell::{self, module, AddressSpace};
|
||||
use llvm_zluda::llvm::core::{
|
||||
LLVMArrayType2, LLVMBFloatType, LLVMBFloatTypeInContext, LLVMVectorType,
|
||||
};
|
||||
use llvm_zluda::llvm::prelude::*;
|
||||
use llvm_zluda::llvm::{LLVMCallConv, LLVMZludaBuildAlloca};
|
||||
use llvm_zluda::analysis::{LLVMVerifierFailureAction, LLVMVerifyModule};
|
||||
use llvm_zluda::bit_writer::LLVMWriteBitcodeToMemoryBuffer;
|
||||
use llvm_zluda::core::*;
|
||||
use llvm_zluda::prelude::*;
|
||||
use llvm_zluda::{LLVMCallConv, LLVMZludaBuildAlloca};
|
||||
|
||||
const LLVM_UNNAMED: &str = "\0";
|
||||
const LLVM_UNNAMED: &CStr = c"";
|
||||
// https://llvm.org/docs/AMDGPUUsage.html#address-spaces
|
||||
const GENERIC_ADDRESS_SPACE: u16 = 0;
|
||||
const GLOBAL_ADDRESS_SPACE: u16 = 1;
|
||||
const SHARED_ADDRESS_SPACE: u16 = 3;
|
||||
const CONSTANT_ADDRESS_SPACE: u16 = 4;
|
||||
const PRIVATE_ADDRESS_SPACE: u16 = 5;
|
||||
const GENERIC_ADDRESS_SPACE: u32 = 0;
|
||||
const GLOBAL_ADDRESS_SPACE: u32 = 1;
|
||||
const SHARED_ADDRESS_SPACE: u32 = 3;
|
||||
const CONSTANT_ADDRESS_SPACE: u32 = 4;
|
||||
const PRIVATE_ADDRESS_SPACE: u32 = 5;
|
||||
|
||||
struct Context(LLVMContextRef);
|
||||
|
||||
impl Context {
|
||||
fn new() -> Self {
|
||||
Self(unsafe { LLVMContextCreate() })
|
||||
}
|
||||
|
||||
fn get(&self) -> LLVMContextRef {
|
||||
self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for Context {
|
||||
fn drop(&mut self) {
|
||||
unsafe {
|
||||
LLVMContextDispose(self.0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct Module(LLVMModuleRef);
|
||||
|
||||
impl Module {
|
||||
fn new(ctx: &Context, name: &CStr) -> Self {
|
||||
Self(unsafe { LLVMModuleCreateWithNameInContext(name.as_ptr(), ctx.get()) })
|
||||
}
|
||||
|
||||
fn get(&self) -> LLVMModuleRef {
|
||||
self.0
|
||||
}
|
||||
|
||||
fn verify(&self) -> Result<(), Message> {
|
||||
let mut err = ptr::null_mut();
|
||||
let error = unsafe {
|
||||
LLVMVerifyModule(
|
||||
self.get(),
|
||||
LLVMVerifierFailureAction::LLVMReturnStatusAction,
|
||||
&mut err,
|
||||
)
|
||||
};
|
||||
if error == 1 && err != ptr::null_mut() {
|
||||
Err(Message(unsafe { CStr::from_ptr(err) }))
|
||||
} else {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
fn write_bitcode_to_memory(&self) -> MemoryBuffer {
|
||||
let memory_buffer = unsafe { LLVMWriteBitcodeToMemoryBuffer(self.get()) };
|
||||
MemoryBuffer(memory_buffer)
|
||||
}
|
||||
|
||||
fn write_to_stderr(&self) {
|
||||
unsafe { LLVMDumpModule(self.get()) };
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for Module {
|
||||
fn drop(&mut self) {
|
||||
unsafe {
|
||||
LLVMDisposeModule(self.0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct Builder(LLVMBuilderRef);
|
||||
|
||||
impl Builder {
|
||||
fn new(ctx: &Context) -> Self {
|
||||
Self::new_raw(ctx.get())
|
||||
}
|
||||
|
||||
fn new_raw(ctx: LLVMContextRef) -> Self {
|
||||
Self(unsafe { LLVMCreateBuilderInContext(ctx) })
|
||||
}
|
||||
|
||||
fn get(&self) -> LLVMBuilderRef {
|
||||
self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for Builder {
|
||||
fn drop(&mut self) {
|
||||
unsafe {
|
||||
LLVMDisposeBuilder(self.0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct Message(&'static CStr);
|
||||
|
||||
impl Drop for Message {
|
||||
fn drop(&mut self) {
|
||||
unsafe {
|
||||
LLVMDisposeMessage(self.0.as_ptr().cast_mut());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for Message {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
std::fmt::Debug::fmt(&self.0, f)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct MemoryBuffer(LLVMMemoryBufferRef);
|
||||
|
||||
impl Drop for MemoryBuffer {
|
||||
fn drop(&mut self) {
|
||||
unsafe {
|
||||
LLVMDisposeMemoryBuffer(self.0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Deref for MemoryBuffer {
|
||||
type Target = [u8];
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
let data = unsafe { LLVMGetBufferStart(self.0) };
|
||||
let len = unsafe { LLVMGetBufferSize(self.0) };
|
||||
unsafe { std::slice::from_raw_parts(data.cast(), len) }
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn run<'input>(
|
||||
id_defs: &GlobalStringIdResolver<'input>,
|
||||
call_map: MethodsCallMap<'input>,
|
||||
directives: Vec<Directive<'input>>,
|
||||
) -> Result<MemoryBuffer, TranslateError> {
|
||||
let context = inkwell::context::Context::create();
|
||||
let module = context.create_module(LLVM_UNNAMED);
|
||||
let builder = context.create_builder();
|
||||
let mut emit_ctx = ModuleEmitContext::new(&context, module, builder, id_defs);
|
||||
let context = Context::new();
|
||||
let module = Module::new(&context, LLVM_UNNAMED);
|
||||
let mut emit_ctx = ModuleEmitContext::new(&context, &module, id_defs);
|
||||
for directive in directives {
|
||||
match directive {
|
||||
Directive::Variable(..) => todo!(),
|
||||
Directive::Method(method) => emit_ctx.emit_method(method)?,
|
||||
}
|
||||
}
|
||||
if let Err(err) = emit_ctx.module.verify() {
|
||||
emit_ctx.module.print_to_stderr();
|
||||
panic!("{}", err);
|
||||
module.write_to_stderr();
|
||||
if let Err(err) = module.verify() {
|
||||
panic!("{:?}", err);
|
||||
}
|
||||
Ok(emit_ctx.module.write_bitcode_to_memory())
|
||||
Ok(module.write_bitcode_to_memory())
|
||||
}
|
||||
|
||||
struct ModuleEmitContext<'ctx, 'input> {
|
||||
context: &'ctx Context,
|
||||
module: module::Module<'ctx>,
|
||||
builder: Builder<'ctx>,
|
||||
id_defs: &'ctx GlobalStringIdResolver<'input>,
|
||||
resolver: ResolveIdent<'ctx>,
|
||||
struct ModuleEmitContext<'a, 'input> {
|
||||
context: LLVMContextRef,
|
||||
module: LLVMModuleRef,
|
||||
builder: Builder,
|
||||
id_defs: &'a GlobalStringIdResolver<'input>,
|
||||
resolver: ResolveIdent,
|
||||
}
|
||||
|
||||
impl<'ctx, 'input> ModuleEmitContext<'ctx, 'input> {
|
||||
impl<'a, 'input> ModuleEmitContext<'a, 'input> {
|
||||
fn new(
|
||||
context: &'ctx Context,
|
||||
module: module::Module<'ctx>,
|
||||
builder: Builder<'ctx>,
|
||||
id_defs: &'ctx GlobalStringIdResolver<'input>,
|
||||
context: &Context,
|
||||
module: &Module,
|
||||
id_defs: &'a GlobalStringIdResolver<'input>,
|
||||
) -> Self {
|
||||
ModuleEmitContext {
|
||||
context: &context,
|
||||
module,
|
||||
builder,
|
||||
context: context.get(),
|
||||
module: module.get(),
|
||||
builder: Builder::new(context),
|
||||
id_defs,
|
||||
resolver: ResolveIdent::new(&id_defs),
|
||||
}
|
||||
|
@ -84,85 +217,86 @@ impl<'ctx, 'input> ModuleEmitContext<'ctx, 'input> {
|
|||
|
||||
fn emit_method(&mut self, method: Function<'input>) -> Result<(), TranslateError> {
|
||||
let func_decl = method.func_decl.borrow();
|
||||
let fn_ = self.module.add_function(
|
||||
method
|
||||
.import_as
|
||||
.as_deref()
|
||||
.unwrap_or_else(|| match func_decl.name {
|
||||
ast::MethodName::Kernel(name) => name,
|
||||
ast::MethodName::Func(id) => self.id_defs.reverse_variables[&id],
|
||||
}),
|
||||
self.function_type(
|
||||
func_decl.return_arguments.iter().map(|v| &v.v_type),
|
||||
func_decl.input_arguments.iter().map(|v| &v.v_type),
|
||||
),
|
||||
None,
|
||||
let name = method
|
||||
.import_as
|
||||
.as_deref()
|
||||
.unwrap_or_else(|| match func_decl.name {
|
||||
ast::MethodName::Kernel(name) => name,
|
||||
ast::MethodName::Func(id) => self.id_defs.reverse_variables[&id],
|
||||
});
|
||||
let name = CString::new(name).map_err(|_| error_unreachable())?;
|
||||
let fn_type = self.function_type(
|
||||
func_decl.return_arguments.iter().map(|v| &v.v_type),
|
||||
func_decl.input_arguments.iter().map(|v| &v.v_type),
|
||||
);
|
||||
let fn_ = unsafe { LLVMAddFunction(self.module, name.as_ptr(), fn_type) };
|
||||
for (i, param) in func_decl.input_arguments.iter().enumerate() {
|
||||
let value = fn_
|
||||
.get_nth_param(i as u32)
|
||||
.ok_or_else(|| error_unreachable())?;
|
||||
value.set_name(self.resolver.get_or_add(param.name));
|
||||
let value = unsafe { LLVMGetParam(fn_, i as u32) };
|
||||
let name = self.resolver.get_or_add(param.name);
|
||||
unsafe { LLVMSetValueName2(value, name.as_ptr().cast(), name.len()) };
|
||||
self.resolver.register(param.name, value);
|
||||
}
|
||||
fn_.set_call_conventions(if func_decl.name.is_kernel() {
|
||||
let call_conv = if func_decl.name.is_kernel() {
|
||||
Self::kernel_call_convention()
|
||||
} else {
|
||||
Self::func_call_convention()
|
||||
});
|
||||
};
|
||||
unsafe { LLVMSetFunctionCallConv(fn_, call_conv) };
|
||||
if let Some(statements) = method.body {
|
||||
let variables_bb = self.context.append_basic_block(fn_, LLVM_UNNAMED);
|
||||
let variables_builder = self.context.create_builder();
|
||||
variables_builder.position_at_end(variables_bb);
|
||||
let real_bb = self.context.append_basic_block(fn_, LLVM_UNNAMED);
|
||||
self.builder.position_at_end(real_bb);
|
||||
let variables_bb =
|
||||
unsafe { LLVMAppendBasicBlockInContext(self.context, fn_, LLVM_UNNAMED.as_ptr()) };
|
||||
let variables_builder = Builder::new_raw(self.context);
|
||||
unsafe { LLVMPositionBuilderAtEnd(variables_builder.get(), variables_bb) };
|
||||
let real_bb =
|
||||
unsafe { LLVMAppendBasicBlockInContext(self.context, fn_, LLVM_UNNAMED.as_ptr()) };
|
||||
unsafe { LLVMPositionBuilderAtEnd(self.builder.get(), real_bb) };
|
||||
let mut method_emitter = MethodEmitContext::new(self, fn_, variables_builder);
|
||||
for statement in statements {
|
||||
method_emitter.emit_statement(statement)?;
|
||||
}
|
||||
method_emitter.variables_builder.build_unconditional_branch(real_bb);
|
||||
unsafe { LLVMBuildBr(method_emitter.variables_builder.get(), real_bb) };
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn function_type<'a>(
|
||||
fn function_type(
|
||||
&self,
|
||||
return_args: impl ExactSizeIterator<Item = &'a ast::Type>,
|
||||
input_args: impl ExactSizeIterator<Item = &'a ast::Type>,
|
||||
) -> FunctionType<'ctx> {
|
||||
) -> LLVMTypeRef {
|
||||
if return_args.len() == 0 {
|
||||
let input_args = input_args
|
||||
let mut input_args = input_args
|
||||
.map(|type_| match type_ {
|
||||
ast::Type::Scalar(scalar) => match scalar {
|
||||
ast::ScalarType::Pred => {
|
||||
BasicMetadataTypeEnum::from(self.context.bool_type())
|
||||
unsafe { LLVMInt1TypeInContext(self.context) }
|
||||
}
|
||||
ast::ScalarType::S8 | ast::ScalarType::B8 | ast::ScalarType::U8 => {
|
||||
BasicMetadataTypeEnum::from(self.context.i8_type())
|
||||
unsafe { LLVMInt8TypeInContext(self.context) }
|
||||
}
|
||||
ast::ScalarType::B16 | ast::ScalarType::U16 | ast::ScalarType::S16 => {
|
||||
BasicMetadataTypeEnum::from(self.context.i16_type())
|
||||
unsafe { LLVMInt16TypeInContext(self.context) }
|
||||
}
|
||||
ast::ScalarType::S32 | ast::ScalarType::B32 | ast::ScalarType::U32 => {
|
||||
BasicMetadataTypeEnum::from(self.context.i32_type())
|
||||
unsafe { LLVMInt32TypeInContext(self.context) }
|
||||
}
|
||||
ast::ScalarType::U64 | ast::ScalarType::S64 | ast::ScalarType::B64 => {
|
||||
BasicMetadataTypeEnum::from(self.context.i64_type())
|
||||
unsafe { LLVMInt64TypeInContext(self.context) }
|
||||
}
|
||||
ast::ScalarType::B128 => {
|
||||
BasicMetadataTypeEnum::from(self.context.i128_type())
|
||||
unsafe { LLVMInt128TypeInContext(self.context) }
|
||||
}
|
||||
ast::ScalarType::F16 => {
|
||||
BasicMetadataTypeEnum::from(self.context.f16_type())
|
||||
unsafe { LLVMHalfTypeInContext(self.context) }
|
||||
}
|
||||
ast::ScalarType::F32 => {
|
||||
BasicMetadataTypeEnum::from(self.context.f32_type())
|
||||
unsafe { LLVMFloatTypeInContext(self.context) }
|
||||
}
|
||||
ast::ScalarType::F64 => {
|
||||
BasicMetadataTypeEnum::from(self.context.f64_type())
|
||||
unsafe { LLVMDoubleTypeInContext(self.context) }
|
||||
}
|
||||
ast::ScalarType::BF16 => {
|
||||
BasicMetadataTypeEnum::from(unsafe { FloatType::new(LLVMBFloatType()) })
|
||||
unsafe { LLVMBFloatTypeInContext(self.context) }
|
||||
}
|
||||
ast::ScalarType::U16x2 => todo!(),
|
||||
ast::ScalarType::S16x2 => todo!(),
|
||||
|
@ -174,41 +308,39 @@ impl<'ctx, 'input> ModuleEmitContext<'ctx, 'input> {
|
|||
ast::Type::Pointer(_, _) => todo!(),
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
return self.context.void_type().fn_type(&*input_args, false);
|
||||
return unsafe {
|
||||
LLVMFunctionType(
|
||||
LLVMVoidTypeInContext(self.context),
|
||||
input_args.as_mut_ptr(),
|
||||
input_args.len() as u32,
|
||||
0,
|
||||
)
|
||||
};
|
||||
}
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn get_type(&self, type_: &ast::Type) -> FunctionType<'ctx> {
|
||||
match type_ {
|
||||
ast::Type::Scalar(_) => todo!(),
|
||||
ast::Type::Vector(_, _) => todo!(),
|
||||
ast::Type::Array(_, _, _) => todo!(),
|
||||
ast::Type::Pointer(_, _) => todo!(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct MethodEmitContext<'a, 'ctx, 'input> {
|
||||
context: &'ctx Context,
|
||||
module: &'a module::Module<'ctx>,
|
||||
method: FunctionValue<'ctx>,
|
||||
builder: &'a Builder<'ctx>,
|
||||
struct MethodEmitContext<'a, 'input> {
|
||||
context: LLVMContextRef,
|
||||
module: LLVMModuleRef,
|
||||
method: LLVMValueRef,
|
||||
builder: LLVMBuilderRef,
|
||||
id_defs: &'a GlobalStringIdResolver<'input>,
|
||||
variables_builder: Builder<'ctx>,
|
||||
resolver: &'a mut ResolveIdent<'ctx>,
|
||||
variables_builder: Builder,
|
||||
resolver: &'a mut ResolveIdent,
|
||||
}
|
||||
|
||||
impl<'a, 'ctx, 'input> MethodEmitContext<'a, 'ctx, 'input> {
|
||||
fn new(
|
||||
parent: &'a mut ModuleEmitContext<'ctx, 'input>,
|
||||
method: FunctionValue<'ctx>,
|
||||
variables_builder: Builder<'ctx>,
|
||||
) -> MethodEmitContext<'a, 'ctx, 'input> {
|
||||
impl<'a, 'input> MethodEmitContext<'a, 'input> {
|
||||
fn new<'x>(
|
||||
parent: &'a mut ModuleEmitContext<'x, 'input>,
|
||||
method: LLVMValueRef,
|
||||
variables_builder: Builder,
|
||||
) -> MethodEmitContext<'a, 'input> {
|
||||
MethodEmitContext {
|
||||
context: &parent.context,
|
||||
module: &parent.module,
|
||||
builder: &parent.builder,
|
||||
context: parent.context,
|
||||
module: parent.module,
|
||||
builder: parent.builder.get(),
|
||||
id_defs: parent.id_defs,
|
||||
variables_builder,
|
||||
resolver: &mut parent.resolver,
|
||||
|
@ -238,19 +370,16 @@ impl<'a, 'ctx, 'input> MethodEmitContext<'a, 'ctx, 'input> {
|
|||
|
||||
fn emit_variable(&mut self, var: ast::Variable<SpirvWord>) -> Result<(), TranslateError> {
|
||||
let alloca = unsafe {
|
||||
PointerValue::new(LLVMZludaBuildAlloca(
|
||||
self.variables_builder.as_mut_ptr(),
|
||||
get_type::<BasicTypeEnum>(&self.context, &var.v_type)?.as_type_ref(),
|
||||
get_state_space(var.state_space)? as u32,
|
||||
LLVMZludaBuildAlloca(
|
||||
self.variables_builder.get(),
|
||||
get_type(self.context, &var.v_type)?,
|
||||
get_state_space(var.state_space)?,
|
||||
self.resolver.get_or_add_raw(var.name),
|
||||
))
|
||||
)
|
||||
};
|
||||
self.resolver.register(var.name, alloca);
|
||||
if let Some(align) = var.align {
|
||||
let alloca = alloca.as_instruction().ok_or_else(|| error_unreachable())?;
|
||||
alloca
|
||||
.set_alignment(align)
|
||||
.map_err(|_| error_unreachable())?;
|
||||
unsafe { LLVMSetAlignment(alloca, align) };
|
||||
}
|
||||
if !var.array_init.is_empty() {
|
||||
todo!()
|
||||
|
@ -259,27 +388,24 @@ impl<'a, 'ctx, 'input> MethodEmitContext<'a, 'ctx, 'input> {
|
|||
}
|
||||
|
||||
fn emit_label(&mut self, label: SpirvWord) {
|
||||
let block = self
|
||||
.context
|
||||
.append_basic_block(self.method, self.resolver.get_or_add(label));
|
||||
if self
|
||||
.builder
|
||||
.get_insert_block()
|
||||
.unwrap()
|
||||
.get_terminator()
|
||||
.is_none()
|
||||
{
|
||||
self.builder.build_unconditional_branch(block);
|
||||
let block = unsafe {
|
||||
LLVMAppendBasicBlockInContext(
|
||||
self.context,
|
||||
self.method,
|
||||
self.resolver.get_or_add_raw(label),
|
||||
)
|
||||
};
|
||||
let last_block = unsafe { LLVMGetInsertBlock(self.builder) };
|
||||
if unsafe { LLVMGetBasicBlockTerminator(last_block) } == ptr::null_mut() {
|
||||
unsafe { LLVMBuildBr(self.builder, block) };
|
||||
}
|
||||
self.builder.position_at_end(block);
|
||||
unsafe { LLVMPositionBuilderAtEnd(self.builder, block) };
|
||||
}
|
||||
|
||||
fn emit_store_var(&mut self, store: StoreVarDetails) -> Result<(), TranslateError> {
|
||||
let src1 = self.resolver.value(store.arg.src1)?;
|
||||
let src2 = self.resolver.value(store.arg.src2)?;
|
||||
self.builder
|
||||
.build_store(src1.as_pointer()?, src2.as_basic()?)
|
||||
.map_err(|_| error_unreachable())?;
|
||||
let ptr = self.resolver.value(store.arg.src1)?;
|
||||
let value = self.resolver.value(store.arg.src2)?;
|
||||
unsafe { LLVMBuildStore(self.builder, value, ptr) };
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
@ -303,7 +429,7 @@ impl<'a, 'ctx, 'input> MethodEmitContext<'a, 'ctx, 'input> {
|
|||
ast::Instruction::Cvt { data, arguments } => todo!(),
|
||||
ast::Instruction::Shr { data, arguments } => todo!(),
|
||||
ast::Instruction::Shl { data, arguments } => todo!(),
|
||||
ast::Instruction::Ret { data } => self.emit_ret(data),
|
||||
ast::Instruction::Ret { data } => Ok(self.emit_ret(data)),
|
||||
ast::Instruction::Cvta { data, arguments } => todo!(),
|
||||
ast::Instruction::Abs { data, arguments } => todo!(),
|
||||
ast::Instruction::Mad { data, arguments } => todo!(),
|
||||
|
@ -351,10 +477,12 @@ impl<'a, 'ctx, 'input> MethodEmitContext<'a, 'ctx, 'input> {
|
|||
todo!()
|
||||
}
|
||||
let builder = self.builder;
|
||||
let type_ = get_type::<BasicTypeEnum>(&self.context, &data.typ)?;
|
||||
let ptr = self.resolver.value(arguments.src)?.as_pointer()?;
|
||||
self.resolver
|
||||
.with_result(arguments.dst, |dst| builder.build_load(type_, ptr, dst))
|
||||
let type_ = get_type(self.context, &data.typ)?;
|
||||
let ptr = self.resolver.value(arguments.src)?;
|
||||
self.resolver.with_result(arguments.dst, |dst| unsafe {
|
||||
LLVMBuildLoad2(builder, type_, ptr, dst)
|
||||
});
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn emit_load_variable(&mut self, var: LoadVarDetails) -> Result<(), TranslateError> {
|
||||
|
@ -362,10 +490,12 @@ impl<'a, 'ctx, 'input> MethodEmitContext<'a, 'ctx, 'input> {
|
|||
todo!()
|
||||
}
|
||||
let builder = self.builder;
|
||||
let type_ = get_type::<BasicTypeEnum>(&self.context, &var.typ)?;
|
||||
let ptr = self.resolver.value(var.arg.src)?.as_pointer()?;
|
||||
self.resolver
|
||||
.with_result(var.arg.dst, |dst| builder.build_load(type_, ptr, dst))
|
||||
let type_ = get_type(self.context, &var.typ)?;
|
||||
let ptr = self.resolver.value(var.arg.src)?;
|
||||
self.resolver.with_result(var.arg.dst, |dst| unsafe {
|
||||
LLVMBuildLoad2(builder, type_, ptr, dst)
|
||||
});
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn emit_conversion(&mut self, conversion: ImplicitConversion) -> Result<(), TranslateError> {
|
||||
|
@ -374,11 +504,12 @@ impl<'a, 'ctx, 'input> MethodEmitContext<'a, 'ctx, 'input> {
|
|||
ConversionKind::Default => todo!(),
|
||||
ConversionKind::SignExtend => todo!(),
|
||||
ConversionKind::BitToPtr => {
|
||||
let src = self.resolver.value(conversion.src)?.as_int()?;
|
||||
let src = self.resolver.value(conversion.src)?;
|
||||
let type_ = get_pointer_type(self.context, conversion.to_space)?;
|
||||
self.resolver.with_result(conversion.dst, |dst| {
|
||||
builder.build_int_to_ptr(src, type_, dst)
|
||||
})
|
||||
self.resolver.with_result(conversion.dst, |dst| unsafe {
|
||||
LLVMBuildIntToPtr(builder, src, type_, dst)
|
||||
});
|
||||
Ok(())
|
||||
}
|
||||
ConversionKind::PtrToPtr => todo!(),
|
||||
ConversionKind::AddressOf => todo!(),
|
||||
|
@ -386,21 +517,12 @@ impl<'a, 'ctx, 'input> MethodEmitContext<'a, 'ctx, 'input> {
|
|||
}
|
||||
|
||||
fn emit_constant(&mut self, constant: ConstantDefinition) -> Result<(), TranslateError> {
|
||||
let type_ = get_scalar_type::<BasicTypeEnum>(&self.context, constant.typ);
|
||||
let value: AnyValueEnum = match (type_, constant.value) {
|
||||
(BasicTypeEnum::IntType(type_), ast::ImmediateValue::U64(x)) => {
|
||||
type_.const_int(x, false).into()
|
||||
}
|
||||
(BasicTypeEnum::IntType(type_), ast::ImmediateValue::S64(x)) => {
|
||||
type_.const_int(x as u64, false).into()
|
||||
}
|
||||
(BasicTypeEnum::FloatType(type_), ast::ImmediateValue::F32(x)) => {
|
||||
type_.const_float(x as f64).into()
|
||||
}
|
||||
(BasicTypeEnum::FloatType(type_), ast::ImmediateValue::F64(x)) => {
|
||||
type_.const_float(x).into()
|
||||
}
|
||||
_ => return Err(error_unreachable()),
|
||||
let type_ = get_scalar_type(self.context, constant.typ);
|
||||
let value = match constant.value {
|
||||
ast::ImmediateValue::U64(x) => unsafe { LLVMConstInt(type_, x, 0) },
|
||||
ast::ImmediateValue::S64(x) => unsafe { LLVMConstInt(type_, x as u64, 0) },
|
||||
ast::ImmediateValue::F32(x) => unsafe { LLVMConstReal(type_, x as f64) },
|
||||
ast::ImmediateValue::F64(x) => unsafe { LLVMConstReal(type_, x) },
|
||||
};
|
||||
self.resolver.register(constant.dst, value);
|
||||
Ok(())
|
||||
|
@ -412,14 +534,16 @@ impl<'a, 'ctx, 'input> MethodEmitContext<'a, 'ctx, 'input> {
|
|||
arguments: ast::AddArgs<SpirvWord>,
|
||||
) -> Result<(), TranslateError> {
|
||||
let builder = self.builder;
|
||||
let src1 = self.resolver.value(arguments.src1)?.as_int()?;
|
||||
let src2 = self.resolver.value(arguments.src2)?.as_int()?;
|
||||
let src1 = self.resolver.value(arguments.src1)?;
|
||||
let src2 = self.resolver.value(arguments.src2)?;
|
||||
let fn_ = match data {
|
||||
ast::ArithDetails::Integer(integer) => Builder::build_int_add,
|
||||
ast::ArithDetails::Float(float) => todo!(),
|
||||
ast::ArithDetails::Integer(integer) => LLVMBuildAdd,
|
||||
ast::ArithDetails::Float(float) => LLVMBuildFAdd,
|
||||
};
|
||||
self.resolver
|
||||
.with_result(arguments.dst, |dst| fn_(builder, src1, src2, dst))
|
||||
self.resolver.with_result(arguments.dst, |dst| unsafe {
|
||||
fn_(builder, src1, src2, dst)
|
||||
});
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn emit_st(
|
||||
|
@ -427,129 +551,80 @@ impl<'a, 'ctx, 'input> MethodEmitContext<'a, 'ctx, 'input> {
|
|||
data: ptx_parser::StData,
|
||||
arguments: ptx_parser::StArgs<SpirvWord>,
|
||||
) -> Result<(), TranslateError> {
|
||||
let builder = self.builder;
|
||||
let src1 = self.resolver.value(arguments.src1)?.as_pointer()?;
|
||||
let src2 = self.resolver.value(arguments.src2)?.as_basic()?;
|
||||
let ptr = self.resolver.value(arguments.src1)?;
|
||||
let value = self.resolver.value(arguments.src2)?;
|
||||
if data.qualifier != ast::LdStQualifier::Weak {
|
||||
todo!()
|
||||
}
|
||||
self.builder
|
||||
.build_store(src1, src2)
|
||||
.map_err(|_| error_unreachable())?;
|
||||
unsafe { LLVMBuildStore(self.builder, value, ptr) };
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn emit_ret(&self, _data: ptx_parser::RetData) -> Result<(), TranslateError> {
|
||||
self.builder
|
||||
.build_return(None)
|
||||
.map_err(|_| error_unreachable())?;
|
||||
Ok(())
|
||||
fn emit_ret(&self, _data: ptx_parser::RetData) {
|
||||
unsafe { LLVMBuildRetVoid(self.builder) };
|
||||
}
|
||||
}
|
||||
|
||||
fn get_pointer_type<'ctx>(
|
||||
context: &'ctx Context,
|
||||
context: LLVMContextRef,
|
||||
to_space: ast::StateSpace,
|
||||
) -> Result<PointerType<'ctx>, TranslateError> {
|
||||
Ok(context.ptr_type(AddressSpace::from(get_state_space(to_space)?)))
|
||||
) -> Result<LLVMTypeRef, TranslateError> {
|
||||
Ok(unsafe { LLVMPointerTypeInContext(context, get_state_space(to_space)?) })
|
||||
}
|
||||
|
||||
fn get_type<
|
||||
'ctx,
|
||||
T: From<IntType<'ctx>>
|
||||
+ From<FloatType<'ctx>>
|
||||
+ From<VectorType<'ctx>>
|
||||
+ From<PointerType<'ctx>>
|
||||
+ From<ArrayType<'ctx>>,
|
||||
>(
|
||||
context: &'ctx Context,
|
||||
type_: &ast::Type,
|
||||
) -> Result<T, TranslateError> {
|
||||
fn get_type(context: LLVMContextRef, type_: &ast::Type) -> Result<LLVMTypeRef, TranslateError> {
|
||||
Ok(match type_ {
|
||||
ast::Type::Scalar(scalar) => get_scalar_type(context, *scalar),
|
||||
ast::Type::Vector(size, scalar) => {
|
||||
let base_type = get_scalar_type::<BasicTypeEnum>(context, *scalar);
|
||||
let base_type = match base_type {
|
||||
BasicTypeEnum::FloatType(t) => t.as_type_ref(),
|
||||
BasicTypeEnum::IntType(t) => t.as_type_ref(),
|
||||
_ => return Err(error_unreachable()),
|
||||
};
|
||||
T::from(unsafe { VectorType::new(LLVMVectorType(base_type, *size as u32)) })
|
||||
let base_type = get_scalar_type(context, *scalar);
|
||||
unsafe { LLVMVectorType(base_type, *size as u32) }
|
||||
}
|
||||
ast::Type::Array(vec, scalar, dimensions) => {
|
||||
let mut underlying_type = get_scalar_type::<BasicTypeEnum>(context, *scalar);
|
||||
let mut underlying_type = get_scalar_type(context, *scalar);
|
||||
if let Some(size) = vec {
|
||||
underlying_type = BasicTypeEnum::VectorType(unsafe {
|
||||
VectorType::new(LLVMVectorType(
|
||||
match underlying_type {
|
||||
BasicTypeEnum::FloatType(t) => t.as_type_ref(),
|
||||
BasicTypeEnum::IntType(t) => t.as_type_ref(),
|
||||
_ => return Err(error_unreachable()),
|
||||
},
|
||||
size.get() as u32,
|
||||
))
|
||||
});
|
||||
underlying_type = unsafe { LLVMVectorType(underlying_type, size.get() as u32) };
|
||||
}
|
||||
if dimensions.is_empty() {
|
||||
return Ok(T::from(underlying_type.array_type(0)));
|
||||
return Ok(unsafe { LLVMArrayType2(underlying_type, 0) });
|
||||
}
|
||||
let llvm_type = dimensions
|
||||
dimensions
|
||||
.iter()
|
||||
.rfold(underlying_type.as_type_ref(), |result, dimension| unsafe {
|
||||
.rfold(underlying_type, |result, dimension| unsafe {
|
||||
LLVMArrayType2(result, *dimension as u64)
|
||||
});
|
||||
T::from(unsafe { ArrayType::new(llvm_type) })
|
||||
}
|
||||
ast::Type::Pointer(_, space) => {
|
||||
T::from(context.ptr_type(AddressSpace::from(get_state_space(*space)?)))
|
||||
})
|
||||
}
|
||||
ast::Type::Pointer(_, space) => get_pointer_type(context, *space)?,
|
||||
})
|
||||
}
|
||||
|
||||
fn get_scalar_type<
|
||||
'ctx,
|
||||
T: From<IntType<'ctx>> + From<FloatType<'ctx>> + From<VectorType<'ctx>>,
|
||||
>(
|
||||
context: &'ctx Context,
|
||||
type_: ast::ScalarType,
|
||||
) -> T {
|
||||
fn get_scalar_type(context: LLVMContextRef, type_: ast::ScalarType) -> LLVMTypeRef {
|
||||
match type_ {
|
||||
ast::ScalarType::Pred => T::from(context.bool_type()),
|
||||
ast::ScalarType::S8 | ast::ScalarType::B8 | ast::ScalarType::U8 => {
|
||||
T::from(context.i8_type())
|
||||
}
|
||||
ast::ScalarType::B16 | ast::ScalarType::U16 | ast::ScalarType::S16 => {
|
||||
T::from(context.i16_type())
|
||||
}
|
||||
ast::ScalarType::S32 | ast::ScalarType::B32 | ast::ScalarType::U32 => {
|
||||
T::from(context.i32_type())
|
||||
}
|
||||
ast::ScalarType::U64 | ast::ScalarType::S64 | ast::ScalarType::B64 => {
|
||||
T::from(context.i64_type())
|
||||
}
|
||||
ast::ScalarType::B128 => T::from(context.i128_type()),
|
||||
ast::ScalarType::F16 => T::from(context.f16_type()),
|
||||
ast::ScalarType::F32 => T::from(context.f32_type()),
|
||||
ast::ScalarType::F64 => T::from(context.f64_type()),
|
||||
ast::ScalarType::BF16 => {
|
||||
T::from(unsafe { FloatType::new(LLVMBFloatTypeInContext(context.as_ctx_ref())) })
|
||||
}
|
||||
ast::ScalarType::U16x2 | ast::ScalarType::S16x2 => {
|
||||
T::from(unsafe { VectorType::new(LLVMVectorType(context.i16_type().as_type_ref(), 2)) })
|
||||
}
|
||||
ast::ScalarType::F16x2 => {
|
||||
T::from(unsafe { VectorType::new(LLVMVectorType(context.f16_type().as_type_ref(), 2)) })
|
||||
}
|
||||
ast::ScalarType::BF16x2 => T::from(unsafe {
|
||||
VectorType::new(LLVMVectorType(
|
||||
LLVMBFloatTypeInContext(context.as_ctx_ref()),
|
||||
2,
|
||||
))
|
||||
}),
|
||||
ast::ScalarType::Pred => unsafe { LLVMInt1TypeInContext(context) },
|
||||
ast::ScalarType::S8 | ast::ScalarType::B8 | ast::ScalarType::U8 => unsafe {
|
||||
LLVMInt8TypeInContext(context)
|
||||
},
|
||||
ast::ScalarType::B16 | ast::ScalarType::U16 | ast::ScalarType::S16 => unsafe {
|
||||
LLVMInt16TypeInContext(context)
|
||||
},
|
||||
ast::ScalarType::S32 | ast::ScalarType::B32 | ast::ScalarType::U32 => unsafe {
|
||||
LLVMInt32TypeInContext(context)
|
||||
},
|
||||
ast::ScalarType::U64 | ast::ScalarType::S64 | ast::ScalarType::B64 => unsafe {
|
||||
LLVMInt64TypeInContext(context)
|
||||
},
|
||||
ast::ScalarType::B128 => unsafe { LLVMInt128TypeInContext(context) },
|
||||
ast::ScalarType::F16 => unsafe { LLVMHalfTypeInContext(context) },
|
||||
ast::ScalarType::F32 => unsafe { LLVMFloatTypeInContext(context) },
|
||||
ast::ScalarType::F64 => unsafe { LLVMDoubleTypeInContext(context) },
|
||||
ast::ScalarType::BF16 => unsafe { LLVMBFloatTypeInContext(context) },
|
||||
ast::ScalarType::U16x2 => todo!(),
|
||||
ast::ScalarType::S16x2 => todo!(),
|
||||
ast::ScalarType::F16x2 => todo!(),
|
||||
ast::ScalarType::BF16x2 => todo!(),
|
||||
}
|
||||
}
|
||||
|
||||
fn get_state_space(space: ast::StateSpace) -> Result<u16, TranslateError> {
|
||||
fn get_state_space(space: ast::StateSpace) -> Result<u32, TranslateError> {
|
||||
match space {
|
||||
ast::StateSpace::Reg => Ok(PRIVATE_ADDRESS_SPACE),
|
||||
ast::StateSpace::Generic => Ok(GENERIC_ADDRESS_SPACE),
|
||||
|
@ -566,12 +641,12 @@ fn get_state_space(space: ast::StateSpace) -> Result<u16, TranslateError> {
|
|||
}
|
||||
}
|
||||
|
||||
struct ResolveIdent<'ctx> {
|
||||
struct ResolveIdent {
|
||||
words: HashMap<SpirvWord, String>,
|
||||
values: HashMap<SpirvWord, AnyValueEnum<'ctx>>,
|
||||
values: HashMap<SpirvWord, LLVMValueRef>,
|
||||
}
|
||||
|
||||
impl<'ctx> ResolveIdent<'ctx> {
|
||||
impl ResolveIdent {
|
||||
fn new<'input>(_id_defs: &GlobalStringIdResolver<'input>) -> Self {
|
||||
ResolveIdent {
|
||||
words: HashMap::new(),
|
||||
|
@ -580,14 +655,15 @@ impl<'ctx> ResolveIdent<'ctx> {
|
|||
}
|
||||
|
||||
fn get_or_ad_impl<'a, T>(&'a mut self, word: SpirvWord, fn_: impl FnOnce(&'a str) -> T) -> T {
|
||||
match self.words.entry(word) {
|
||||
hash_map::Entry::Occupied(entry) => fn_(entry.into_mut()),
|
||||
let str = match self.words.entry(word) {
|
||||
hash_map::Entry::Occupied(entry) => entry.into_mut(),
|
||||
hash_map::Entry::Vacant(entry) => {
|
||||
let mut text = word.0.to_string();
|
||||
text.push('\0');
|
||||
fn_(entry.insert(text))
|
||||
entry.insert(text)
|
||||
}
|
||||
}
|
||||
};
|
||||
fn_(&str[..str.len() - 1])
|
||||
}
|
||||
|
||||
fn get_or_add(&mut self, word: SpirvWord) -> &str {
|
||||
|
@ -598,153 +674,19 @@ impl<'ctx> ResolveIdent<'ctx> {
|
|||
self.get_or_add(word).as_ptr().cast()
|
||||
}
|
||||
|
||||
fn register(&mut self, word: SpirvWord, t: impl AnyValue<'ctx>) {
|
||||
self.values.insert(word, t.as_any_value_enum());
|
||||
fn register(&mut self, word: SpirvWord, v: LLVMValueRef) {
|
||||
self.values.insert(word, v);
|
||||
}
|
||||
|
||||
fn value(&self, word: SpirvWord) -> Result<AnyValueEnum<'ctx>, TranslateError> {
|
||||
fn value(&self, word: SpirvWord) -> Result<LLVMValueRef, TranslateError> {
|
||||
self.values
|
||||
.get(&word)
|
||||
.copied()
|
||||
.ok_or_else(|| error_unreachable())
|
||||
}
|
||||
|
||||
fn with_result<T: AnyValue<'ctx>>(
|
||||
&mut self,
|
||||
word: SpirvWord,
|
||||
fn_: impl FnOnce(&str) -> Result<T, BuilderError>,
|
||||
) -> Result<(), TranslateError> {
|
||||
let t = self
|
||||
.get_or_ad_impl(word, fn_)
|
||||
.map_err(|_| error_unreachable())?;
|
||||
fn with_result(&mut self, word: SpirvWord, fn_: impl FnOnce(*const i8) -> LLVMValueRef) {
|
||||
let t = self.get_or_ad_impl(word, |dst| fn_(dst.as_ptr().cast()));
|
||||
self.register(word, t);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn build_int_math(
|
||||
&mut self,
|
||||
builder: &Builder<'ctx>,
|
||||
dst: SpirvWord,
|
||||
src1: SpirvWord,
|
||||
src2: SpirvWord,
|
||||
fn_: impl IntMathOp<'ctx>,
|
||||
) -> Result<(), TranslateError> {
|
||||
let src1 = self.value(src1)?;
|
||||
let src2 = self.value(src2)?;
|
||||
self.with_result(dst, |dst| {
|
||||
Ok(match (src1, src2) {
|
||||
(AnyValueEnum::IntValue(src1), AnyValueEnum::IntValue(src2)) => {
|
||||
AnyValueEnum::from(fn_.call(builder, src1, src2, dst)?)
|
||||
}
|
||||
(AnyValueEnum::PointerValue(src1), AnyValueEnum::PointerValue(src2)) => {
|
||||
AnyValueEnum::from(fn_.call(builder, src1, src2, dst)?)
|
||||
}
|
||||
(AnyValueEnum::VectorValue(src1), AnyValueEnum::VectorValue(src2)) => {
|
||||
AnyValueEnum::from(fn_.call(builder, src1, src2, dst)?)
|
||||
}
|
||||
_ => return todo!(),
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
trait IntMathOp<'ctx> {
|
||||
fn call<T: IntMathValue<'ctx>>(
|
||||
self,
|
||||
builder: &Builder<'ctx>,
|
||||
src1: T,
|
||||
src2: T,
|
||||
dst: &str,
|
||||
) -> Result<T, BuilderError>;
|
||||
}
|
||||
|
||||
trait AnyValueEnumExt<'ctx> {
|
||||
fn as_array(self) -> Result<ArrayValue<'ctx>, TranslateError>;
|
||||
fn as_int(self) -> Result<IntValue<'ctx>, TranslateError>;
|
||||
fn as_float(self) -> Result<FloatValue<'ctx>, TranslateError>;
|
||||
fn as_phi(self) -> Result<PhiValue<'ctx>, TranslateError>;
|
||||
fn as_function(self) -> Result<FunctionValue<'ctx>, TranslateError>;
|
||||
fn as_pointer(self) -> Result<PointerValue<'ctx>, TranslateError>;
|
||||
fn as_struct(self) -> Result<StructValue<'ctx>, TranslateError>;
|
||||
fn as_vector(self) -> Result<VectorValue<'ctx>, TranslateError>;
|
||||
fn as_instruction(self) -> Result<InstructionValue<'ctx>, TranslateError>;
|
||||
fn as_basic(self) -> Result<BasicValueEnum<'ctx>, TranslateError>;
|
||||
}
|
||||
|
||||
impl<'ctx> AnyValueEnumExt<'ctx> for AnyValueEnum<'ctx> {
|
||||
fn as_array(self) -> Result<ArrayValue<'ctx>, TranslateError> {
|
||||
if let AnyValueEnum::ArrayValue(x) = self {
|
||||
Ok(x)
|
||||
} else {
|
||||
Err(error_unreachable())
|
||||
}
|
||||
}
|
||||
|
||||
fn as_int(self) -> Result<IntValue<'ctx>, TranslateError> {
|
||||
if let AnyValueEnum::IntValue(x) = self {
|
||||
Ok(x)
|
||||
} else {
|
||||
Err(error_unreachable())
|
||||
}
|
||||
}
|
||||
|
||||
fn as_float(self) -> Result<FloatValue<'ctx>, TranslateError> {
|
||||
if let AnyValueEnum::FloatValue(x) = self {
|
||||
Ok(x)
|
||||
} else {
|
||||
Err(error_unreachable())
|
||||
}
|
||||
}
|
||||
|
||||
fn as_phi(self) -> Result<PhiValue<'ctx>, TranslateError> {
|
||||
if let AnyValueEnum::PhiValue(x) = self {
|
||||
Ok(x)
|
||||
} else {
|
||||
Err(error_unreachable())
|
||||
}
|
||||
}
|
||||
|
||||
fn as_function(self) -> Result<FunctionValue<'ctx>, TranslateError> {
|
||||
if let AnyValueEnum::FunctionValue(x) = self {
|
||||
Ok(x)
|
||||
} else {
|
||||
Err(error_unreachable())
|
||||
}
|
||||
}
|
||||
|
||||
fn as_pointer(self) -> Result<PointerValue<'ctx>, TranslateError> {
|
||||
if let AnyValueEnum::PointerValue(x) = self {
|
||||
Ok(x)
|
||||
} else {
|
||||
Err(error_unreachable())
|
||||
}
|
||||
}
|
||||
|
||||
fn as_struct(self) -> Result<StructValue<'ctx>, TranslateError> {
|
||||
if let AnyValueEnum::StructValue(x) = self {
|
||||
Ok(x)
|
||||
} else {
|
||||
Err(error_unreachable())
|
||||
}
|
||||
}
|
||||
|
||||
fn as_vector(self) -> Result<VectorValue<'ctx>, TranslateError> {
|
||||
if let AnyValueEnum::VectorValue(x) = self {
|
||||
Ok(x)
|
||||
} else {
|
||||
Err(error_unreachable())
|
||||
}
|
||||
}
|
||||
|
||||
fn as_instruction(self) -> Result<InstructionValue<'ctx>, TranslateError> {
|
||||
if let AnyValueEnum::InstructionValue(x) = self {
|
||||
Ok(x)
|
||||
} else {
|
||||
Err(error_unreachable())
|
||||
}
|
||||
}
|
||||
|
||||
fn as_basic(self) -> Result<BasicValueEnum<'ctx>, TranslateError> {
|
||||
BasicValueEnum::try_from(self).map_err(|_| error_unreachable())
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
use llvm_zluda::inkwell::memory_buffer::MemoryBuffer;
|
||||
use ptx_parser as ast;
|
||||
use rspirv::{binary::Assemble, dr};
|
||||
use std::hash::Hash;
|
||||
|
@ -17,7 +16,7 @@ use std::{
|
|||
mod convert_dynamic_shared_memory_usage;
|
||||
mod convert_to_stateful_memory_access;
|
||||
mod convert_to_typed;
|
||||
mod emit_llvm;
|
||||
pub(crate) mod emit_llvm;
|
||||
mod emit_spirv;
|
||||
mod expand_arguments;
|
||||
mod extract_globals;
|
||||
|
@ -182,7 +181,7 @@ fn to_ssa<'input, 'b>(
|
|||
}
|
||||
|
||||
pub struct Module {
|
||||
pub llvm_ir: MemoryBuffer,
|
||||
pub llvm_ir: emit_llvm::MemoryBuffer,
|
||||
pub kernel_info: HashMap<String, KernelInfo>,
|
||||
}
|
||||
|
||||
|
@ -598,6 +597,7 @@ fn error_unreachable() -> TranslateError {
|
|||
TranslateError::Unreachable
|
||||
}
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
fn error_unknown_symbol() -> TranslateError {
|
||||
panic!()
|
||||
}
|
||||
|
@ -607,6 +607,7 @@ fn error_unknown_symbol() -> TranslateError {
|
|||
TranslateError::UnknownSymbol
|
||||
}
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
fn error_mismatched_type() -> TranslateError {
|
||||
panic!()
|
||||
}
|
||||
|
|
|
@ -2,7 +2,6 @@ use crate::pass;
|
|||
use crate::ptx;
|
||||
use crate::translate;
|
||||
use hip_runtime_sys::hipError_t;
|
||||
use llvm_zluda::inkwell::memory_buffer::MemoryBuffer;
|
||||
use rspirv::{
|
||||
binary::{Assemble, Disassemble},
|
||||
dr::{Block, Function, Instruction, Loader, Operand},
|
||||
|
@ -379,21 +378,21 @@ fn run_hip<Input: From<u8> + Copy + Debug, Output: From<u8> + Copy + Debug + Def
|
|||
Ok(result)
|
||||
}
|
||||
|
||||
unsafe fn compile_amd(buffer: &MemoryBuffer) -> Vec<u8> {
|
||||
unsafe fn compile_amd(buffer: &pass::emit_llvm::MemoryBuffer) -> Vec<u8> {
|
||||
use amd_comgr_sys::*;
|
||||
let mut data_set = mem::zeroed();
|
||||
amd_comgr_create_data_set(&mut data_set).unwrap();
|
||||
let mut data = mem::zeroed();
|
||||
amd_comgr_create_data(amd_comgr_data_kind_t::AMD_COMGR_DATA_KIND_BC, &mut data).unwrap();
|
||||
let buffer = buffer.as_slice();
|
||||
let buffer = &**buffer;
|
||||
amd_comgr_set_data(data, buffer.len(), buffer.as_ptr().cast()).unwrap();
|
||||
amd_comgr_set_data_name(data, "zluda.bc\0".as_ptr().cast()).unwrap();
|
||||
amd_comgr_set_data_name(data, c"zluda.bc".as_ptr()).unwrap();
|
||||
amd_comgr_data_set_add(data_set, data).unwrap();
|
||||
let mut reloc_data = mem::zeroed();
|
||||
amd_comgr_create_data_set(&mut reloc_data).unwrap();
|
||||
let mut action_info = mem::zeroed();
|
||||
amd_comgr_create_action_info(&mut action_info).unwrap();
|
||||
amd_comgr_action_info_set_isa_name(action_info, "amdgcn-amd-amdhsa--gfx1030\0".as_ptr().cast())
|
||||
amd_comgr_action_info_set_isa_name(action_info, c"amdgcn-amd-amdhsa--gfx1030".as_ptr())
|
||||
.unwrap();
|
||||
amd_comgr_do_action(
|
||||
amd_comgr_action_kind_t::AMD_COMGR_ACTION_CODEGEN_BC_TO_RELOCATABLE,
|
||||
|
|
Loading…
Add table
Reference in a new issue