mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-08-03 06:40:21 +00:00
Implement atomics
This commit is contained in:
parent
c4e1315194
commit
820eaf8ada
4 changed files with 298 additions and 26 deletions
|
@ -1,6 +1,112 @@
|
||||||
#include <llvm-c/Core.h>
|
#include <llvm-c/Core.h>
|
||||||
#include "llvm/IR/IRBuilder.h"
|
#include "llvm/IR/IRBuilder.h"
|
||||||
#include "llvm/IR/Type.h"
|
#include "llvm/IR/Type.h"
|
||||||
|
#include "llvm/IR/Instructions.h"
|
||||||
|
|
||||||
|
using namespace llvm;
|
||||||
|
|
||||||
|
typedef enum
|
||||||
|
{
|
||||||
|
LLVMZludaAtomicRMWBinOpXchg, /**< Set the new value and return the one old */
|
||||||
|
LLVMZludaAtomicRMWBinOpAdd, /**< Add a value and return the old one */
|
||||||
|
LLVMZludaAtomicRMWBinOpSub, /**< Subtract a value and return the old one */
|
||||||
|
LLVMZludaAtomicRMWBinOpAnd, /**< And a value and return the old one */
|
||||||
|
LLVMZludaAtomicRMWBinOpNand, /**< Not-And a value and return the old one */
|
||||||
|
LLVMZludaAtomicRMWBinOpOr, /**< OR a value and return the old one */
|
||||||
|
LLVMZludaAtomicRMWBinOpXor, /**< Xor a value and return the old one */
|
||||||
|
LLVMZludaAtomicRMWBinOpMax, /**< Sets the value if it's greater than the
|
||||||
|
original using a signed comparison and return
|
||||||
|
the old one */
|
||||||
|
LLVMZludaAtomicRMWBinOpMin, /**< Sets the value if it's Smaller than the
|
||||||
|
original using a signed comparison and return
|
||||||
|
the old one */
|
||||||
|
LLVMZludaAtomicRMWBinOpUMax, /**< Sets the value if it's greater than the
|
||||||
|
original using an unsigned comparison and return
|
||||||
|
the old one */
|
||||||
|
LLVMZludaAtomicRMWBinOpUMin, /**< Sets the value if it's greater than the
|
||||||
|
original using an unsigned comparison and return
|
||||||
|
the old one */
|
||||||
|
LLVMZludaAtomicRMWBinOpFAdd, /**< Add a floating point value and return the
|
||||||
|
old one */
|
||||||
|
LLVMZludaAtomicRMWBinOpFSub, /**< Subtract a floating point value and return the
|
||||||
|
old one */
|
||||||
|
LLVMZludaAtomicRMWBinOpFMax, /**< Sets the value if it's greater than the
|
||||||
|
original using an floating point comparison and
|
||||||
|
return the old one */
|
||||||
|
LLVMZludaAtomicRMWBinOpFMin, /**< Sets the value if it's smaller than the
|
||||||
|
original using an floating point comparison and
|
||||||
|
return the old one */
|
||||||
|
LLVMZludaAtomicRMWBinOpUIncWrap, /**< Increments the value, wrapping back to zero
|
||||||
|
when incremented above input value */
|
||||||
|
LLVMZludaAtomicRMWBinOpUDecWrap, /**< Decrements the value, wrapping back to
|
||||||
|
the input value when decremented below zero */
|
||||||
|
} LLVMZludaAtomicRMWBinOp;
|
||||||
|
|
||||||
|
static llvm::AtomicRMWInst::BinOp mapFromLLVMRMWBinOp(LLVMZludaAtomicRMWBinOp BinOp)
|
||||||
|
{
|
||||||
|
switch (BinOp)
|
||||||
|
{
|
||||||
|
case LLVMZludaAtomicRMWBinOpXchg:
|
||||||
|
return llvm::AtomicRMWInst::Xchg;
|
||||||
|
case LLVMZludaAtomicRMWBinOpAdd:
|
||||||
|
return llvm::AtomicRMWInst::Add;
|
||||||
|
case LLVMZludaAtomicRMWBinOpSub:
|
||||||
|
return llvm::AtomicRMWInst::Sub;
|
||||||
|
case LLVMZludaAtomicRMWBinOpAnd:
|
||||||
|
return llvm::AtomicRMWInst::And;
|
||||||
|
case LLVMZludaAtomicRMWBinOpNand:
|
||||||
|
return llvm::AtomicRMWInst::Nand;
|
||||||
|
case LLVMZludaAtomicRMWBinOpOr:
|
||||||
|
return llvm::AtomicRMWInst::Or;
|
||||||
|
case LLVMZludaAtomicRMWBinOpXor:
|
||||||
|
return llvm::AtomicRMWInst::Xor;
|
||||||
|
case LLVMZludaAtomicRMWBinOpMax:
|
||||||
|
return llvm::AtomicRMWInst::Max;
|
||||||
|
case LLVMZludaAtomicRMWBinOpMin:
|
||||||
|
return llvm::AtomicRMWInst::Min;
|
||||||
|
case LLVMZludaAtomicRMWBinOpUMax:
|
||||||
|
return llvm::AtomicRMWInst::UMax;
|
||||||
|
case LLVMZludaAtomicRMWBinOpUMin:
|
||||||
|
return llvm::AtomicRMWInst::UMin;
|
||||||
|
case LLVMZludaAtomicRMWBinOpFAdd:
|
||||||
|
return llvm::AtomicRMWInst::FAdd;
|
||||||
|
case LLVMZludaAtomicRMWBinOpFSub:
|
||||||
|
return llvm::AtomicRMWInst::FSub;
|
||||||
|
case LLVMZludaAtomicRMWBinOpFMax:
|
||||||
|
return llvm::AtomicRMWInst::FMax;
|
||||||
|
case LLVMZludaAtomicRMWBinOpFMin:
|
||||||
|
return llvm::AtomicRMWInst::FMin;
|
||||||
|
case LLVMZludaAtomicRMWBinOpUIncWrap:
|
||||||
|
return llvm::AtomicRMWInst::UIncWrap;
|
||||||
|
case LLVMZludaAtomicRMWBinOpUDecWrap:
|
||||||
|
return llvm::AtomicRMWInst::UDecWrap;
|
||||||
|
}
|
||||||
|
|
||||||
|
llvm_unreachable("Invalid LLVMZludaAtomicRMWBinOp value!");
|
||||||
|
}
|
||||||
|
|
||||||
|
static AtomicOrdering mapFromLLVMOrdering(LLVMAtomicOrdering Ordering)
|
||||||
|
{
|
||||||
|
switch (Ordering)
|
||||||
|
{
|
||||||
|
case LLVMAtomicOrderingNotAtomic:
|
||||||
|
return AtomicOrdering::NotAtomic;
|
||||||
|
case LLVMAtomicOrderingUnordered:
|
||||||
|
return AtomicOrdering::Unordered;
|
||||||
|
case LLVMAtomicOrderingMonotonic:
|
||||||
|
return AtomicOrdering::Monotonic;
|
||||||
|
case LLVMAtomicOrderingAcquire:
|
||||||
|
return AtomicOrdering::Acquire;
|
||||||
|
case LLVMAtomicOrderingRelease:
|
||||||
|
return AtomicOrdering::Release;
|
||||||
|
case LLVMAtomicOrderingAcquireRelease:
|
||||||
|
return AtomicOrdering::AcquireRelease;
|
||||||
|
case LLVMAtomicOrderingSequentiallyConsistent:
|
||||||
|
return AtomicOrdering::SequentiallyConsistent;
|
||||||
|
}
|
||||||
|
|
||||||
|
llvm_unreachable("Invalid LLVMAtomicOrdering value!");
|
||||||
|
}
|
||||||
|
|
||||||
LLVM_C_EXTERN_C_BEGIN
|
LLVM_C_EXTERN_C_BEGIN
|
||||||
|
|
||||||
|
@ -10,4 +116,18 @@ LLVMValueRef LLVMZludaBuildAlloca(LLVMBuilderRef B, LLVMTypeRef Ty, unsigned Add
|
||||||
return llvm::wrap(llvm::unwrap(B)->CreateAlloca(llvm::unwrap(Ty), AddrSpace, nullptr, Name));
|
return llvm::wrap(llvm::unwrap(B)->CreateAlloca(llvm::unwrap(Ty), AddrSpace, nullptr, Name));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
LLVMValueRef LLVMZludaBuildAtomicRMW(LLVMBuilderRef B, LLVMZludaAtomicRMWBinOp op,
|
||||||
|
LLVMValueRef PTR, LLVMValueRef Val,
|
||||||
|
char *scope,
|
||||||
|
LLVMAtomicOrdering ordering)
|
||||||
|
{
|
||||||
|
auto builder = llvm::unwrap(B);
|
||||||
|
LLVMContext &context = builder->getContext();
|
||||||
|
llvm::AtomicRMWInst::BinOp intop = mapFromLLVMRMWBinOp(op);
|
||||||
|
return llvm::wrap(builder->CreateAtomicRMW(
|
||||||
|
intop, llvm::unwrap(PTR), llvm::unwrap(Val), llvm::MaybeAlign(),
|
||||||
|
mapFromLLVMOrdering(ordering),
|
||||||
|
context.getOrInsertSyncScopeID(scope)));
|
||||||
|
}
|
||||||
|
|
||||||
LLVM_C_EXTERN_C_END
|
LLVM_C_EXTERN_C_END
|
|
@ -1,5 +1,28 @@
|
||||||
use llvm_sys::prelude::*;
|
use llvm_sys::prelude::*;
|
||||||
pub use llvm_sys::*;
|
pub use llvm_sys::*;
|
||||||
|
|
||||||
|
#[repr(C)]
|
||||||
|
#[derive(Clone, Copy, Debug, PartialEq)]
|
||||||
|
pub enum LLVMZludaAtomicRMWBinOp {
|
||||||
|
LLVMZludaAtomicRMWBinOpXchg = 0,
|
||||||
|
LLVMZludaAtomicRMWBinOpAdd = 1,
|
||||||
|
LLVMZludaAtomicRMWBinOpSub = 2,
|
||||||
|
LLVMZludaAtomicRMWBinOpAnd = 3,
|
||||||
|
LLVMZludaAtomicRMWBinOpNand = 4,
|
||||||
|
LLVMZludaAtomicRMWBinOpOr = 5,
|
||||||
|
LLVMZludaAtomicRMWBinOpXor = 6,
|
||||||
|
LLVMZludaAtomicRMWBinOpMax = 7,
|
||||||
|
LLVMZludaAtomicRMWBinOpMin = 8,
|
||||||
|
LLVMZludaAtomicRMWBinOpUMax = 9,
|
||||||
|
LLVMZludaAtomicRMWBinOpUMin = 10,
|
||||||
|
LLVMZludaAtomicRMWBinOpFAdd = 11,
|
||||||
|
LLVMZludaAtomicRMWBinOpFSub = 12,
|
||||||
|
LLVMZludaAtomicRMWBinOpFMax = 13,
|
||||||
|
LLVMZludaAtomicRMWBinOpFMin = 14,
|
||||||
|
LLVMZludaAtomicRMWBinOpUIncWrap = 15,
|
||||||
|
LLVMZludaAtomicRMWBinOpUDecWrap = 16,
|
||||||
|
}
|
||||||
|
|
||||||
extern "C" {
|
extern "C" {
|
||||||
pub fn LLVMZludaBuildAlloca(
|
pub fn LLVMZludaBuildAlloca(
|
||||||
B: LLVMBuilderRef,
|
B: LLVMBuilderRef,
|
||||||
|
@ -7,4 +30,13 @@ extern "C" {
|
||||||
AddrSpace: u32,
|
AddrSpace: u32,
|
||||||
Name: *const i8,
|
Name: *const i8,
|
||||||
) -> LLVMValueRef;
|
) -> LLVMValueRef;
|
||||||
|
|
||||||
|
pub fn LLVMZludaBuildAtomicRMW(
|
||||||
|
B: LLVMBuilderRef,
|
||||||
|
op: LLVMZludaAtomicRMWBinOp,
|
||||||
|
PTR: LLVMValueRef,
|
||||||
|
Val: LLVMValueRef,
|
||||||
|
scope: *const i8,
|
||||||
|
ordering: LLVMAtomicOrdering,
|
||||||
|
) -> LLVMValueRef;
|
||||||
}
|
}
|
||||||
|
|
|
@ -19,15 +19,15 @@
|
||||||
// unsafe { LLVMBuildAdd(builder, lhs, rhs, dst) };
|
// unsafe { LLVMBuildAdd(builder, lhs, rhs, dst) };
|
||||||
|
|
||||||
use std::convert::{TryFrom, TryInto};
|
use std::convert::{TryFrom, TryInto};
|
||||||
use std::ffi::CStr;
|
use std::ffi::{CStr, NulError};
|
||||||
use std::ops::Deref;
|
use std::ops::Deref;
|
||||||
use std::ptr;
|
use std::ptr;
|
||||||
|
|
||||||
use super::*;
|
use super::*;
|
||||||
use llvm_zluda::analysis::{LLVMVerifierFailureAction, LLVMVerifyModule};
|
use llvm_zluda::analysis::{LLVMVerifierFailureAction, LLVMVerifyModule};
|
||||||
use llvm_zluda::bit_writer::LLVMWriteBitcodeToMemoryBuffer;
|
use llvm_zluda::bit_writer::LLVMWriteBitcodeToMemoryBuffer;
|
||||||
use llvm_zluda::core::*;
|
use llvm_zluda::{core::*, LLVMAtomicOrdering, LLVMAtomicRMWBinOp, LLVMZludaAtomicRMWBinOp};
|
||||||
use llvm_zluda::prelude::*;
|
use llvm_zluda::{prelude::*, LLVMZludaBuildAtomicRMW};
|
||||||
use llvm_zluda::{LLVMCallConv, LLVMZludaBuildAlloca};
|
use llvm_zluda::{LLVMCallConv, LLVMZludaBuildAlloca};
|
||||||
|
|
||||||
const LLVM_UNNAMED: &CStr = c"";
|
const LLVM_UNNAMED: &CStr = c"";
|
||||||
|
@ -172,7 +172,7 @@ pub(super) fn run<'input>(
|
||||||
let mut emit_ctx = ModuleEmitContext::new(&context, &module, &id_defs);
|
let mut emit_ctx = ModuleEmitContext::new(&context, &module, &id_defs);
|
||||||
for directive in directives {
|
for directive in directives {
|
||||||
match directive {
|
match directive {
|
||||||
Directive2::Variable(..) => todo!(),
|
Directive2::Variable(linking, variable) => emit_ctx.emit_global(linking, variable)?,
|
||||||
Directive2::Method(method) => emit_ctx.emit_method(method)?,
|
Directive2::Method(method) => emit_ctx.emit_method(method)?,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -281,6 +281,43 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> {
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn emit_global(
|
||||||
|
&mut self,
|
||||||
|
linking: ast::LinkingDirective,
|
||||||
|
var: ptx_parser::Variable<SpirvWord>,
|
||||||
|
) -> Result<(), TranslateError> {
|
||||||
|
let name = self
|
||||||
|
.id_defs
|
||||||
|
.ident_map
|
||||||
|
.get(&var.name)
|
||||||
|
.map(|entry| {
|
||||||
|
entry
|
||||||
|
.name
|
||||||
|
.as_ref()
|
||||||
|
.map(|text| Ok::<_, NulError>(Cow::Owned(CString::new(&**text)?)))
|
||||||
|
})
|
||||||
|
.flatten()
|
||||||
|
.transpose()
|
||||||
|
.map_err(|_| error_unreachable())?
|
||||||
|
.unwrap_or(Cow::Borrowed(LLVM_UNNAMED));
|
||||||
|
let global = unsafe {
|
||||||
|
LLVMAddGlobalInAddressSpace(
|
||||||
|
self.module,
|
||||||
|
get_type(self.context, &var.v_type)?,
|
||||||
|
name.as_ptr(),
|
||||||
|
get_state_space(var.state_space)?,
|
||||||
|
)
|
||||||
|
};
|
||||||
|
self.resolver.register(var.name, global);
|
||||||
|
if let Some(align) = var.align {
|
||||||
|
unsafe { LLVMSetAlignment(global, align) };
|
||||||
|
}
|
||||||
|
if !var.array_init.is_empty() {
|
||||||
|
todo!()
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_input_argument_type(
|
fn get_input_argument_type(
|
||||||
|
@ -419,7 +456,7 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> {
|
||||||
ast::Instruction::Rsqrt { data, arguments } => todo!(),
|
ast::Instruction::Rsqrt { data, arguments } => todo!(),
|
||||||
ast::Instruction::Selp { data, arguments } => todo!(),
|
ast::Instruction::Selp { data, arguments } => todo!(),
|
||||||
ast::Instruction::Bar { data, arguments } => todo!(),
|
ast::Instruction::Bar { data, arguments } => todo!(),
|
||||||
ast::Instruction::Atom { data, arguments } => todo!(),
|
ast::Instruction::Atom { data, arguments } => self.emit_atom(data, arguments),
|
||||||
ast::Instruction::AtomCas { data, arguments } => todo!(),
|
ast::Instruction::AtomCas { data, arguments } => todo!(),
|
||||||
ast::Instruction::Div { data, arguments } => todo!(),
|
ast::Instruction::Div { data, arguments } => todo!(),
|
||||||
ast::Instruction::Neg { data, arguments } => todo!(),
|
ast::Instruction::Neg { data, arguments } => todo!(),
|
||||||
|
@ -499,7 +536,14 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> {
|
||||||
});
|
});
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
ConversionKind::PtrToPtr => todo!(),
|
ConversionKind::PtrToPtr => {
|
||||||
|
let src = self.resolver.value(conversion.src)?;
|
||||||
|
let dst_type = get_pointer_type(self.context, conversion.to_space)?;
|
||||||
|
self.resolver.with_result(conversion.dst, |dst| unsafe {
|
||||||
|
LLVMBuildAddrSpaceCast(builder, src, dst_type, dst)
|
||||||
|
});
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
ConversionKind::AddressOf => todo!(),
|
ConversionKind::AddressOf => todo!(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -635,6 +679,51 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> {
|
||||||
});
|
});
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn emit_atom(
|
||||||
|
&mut self,
|
||||||
|
data: ptx_parser::AtomDetails,
|
||||||
|
arguments: ptx_parser::AtomArgs<SpirvWord>,
|
||||||
|
) -> Result<(), TranslateError> {
|
||||||
|
let builder = self.builder;
|
||||||
|
let src1 = self.resolver.value(arguments.src1)?;
|
||||||
|
let src2 = self.resolver.value(arguments.src2)?;
|
||||||
|
let op = match data.op {
|
||||||
|
ptx_parser::AtomicOp::And => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpAnd,
|
||||||
|
ptx_parser::AtomicOp::Or => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpOr,
|
||||||
|
ptx_parser::AtomicOp::Xor => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpXor,
|
||||||
|
ptx_parser::AtomicOp::Exchange => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpXchg,
|
||||||
|
ptx_parser::AtomicOp::Add => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpAdd,
|
||||||
|
ptx_parser::AtomicOp::IncrementWrap => {
|
||||||
|
LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpUIncWrap
|
||||||
|
}
|
||||||
|
ptx_parser::AtomicOp::DecrementWrap => {
|
||||||
|
LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpUDecWrap
|
||||||
|
}
|
||||||
|
ptx_parser::AtomicOp::SignedMin => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpMin,
|
||||||
|
ptx_parser::AtomicOp::UnsignedMin => {
|
||||||
|
LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpUMin
|
||||||
|
}
|
||||||
|
ptx_parser::AtomicOp::SignedMax => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpMax,
|
||||||
|
ptx_parser::AtomicOp::UnsignedMax => {
|
||||||
|
LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpUMax
|
||||||
|
}
|
||||||
|
ptx_parser::AtomicOp::FloatAdd => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpFAdd,
|
||||||
|
ptx_parser::AtomicOp::FloatMin => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpFMin,
|
||||||
|
ptx_parser::AtomicOp::FloatMax => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpFMax,
|
||||||
|
};
|
||||||
|
self.resolver.register(arguments.dst, unsafe {
|
||||||
|
LLVMZludaBuildAtomicRMW(
|
||||||
|
builder,
|
||||||
|
op,
|
||||||
|
src1,
|
||||||
|
src2,
|
||||||
|
get_scope(data.scope)?,
|
||||||
|
get_ordering(data.semantics),
|
||||||
|
)
|
||||||
|
});
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_pointer_type<'ctx>(
|
fn get_pointer_type<'ctx>(
|
||||||
|
@ -644,6 +733,26 @@ fn get_pointer_type<'ctx>(
|
||||||
Ok(unsafe { LLVMPointerTypeInContext(context, get_state_space(to_space)?) })
|
Ok(unsafe { LLVMPointerTypeInContext(context, get_state_space(to_space)?) })
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// https://llvm.org/docs/AMDGPUUsage.html#memory-scopes
|
||||||
|
fn get_scope(scope: ast::MemScope) -> Result<*const i8, TranslateError> {
|
||||||
|
Ok(match scope {
|
||||||
|
ast::MemScope::Cta => c"workgroup-one-as",
|
||||||
|
ast::MemScope::Gpu => c"agent-one-as",
|
||||||
|
ast::MemScope::Sys => c"one-as",
|
||||||
|
ast::MemScope::Cluster => todo!(),
|
||||||
|
}
|
||||||
|
.as_ptr())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_ordering(semantics: ast::AtomSemantics) -> LLVMAtomicOrdering {
|
||||||
|
match semantics {
|
||||||
|
ast::AtomSemantics::Relaxed => LLVMAtomicOrdering::LLVMAtomicOrderingMonotonic,
|
||||||
|
ast::AtomSemantics::Acquire => LLVMAtomicOrdering::LLVMAtomicOrderingAcquire,
|
||||||
|
ast::AtomSemantics::Release => LLVMAtomicOrdering::LLVMAtomicOrderingRelease,
|
||||||
|
ast::AtomSemantics::AcqRel => LLVMAtomicOrdering::LLVMAtomicOrderingAcquireRelease,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn get_type(context: LLVMContextRef, type_: &ast::Type) -> Result<LLVMTypeRef, TranslateError> {
|
fn get_type(context: LLVMContextRef, type_: &ast::Type) -> Result<LLVMTypeRef, TranslateError> {
|
||||||
Ok(match type_ {
|
Ok(match type_ {
|
||||||
ast::Type::Scalar(scalar) => get_scalar_type(context, *scalar),
|
ast::Type::Scalar(scalar) => get_scalar_type(context, *scalar),
|
||||||
|
|
|
@ -52,7 +52,7 @@ fn run_method<'a, 'input>(
|
||||||
let new_name = visitor
|
let new_name = visitor
|
||||||
.resolver
|
.resolver
|
||||||
.register_unnamed(Some((arg.v_type.clone(), new_space)));
|
.register_unnamed(Some((arg.v_type.clone(), new_space)));
|
||||||
visitor.input_argument(old_name, new_name, old_space);
|
visitor.input_argument(old_name, new_name, old_space)?;
|
||||||
arg.name = new_name;
|
arg.name = new_name;
|
||||||
arg.state_space = new_space;
|
arg.state_space = new_space;
|
||||||
}
|
}
|
||||||
|
@ -154,7 +154,7 @@ impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> {
|
||||||
old_name: SpirvWord,
|
old_name: SpirvWord,
|
||||||
new_name: SpirvWord,
|
new_name: SpirvWord,
|
||||||
old_space: ast::StateSpace,
|
old_space: ast::StateSpace,
|
||||||
) -> Result<(), TranslateError> {
|
) -> Result<bool, TranslateError> {
|
||||||
Ok(match old_space {
|
Ok(match old_space {
|
||||||
ast::StateSpace::Reg => {
|
ast::StateSpace::Reg => {
|
||||||
self.variables.insert(
|
self.variables.insert(
|
||||||
|
@ -164,6 +164,7 @@ impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> {
|
||||||
type_: type_.clone(),
|
type_: type_.clone(),
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
|
true
|
||||||
}
|
}
|
||||||
ast::StateSpace::Param => {
|
ast::StateSpace::Param => {
|
||||||
self.variables.insert(
|
self.variables.insert(
|
||||||
|
@ -174,19 +175,18 @@ impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> {
|
||||||
name: new_name,
|
name: new_name,
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
|
true
|
||||||
}
|
}
|
||||||
// Good as-is
|
// Good as-is
|
||||||
ast::StateSpace::Local => {}
|
ast::StateSpace::Local
|
||||||
// Will be pulled into global scope later
|
| ast::StateSpace::Generic
|
||||||
ast::StateSpace::Generic
|
|
||||||
| ast::StateSpace::SharedCluster
|
| ast::StateSpace::SharedCluster
|
||||||
| ast::StateSpace::Global
|
| ast::StateSpace::Global
|
||||||
| ast::StateSpace::Const
|
| ast::StateSpace::Const
|
||||||
| ast::StateSpace::SharedCta
|
| ast::StateSpace::SharedCta
|
||||||
| ast::StateSpace::Shared => {}
|
| ast::StateSpace::Shared
|
||||||
ast::StateSpace::ParamEntry | ast::StateSpace::ParamFunc => {
|
| ast::StateSpace::ParamEntry
|
||||||
return Err(error_unreachable())
|
| ast::StateSpace::ParamFunc => return Err(error_unreachable()),
|
||||||
}
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -239,17 +239,28 @@ impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn visit_variable(&mut self, var: &mut ast::Variable<SpirvWord>) -> Result<(), TranslateError> {
|
fn visit_variable(&mut self, var: &mut ast::Variable<SpirvWord>) -> Result<(), TranslateError> {
|
||||||
if var.state_space != ast::StateSpace::Local {
|
let old_space = match var.state_space {
|
||||||
let old_name = var.name;
|
space @ (ptx_parser::StateSpace::Reg | ptx_parser::StateSpace::Param) => space,
|
||||||
let old_space = var.state_space;
|
// Do nothing
|
||||||
let new_space = ast::StateSpace::Local;
|
ptx_parser::StateSpace::Local => return Ok(()),
|
||||||
let new_name = self
|
// Handled by another pass
|
||||||
.resolver
|
ptx_parser::StateSpace::Generic
|
||||||
.register_unnamed(Some((var.v_type.clone(), new_space)));
|
| ptx_parser::StateSpace::SharedCluster
|
||||||
self.variable(&var.v_type, old_name, new_name, old_space)?;
|
| ptx_parser::StateSpace::ParamEntry
|
||||||
var.name = new_name;
|
| ptx_parser::StateSpace::Global
|
||||||
var.state_space = new_space;
|
| ptx_parser::StateSpace::SharedCta
|
||||||
}
|
| ptx_parser::StateSpace::Const
|
||||||
|
| ptx_parser::StateSpace::Shared
|
||||||
|
| ptx_parser::StateSpace::ParamFunc => return Ok(()),
|
||||||
|
};
|
||||||
|
let old_name = var.name;
|
||||||
|
let new_space = ast::StateSpace::Local;
|
||||||
|
let new_name = self
|
||||||
|
.resolver
|
||||||
|
.register_unnamed(Some((var.v_type.clone(), new_space)));
|
||||||
|
self.variable(&var.v_type, old_name, new_name, old_space)?;
|
||||||
|
var.name = new_name;
|
||||||
|
var.state_space = new_space;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue