Implement atomics

This commit is contained in:
Andrzej Janik 2024-09-26 18:54:15 +02:00
parent c4e1315194
commit 820eaf8ada
4 changed files with 298 additions and 26 deletions

View file

@ -1,6 +1,112 @@
#include <llvm-c/Core.h>
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Instructions.h"
using namespace llvm;
/**
 * Atomic read-modify-write operations accepted by LLVMZludaBuildAtomicRMW.
 *
 * The enumerators carry no explicit values, so they are numbered from 0 in
 * declaration order. mapFromLLVMRMWBinOp translates each one to the matching
 * llvm::AtomicRMWInst::BinOp. Any foreign-language binding that mirrors these
 * values by number must be updated together with this list.
 */
typedef enum
{
LLVMZludaAtomicRMWBinOpXchg, /**< Set the new value and return the old one */
LLVMZludaAtomicRMWBinOpAdd, /**< Add a value and return the old one */
LLVMZludaAtomicRMWBinOpSub, /**< Subtract a value and return the old one */
LLVMZludaAtomicRMWBinOpAnd, /**< And a value and return the old one */
LLVMZludaAtomicRMWBinOpNand, /**< Not-And a value and return the old one */
LLVMZludaAtomicRMWBinOpOr, /**< OR a value and return the old one */
LLVMZludaAtomicRMWBinOpXor, /**< Xor a value and return the old one */
LLVMZludaAtomicRMWBinOpMax, /**< Sets the value if it's greater than the
original using a signed comparison and return
the old one */
LLVMZludaAtomicRMWBinOpMin, /**< Sets the value if it's smaller than the
original using a signed comparison and return
the old one */
LLVMZludaAtomicRMWBinOpUMax, /**< Sets the value if it's greater than the
original using an unsigned comparison and return
the old one */
LLVMZludaAtomicRMWBinOpUMin, /**< Sets the value if it's smaller than the
original using an unsigned comparison and return
the old one */
LLVMZludaAtomicRMWBinOpFAdd, /**< Add a floating point value and return the
old one */
LLVMZludaAtomicRMWBinOpFSub, /**< Subtract a floating point value and return the
old one */
LLVMZludaAtomicRMWBinOpFMax, /**< Sets the value if it's greater than the
original using a floating point comparison and
return the old one */
LLVMZludaAtomicRMWBinOpFMin, /**< Sets the value if it's smaller than the
original using a floating point comparison and
return the old one */
LLVMZludaAtomicRMWBinOpUIncWrap, /**< Increments the value, wrapping back to zero
when incremented above input value */
LLVMZludaAtomicRMWBinOpUDecWrap, /**< Decrements the value, wrapping back to
the input value when decremented below zero */
} LLVMZludaAtomicRMWBinOp;
/// Converts the C-visible LLVMZludaAtomicRMWBinOp into the corresponding
/// llvm::AtomicRMWInst::BinOp.
///
/// Every enumerator has its own case and the switch has no default label; any
/// other value reaches llvm_unreachable.
static llvm::AtomicRMWInst::BinOp mapFromLLVMRMWBinOp(LLVMZludaAtomicRMWBinOp BinOp)
{
    using RMW = llvm::AtomicRMWInst;
    switch (BinOp)
    {
    // Exchange, integer arithmetic and bitwise operations.
    case LLVMZludaAtomicRMWBinOpXchg: return RMW::Xchg;
    case LLVMZludaAtomicRMWBinOpAdd:  return RMW::Add;
    case LLVMZludaAtomicRMWBinOpSub:  return RMW::Sub;
    case LLVMZludaAtomicRMWBinOpAnd:  return RMW::And;
    case LLVMZludaAtomicRMWBinOpNand: return RMW::Nand;
    case LLVMZludaAtomicRMWBinOpOr:   return RMW::Or;
    case LLVMZludaAtomicRMWBinOpXor:  return RMW::Xor;
    // Signed and unsigned integer min/max.
    case LLVMZludaAtomicRMWBinOpMax:  return RMW::Max;
    case LLVMZludaAtomicRMWBinOpMin:  return RMW::Min;
    case LLVMZludaAtomicRMWBinOpUMax: return RMW::UMax;
    case LLVMZludaAtomicRMWBinOpUMin: return RMW::UMin;
    // Floating point operations.
    case LLVMZludaAtomicRMWBinOpFAdd: return RMW::FAdd;
    case LLVMZludaAtomicRMWBinOpFSub: return RMW::FSub;
    case LLVMZludaAtomicRMWBinOpFMax: return RMW::FMax;
    case LLVMZludaAtomicRMWBinOpFMin: return RMW::FMin;
    // Wrapping increment / decrement.
    case LLVMZludaAtomicRMWBinOpUIncWrap: return RMW::UIncWrap;
    case LLVMZludaAtomicRMWBinOpUDecWrap: return RMW::UDecWrap;
    }
    llvm_unreachable("Invalid LLVMZludaAtomicRMWBinOp value!");
}
/// Converts the LLVM-C LLVMAtomicOrdering into the C++ llvm::AtomicOrdering
/// expected by IRBuilder.
///
/// Every enumerator has its own case and the switch has no default label; any
/// other value reaches llvm_unreachable.
static AtomicOrdering mapFromLLVMOrdering(LLVMAtomicOrdering Ordering)
{
    using Order = AtomicOrdering;
    switch (Ordering)
    {
    case LLVMAtomicOrderingNotAtomic:              return Order::NotAtomic;
    case LLVMAtomicOrderingUnordered:              return Order::Unordered;
    case LLVMAtomicOrderingMonotonic:              return Order::Monotonic;
    case LLVMAtomicOrderingAcquire:                return Order::Acquire;
    case LLVMAtomicOrderingRelease:                return Order::Release;
    case LLVMAtomicOrderingAcquireRelease:         return Order::AcquireRelease;
    case LLVMAtomicOrderingSequentiallyConsistent: return Order::SequentiallyConsistent;
    }
    llvm_unreachable("Invalid LLVMAtomicOrdering value!");
}
LLVM_C_EXTERN_C_BEGIN
@ -10,4 +116,18 @@ LLVMValueRef LLVMZludaBuildAlloca(LLVMBuilderRef B, LLVMTypeRef Ty, unsigned Add
return llvm::wrap(llvm::unwrap(B)->CreateAlloca(llvm::unwrap(Ty), AddrSpace, nullptr, Name));
}
/**
 * Builds an `atomicrmw` instruction with an explicit, named synchronization
 * scope and returns it as an LLVMValueRef.
 *
 * @param B        Builder used to create the instruction.
 * @param op       Read-modify-write operation to perform.
 * @param PTR      Pointer operand.
 * @param Val      Value operand combined with the value at PTR.
 * @param scope    NUL-terminated synchronization scope name. It is only read
 *                 here (hence `const`) and is interned in the builder's
 *                 LLVMContext via getOrInsertSyncScopeID.
 * @param ordering Atomic memory ordering of the operation.
 *
 * The alignment is passed as an empty llvm::MaybeAlign, i.e. no explicit
 * alignment is requested.
 */
