diff --git a/llvm_zluda/src/lib.cpp b/llvm_zluda/src/lib.cpp index e9b3a52..cff2e31 100644 --- a/llvm_zluda/src/lib.cpp +++ b/llvm_zluda/src/lib.cpp @@ -130,4 +130,19 @@ LLVMValueRef LLVMZludaBuildAtomicRMW(LLVMBuilderRef B, LLVMZludaAtomicRMWBinOp o context.getOrInsertSyncScopeID(scope))); } +LLVMValueRef LLVMZludaBuildAtomicCmpXchg(LLVMBuilderRef B, LLVMValueRef Ptr, + LLVMValueRef Cmp, LLVMValueRef New, + char *scope, + LLVMAtomicOrdering SuccessOrdering, + LLVMAtomicOrdering FailureOrdering) +{ +/* C wrapper that builds a cmpxchg through the unwrapped builder B: Ptr, Cmp and New are unwrapped and forwarded, the alignment argument is a default-constructed MaybeAlign(), both orderings go through mapFromLLVMOrdering, and the sync scope ID is obtained from the builder context via getOrInsertSyncScopeID(scope). The created instruction is returned wrapped as an LLVMValueRef. */ auto builder = llvm::unwrap(B); + LLVMContext &context = builder->getContext(); + return wrap(builder->CreateAtomicCmpXchg( + unwrap(Ptr), unwrap(Cmp), unwrap(New), MaybeAlign(), + mapFromLLVMOrdering(SuccessOrdering), + mapFromLLVMOrdering(FailureOrdering), + context.getOrInsertSyncScopeID(scope))); +} + LLVM_C_EXTERN_C_END \ No newline at end of file diff --git a/llvm_zluda/src/lib.rs b/llvm_zluda/src/lib.rs index b995cdb..3d941fa 100644 --- a/llvm_zluda/src/lib.rs +++ b/llvm_zluda/src/lib.rs @@ -39,4 +39,14 @@ extern "C" { scope: *const i8, ordering: LLVMAtomicOrdering, ) -> LLVMValueRef; + +/* Rust declaration of the C++ LLVMZludaBuildAtomicCmpXchg wrapper added to lib.cpp by this same patch; the parameter order mirrors the C++ signature. NOTE(review): scope is *const i8 here but char *scope (non-const) on the C++ side - confirm the callee never writes through it. */ pub fn LLVMZludaBuildAtomicCmpXchg( + B: LLVMBuilderRef, + Ptr: LLVMValueRef, + Cmp: LLVMValueRef, + New: LLVMValueRef, + scope: *const i8, + SuccessOrdering: LLVMAtomicOrdering, + FailureOrdering: LLVMAtomicOrdering, + ) -> LLVMValueRef; } diff --git a/ptx/src/pass/emit_llvm.rs b/ptx/src/pass/emit_llvm.rs index bc5f745..36a9623 100644 --- a/ptx/src/pass/emit_llvm.rs +++ b/ptx/src/pass/emit_llvm.rs @@ -26,7 +26,10 @@ use std::ptr; use super::*; use llvm_zluda::analysis::{LLVMVerifierFailureAction, LLVMVerifyModule}; use llvm_zluda::bit_writer::LLVMWriteBitcodeToMemoryBuffer; -use llvm_zluda::{core::*, LLVMAtomicOrdering, LLVMAtomicRMWBinOp, LLVMZludaAtomicRMWBinOp}; +use llvm_zluda::{ + core::*, LLVMAtomicOrdering, LLVMAtomicRMWBinOp, LLVMZludaAtomicRMWBinOp, + LLVMZludaBuildAtomicCmpXchg, +}; use llvm_zluda::{prelude::*, LLVMZludaBuildAtomicRMW}; use llvm_zluda::{LLVMCallConv, 
LLVMZludaBuildAlloca}; @@ -457,7 +460,7 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> { ast::Instruction::Selp { data, arguments } => todo!(), ast::Instruction::Bar { data, arguments } => todo!(), ast::Instruction::Atom { data, arguments } => self.emit_atom(data, arguments), - ast::Instruction::AtomCas { data, arguments } => todo!(), + ast::Instruction::AtomCas { data, arguments } => self.emit_atom_cas(data, arguments), ast::Instruction::Div { data, arguments } => todo!(), ast::Instruction::Neg { data, arguments } => todo!(), ast::Instruction::Sin { data, arguments } => todo!(), @@ -724,6 +727,33 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> { }); Ok(()) } + + fn emit_atom_cas( + &mut self, + data: ptx_parser::AtomCasDetails, + arguments: ptx_parser::AtomCasArgs, + ) -> Result<(), TranslateError> { +/* Emits the LLVM IR for ast::Instruction::AtomCas. src1, src2 and src3 are resolved through self.resolver.value and passed, in that order, as the Ptr, Cmp and New arguments of LLVMZludaBuildAtomicCmpXchg. Any resolver or get_scope failure is propagated to the caller with the ? operator. */ let src1 = self.resolver.value(arguments.src1)?; + let src2 = self.resolver.value(arguments.src2)?; + let src3 = self.resolver.value(arguments.src3)?; +/* Both orderings are derived from the same data.semantics: success via get_ordering, failure via get_ordering_failure. */ let success_ordering = get_ordering(data.semantics); + let failure_ordering = get_ordering_failure(data.semantics); + let temp = unsafe { + LLVMZludaBuildAtomicCmpXchg( + self.builder, + src1, + src2, + src3, + get_scope(data.scope)?, + success_ordering, + failure_ordering, + ) + }; +/* Only element 0 of the cmpxchg result aggregate is bound to arguments.dst; per the LLVM LangRef that element is the value loaded from memory, and the success flag in element 1 is not used here. */ self.resolver.with_result(arguments.dst, |dst| unsafe { + LLVMBuildExtractValue(self.builder, temp, 0, dst) + }); + Ok(()) + } } fn get_pointer_type<'ctx>( @@ -753,6 +783,15 @@ fn get_ordering(semantics: ast::AtomSemantics) -> LLVMAtomicOrdering { } } +fn get_ordering_failure(semantics: ast::AtomSemantics) -> LLVMAtomicOrdering { +/* Chooses the cmpxchg failure ordering: Relaxed maps to Monotonic, while Acquire, Release and AcqRel all map to Acquire, so none of the visible arms yields a Release or AcqRel ordering (the LLVM LangRef forbids those as failure orderings). NOTE(review): Release success is paired with an Acquire failure ordering here - confirm that Monotonic was not intended for that arm. */ match semantics { + ast::AtomSemantics::Relaxed => LLVMAtomicOrdering::LLVMAtomicOrderingMonotonic, + ast::AtomSemantics::Acquire => LLVMAtomicOrdering::LLVMAtomicOrderingAcquire, + ast::AtomSemantics::Release => LLVMAtomicOrdering::LLVMAtomicOrderingAcquire, + ast::AtomSemantics::AcqRel => LLVMAtomicOrdering::LLVMAtomicOrderingAcquire, + } +} + fn get_type(context: LLVMContextRef, type_: &ast::Type) -> 
Result { Ok(match type_ { ast::Type::Scalar(scalar) => get_scalar_type(context, *scalar), diff --git a/ptx/src/pass/mod.rs b/ptx/src/pass/mod.rs index ff7e2ad..ead747a 100644 --- a/ptx/src/pass/mod.rs +++ b/ptx/src/pass/mod.rs @@ -86,8 +86,7 @@ pub fn to_llvm_module2<'input>(ast: ast::Module<'input>) -> Result, SpirvWord>> = - expand_operands::run(&mut flat_resolver, directives)?; + let directives = expand_operands::run(&mut flat_resolver, directives)?; let directives = deparamize_functions::run(&mut flat_resolver, directives)?; let directives = insert_explicit_load_store::run(&mut flat_resolver, directives)?; let directives = insert_implicit_conversions2::run(&mut flat_resolver, directives)?;