diff --git a/llvm_zluda/src/lib.cpp b/llvm_zluda/src/lib.cpp index 3da88fb..e9b3a52 100644 --- a/llvm_zluda/src/lib.cpp +++ b/llvm_zluda/src/lib.cpp @@ -1,6 +1,112 @@ #include #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Type.h" +#include "llvm/IR/Instructions.h" + +using namespace llvm; + +typedef enum +{ + LLVMZludaAtomicRMWBinOpXchg, /**< Set the new value and return the old one */ + LLVMZludaAtomicRMWBinOpAdd, /**< Add a value and return the old one */ + LLVMZludaAtomicRMWBinOpSub, /**< Subtract a value and return the old one */ + LLVMZludaAtomicRMWBinOpAnd, /**< And a value and return the old one */ + LLVMZludaAtomicRMWBinOpNand, /**< Not-And a value and return the old one */ + LLVMZludaAtomicRMWBinOpOr, /**< OR a value and return the old one */ + LLVMZludaAtomicRMWBinOpXor, /**< Xor a value and return the old one */ + LLVMZludaAtomicRMWBinOpMax, /**< Sets the value if it's greater than the + original using a signed comparison and return + the old one */ + LLVMZludaAtomicRMWBinOpMin, /**< Sets the value if it's smaller than the + original using a signed comparison and return + the old one */ + LLVMZludaAtomicRMWBinOpUMax, /**< Sets the value if it's greater than the + original using an unsigned comparison and return + the old one */ + LLVMZludaAtomicRMWBinOpUMin, /**< Sets the value if it's smaller than the + original using an unsigned comparison and return + the old one */ + LLVMZludaAtomicRMWBinOpFAdd, /**< Add a floating point value and return the + old one */ + LLVMZludaAtomicRMWBinOpFSub, /**< Subtract a floating point value and return the + old one */ + LLVMZludaAtomicRMWBinOpFMax, /**< Sets the value if it's greater than the + original using a floating point comparison and + return the old one */ + LLVMZludaAtomicRMWBinOpFMin, /**< Sets the value if it's smaller than the + original using a floating point comparison and + return the old one */ + LLVMZludaAtomicRMWBinOpUIncWrap, /**< Increments the value, wrapping back to zero + when incremented 
above input value */ + LLVMZludaAtomicRMWBinOpUDecWrap, /**< Decrements the value, wrapping back to + the input value when decremented below zero */ +} LLVMZludaAtomicRMWBinOp; + +static llvm::AtomicRMWInst::BinOp mapFromLLVMRMWBinOp(LLVMZludaAtomicRMWBinOp BinOp) +{ + switch (BinOp) + { + case LLVMZludaAtomicRMWBinOpXchg: + return llvm::AtomicRMWInst::Xchg; + case LLVMZludaAtomicRMWBinOpAdd: + return llvm::AtomicRMWInst::Add; + case LLVMZludaAtomicRMWBinOpSub: + return llvm::AtomicRMWInst::Sub; + case LLVMZludaAtomicRMWBinOpAnd: + return llvm::AtomicRMWInst::And; + case LLVMZludaAtomicRMWBinOpNand: + return llvm::AtomicRMWInst::Nand; + case LLVMZludaAtomicRMWBinOpOr: + return llvm::AtomicRMWInst::Or; + case LLVMZludaAtomicRMWBinOpXor: + return llvm::AtomicRMWInst::Xor; + case LLVMZludaAtomicRMWBinOpMax: + return llvm::AtomicRMWInst::Max; + case LLVMZludaAtomicRMWBinOpMin: + return llvm::AtomicRMWInst::Min; + case LLVMZludaAtomicRMWBinOpUMax: + return llvm::AtomicRMWInst::UMax; + case LLVMZludaAtomicRMWBinOpUMin: + return llvm::AtomicRMWInst::UMin; + case LLVMZludaAtomicRMWBinOpFAdd: + return llvm::AtomicRMWInst::FAdd; + case LLVMZludaAtomicRMWBinOpFSub: + return llvm::AtomicRMWInst::FSub; + case LLVMZludaAtomicRMWBinOpFMax: + return llvm::AtomicRMWInst::FMax; + case LLVMZludaAtomicRMWBinOpFMin: + return llvm::AtomicRMWInst::FMin; + case LLVMZludaAtomicRMWBinOpUIncWrap: + return llvm::AtomicRMWInst::UIncWrap; + case LLVMZludaAtomicRMWBinOpUDecWrap: + return llvm::AtomicRMWInst::UDecWrap; + } + + llvm_unreachable("Invalid LLVMZludaAtomicRMWBinOp value!"); +} + +static AtomicOrdering mapFromLLVMOrdering(LLVMAtomicOrdering Ordering) +{ + switch (Ordering) + { + case LLVMAtomicOrderingNotAtomic: + return AtomicOrdering::NotAtomic; + case LLVMAtomicOrderingUnordered: + return AtomicOrdering::Unordered; + case LLVMAtomicOrderingMonotonic: + return AtomicOrdering::Monotonic; + case LLVMAtomicOrderingAcquire: + return AtomicOrdering::Acquire; + case 
LLVMAtomicOrderingRelease: + return AtomicOrdering::Release; + case LLVMAtomicOrderingAcquireRelease: + return AtomicOrdering::AcquireRelease; + case LLVMAtomicOrderingSequentiallyConsistent: + return AtomicOrdering::SequentiallyConsistent; + } + + llvm_unreachable("Invalid LLVMAtomicOrdering value!"); +} LLVM_C_EXTERN_C_BEGIN @@ -10,4 +116,18 @@ LLVMValueRef LLVMZludaBuildAlloca(LLVMBuilderRef B, LLVMTypeRef Ty, unsigned Add return llvm::wrap(llvm::unwrap(B)->CreateAlloca(llvm::unwrap(Ty), AddrSpace, nullptr, Name)); } +LLVMValueRef LLVMZludaBuildAtomicRMW(LLVMBuilderRef B, LLVMZludaAtomicRMWBinOp op, + LLVMValueRef PTR, LLVMValueRef Val, + char *scope, + LLVMAtomicOrdering ordering) +{ + auto builder = llvm::unwrap(B); + LLVMContext &context = builder->getContext(); + llvm::AtomicRMWInst::BinOp intop = mapFromLLVMRMWBinOp(op); + return llvm::wrap(builder->CreateAtomicRMW( + intop, llvm::unwrap(PTR), llvm::unwrap(Val), llvm::MaybeAlign(), + mapFromLLVMOrdering(ordering), + context.getOrInsertSyncScopeID(scope))); +} + LLVM_C_EXTERN_C_END \ No newline at end of file diff --git a/llvm_zluda/src/lib.rs b/llvm_zluda/src/lib.rs index 18072a8..b995cdb 100644 --- a/llvm_zluda/src/lib.rs +++ b/llvm_zluda/src/lib.rs @@ -1,5 +1,28 @@ use llvm_sys::prelude::*; pub use llvm_sys::*; + +#[repr(C)] +#[derive(Clone, Copy, Debug, PartialEq)] +pub enum LLVMZludaAtomicRMWBinOp { + LLVMZludaAtomicRMWBinOpXchg = 0, + LLVMZludaAtomicRMWBinOpAdd = 1, + LLVMZludaAtomicRMWBinOpSub = 2, + LLVMZludaAtomicRMWBinOpAnd = 3, + LLVMZludaAtomicRMWBinOpNand = 4, + LLVMZludaAtomicRMWBinOpOr = 5, + LLVMZludaAtomicRMWBinOpXor = 6, + LLVMZludaAtomicRMWBinOpMax = 7, + LLVMZludaAtomicRMWBinOpMin = 8, + LLVMZludaAtomicRMWBinOpUMax = 9, + LLVMZludaAtomicRMWBinOpUMin = 10, + LLVMZludaAtomicRMWBinOpFAdd = 11, + LLVMZludaAtomicRMWBinOpFSub = 12, + LLVMZludaAtomicRMWBinOpFMax = 13, + LLVMZludaAtomicRMWBinOpFMin = 14, + LLVMZludaAtomicRMWBinOpUIncWrap = 15, + LLVMZludaAtomicRMWBinOpUDecWrap = 16, +} + extern 
"C" { pub fn LLVMZludaBuildAlloca( B: LLVMBuilderRef, @@ -7,4 +30,13 @@ extern "C" { AddrSpace: u32, Name: *const i8, ) -> LLVMValueRef; + + pub fn LLVMZludaBuildAtomicRMW( + B: LLVMBuilderRef, + op: LLVMZludaAtomicRMWBinOp, + PTR: LLVMValueRef, + Val: LLVMValueRef, + scope: *const i8, + ordering: LLVMAtomicOrdering, + ) -> LLVMValueRef; } diff --git a/ptx/src/pass/emit_llvm.rs b/ptx/src/pass/emit_llvm.rs index 7f74d1a..bc5f745 100644 --- a/ptx/src/pass/emit_llvm.rs +++ b/ptx/src/pass/emit_llvm.rs @@ -19,15 +19,15 @@ // unsafe { LLVMBuildAdd(builder, lhs, rhs, dst) }; use std::convert::{TryFrom, TryInto}; -use std::ffi::CStr; +use std::ffi::{CStr, NulError}; use std::ops::Deref; use std::ptr; use super::*; use llvm_zluda::analysis::{LLVMVerifierFailureAction, LLVMVerifyModule}; use llvm_zluda::bit_writer::LLVMWriteBitcodeToMemoryBuffer; -use llvm_zluda::core::*; -use llvm_zluda::prelude::*; +use llvm_zluda::{core::*, LLVMAtomicOrdering, LLVMAtomicRMWBinOp, LLVMZludaAtomicRMWBinOp}; +use llvm_zluda::{prelude::*, LLVMZludaBuildAtomicRMW}; use llvm_zluda::{LLVMCallConv, LLVMZludaBuildAlloca}; const LLVM_UNNAMED: &CStr = c""; @@ -172,7 +172,7 @@ pub(super) fn run<'input>( let mut emit_ctx = ModuleEmitContext::new(&context, &module, &id_defs); for directive in directives { match directive { - Directive2::Variable(..) => todo!(), + Directive2::Variable(linking, variable) => emit_ctx.emit_global(linking, variable)?, Directive2::Method(method) => emit_ctx.emit_method(method)?, } } @@ -281,6 +281,43 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> { } Ok(()) } + + fn emit_global( + &mut self, + linking: ast::LinkingDirective, + var: ptx_parser::Variable, + ) -> Result<(), TranslateError> { + let name = self + .id_defs + .ident_map + .get(&var.name) + .map(|entry| { + entry + .name + .as_ref() + .map(|text| Ok::<_, NulError>(Cow::Owned(CString::new(&**text)?))) + }) + .flatten() + .transpose() + .map_err(|_| error_unreachable())? 
+ .unwrap_or(Cow::Borrowed(LLVM_UNNAMED)); + let global = unsafe { + LLVMAddGlobalInAddressSpace( + self.module, + get_type(self.context, &var.v_type)?, + name.as_ptr(), + get_state_space(var.state_space)?, + ) + }; + self.resolver.register(var.name, global); + if let Some(align) = var.align { + unsafe { LLVMSetAlignment(global, align) }; + } + if !var.array_init.is_empty() { + todo!() + } + Ok(()) + } } fn get_input_argument_type( @@ -419,7 +456,7 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> { ast::Instruction::Rsqrt { data, arguments } => todo!(), ast::Instruction::Selp { data, arguments } => todo!(), ast::Instruction::Bar { data, arguments } => todo!(), - ast::Instruction::Atom { data, arguments } => todo!(), + ast::Instruction::Atom { data, arguments } => self.emit_atom(data, arguments), ast::Instruction::AtomCas { data, arguments } => todo!(), ast::Instruction::Div { data, arguments } => todo!(), ast::Instruction::Neg { data, arguments } => todo!(), @@ -499,7 +536,14 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> { }); Ok(()) } - ConversionKind::PtrToPtr => todo!(), + ConversionKind::PtrToPtr => { + let src = self.resolver.value(conversion.src)?; + let dst_type = get_pointer_type(self.context, conversion.to_space)?; + self.resolver.with_result(conversion.dst, |dst| unsafe { + LLVMBuildAddrSpaceCast(builder, src, dst_type, dst) + }); + Ok(()) + } ConversionKind::AddressOf => todo!(), } } @@ -635,6 +679,51 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> { }); Ok(()) } + + fn emit_atom( + &mut self, + data: ptx_parser::AtomDetails, + arguments: ptx_parser::AtomArgs, + ) -> Result<(), TranslateError> { + let builder = self.builder; + let src1 = self.resolver.value(arguments.src1)?; + let src2 = self.resolver.value(arguments.src2)?; + let op = match data.op { + ptx_parser::AtomicOp::And => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpAnd, + ptx_parser::AtomicOp::Or => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpOr, + ptx_parser::AtomicOp::Xor => 
LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpXor, + ptx_parser::AtomicOp::Exchange => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpXchg, + ptx_parser::AtomicOp::Add => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpAdd, + ptx_parser::AtomicOp::IncrementWrap => { + LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpUIncWrap + } + ptx_parser::AtomicOp::DecrementWrap => { + LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpUDecWrap + } + ptx_parser::AtomicOp::SignedMin => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpMin, + ptx_parser::AtomicOp::UnsignedMin => { + LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpUMin + } + ptx_parser::AtomicOp::SignedMax => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpMax, + ptx_parser::AtomicOp::UnsignedMax => { + LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpUMax + } + ptx_parser::AtomicOp::FloatAdd => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpFAdd, + ptx_parser::AtomicOp::FloatMin => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpFMin, + ptx_parser::AtomicOp::FloatMax => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpFMax, + }; + self.resolver.register(arguments.dst, unsafe { + LLVMZludaBuildAtomicRMW( + builder, + op, + src1, + src2, + get_scope(data.scope)?, + get_ordering(data.semantics), + ) + }); + Ok(()) + } } fn get_pointer_type<'ctx>( @@ -644,6 +733,26 @@ fn get_pointer_type<'ctx>( Ok(unsafe { LLVMPointerTypeInContext(context, get_state_space(to_space)?) 
}) } +// https://llvm.org/docs/AMDGPUUsage.html#memory-scopes +fn get_scope(scope: ast::MemScope) -> Result<*const i8, TranslateError> { + Ok(match scope { + ast::MemScope::Cta => c"workgroup-one-as", + ast::MemScope::Gpu => c"agent-one-as", + ast::MemScope::Sys => c"one-as", + ast::MemScope::Cluster => todo!(), + } + .as_ptr()) +} + +fn get_ordering(semantics: ast::AtomSemantics) -> LLVMAtomicOrdering { + match semantics { + ast::AtomSemantics::Relaxed => LLVMAtomicOrdering::LLVMAtomicOrderingMonotonic, + ast::AtomSemantics::Acquire => LLVMAtomicOrdering::LLVMAtomicOrderingAcquire, + ast::AtomSemantics::Release => LLVMAtomicOrdering::LLVMAtomicOrderingRelease, + ast::AtomSemantics::AcqRel => LLVMAtomicOrdering::LLVMAtomicOrderingAcquireRelease, + } +} + fn get_type(context: LLVMContextRef, type_: &ast::Type) -> Result { Ok(match type_ { ast::Type::Scalar(scalar) => get_scalar_type(context, *scalar), diff --git a/ptx/src/pass/insert_explicit_load_store.rs b/ptx/src/pass/insert_explicit_load_store.rs index ec6498c..42988ea 100644 --- a/ptx/src/pass/insert_explicit_load_store.rs +++ b/ptx/src/pass/insert_explicit_load_store.rs @@ -52,7 +52,7 @@ fn run_method<'a, 'input>( let new_name = visitor .resolver .register_unnamed(Some((arg.v_type.clone(), new_space))); - visitor.input_argument(old_name, new_name, old_space); + visitor.input_argument(old_name, new_name, old_space)?; arg.name = new_name; arg.state_space = new_space; } @@ -154,7 +154,7 @@ impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> { old_name: SpirvWord, new_name: SpirvWord, old_space: ast::StateSpace, - ) -> Result<(), TranslateError> { + ) -> Result { Ok(match old_space { ast::StateSpace::Reg => { self.variables.insert( @@ -164,6 +164,7 @@ impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> { type_: type_.clone(), }, ); + true } ast::StateSpace::Param => { self.variables.insert( @@ -174,19 +175,18 @@ impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> { name: new_name, }, ); + true } // Good as-is - 
ast::StateSpace::Local => {} - // Will be pulled into global scope later - ast::StateSpace::Generic + ast::StateSpace::Local + | ast::StateSpace::Generic | ast::StateSpace::SharedCluster | ast::StateSpace::Global | ast::StateSpace::Const | ast::StateSpace::SharedCta - | ast::StateSpace::Shared => {} - ast::StateSpace::ParamEntry | ast::StateSpace::ParamFunc => { - return Err(error_unreachable()) - } + | ast::StateSpace::Shared + | ast::StateSpace::ParamEntry + | ast::StateSpace::ParamFunc => return Err(error_unreachable()), }) } @@ -239,17 +239,28 @@ impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> { } fn visit_variable(&mut self, var: &mut ast::Variable) -> Result<(), TranslateError> { - if var.state_space != ast::StateSpace::Local { - let old_name = var.name; - let old_space = var.state_space; - let new_space = ast::StateSpace::Local; - let new_name = self - .resolver - .register_unnamed(Some((var.v_type.clone(), new_space))); - self.variable(&var.v_type, old_name, new_name, old_space)?; - var.name = new_name; - var.state_space = new_space; - } + let old_space = match var.state_space { + space @ (ptx_parser::StateSpace::Reg | ptx_parser::StateSpace::Param) => space, + // Do nothing + ptx_parser::StateSpace::Local => return Ok(()), + // Handled by another pass + ptx_parser::StateSpace::Generic + | ptx_parser::StateSpace::SharedCluster + | ptx_parser::StateSpace::ParamEntry + | ptx_parser::StateSpace::Global + | ptx_parser::StateSpace::SharedCta + | ptx_parser::StateSpace::Const + | ptx_parser::StateSpace::Shared + | ptx_parser::StateSpace::ParamFunc => return Ok(()), + }; + let old_name = var.name; + let new_space = ast::StateSpace::Local; + let new_name = self + .resolver + .register_unnamed(Some((var.v_type.clone(), new_space))); + self.variable(&var.v_type, old_name, new_name, old_space)?; + var.name = new_name; + var.state_space = new_space; Ok(()) } }