LLVMValueRef LLVMZludaBuildAtomicRMW(LLVMBuilderRef B, LLVMZludaAtomicRMWBinOp op,
                                     LLVMValueRef PTR, LLVMValueRef Val,
                                     const char *scope,
                                     LLVMAtomicOrdering ordering)
{
    auto builder = llvm::unwrap(B);
    LLVMContext &context = builder->getContext();
    // Resolve the scope name to the context's sync scope id first so the
    // instruction-building call below reads as a plain list of operands.
    auto scopeId = context.getOrInsertSyncScopeID(scope);
    return llvm::wrap(builder->CreateAtomicRMW(
        mapFromLLVMRMWBinOp(op), llvm::unwrap(PTR), llvm::unwrap(Val),
        llvm::MaybeAlign(), mapFromLLVMOrdering(ordering), scopeId));
}
LLVM_C_EXTERN_C_END

View file

@ -1,5 +1,28 @@
use llvm_sys::prelude::*;
pub use llvm_sys::*;
/// Atomic read-modify-write operation passed to `LLVMZludaBuildAtomicRMW`.
///
/// This is passed by value across the C ABI, so the layout is `#[repr(C)]` and every
/// discriminant is spelled out explicitly; the numbers must stay equal to the position of the
/// same-named enumerator in the C++ `LLVMZludaAtomicRMWBinOp` enum.
#[repr(C)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum LLVMZludaAtomicRMWBinOp {
    /// Set the new value and return the old one.
    LLVMZludaAtomicRMWBinOpXchg = 0,
    /// Add a value and return the old one.
    LLVMZludaAtomicRMWBinOpAdd = 1,
    /// Subtract a value and return the old one.
    LLVMZludaAtomicRMWBinOpSub = 2,
    /// Bitwise AND a value and return the old one.
    LLVMZludaAtomicRMWBinOpAnd = 3,
    /// Bitwise NOT-AND a value and return the old one.
    LLVMZludaAtomicRMWBinOpNand = 4,
    /// Bitwise OR a value and return the old one.
    LLVMZludaAtomicRMWBinOpOr = 5,
    /// Bitwise XOR a value and return the old one.
    LLVMZludaAtomicRMWBinOpXor = 6,
    /// Signed maximum; returns the old value.
    LLVMZludaAtomicRMWBinOpMax = 7,
    /// Signed minimum; returns the old value.
    LLVMZludaAtomicRMWBinOpMin = 8,
    /// Unsigned maximum; returns the old value.
    LLVMZludaAtomicRMWBinOpUMax = 9,
    /// Unsigned minimum; returns the old value.
    LLVMZludaAtomicRMWBinOpUMin = 10,
    /// Floating point add; returns the old value.
    LLVMZludaAtomicRMWBinOpFAdd = 11,
    /// Floating point subtract; returns the old value.
    LLVMZludaAtomicRMWBinOpFSub = 12,
    /// Floating point maximum; returns the old value.
    LLVMZludaAtomicRMWBinOpFMax = 13,
    /// Floating point minimum; returns the old value.
    LLVMZludaAtomicRMWBinOpFMin = 14,
    /// Increment, wrapping back to zero when incremented above the input value.
    LLVMZludaAtomicRMWBinOpUIncWrap = 15,
    /// Decrement, wrapping back to the input value when decremented below zero.
    LLVMZludaAtomicRMWBinOpUDecWrap = 16,
}
extern "C" {
pub fn LLVMZludaBuildAlloca(
B: LLVMBuilderRef,
@ -7,4 +30,13 @@ extern "C" {
AddrSpace: u32,
Name: *const i8,
) -> LLVMValueRef;
/// Builds an LLVM `atomicrmw` instruction with a named synchronization scope.
///
/// Implemented by the C++ shim function of the same name, which forwards to
/// `IRBuilder::CreateAtomicRMW` without requesting an explicit alignment.
///
/// * `B` - builder used to create the instruction.
/// * `op` - read-modify-write operation.
/// * `PTR` - pointer operand.
/// * `Val` - value operand.
/// * `scope` - NUL-terminated synchronization scope name; the shim interns it in the
///   builder's context.
/// * `ordering` - atomic memory ordering.
///
/// Returns the created instruction.
pub fn LLVMZludaBuildAtomicRMW(
B: LLVMBuilderRef,
op: LLVMZludaAtomicRMWBinOp,
PTR: LLVMValueRef,
Val: LLVMValueRef,
scope: *const i8,
ordering: LLVMAtomicOrdering,
) -> LLVMValueRef;
}

View file

@ -19,15 +19,15 @@
// unsafe { LLVMBuildAdd(builder, lhs, rhs, dst) };
use std::convert::{TryFrom, TryInto};
use std::ffi::CStr;
use std::ffi::{CStr, NulError};
use std::ops::Deref;
use std::ptr;
use super::*;
use llvm_zluda::analysis::{LLVMVerifierFailureAction, LLVMVerifyModule};
use llvm_zluda::bit_writer::LLVMWriteBitcodeToMemoryBuffer;
use llvm_zluda::core::*;
use llvm_zluda::prelude::*;
use llvm_zluda::{core::*, LLVMAtomicOrdering, LLVMAtomicRMWBinOp, LLVMZludaAtomicRMWBinOp};
use llvm_zluda::{prelude::*, LLVMZludaBuildAtomicRMW};
use llvm_zluda::{LLVMCallConv, LLVMZludaBuildAlloca};
const LLVM_UNNAMED: &CStr = c"";
@ -172,7 +172,7 @@ pub(super) fn run<'input>(
let mut emit_ctx = ModuleEmitContext::new(&context, &module, &id_defs);
for directive in directives {
match directive {
Directive2::Variable(..) => todo!(),
Directive2::Variable(linking, variable) => emit_ctx.emit_global(linking, variable)?,
Directive2::Method(method) => emit_ctx.emit_method(method)?,
}
}
@ -281,6 +281,43 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> {
}
Ok(())
}
fn emit_global(
&mut self,
linking: ast::LinkingDirective,
var: ptx_parser::Variable<SpirvWord>,
) -> Result<(), TranslateError> {
let name = self
.id_defs
.ident_map
.get(&var.name)
.map(|entry| {
entry
.name
.as_ref()
.map(|text| Ok::<_, NulError>(Cow::Owned(CString::new(&**text)?)))
})
.flatten()
.transpose()
.map_err(|_| error_unreachable())?
.unwrap_or(Cow::Borrowed(LLVM_UNNAMED));
let global = unsafe {
LLVMAddGlobalInAddressSpace(
self.module,
get_type(self.context, &var.v_type)?,
name.as_ptr(),
get_state_space(var.state_space)?,
)
};
self.resolver.register(var.name, global);
if let Some(align) = var.align {
unsafe { LLVMSetAlignment(global, align) };
}
if !var.array_init.is_empty() {
todo!()
}
Ok(())
}
}
fn get_input_argument_type(
@ -419,7 +456,7 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> {
ast::Instruction::Rsqrt { data, arguments } => todo!(),
ast::Instruction::Selp { data, arguments } => todo!(),
ast::Instruction::Bar { data, arguments } => todo!(),
ast::Instruction::Atom { data, arguments } => todo!(),
ast::Instruction::Atom { data, arguments } => self.emit_atom(data, arguments),
ast::Instruction::AtomCas { data, arguments } => todo!(),
ast::Instruction::Div { data, arguments } => todo!(),
ast::Instruction::Neg { data, arguments } => todo!(),
@ -499,7 +536,14 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> {
});
Ok(())
}
ConversionKind::PtrToPtr => todo!(),
ConversionKind::PtrToPtr => {
let src = self.resolver.value(conversion.src)?;
let dst_type = get_pointer_type(self.context, conversion.to_space)?;
self.resolver.with_result(conversion.dst, |dst| unsafe {
LLVMBuildAddrSpaceCast(builder, src, dst_type, dst)
});
Ok(())
}
ConversionKind::AddressOf => todo!(),
}
}
@ -635,6 +679,51 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> {
});
Ok(())
}
fn emit_atom(
&mut self,
data: ptx_parser::AtomDetails,
arguments: ptx_parser::AtomArgs<SpirvWord>,
) -> Result<(), TranslateError> {
let builder = self.builder;
let src1 = self.resolver.value(arguments.src1)?;
let src2 = self.resolver.value(arguments.src2)?;
let op = match data.op {
ptx_parser::AtomicOp::And => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpAnd,
ptx_parser::AtomicOp::Or => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpOr,
ptx_parser::AtomicOp::Xor => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpXor,
ptx_parser::AtomicOp::Exchange => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpXchg,
ptx_parser::AtomicOp::Add => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpAdd,
ptx_parser::AtomicOp::IncrementWrap => {
LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpUIncWrap
}
ptx_parser::AtomicOp::DecrementWrap => {
LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpUDecWrap
}
ptx_parser::AtomicOp::SignedMin => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpMin,
ptx_parser::AtomicOp::UnsignedMin => {
LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpUMin
}
ptx_parser::AtomicOp::SignedMax => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpMax,
ptx_parser::AtomicOp::UnsignedMax => {
LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpUMax
}
ptx_parser::AtomicOp::FloatAdd => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpFAdd,
ptx_parser::AtomicOp::FloatMin => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpFMin,
ptx_parser::AtomicOp::FloatMax => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpFMax,
};
self.resolver.register(arguments.dst, unsafe {
LLVMZludaBuildAtomicRMW(
builder,
op,
src1,
src2,
get_scope(data.scope)?,
get_ordering(data.semantics),
)
});
Ok(())
}
}
fn get_pointer_type<'ctx>(
@ -644,6 +733,26 @@ fn get_pointer_type<'ctx>(
Ok(unsafe { LLVMPointerTypeInContext(context, get_state_space(to_space)?) })
}
/// Maps a PTX memory scope to the name of the LLVM synchronization scope used for it.
///
/// The names are AMDGPU memory scopes, see
/// https://llvm.org/docs/AMDGPUUsage.html#memory-scopes
///
/// The returned pointer refers to a C string literal, so it stays valid for the whole
/// program.
///
/// # Panics
/// `Cluster` scope is not implemented yet and hits `todo!()`.
fn get_scope(scope: ast::MemScope) -> Result<*const i8, TranslateError> {
    let scope_name = match scope {
        ast::MemScope::Cta => c"workgroup-one-as",
        ast::MemScope::Gpu => c"agent-one-as",
        ast::MemScope::Sys => c"one-as",
        ast::MemScope::Cluster => todo!(),
    };
    Ok(scope_name.as_ptr())
}
/// Maps PTX atomic semantics to the LLVM atomic ordering used for them.
///
/// `Relaxed` becomes LLVM's `Monotonic`; the remaining variants map to the ordering of the
/// same name.
fn get_ordering(semantics: ast::AtomSemantics) -> LLVMAtomicOrdering {
    type Ordering = LLVMAtomicOrdering;
    match semantics {
        ast::AtomSemantics::Relaxed => Ordering::LLVMAtomicOrderingMonotonic,
        ast::AtomSemantics::Acquire => Ordering::LLVMAtomicOrderingAcquire,
        ast::AtomSemantics::Release => Ordering::LLVMAtomicOrderingRelease,
        ast::AtomSemantics::AcqRel => Ordering::LLVMAtomicOrderingAcquireRelease,
    }
}
fn get_type(context: LLVMContextRef, type_: &ast::Type) -> Result<LLVMTypeRef, TranslateError> {
Ok(match type_ {
ast::Type::Scalar(scalar) => get_scalar_type(context, *scalar),

View file

@ -52,7 +52,7 @@ fn run_method<'a, 'input>(
let new_name = visitor
.resolver
.register_unnamed(Some((arg.v_type.clone(), new_space)));
visitor.input_argument(old_name, new_name, old_space);
visitor.input_argument(old_name, new_name, old_space)?;
arg.name = new_name;
arg.state_space = new_space;
}
@ -154,7 +154,7 @@ impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> {
old_name: SpirvWord,
new_name: SpirvWord,
old_space: ast::StateSpace,
) -> Result<(), TranslateError> {
) -> Result<bool, TranslateError> {
Ok(match old_space {
ast::StateSpace::Reg => {
self.variables.insert(
@ -164,6 +164,7 @@ impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> {
type_: type_.clone(),
},
);
true
}
ast::StateSpace::Param => {
self.variables.insert(
@ -174,19 +175,18 @@ impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> {
name: new_name,
},
);
true
}
// Good as-is
ast::StateSpace::Local => {}
// Will be pulled into global scope later
ast::StateSpace::Generic
ast::StateSpace::Local
| ast::StateSpace::Generic
| ast::StateSpace::SharedCluster
| ast::StateSpace::Global
| ast::StateSpace::Const
| ast::StateSpace::SharedCta
| ast::StateSpace::Shared => {}
ast::StateSpace::ParamEntry | ast::StateSpace::ParamFunc => {
return Err(error_unreachable())
}
| ast::StateSpace::Shared
| ast::StateSpace::ParamEntry
| ast::StateSpace::ParamFunc => return Err(error_unreachable()),
})
}
@ -239,17 +239,28 @@ impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> {
}
/// Rewrites a variable declaration so that it lives in `Local` space.
///
/// A fresh unnamed id is registered for the `Local` copy, the old and new names together with
/// the original state space are handed to `self.variable`, and `var` is then updated in place
/// to carry the new name and space.
fn visit_variable(&mut self, var: &mut ast::Variable<SpirvWord>) -> Result<(), TranslateError> {
// NOTE(review): this block rewrites every non-`Local` variable before the `match` below runs,
// after which the `match` always sees `Local` and returns early. It looks like the pre-change
// body kept by the diff view next to its replacement - confirm which version is intended.
if var.state_space != ast::StateSpace::Local {
let old_name = var.name;
let old_space = var.state_space;
let new_space = ast::StateSpace::Local;
let new_name = self
.resolver
.register_unnamed(Some((var.v_type.clone(), new_space)));
self.variable(&var.v_type, old_name, new_name, old_space)?;
var.name = new_name;
var.state_space = new_space;
}
// Only `Reg` and `Param` variables continue past this match; every other space returns
// `Ok(())` without touching `var`.
let old_space = match var.state_space {
space @ (ptx_parser::StateSpace::Reg | ptx_parser::StateSpace::Param) => space,
// Do nothing
ptx_parser::StateSpace::Local => return Ok(()),
// Handled by another pass
ptx_parser::StateSpace::Generic
| ptx_parser::StateSpace::SharedCluster
| ptx_parser::StateSpace::ParamEntry
| ptx_parser::StateSpace::Global
| ptx_parser::StateSpace::SharedCta
| ptx_parser::StateSpace::Const
| ptx_parser::StateSpace::Shared
| ptx_parser::StateSpace::ParamFunc => return Ok(()),
};
let old_name = var.name;
let new_space = ast::StateSpace::Local;
// The replacement id is registered with the variable's type and the new `Local` space.
let new_name = self
.resolver
.register_unnamed(Some((var.v_type.clone(), new_space)));
self.variable(&var.v_type, old_name, new_name, old_space)?;
var.name = new_name;
var.state_space = new_space;
Ok(())
}
}