diff --git a/level_zero/Cargo.toml b/level_zero/Cargo.toml index 97537b3..851159d 100644 --- a/level_zero/Cargo.toml +++ b/level_zero/Cargo.toml @@ -7,4 +7,8 @@ edition = "2018" [lib] [dependencies] -level_zero-sys = { path = "../level_zero-sys" } \ No newline at end of file +level_zero-sys = { path = "../level_zero-sys" } + +[dependencies.ocl-core] +version = "0.11" +features = ["opencl_version_1_2", "opencl_version_2_0", "opencl_version_2_1"] \ No newline at end of file diff --git a/level_zero/src/ze.rs b/level_zero/src/ze.rs index 5ced5d0..f8a2c3b 100644 --- a/level_zero/src/ze.rs +++ b/level_zero/src/ze.rs @@ -238,7 +238,76 @@ impl Drop for CommandQueue { pub struct Module(sys::ze_module_handle_t); impl Module { - pub fn new_spirv( + // HACK ALERT + // We use OpenCL for now to do SPIR-V linking, because Level0 + // does not allow linking. Don't let presence of zeModuleDynamicLink fool + // you, it's not currently possible to create non-compiled modules. + // zeModuleCreate always compiles (builds and links). 
+ pub fn build_link_spirv<'a>( + ctx: &mut Context, + d: &Device, + binaries: &[&'a [u8]], + ) -> (Result, Option) { + let ocl_program = match Self::build_link_spirv_impl(binaries) { + Err(_) => return (Err(sys::ze_result_t::ZE_RESULT_ERROR_UNKNOWN), None), + Ok(prog) => prog, + }; + match ocl_core::get_program_info(&ocl_program, ocl_core::ProgramInfo::Binaries) { + Ok(ocl_core::ProgramInfoResult::Binaries(binaries)) => { + let (module, build_log) = Self::build_native(ctx, d, &binaries[0]); + (module, Some(build_log)) + } + _ => return (Err(sys::ze_result_t::ZE_RESULT_ERROR_UNKNOWN), None), + } + } + + fn build_link_spirv_impl<'a>(binaries: &[&'a [u8]]) -> ocl_core::Result { + let platforms = ocl_core::get_platform_ids()?; + let (platform, device) = platforms + .iter() + .find_map(|plat| { + let devices = + ocl_core::get_device_ids(plat, Some(ocl_core::DeviceType::GPU), None).ok()?; + for dev in devices { + let vendor = + ocl_core::get_device_info(dev, ocl_core::DeviceInfo::VendorId).ok()?; + if let ocl_core::DeviceInfoResult::VendorId(0x8086) = vendor { + let dev_type = + ocl_core::get_device_info(dev, ocl_core::DeviceInfo::Type).ok()?; + if let ocl_core::DeviceInfoResult::Type(ocl_core::DeviceType::GPU) = + dev_type + { + return Some((plat.clone(), dev)); + } + } + } + None + }) + .ok_or("")?; + let ctx_props = ocl_core::ContextProperties::new().platform(platform); + let ocl_ctx = ocl_core::create_context_from_type::( + Some(&ctx_props), + ocl_core::DeviceType::GPU, + None, + None, + )?; + let mut programs = Vec::with_capacity(binaries.len()); + for binary in binaries { + programs.push(ocl_core::create_program_with_il(&ocl_ctx, binary, None)?); + } + let options = CString::default(); + ocl_core::link_program::( + &ocl_ctx, + Some(&[device]), + &options, + &programs.iter().collect::>(), + None, + None, + None, + ) + } + + pub fn build_spirv( ctx: &mut Context, d: &Device, bin: &[u8], @@ -247,7 +316,7 @@ impl Module { Module::new(ctx, true, d, bin, opts) } - pub fn 
new_native(ctx: &mut Context, d: &Device, bin: &[u8]) -> (Result, BuildLog) { + pub fn build_native(ctx: &mut Context, d: &Device, bin: &[u8]) -> (Result, BuildLog) { Module::new(ctx, false, d, bin, None) } diff --git a/notcuda/src/impl/module.rs b/notcuda/src/impl/module.rs index eea862b..35436c3 100644 --- a/notcuda/src/impl/module.rs +++ b/notcuda/src/impl/module.rs @@ -53,7 +53,7 @@ impl ModuleData { Ok(_) if errors.len() > 0 => return Err(ModuleCompileError::Parse(errors, None)), Ok(ast) => ast, }; - let (spirv, all_arg_lens) = ptx::to_spirv(ast)?; + let (_, spirv, all_arg_lens) = ptx::to_spirv(ast)?; let byte_il = unsafe { slice::from_raw_parts::( spirv.as_ptr() as *const _, @@ -61,7 +61,7 @@ impl ModuleData { ) }; let module = super::device::with_current_exclusive(|dev| { - l0::Module::new_spirv(&mut dev.l0_context, &dev.base, byte_il, None) + l0::Module::build_spirv(&mut dev.l0_context, &dev.base, byte_il, None) }); match module { Ok((Ok(module), _)) => Ok(Mutex::new(Self { diff --git a/ptx/lib/notcuda_ptx_impl.cl b/ptx/lib/notcuda_ptx_impl.cl new file mode 100644 index 0000000..a0d487b --- /dev/null +++ b/ptx/lib/notcuda_ptx_impl.cl @@ -0,0 +1,121 @@ +// Every time this file changes it must be rebuilt: +// ocloc -file notcuda_ptx_impl.cl -64 -options "-cl-std=CL2.0" -out_dir . -device kbl -output_no_suffix -spv_only +// Additionally you should strip names: +// spirv-opt --strip-debug notcuda_ptx_impl.spv -o notcuda_ptx_impl.spv + +#define FUNC(NAME) __notcuda_ptx_impl__ ## NAME + +#define atomic_inc(NAME, SUCCESS, FAILURE, SCOPE, SPACE) \ + uint FUNC(NAME)(SPACE uint* ptr, uint threshold) { \ + uint expected = *ptr; \ + uint desired; \ + do { \ + desired = (expected >= threshold) ?
0 : expected + 1; \ + } while (!atomic_compare_exchange_strong_explicit((volatile SPACE atomic_uint*)ptr, &expected, desired, SUCCESS, FAILURE, SCOPE)); \ + return expected; \ + } + +#define atomic_dec(NAME, SUCCESS, FAILURE, SCOPE, SPACE) \ + uint FUNC(NAME)(SPACE uint* ptr, uint threshold) { \ + uint expected = *ptr; \ + uint desired; \ + do { \ + desired = (expected == 0 || expected > threshold) ? threshold : expected - 1; \ + } while (!atomic_compare_exchange_strong_explicit((volatile SPACE atomic_uint*)ptr, &expected, desired, SUCCESS, FAILURE, SCOPE)); \ + return expected; \ + } + +// We are doing all this mess instead of accepting memory_order and memory_scope parameters +// because ocloc emits broken (failing spirv-dis) SPIR-V when memory_order or memory_scope is a parameter + +// atom.inc +atomic_inc(atom_relaxed_cta_generic_inc, memory_order_relaxed, memory_order_relaxed, memory_scope_work_group, ); +atomic_inc(atom_acquire_cta_generic_inc, memory_order_acquire, memory_order_acquire, memory_scope_work_group, ); +atomic_inc(atom_release_cta_generic_inc, memory_order_release, memory_order_acquire, memory_scope_work_group, ); +atomic_inc(atom_acq_rel_cta_generic_inc, memory_order_acq_rel, memory_order_acquire, memory_scope_work_group, ); + +atomic_inc(atom_relaxed_gpu_generic_inc, memory_order_relaxed, memory_order_relaxed, memory_scope_device, ); +atomic_inc(atom_acquire_gpu_generic_inc, memory_order_acquire, memory_order_acquire, memory_scope_device, ); +atomic_inc(atom_release_gpu_generic_inc, memory_order_release, memory_order_acquire, memory_scope_device, ); +atomic_inc(atom_acq_rel_gpu_generic_inc, memory_order_acq_rel, memory_order_acquire, memory_scope_device, ); + +atomic_inc(atom_relaxed_sys_generic_inc, memory_order_relaxed, memory_order_relaxed, memory_scope_device, ); +atomic_inc(atom_acquire_sys_generic_inc, memory_order_acquire, memory_order_acquire, memory_scope_device, ); +atomic_inc(atom_release_sys_generic_inc, memory_order_release, 
memory_order_acquire, memory_scope_device, ); +atomic_inc(atom_acq_rel_sys_generic_inc, memory_order_acq_rel, memory_order_acquire, memory_scope_device, ); + +atomic_inc(atom_relaxed_cta_global_inc, memory_order_relaxed, memory_order_relaxed, memory_scope_work_group, __global); +atomic_inc(atom_acquire_cta_global_inc, memory_order_acquire, memory_order_acquire, memory_scope_work_group, __global); +atomic_inc(atom_release_cta_global_inc, memory_order_release, memory_order_acquire, memory_scope_work_group, __global); +atomic_inc(atom_acq_rel_cta_global_inc, memory_order_acq_rel, memory_order_acquire, memory_scope_work_group, __global); + +atomic_inc(atom_relaxed_gpu_global_inc, memory_order_relaxed, memory_order_relaxed, memory_scope_device, __global); +atomic_inc(atom_acquire_gpu_global_inc, memory_order_acquire, memory_order_acquire, memory_scope_device, __global); +atomic_inc(atom_release_gpu_global_inc, memory_order_release, memory_order_acquire, memory_scope_device, __global); +atomic_inc(atom_acq_rel_gpu_global_inc, memory_order_acq_rel, memory_order_acquire, memory_scope_device, __global); + +atomic_inc(atom_relaxed_sys_global_inc, memory_order_relaxed, memory_order_relaxed, memory_scope_device, __global); +atomic_inc(atom_acquire_sys_global_inc, memory_order_acquire, memory_order_acquire, memory_scope_device, __global); +atomic_inc(atom_release_sys_global_inc, memory_order_release, memory_order_acquire, memory_scope_device, __global); +atomic_inc(atom_acq_rel_sys_global_inc, memory_order_acq_rel, memory_order_acquire, memory_scope_device, __global); + +atomic_inc(atom_relaxed_cta_shared_inc, memory_order_relaxed, memory_order_relaxed, memory_scope_work_group, __local); +atomic_inc(atom_acquire_cta_shared_inc, memory_order_acquire, memory_order_acquire, memory_scope_work_group, __local); +atomic_inc(atom_release_cta_shared_inc, memory_order_release, memory_order_acquire, memory_scope_work_group, __local); +atomic_inc(atom_acq_rel_cta_shared_inc, 
memory_order_acq_rel, memory_order_acquire, memory_scope_work_group, __local); + +atomic_inc(atom_relaxed_gpu_shared_inc, memory_order_relaxed, memory_order_relaxed, memory_scope_device, __local); +atomic_inc(atom_acquire_gpu_shared_inc, memory_order_acquire, memory_order_acquire, memory_scope_device, __local); +atomic_inc(atom_release_gpu_shared_inc, memory_order_release, memory_order_acquire, memory_scope_device, __local); +atomic_inc(atom_acq_rel_gpu_shared_inc, memory_order_acq_rel, memory_order_acquire, memory_scope_device, __local); + +atomic_inc(atom_relaxed_sys_shared_inc, memory_order_relaxed, memory_order_relaxed, memory_scope_device, __local); +atomic_inc(atom_acquire_sys_shared_inc, memory_order_acquire, memory_order_acquire, memory_scope_device, __local); +atomic_inc(atom_release_sys_shared_inc, memory_order_release, memory_order_acquire, memory_scope_device, __local); +atomic_inc(atom_acq_rel_sys_shared_inc, memory_order_acq_rel, memory_order_acquire, memory_scope_device, __local); + +// atom.dec +atomic_dec(atom_relaxed_cta_generic_dec, memory_order_relaxed, memory_order_relaxed, memory_scope_work_group, ); +atomic_dec(atom_acquire_cta_generic_dec, memory_order_acquire, memory_order_acquire, memory_scope_work_group, ); +atomic_dec(atom_release_cta_generic_dec, memory_order_release, memory_order_acquire, memory_scope_work_group, ); +atomic_dec(atom_acq_rel_cta_generic_dec, memory_order_acq_rel, memory_order_acquire, memory_scope_work_group, ); + +atomic_dec(atom_relaxed_gpu_generic_dec, memory_order_relaxed, memory_order_relaxed, memory_scope_device, ); +atomic_dec(atom_acquire_gpu_generic_dec, memory_order_acquire, memory_order_acquire, memory_scope_device, ); +atomic_dec(atom_release_gpu_generic_dec, memory_order_release, memory_order_acquire, memory_scope_device, ); +atomic_dec(atom_acq_rel_gpu_generic_dec, memory_order_acq_rel, memory_order_acquire, memory_scope_device, ); + +atomic_dec(atom_relaxed_sys_generic_dec, memory_order_relaxed, 
memory_order_relaxed, memory_scope_device, ); +atomic_dec(atom_acquire_sys_generic_dec, memory_order_acquire, memory_order_acquire, memory_scope_device, ); +atomic_dec(atom_release_sys_generic_dec, memory_order_release, memory_order_acquire, memory_scope_device, ); +atomic_dec(atom_acq_rel_sys_generic_dec, memory_order_acq_rel, memory_order_acquire, memory_scope_device, ); + +atomic_dec(atom_relaxed_cta_global_dec, memory_order_relaxed, memory_order_relaxed, memory_scope_work_group, __global); +atomic_dec(atom_acquire_cta_global_dec, memory_order_acquire, memory_order_acquire, memory_scope_work_group, __global); +atomic_dec(atom_release_cta_global_dec, memory_order_release, memory_order_acquire, memory_scope_work_group, __global); +atomic_dec(atom_acq_rel_cta_global_dec, memory_order_acq_rel, memory_order_acquire, memory_scope_work_group, __global); + +atomic_dec(atom_relaxed_gpu_global_dec, memory_order_relaxed, memory_order_relaxed, memory_scope_device, __global); +atomic_dec(atom_acquire_gpu_global_dec, memory_order_acquire, memory_order_acquire, memory_scope_device, __global); +atomic_dec(atom_release_gpu_global_dec, memory_order_release, memory_order_acquire, memory_scope_device, __global); +atomic_dec(atom_acq_rel_gpu_global_dec, memory_order_acq_rel, memory_order_acquire, memory_scope_device, __global); + +atomic_dec(atom_relaxed_sys_global_dec, memory_order_relaxed, memory_order_relaxed, memory_scope_device, __global); +atomic_dec(atom_acquire_sys_global_dec, memory_order_acquire, memory_order_acquire, memory_scope_device, __global); +atomic_dec(atom_release_sys_global_dec, memory_order_release, memory_order_acquire, memory_scope_device, __global); +atomic_dec(atom_acq_rel_sys_global_dec, memory_order_acq_rel, memory_order_acquire, memory_scope_device, __global); + +atomic_dec(atom_relaxed_cta_shared_dec, memory_order_relaxed, memory_order_relaxed, memory_scope_work_group, __local); +atomic_dec(atom_acquire_cta_shared_dec, memory_order_acquire, 
memory_order_acquire, memory_scope_work_group, __local); +atomic_dec(atom_release_cta_shared_dec, memory_order_release, memory_order_acquire, memory_scope_work_group, __local); +atomic_dec(atom_acq_rel_cta_shared_dec, memory_order_acq_rel, memory_order_acquire, memory_scope_work_group, __local); + +atomic_dec(atom_relaxed_gpu_shared_dec, memory_order_relaxed, memory_order_relaxed, memory_scope_device, __local); +atomic_dec(atom_acquire_gpu_shared_dec, memory_order_acquire, memory_order_acquire, memory_scope_device, __local); +atomic_dec(atom_release_gpu_shared_dec, memory_order_release, memory_order_acquire, memory_scope_device, __local); +atomic_dec(atom_acq_rel_gpu_shared_dec, memory_order_acq_rel, memory_order_acquire, memory_scope_device, __local); + +atomic_dec(atom_relaxed_sys_shared_dec, memory_order_relaxed, memory_order_relaxed, memory_scope_device, __local); +atomic_dec(atom_acquire_sys_shared_dec, memory_order_acquire, memory_order_acquire, memory_scope_device, __local); +atomic_dec(atom_release_sys_shared_dec, memory_order_release, memory_order_acquire, memory_scope_device, __local); +atomic_dec(atom_acq_rel_sys_shared_dec, memory_order_acq_rel, memory_order_acquire, memory_scope_device, __local); diff --git a/ptx/lib/notcuda_ptx_impl.spv b/ptx/lib/notcuda_ptx_impl.spv new file mode 100644 index 0000000..36f37bb Binary files /dev/null and b/ptx/lib/notcuda_ptx_impl.spv differ diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs index 1266ea4..ad8e87d 100644 --- a/ptx/src/ast.rs +++ b/ptx/src/ast.rs @@ -109,11 +109,12 @@ macro_rules! sub_type { }; } -// Pointer is used when doing SLM converison to SPIRV sub_type! 
{ VariableRegType { Scalar(ScalarType), Vector(SizedScalarType, u8), + // Pointer variant is used when passing around SLM pointer between + // function calls for dynamic SLM Pointer(SizedScalarType, PointerStateSpace) } } @@ -215,6 +216,11 @@ sub_enum!(SelpType { F64, }); +#[derive(Copy, Clone, Eq, PartialEq)] +pub enum BarDetails { + SyncAligned, +} + pub trait UnwrapWithVec { fn unwrap_with(self, errs: &mut Vec) -> To; } @@ -301,6 +307,7 @@ impl From for Type { sub_enum!( PointerStateSpace : LdStateSpace { + Generic, Global, Const, Shared, @@ -372,6 +379,8 @@ sub_enum!(IntType { S64 }); +sub_enum!(BitType { B8, B16, B32, B64 }); + sub_enum!(UIntType { U8, U16, U32, U64 }); sub_enum!(SIntType { S8, S16, S32, S64 }); @@ -527,6 +536,9 @@ pub enum Instruction { Rcp(RcpDetails, Arg2

), And(OrAndType, Arg3

), Selp(SelpType, Arg4

), + Bar(BarDetails, Arg1Bar

), + Atom(AtomDetails, Arg3

), + AtomCas(AtomCasDetails, Arg4

), } #[derive(Copy, Clone)] @@ -577,6 +589,10 @@ pub struct Arg1 { pub src: P::Id, // it is a jump destination, but in terms of operands it is a source operand } +pub struct Arg1Bar { + pub src: P::Operand, +} + pub struct Arg2 { pub dst: P::Id, pub src: P::Operand, @@ -712,12 +728,12 @@ impl From for PointerType { pub enum LdStQualifier { Weak, Volatile, - Relaxed(LdScope), - Acquire(LdScope), + Relaxed(MemScope), + Acquire(MemScope), } #[derive(Copy, Clone, PartialEq, Eq)] -pub enum LdScope { +pub enum MemScope { Cta, Gpu, Sys, @@ -1051,6 +1067,74 @@ pub struct MinMaxFloat { pub typ: FloatType, } +#[derive(Copy, Clone)] +pub struct AtomDetails { + pub semantics: AtomSemantics, + pub scope: MemScope, + pub space: AtomSpace, + pub inner: AtomInnerDetails, +} + +#[derive(Copy, Clone)] +pub enum AtomSemantics { + Relaxed, + Acquire, + Release, + AcquireRelease, +} + +#[derive(Copy, Clone)] +pub enum AtomSpace { + Generic, + Global, + Shared, +} + +#[derive(Copy, Clone)] +pub enum AtomInnerDetails { + Bit { op: AtomBitOp, typ: BitType }, + Unsigned { op: AtomUIntOp, typ: UIntType }, + Signed { op: AtomSIntOp, typ: SIntType }, + Float { op: AtomFloatOp, typ: FloatType }, +} + +#[derive(Copy, Clone, Eq, PartialEq)] +pub enum AtomBitOp { + And, + Or, + Xor, + Exchange, +} + +#[derive(Copy, Clone, Eq, PartialEq)] +pub enum AtomUIntOp { + Add, + Inc, + Dec, + Min, + Max, +} + +#[derive(Copy, Clone, Eq, PartialEq)] +pub enum AtomSIntOp { + Add, + Min, + Max, +} + +#[derive(Copy, Clone, Eq, PartialEq)] +pub enum AtomFloatOp { + Add, +} + +#[derive(Copy, Clone)] +pub struct AtomCasDetails { + pub semantics: AtomSemantics, + pub scope: MemScope, + pub space: AtomSpace, + pub typ: BitType +} + pub enum NumsOrArrays<'a> { Nums(Vec<(&'a str, u32)>), Arrays(Vec>), diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop index dfe5a5f..806a3fc 100644 --- a/ptx/src/ptx.lalrpop +++ b/ptx/src/ptx.lalrpop @@ -35,9 +35,12 @@ match { "<", ">", "|", "=", + ".acq_rel", ".acquire", + ".add", 
".address_size", ".align", + ".aligned", ".and", ".approx", ".b16", @@ -45,14 +48,17 @@ match { ".b64", ".b8", ".ca", + ".cas", ".cg", ".const", ".cs", ".cta", ".cv", + ".dec", ".entry", ".eq", ".equ", + ".exch", ".extern", ".f16", ".f16x2", @@ -69,6 +75,7 @@ match { ".gtu", ".hi", ".hs", + ".inc", ".le", ".leu", ".lo", @@ -78,6 +85,8 @@ match { ".lt", ".ltu", ".lu", + ".max", + ".min", ".nan", ".NaN", ".ne", @@ -88,6 +97,7 @@ match { ".pred", ".reg", ".relaxed", + ".release", ".rm", ".rmi", ".rn", @@ -103,6 +113,7 @@ match { ".sat", ".section", ".shared", + ".sync", ".sys", ".target", ".to", @@ -126,6 +137,9 @@ match { "abs", "add", "and", + "atom", + "bar", + "barrier", "bra", "call", "cvt", @@ -162,6 +176,9 @@ ExtendedID : &'input str = { "abs", "add", "and", + "atom", + "bar", + "barrier", "bra", "call", "cvt", @@ -372,6 +389,7 @@ StateSpaceSpecifier: ast::StateSpace = { ".param" => ast::StateSpace::Param, // used to prepare function call }; +#[inline] ScalarType: ast::ScalarType = { ".f16" => ast::ScalarType::F16, ".f16x2" => ast::ScalarType::F16x2, @@ -438,6 +456,7 @@ Variable: ast::Variable = { let v_type = ast::VariableType::Param(v_type); ast::Variable {align, v_type, name, array_init} }, + SharedVariable, }; RegVariable: (Option, ast::VariableRegType, &'input str) = { @@ -478,6 +497,32 @@ LocalVariable: ast::Variable = { } } +SharedVariable: ast::Variable = { + ".shared" > => { + let (align, t, name) = var; + let v_type = ast::VariableGlobalType::Scalar(t); + ast::Variable { align, v_type: ast::VariableType::Shared(v_type), name, array_init: Vec::new() } + }, + ".shared" > => { + let (align, v_len, t, name) = var; + let v_type = ast::VariableGlobalType::Vector(t, v_len); + ast::Variable { align, v_type: ast::VariableType::Shared(v_type), name, array_init: Vec::new() } + }, + ".shared" > =>? 
{ + let (align, t, name, arr_or_ptr) = var; + let (v_type, array_init) = match arr_or_ptr { + ast::ArrayOrPointer::Array { dimensions, init } => { + (ast::VariableGlobalType::Array(t, dimensions), init) + } + ast::ArrayOrPointer::Pointer => { + return Err(ParseError::User { error: ast::PtxError::ZeroDimensionArray }); + } + }; + Ok(ast::Variable { align, v_type: ast::VariableType::Shared(v_type), name, array_init }) + } +} + + ModuleVariable: ast::Variable = { LinkingDirectives ".global" => { let (align, v_type, name, array_init) = def; @@ -619,7 +664,10 @@ Instruction: ast::Instruction> = { InstMin, InstMax, InstRcp, - InstSelp + InstSelp, + InstBar, + InstAtom, + InstAtomCas }; // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-ld @@ -655,14 +703,14 @@ LdStType: ast::LdStType = { LdStQualifier: ast::LdStQualifier = { ".weak" => ast::LdStQualifier::Weak, ".volatile" => ast::LdStQualifier::Volatile, - ".relaxed" => ast::LdStQualifier::Relaxed(s), - ".acquire" => ast::LdStQualifier::Acquire(s), + ".relaxed" => ast::LdStQualifier::Relaxed(s), + ".acquire" => ast::LdStQualifier::Acquire(s), }; -LdScope: ast::LdScope = { - ".cta" => ast::LdScope::Cta, - ".gpu" => ast::LdScope::Gpu, - ".sys" => ast::LdScope::Sys +MemScope: ast::MemScope = { + ".cta" => ast::MemScope::Cta, + ".gpu" => ast::MemScope::Gpu, + ".sys" => ast::MemScope::Sys }; LdStateSpace: ast::LdStateSpace = { @@ -798,6 +846,13 @@ SIntType: ast::SIntType = { ".s64" => ast::SIntType::S64, }; +FloatType: ast::FloatType = { + ".f16" => ast::FloatType::F16, + ".f16x2" => ast::FloatType::F16x2, + ".f32" => ast::FloatType::F32, + ".f64" => ast::FloatType::F64, +}; + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-add // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-add // 
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-add @@ -1296,6 +1351,140 @@ SelpType: ast::SelpType = { ".f64" => ast::SelpType::F64, }; +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-bar +InstBar: ast::Instruction> = { + "barrier" ".sync" ".aligned" => ast::Instruction::Bar(ast::BarDetails::SyncAligned, a), + "bar" ".sync" => ast::Instruction::Bar(ast::BarDetails::SyncAligned, a) +} + +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-atom +// The documentation does not mention all supported operations: +// * Operation .add requires .u32 or .s32 or .u64 or .f64 or f16 or f16x2 or .f32 +// * Operation .inc requires .u32 type for instruction +// * Operation .dec requires .u32 type for instruction +// Otherwise as documented +InstAtom: ast::Instruction> = { + "atom" => { + let details = ast::AtomDetails { + semantics: sema.unwrap_or(ast::AtomSemantics::Relaxed), + scope: scope.unwrap_or(ast::MemScope::Gpu), + space: space.unwrap_or(ast::AtomSpace::Generic), + inner: ast::AtomInnerDetails::Bit { op, typ } + }; + ast::Instruction::Atom(details,a) + }, + "atom" ".inc" ".u32" => { + let details = ast::AtomDetails { + semantics: sema.unwrap_or(ast::AtomSemantics::Relaxed), + scope: scope.unwrap_or(ast::MemScope::Gpu), + space: space.unwrap_or(ast::AtomSpace::Generic), + inner: ast::AtomInnerDetails::Unsigned { + op: ast::AtomUIntOp::Inc, + typ: ast::UIntType::U32 + } + }; + ast::Instruction::Atom(details,a) + }, + "atom" ".dec" ".u32" => { + let details = ast::AtomDetails { + semantics: sema.unwrap_or(ast::AtomSemantics::Relaxed), + scope: scope.unwrap_or(ast::MemScope::Gpu), + space: space.unwrap_or(ast::AtomSpace::Generic), + inner: ast::AtomInnerDetails::Unsigned { + op: ast::AtomUIntOp::Dec, + typ: ast::UIntType::U32 + } + }; + 
ast::Instruction::Atom(details,a) + }, + "atom" ".add" => { + let op = ast::AtomFloatOp::Add; + let details = ast::AtomDetails { + semantics: sema.unwrap_or(ast::AtomSemantics::Relaxed), + scope: scope.unwrap_or(ast::MemScope::Gpu), + space: space.unwrap_or(ast::AtomSpace::Generic), + inner: ast::AtomInnerDetails::Float { op, typ } + }; + ast::Instruction::Atom(details,a) + }, + "atom" => { + let details = ast::AtomDetails { + semantics: sema.unwrap_or(ast::AtomSemantics::Relaxed), + scope: scope.unwrap_or(ast::MemScope::Gpu), + space: space.unwrap_or(ast::AtomSpace::Generic), + inner: ast::AtomInnerDetails::Unsigned { op, typ } + }; + ast::Instruction::Atom(details,a) + }, + "atom" => { + let details = ast::AtomDetails { + semantics: sema.unwrap_or(ast::AtomSemantics::Relaxed), + scope: scope.unwrap_or(ast::MemScope::Gpu), + space: space.unwrap_or(ast::AtomSpace::Generic), + inner: ast::AtomInnerDetails::Signed { op, typ } + }; + ast::Instruction::Atom(details,a) + } +} + +InstAtomCas: ast::Instruction> = { + "atom" ".cas" => { + let details = ast::AtomCasDetails { + semantics: sema.unwrap_or(ast::AtomSemantics::Relaxed), + scope: scope.unwrap_or(ast::MemScope::Gpu), + space: space.unwrap_or(ast::AtomSpace::Generic), + typ, + }; + ast::Instruction::AtomCas(details,a) + }, +} + +AtomSemantics: ast::AtomSemantics = { + ".relaxed" => ast::AtomSemantics::Relaxed, + ".acquire" => ast::AtomSemantics::Acquire, + ".release" => ast::AtomSemantics::Release, + ".acq_rel" => ast::AtomSemantics::AcquireRelease +} + +AtomSpace: ast::AtomSpace = { + ".global" => ast::AtomSpace::Global, + ".shared" => ast::AtomSpace::Shared +} + +AtomBitOp: ast::AtomBitOp = { + ".and" => ast::AtomBitOp::And, + ".or" => ast::AtomBitOp::Or, + ".xor" => ast::AtomBitOp::Xor, + ".exch" => ast::AtomBitOp::Exchange, +} + +AtomUIntOp: ast::AtomUIntOp = { + ".add" => ast::AtomUIntOp::Add, + ".min" => ast::AtomUIntOp::Min, + ".max" => ast::AtomUIntOp::Max, +} + +AtomSIntOp: ast::AtomSIntOp = { + ".add" => 
ast::AtomSIntOp::Add, + ".min" => ast::AtomSIntOp::Min, + ".max" => ast::AtomSIntOp::Max, +} + +AtomBitType: ast::BitType = { + ".b32" => ast::BitType::B32, + ".b64" => ast::BitType::B64, +} + +AtomUIntType: ast::UIntType = { + ".u32" => ast::UIntType::U32, + ".u64" => ast::UIntType::U64, +} + +AtomSIntType: ast::SIntType = { + ".s32" => ast::SIntType::S32, + ".s64" => ast::SIntType::S64, +} + ArithDetails: ast::ArithDetails = { => ast::ArithDetails::Unsigned(t), => ast::ArithDetails::Signed(ast::ArithSInt { @@ -1414,6 +1603,10 @@ Arg1: ast::Arg1> = { => ast::Arg1{<>} }; +Arg1Bar: ast::Arg1Bar> = { + => ast::Arg1Bar{<>} +}; + Arg2: ast::Arg2> = { "," => ast::Arg2{<>} }; @@ -1448,10 +1641,18 @@ Arg3: ast::Arg3> = { "," "," => ast::Arg3{<>} }; +Arg3Atom: ast::Arg3> = { + "," "[" "]" "," => ast::Arg3{<>} +}; + Arg4: ast::Arg4> = { "," "," "," => ast::Arg4{<>} }; +Arg4Atom: ast::Arg4> = { + "," "[" "]" "," "," => ast::Arg4{<>} +}; + Arg4Setp: ast::Arg4Setp> = { "," "," => ast::Arg4Setp{<>} }; diff --git a/ptx/src/test/spirv_build/bar_sync.ptx b/ptx/src/test/spirv_build/bar_sync.ptx new file mode 100644 index 0000000..54c6663 --- /dev/null +++ b/ptx/src/test/spirv_build/bar_sync.ptx @@ -0,0 +1,10 @@ +.version 6.5 +.target sm_30 +.address_size 64 + +.visible .entry bar_sync() +{ + .reg .u32 temp_32; + bar.sync temp_32; + ret; +} diff --git a/ptx/src/test/spirv_run/and.spvtxt b/ptx/src/test/spirv_run/and.spvtxt index 9b72477..8358c28 100644 --- a/ptx/src/test/spirv_run/and.spvtxt +++ b/ptx/src/test/spirv_run/and.spvtxt @@ -11,8 +11,8 @@ OpCapability Int16 OpCapability Int64 OpCapability Float16 OpCapability Float64 -OpCapability FunctionFloatControlINTEL -OpExtension "SPV_INTEL_float_controls2" +; OpCapability FunctionFloatControlINTEL +; OpExtension "SPV_INTEL_float_controls2" %33 = OpExtInstImport "OpenCL.std" OpMemoryModel Physical64 OpenCL OpEntryPoint Kernel %1 "and" diff --git a/ptx/src/test/spirv_run/atom_add.ptx b/ptx/src/test/spirv_run/atom_add.ptx new file mode 
100644 index 0000000..5d1f667 --- /dev/null +++ b/ptx/src/test/spirv_run/atom_add.ptx @@ -0,0 +1,28 @@ +.version 6.5 +.target sm_30 +.address_size 64 + +.visible .entry atom_add( + .param .u64 input, + .param .u64 output +) +{ + .shared .align 4 .b8 shared_mem[1024]; + + .reg .u64 in_addr; + .reg .u64 out_addr; + .reg .u32 temp1; + .reg .u32 temp2; + + ld.param.u64 in_addr, [input]; + ld.param.u64 out_addr, [output]; + + ld.u32 temp1, [in_addr]; + ld.u32 temp2, [in_addr+4]; + st.shared.u32 [shared_mem], temp1; + atom.shared.add.u32 temp1, [shared_mem], temp2; + ld.shared.u32 temp2, [shared_mem]; + st.u32 [out_addr], temp1; + st.u32 [out_addr+4], temp2; + ret; +} diff --git a/ptx/src/test/spirv_run/atom_add.spvtxt b/ptx/src/test/spirv_run/atom_add.spvtxt new file mode 100644 index 0000000..2c83fe9 --- /dev/null +++ b/ptx/src/test/spirv_run/atom_add.spvtxt @@ -0,0 +1,84 @@ +; SPIR-V +; Version: 1.3 +; Generator: rspirv +; Bound: 55 +OpCapability GenericPointer +OpCapability Linkage +OpCapability Addresses +OpCapability Kernel +OpCapability Int8 +OpCapability Int16 +OpCapability Int64 +OpCapability Float16 +OpCapability Float64 +; OpCapability FunctionFloatControlINTEL +; OpExtension "SPV_INTEL_float_controls2" +%40 = OpExtInstImport "OpenCL.std" +OpMemoryModel Physical64 OpenCL +OpEntryPoint Kernel %1 "atom_add" %4 +OpDecorate %4 Alignment 4 +%41 = OpTypeVoid +%42 = OpTypeInt 32 0 +%43 = OpTypeInt 8 0 +%44 = OpConstant %42 1024 +%45 = OpTypeArray %43 %44 +%46 = OpTypePointer Workgroup %45 +%4 = OpVariable %46 Workgroup +%47 = OpTypeInt 64 0 +%48 = OpTypeFunction %41 %47 %47 +%49 = OpTypePointer Function %47 +%50 = OpTypePointer Function %42 +%51 = OpTypePointer Generic %42 +%27 = OpConstant %47 4 +%52 = OpTypePointer Workgroup %42 +%53 = OpConstant %42 1 +%54 = OpConstant %42 0 +%29 = OpConstant %47 4 +%1 = OpFunction %41 None %48 +%9 = OpFunctionParameter %47 +%10 = OpFunctionParameter %47 +%38 = OpLabel +%2 = OpVariable %49 Function +%3 = OpVariable %49 Function 
+%5 = OpVariable %49 Function +%6 = OpVariable %49 Function +%7 = OpVariable %50 Function +%8 = OpVariable %50 Function +OpStore %2 %9 +OpStore %3 %10 +%12 = OpLoad %47 %2 +%11 = OpCopyObject %47 %12 +OpStore %5 %11 +%14 = OpLoad %47 %3 +%13 = OpCopyObject %47 %14 +OpStore %6 %13 +%16 = OpLoad %47 %5 +%31 = OpConvertUToPtr %51 %16 +%15 = OpLoad %42 %31 +OpStore %7 %15 +%18 = OpLoad %47 %5 +%28 = OpIAdd %47 %18 %27 +%32 = OpConvertUToPtr %51 %28 +%17 = OpLoad %42 %32 +OpStore %8 %17 +%19 = OpLoad %42 %7 +%33 = OpBitcast %52 %4 +OpStore %33 %19 +%21 = OpLoad %42 %8 +%34 = OpBitcast %52 %4 +%20 = OpAtomicIAdd %42 %34 %53 %54 %21 +OpStore %7 %20 +%35 = OpBitcast %52 %4 +%22 = OpLoad %42 %35 +OpStore %8 %22 +%23 = OpLoad %47 %6 +%24 = OpLoad %42 %7 +%36 = OpConvertUToPtr %51 %23 +OpStore %36 %24 +%25 = OpLoad %47 %6 +%26 = OpLoad %42 %8 +%30 = OpIAdd %47 %25 %29 +%37 = OpConvertUToPtr %51 %30 +OpStore %37 %26 +OpReturn +OpFunctionEnd \ No newline at end of file diff --git a/ptx/src/test/spirv_run/atom_cas.ptx b/ptx/src/test/spirv_run/atom_cas.ptx new file mode 100644 index 0000000..440a1cb --- /dev/null +++ b/ptx/src/test/spirv_run/atom_cas.ptx @@ -0,0 +1,24 @@ +.version 6.5 +.target sm_30 +.address_size 64 + +.visible .entry atom_cas( + .param .u64 input, + .param .u64 output +) +{ + .reg .u64 in_addr; + .reg .u64 out_addr; + .reg .u32 temp1; + .reg .u32 temp2; + + ld.param.u64 in_addr, [input]; + ld.param.u64 out_addr, [output]; + + ld.u32 temp1, [in_addr]; + atom.cas.b32 temp1, [in_addr+4], temp1, 100; + ld.u32 temp2, [in_addr+4]; + st.u32 [out_addr], temp1; + st.u32 [out_addr+4], temp2; + ret; +} diff --git a/ptx/src/test/spirv_run/atom_cas.spvtxt b/ptx/src/test/spirv_run/atom_cas.spvtxt new file mode 100644 index 0000000..c5fb922 --- /dev/null +++ b/ptx/src/test/spirv_run/atom_cas.spvtxt @@ -0,0 +1,77 @@ +; SPIR-V +; Version: 1.3 +; Generator: rspirv +; Bound: 51 +OpCapability GenericPointer +OpCapability Linkage +OpCapability Addresses +OpCapability Kernel 
+OpCapability Int8 +OpCapability Int16 +OpCapability Int64 +OpCapability Float16 +OpCapability Float64 +; OpCapability FunctionFloatControlINTEL +; OpExtension "SPV_INTEL_float_controls2" +%41 = OpExtInstImport "OpenCL.std" +OpMemoryModel Physical64 OpenCL +OpEntryPoint Kernel %1 "atom_cas" +%42 = OpTypeVoid +%43 = OpTypeInt 64 0 +%44 = OpTypeFunction %42 %43 %43 +%45 = OpTypePointer Function %43 +%46 = OpTypeInt 32 0 +%47 = OpTypePointer Function %46 +%48 = OpTypePointer Generic %46 +%25 = OpConstant %43 4 +%27 = OpConstant %46 100 +%49 = OpConstant %46 1 +%50 = OpConstant %46 0 +%28 = OpConstant %43 4 +%30 = OpConstant %43 4 +%1 = OpFunction %42 None %44 +%8 = OpFunctionParameter %43 +%9 = OpFunctionParameter %43 +%39 = OpLabel +%2 = OpVariable %45 Function +%3 = OpVariable %45 Function +%4 = OpVariable %45 Function +%5 = OpVariable %45 Function +%6 = OpVariable %47 Function +%7 = OpVariable %47 Function +OpStore %2 %8 +OpStore %3 %9 +%11 = OpLoad %43 %2 +%10 = OpCopyObject %43 %11 +OpStore %4 %10 +%13 = OpLoad %43 %3 +%12 = OpCopyObject %43 %13 +OpStore %5 %12 +%15 = OpLoad %43 %4 +%32 = OpConvertUToPtr %48 %15 +%14 = OpLoad %46 %32 +OpStore %6 %14 +%17 = OpLoad %43 %4 +%18 = OpLoad %46 %6 +%26 = OpIAdd %43 %17 %25 +%34 = OpConvertUToPtr %48 %26 +%35 = OpCopyObject %46 %18 +%33 = OpAtomicCompareExchange %46 %34 %49 %50 %50 %27 %35 +%16 = OpCopyObject %46 %33 +OpStore %6 %16 +%20 = OpLoad %43 %4 +%29 = OpIAdd %43 %20 %28 +%36 = OpConvertUToPtr %48 %29 +%19 = OpLoad %46 %36 +OpStore %7 %19 +%21 = OpLoad %43 %5 +%22 = OpLoad %46 %6 +%37 = OpConvertUToPtr %48 %21 +OpStore %37 %22 +%23 = OpLoad %43 %5 +%24 = OpLoad %46 %7 +%31 = OpIAdd %43 %23 %30 +%38 = OpConvertUToPtr %48 %31 +OpStore %38 %24 +OpReturn +OpFunctionEnd \ No newline at end of file diff --git a/ptx/src/test/spirv_run/atom_inc.ptx b/ptx/src/test/spirv_run/atom_inc.ptx new file mode 100644 index 0000000..ed3df08 --- /dev/null +++ b/ptx/src/test/spirv_run/atom_inc.ptx @@ -0,0 +1,26 @@ +.version 6.5 
+.target sm_30 +.address_size 64 + +.visible .entry atom_inc( + .param .u64 input, + .param .u64 output +) +{ + .reg .u64 in_addr; + .reg .u64 out_addr; + .reg .u32 temp1; + .reg .u32 temp2; + .reg .u32 temp3; + + ld.param.u64 in_addr, [input]; + ld.param.u64 out_addr, [output]; + + atom.inc.u32 temp1, [in_addr], 101; + atom.global.inc.u32 temp2, [in_addr], 101; + ld.u32 temp3, [in_addr]; + st.u32 [out_addr], temp1; + st.u32 [out_addr+4], temp2; + st.u32 [out_addr+8], temp3; + ret; +} diff --git a/ptx/src/test/spirv_run/atom_inc.spvtxt b/ptx/src/test/spirv_run/atom_inc.spvtxt new file mode 100644 index 0000000..6948cd9 --- /dev/null +++ b/ptx/src/test/spirv_run/atom_inc.spvtxt @@ -0,0 +1,89 @@ +; SPIR-V +; Version: 1.3 +; Generator: rspirv +; Bound: 60 +OpCapability GenericPointer +OpCapability Linkage +OpCapability Addresses +OpCapability Kernel +OpCapability Int8 +OpCapability Int16 +OpCapability Int64 +OpCapability Float16 +OpCapability Float64 +; OpCapability FunctionFloatControlINTEL +; OpExtension "SPV_INTEL_float_controls2" +%49 = OpExtInstImport "OpenCL.std" +OpMemoryModel Physical64 OpenCL +OpEntryPoint Kernel %1 "atom_inc" +OpDecorate %40 LinkageAttributes "__notcuda_ptx_impl__atom_relaxed_gpu_generic_inc" Import +OpDecorate %44 LinkageAttributes "__notcuda_ptx_impl__atom_relaxed_gpu_global_inc" Import +%50 = OpTypeVoid +%51 = OpTypeInt 32 0 +%52 = OpTypePointer Generic %51 +%53 = OpTypeFunction %51 %52 %51 +%54 = OpTypePointer CrossWorkgroup %51 +%55 = OpTypeFunction %51 %54 %51 +%56 = OpTypeInt 64 0 +%57 = OpTypeFunction %50 %56 %56 +%58 = OpTypePointer Function %56 +%59 = OpTypePointer Function %51 +%27 = OpConstant %51 101 +%28 = OpConstant %51 101 +%29 = OpConstant %56 4 +%31 = OpConstant %56 8 +%40 = OpFunction %51 None %53 +%42 = OpFunctionParameter %52 +%43 = OpFunctionParameter %51 +OpFunctionEnd +%44 = OpFunction %51 None %55 +%46 = OpFunctionParameter %54 +%47 = OpFunctionParameter %51 +OpFunctionEnd +%1 = OpFunction %50 None %57 +%9 = 
OpFunctionParameter %56 +%10 = OpFunctionParameter %56 +%39 = OpLabel +%2 = OpVariable %58 Function +%3 = OpVariable %58 Function +%4 = OpVariable %58 Function +%5 = OpVariable %58 Function +%6 = OpVariable %59 Function +%7 = OpVariable %59 Function +%8 = OpVariable %59 Function +OpStore %2 %9 +OpStore %3 %10 +%12 = OpLoad %56 %2 +%11 = OpCopyObject %56 %12 +OpStore %4 %11 +%14 = OpLoad %56 %3 +%13 = OpCopyObject %56 %14 +OpStore %5 %13 +%16 = OpLoad %56 %4 +%33 = OpConvertUToPtr %52 %16 +%15 = OpFunctionCall %51 %40 %33 %27 +OpStore %6 %15 +%18 = OpLoad %56 %4 +%34 = OpConvertUToPtr %54 %18 +%17 = OpFunctionCall %51 %44 %34 %28 +OpStore %7 %17 +%20 = OpLoad %56 %4 +%35 = OpConvertUToPtr %52 %20 +%19 = OpLoad %51 %35 +OpStore %8 %19 +%21 = OpLoad %56 %5 +%22 = OpLoad %51 %6 +%36 = OpConvertUToPtr %52 %21 +OpStore %36 %22 +%23 = OpLoad %56 %5 +%24 = OpLoad %51 %7 +%30 = OpIAdd %56 %23 %29 +%37 = OpConvertUToPtr %52 %30 +OpStore %37 %24 +%25 = OpLoad %56 %5 +%26 = OpLoad %51 %8 +%32 = OpIAdd %56 %25 %31 +%38 = OpConvertUToPtr %52 %32 +OpStore %38 %26 +OpReturn +OpFunctionEnd \ No newline at end of file diff --git a/ptx/src/test/spirv_run/constant_f32.spvtxt b/ptx/src/test/spirv_run/constant_f32.spvtxt index 905bec4..27c5f4e 100644 --- a/ptx/src/test/spirv_run/constant_f32.spvtxt +++ b/ptx/src/test/spirv_run/constant_f32.spvtxt @@ -11,12 +11,12 @@ OpCapability Int16 OpCapability Int64 OpCapability Float16 OpCapability Float64 -OpCapability FunctionFloatControlINTEL -OpExtension "SPV_INTEL_float_controls2" +; OpCapability FunctionFloatControlINTEL +; OpExtension "SPV_INTEL_float_controls2" %24 = OpExtInstImport "OpenCL.std" OpMemoryModel Physical64 OpenCL OpEntryPoint Kernel %1 "constant_f32" -OpDecorate %1 FunctionDenormModeINTEL 32 Preserve +; OpDecorate %1 FunctionDenormModeINTEL 32 Preserve %25 = OpTypeVoid %26 = OpTypeInt 64 0 %27 = OpTypeFunction %25 %26 %26 diff --git a/ptx/src/test/spirv_run/constant_negative.spvtxt 
b/ptx/src/test/spirv_run/constant_negative.spvtxt index 39e5d19..ec2ff72 100644 --- a/ptx/src/test/spirv_run/constant_negative.spvtxt +++ b/ptx/src/test/spirv_run/constant_negative.spvtxt @@ -11,8 +11,8 @@ OpCapability Int16 OpCapability Int64 OpCapability Float16 OpCapability Float64 -OpCapability FunctionFloatControlINTEL -OpExtension "SPV_INTEL_float_controls2" +; OpCapability FunctionFloatControlINTEL +; OpExtension "SPV_INTEL_float_controls2" %24 = OpExtInstImport "OpenCL.std" OpMemoryModel Physical64 OpenCL OpEntryPoint Kernel %1 "constant_negative" diff --git a/ptx/src/test/spirv_run/fma.spvtxt b/ptx/src/test/spirv_run/fma.spvtxt index 734bf0f..4a90d09 100644 --- a/ptx/src/test/spirv_run/fma.spvtxt +++ b/ptx/src/test/spirv_run/fma.spvtxt @@ -11,12 +11,12 @@ OpCapability Int16 OpCapability Int64 OpCapability Float16 OpCapability Float64 -OpCapability FunctionFloatControlINTEL -OpExtension "SPV_INTEL_float_controls2" +; OpCapability FunctionFloatControlINTEL +; OpExtension "SPV_INTEL_float_controls2" %37 = OpExtInstImport "OpenCL.std" OpMemoryModel Physical64 OpenCL OpEntryPoint Kernel %1 "fma" -OpDecorate %1 FunctionDenormModeINTEL 32 Preserve +; OpDecorate %1 FunctionDenormModeINTEL 32 Preserve %38 = OpTypeVoid %39 = OpTypeInt 64 0 %40 = OpTypeFunction %38 %39 %39 diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index 98b9630..40a9d64 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -86,12 +86,20 @@ test_ptx!(rcp, [2f32], [0.5f32]); // 0x3f000000 is 0.5 // TODO: mul_ftz fails because IGC does not yet handle SPV_INTEL_float_controls2 // test_ptx!(mul_ftz, [0b1_00000000_10000000000000000000000u32, 0x3f000000u32], [0u32]); -test_ptx!(mul_non_ftz, [0b1_00000000_10000000000000000000000u32, 0x3f000000u32], [0b1_00000000_01000000000000000000000u32]); +test_ptx!( + mul_non_ftz, + [0b1_00000000_10000000000000000000000u32, 0x3f000000u32], + [0b1_00000000_01000000000000000000000u32] +); test_ptx!(constant_f32, 
[10f32], [5f32]); test_ptx!(constant_negative, [-101i32], [101i32]); test_ptx!(and, [6u32, 3u32], [2u32]); test_ptx!(selp, [100u16, 200u16], [200u16]); -test_ptx!(fma, [2f32, 3f32, 5f32], [11f32]); +test_ptx!(fma, [2f32, 3f32, 5f32], [11f32]); +test_ptx!(shared_variable, [513u64], [513u64]); +test_ptx!(atom_cas, [91u32, 91u32], [91u32, 100u32]); +test_ptx!(atom_inc, [100u32], [100u32, 101u32, 0u32]); +test_ptx!(atom_add, [2u32, 4u32], [2u32, 6u32]); struct DisplayError { err: T, @@ -124,7 +132,7 @@ fn test_ptx_assert<'a, T: From + ze::SafeRepr + Debug + Copy + PartialEq>( let name = CString::new(name)?; let result = run_spirv(name.as_c_str(), notcuda_module, input, output) .map_err(|err| DisplayError { err })?; - assert_eq!(output, result.as_slice()); + assert_eq!(result.as_slice(), output); Ok(()) } @@ -145,8 +153,8 @@ fn run_spirv + ze::SafeRepr + Copy + Debug>( let use_shared_mem = module .kernel_info .get(name.to_str().unwrap()) - .unwrap() - .uses_shared_mem; + .map(|info| info.uses_shared_mem) + .unwrap_or(false); let mut result = vec![0u8.into(); output.len()]; { let mut drivers = ze::Driver::get()?; @@ -155,11 +163,20 @@ fn run_spirv + ze::SafeRepr + Copy + Debug>( let mut devices = drv.devices()?; let dev = devices.drain(0..1).next().unwrap(); let queue = ze::CommandQueue::new(&mut ctx, &dev)?; - let (module, log) = ze::Module::new_spirv(&mut ctx, &dev, byte_il, None); + let (module, maybe_log) = match module.should_link_ptx_impl { + Some(ptx_impl) => ze::Module::build_link_spirv(&mut ctx, &dev, &[ptx_impl, byte_il]), + None => { + let (module, log) = ze::Module::build_spirv(&mut ctx, &dev, byte_il, None); + (module, Some(log)) + } + }; let module = match module { Ok(m) => m, Err(err) => { - let raw_err_string = log.get_cstring()?; + let raw_err_string = maybe_log + .map(|log| log.get_cstring()) + .transpose()? 
+ .unwrap_or(CString::default()); let err_string = raw_err_string.to_string_lossy(); panic!("{:?}\n{}", err, err_string); } @@ -215,7 +232,11 @@ fn test_spvtxt_assert<'a>( ptr::null_mut(), ) }; - assert!(result == spv_result_t::SPV_SUCCESS); + if result != spv_result_t::SPV_SUCCESS { + panic!("{:?}\n{}", result, unsafe { + str::from_utf8_unchecked(spirv_txt) + }); + } let mut parsed_spirv = Vec::::new(); let result = unsafe { spirv_tools::spvBinaryParse( diff --git a/ptx/src/test/spirv_run/mul_ftz.spvtxt b/ptx/src/test/spirv_run/mul_ftz.spvtxt index da6a12a..56cec5a 100644 --- a/ptx/src/test/spirv_run/mul_ftz.spvtxt +++ b/ptx/src/test/spirv_run/mul_ftz.spvtxt @@ -11,8 +11,8 @@ OpCapability Int16 OpCapability Int64 OpCapability Float16 OpCapability Float64 -OpCapability FunctionFloatControlINTEL -OpExtension "SPV_INTEL_float_controls2" +; OpCapability FunctionFloatControlINTEL +; OpExtension "SPV_INTEL_float_controls2" %30 = OpExtInstImport "OpenCL.std" OpMemoryModel Physical64 OpenCL OpEntryPoint Kernel %1 "mul_ftz" diff --git a/ptx/src/test/spirv_run/selp.spvtxt b/ptx/src/test/spirv_run/selp.spvtxt index dffd9af..6f73bc2 100644 --- a/ptx/src/test/spirv_run/selp.spvtxt +++ b/ptx/src/test/spirv_run/selp.spvtxt @@ -11,8 +11,8 @@ OpCapability Int16 OpCapability Int64 OpCapability Float16 OpCapability Float64 -OpCapability FunctionFloatControlINTEL -OpExtension "SPV_INTEL_float_controls2" +; OpCapability FunctionFloatControlINTEL +; OpExtension "SPV_INTEL_float_controls2" %31 = OpExtInstImport "OpenCL.std" OpMemoryModel Physical64 OpenCL OpEntryPoint Kernel %1 "selp" diff --git a/ptx/src/test/spirv_run/shared_variable.ptx b/ptx/src/test/spirv_run/shared_variable.ptx new file mode 100644 index 0000000..4f7eff3 --- /dev/null +++ b/ptx/src/test/spirv_run/shared_variable.ptx @@ -0,0 +1,26 @@ +.version 6.5 +.target sm_30 +.address_size 64 + + +.visible .entry shared_variable( + .param .u64 input, + .param .u64 output +) +{ + .shared .align 4 .b8 shared_mem1[128]; + + .reg 
.u64 in_addr; + .reg .u64 out_addr; + .reg .u64 temp1; + .reg .u64 temp2; + + ld.param.u64 in_addr, [input]; + ld.param.u64 out_addr, [output]; + + ld.global.u64 temp1, [in_addr]; + st.shared.u64 [shared_mem1], temp1; + ld.shared.u64 temp2, [shared_mem1]; + st.global.u64 [out_addr], temp2; + ret; +} \ No newline at end of file diff --git a/ptx/src/test/spirv_run/shared_variable.spvtxt b/ptx/src/test/spirv_run/shared_variable.spvtxt new file mode 100644 index 0000000..1af2bd1 --- /dev/null +++ b/ptx/src/test/spirv_run/shared_variable.spvtxt @@ -0,0 +1,65 @@ +; SPIR-V +; Version: 1.3 +; Generator: rspirv +; Bound: 39 +OpCapability GenericPointer +OpCapability Linkage +OpCapability Addresses +OpCapability Kernel +OpCapability Int8 +OpCapability Int16 +OpCapability Int64 +OpCapability Float16 +OpCapability Float64 +; OpCapability FunctionFloatControlINTEL +; OpExtension "SPV_INTEL_float_controls2" +%27 = OpExtInstImport "OpenCL.std" +OpMemoryModel Physical64 OpenCL +OpEntryPoint Kernel %1 "shared_variable" %4 +OpDecorate %4 Alignment 4 +%28 = OpTypeVoid +%29 = OpTypeInt 32 0 +%30 = OpTypeInt 8 0 +%31 = OpConstant %29 128 +%32 = OpTypeArray %30 %31 +%33 = OpTypePointer Workgroup %32 +%4 = OpVariable %33 Workgroup +%34 = OpTypeInt 64 0 +%35 = OpTypeFunction %28 %34 %34 +%36 = OpTypePointer Function %34 +%37 = OpTypePointer CrossWorkgroup %34 +%38 = OpTypePointer Workgroup %34 +%1 = OpFunction %28 None %35 +%9 = OpFunctionParameter %34 +%10 = OpFunctionParameter %34 +%25 = OpLabel +%2 = OpVariable %36 Function +%3 = OpVariable %36 Function +%5 = OpVariable %36 Function +%6 = OpVariable %36 Function +%7 = OpVariable %36 Function +%8 = OpVariable %36 Function +OpStore %2 %9 +OpStore %3 %10 +%12 = OpLoad %34 %2 +%11 = OpCopyObject %34 %12 +OpStore %5 %11 +%14 = OpLoad %34 %3 +%13 = OpCopyObject %34 %14 +OpStore %6 %13 +%16 = OpLoad %34 %5 +%21 = OpConvertUToPtr %37 %16 +%15 = OpLoad %34 %21 +OpStore %7 %15 +%17 = OpLoad %34 %7 +%22 = OpBitcast %38 %4 +OpStore %22 %17 +%23 = 
OpBitcast %38 %4 +%18 = OpLoad %34 %23 +OpStore %8 %18 +%19 = OpLoad %34 %6 +%20 = OpLoad %34 %8 +%24 = OpConvertUToPtr %37 %19 +OpStore %24 %20 +OpReturn +OpFunctionEnd \ No newline at end of file diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index a7025b1..6b07c0f 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -1,14 +1,13 @@ use crate::ast; use half::f16; use rspirv::{binary::Disassemble, dr}; +use std::collections::{hash_map, HashMap, HashSet}; use std::{borrow::Cow, hash::Hash, iter, mem}; -use std::{ - collections::{hash_map, HashMap, HashSet}, - convert::TryFrom, -}; use rspirv::binary::Assemble; +static NOTCUDA_PTX_IMPL: &'static [u8] = include_bytes!("../lib/notcuda_ptx_impl.spv"); + quick_error! { #[derive(Debug)] pub enum TranslateError { @@ -69,6 +68,7 @@ impl Into for ast::PointerStateSpace { ast::PointerStateSpace::Global => spirv::StorageClass::CrossWorkgroup, ast::PointerStateSpace::Shared => spirv::StorageClass::Workgroup, ast::PointerStateSpace::Param => spirv::StorageClass::Function, + ast::PointerStateSpace::Generic => spirv::StorageClass::Generic, } } } @@ -419,6 +419,7 @@ impl TypeWordMap { pub struct Module { pub spirv: dr::Module, pub kernel_info: HashMap, + pub should_link_ptx_impl: Option<&'static [u8]>, } pub struct KernelInfo { @@ -428,15 +429,22 @@ pub struct KernelInfo { pub fn to_spirv_module<'a>(ast: ast::Module<'a>) -> Result { let mut id_defs = GlobalStringIdResolver::new(1); + let mut ptx_impl_imports = HashMap::new(); let directives = ast .directives .into_iter() - .map(|f| translate_directive(&mut id_defs, f)) + .map(|directive| translate_directive(&mut id_defs, &mut ptx_impl_imports, directive)) .collect::, _>>()?; + let must_link_ptx_impl = ptx_impl_imports.len() > 0; + let directives = ptx_impl_imports + .into_iter() + .map(|(_, v)| v) + .chain(directives.into_iter()) + .collect::>(); let mut builder = dr::Builder::new(); builder.reserve_ids(id_defs.current_id()); - let mut directives = - 
convert_dynamic_shared_memory_usage(&mut id_defs, directives, &mut || builder.id()); + let call_map = get_call_map(&directives); + let mut directives = convert_dynamic_shared_memory_usage(directives, &mut || builder.id()); normalize_variable_decls(&mut directives); let denorm_information = compute_denorm_information(&directives); // https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_a_id_logicallayout_a_logical_layout_of_a_module @@ -448,32 +456,142 @@ pub fn to_spirv_module<'a>(ast: ast::Module<'a>) -> Result( + builder: &mut dr::Builder, + map: &mut TypeWordMap, + id_defs: &GlobalStringIdResolver<'input>, + opencl_id: spirv::Word, + denorm_information: &HashMap, HashMap>, + call_map: &HashMap<&'input str, HashSet>, + directives: Vec, + kernel_info: &mut HashMap, +) -> Result<(), TranslateError> { + let empty_body = Vec::new(); + for d in directives.iter() { match d { Directive::Variable(var) => { - emit_variable(&mut builder, &mut map, &var)?; + emit_variable(builder, map, &var)?; } Directive::Method(f) => { - let f_body = match f.body { + let f_body = match &f.body { Some(f) => f, - None => continue, + None => { + if f.import_as.is_some() { + &empty_body + } else { + continue; + } + } }; - emit_function_body_ops(&mut builder, &mut map, opencl_id, &f.globals)?; + for var in f.globals.iter() { + emit_variable(builder, map, var)?; + } emit_function_header( - &mut builder, - &mut map, + builder, + map, &id_defs, - f.func_decl, + &f.globals, + &f.func_decl, &denorm_information, - &mut kernel_info, + call_map, + &directives, + kernel_info, )?; - emit_function_body_ops(&mut builder, &mut map, opencl_id, &f_body)?; + emit_function_body_ops(builder, map, opencl_id, &f_body)?; builder.end_function()?; + if let (ast::MethodDecl::Func(_, fn_id, _), Some(name)) = + (&f.func_decl, &f.import_as) + { + builder.decorate( + *fn_id, + spirv::Decoration::LinkageAttributes, + &[ + dr::Operand::LiteralString(name.clone()), + 
dr::Operand::LinkageType(spirv::LinkageType::Import), + ], + ); + } } } } - let spirv = builder.module(); - Ok(Module { spirv, kernel_info }) + Ok(()) +} + +fn get_call_map<'input>( + module: &[Directive<'input>], +) -> HashMap<&'input str, HashSet> { + let mut directly_called_by = HashMap::new(); + for directive in module { + match directive { + Directive::Method(Function { + func_decl, + body: Some(statements), + .. + }) => { + let call_key = CallgraphKey::new(&func_decl); + for statement in statements { + match statement { + Statement::Call(call) => { + multi_hash_map_append(&mut directly_called_by, call_key, call.func); + } + _ => {} + } + } + } + _ => {} + } + } + let mut result = HashMap::new(); + for (method_key, children) in directly_called_by.iter() { + match method_key { + CallgraphKey::Kernel(name) => { + let mut visited = HashSet::new(); + for child in children { + add_call_map_single(&directly_called_by, &mut visited, *child); + } + result.insert(*name, visited); + } + CallgraphKey::Func(_) => {} + } + } + result +} + +fn add_call_map_single<'input>( + directly_called_by: &MultiHashMap, spirv::Word>, + visited: &mut HashSet, + current: spirv::Word, +) { + if !visited.insert(current) { + return; + } + if let Some(children) = directly_called_by.get(&CallgraphKey::Func(current)) { + for child in children { + add_call_map_single(directly_called_by, visited, *child); + } + } } type MultiHashMap = HashMap>; @@ -495,7 +613,6 @@ fn multi_hash_map_append(m: &mut MultiHashMap, // This pass looks for all uses of .extern .shared and converts them to // an additional method argument fn convert_dynamic_shared_memory_usage<'input>( - id_defs: &mut GlobalStringIdResolver<'input>, module: Vec>, new_id: &mut impl FnMut() -> spirv::Word, ) -> Vec> { @@ -524,6 +641,7 @@ fn convert_dynamic_shared_memory_usage<'input>( func_decl, globals, body: Some(statements), + import_as, }) => { let call_key = CallgraphKey::new(&func_decl); let statements = statements @@ -545,6 +663,7 
@@ fn convert_dynamic_shared_memory_usage<'input>( func_decl, globals, body: Some(statements), + import_as, }) } directive => directive, @@ -561,6 +680,7 @@ fn convert_dynamic_shared_memory_usage<'input>( mut func_decl, globals, body: Some(statements), + import_as, }) => { let call_key = CallgraphKey::new(&func_decl); if !methods_using_extern_shared.contains(&call_key) { @@ -568,6 +688,7 @@ fn convert_dynamic_shared_memory_usage<'input>( func_decl, globals, body: Some(statements), + import_as, }); } let shared_id_param = new_id(); @@ -625,6 +746,7 @@ fn convert_dynamic_shared_memory_usage<'input>( func_decl, globals, body: Some(new_statements), + import_as, }) } directive => directive, @@ -744,15 +866,6 @@ fn denorm_count_map_update_impl( } } -fn denorm_count_map_merge( - dst: &mut DenormCountMap, - src: &DenormCountMap, -) { - for (k, count) in src { - denorm_count_map_update_impl(dst, *k, *count); - } -} - // HACK ALERT! // This function is a "good enough" heuristic of whetever to mark f16/f32 operations // in the kernel as flushing denorms to zero or preserving them @@ -763,7 +876,7 @@ fn compute_denorm_information<'input>( module: &[Directive<'input>], ) -> HashMap, HashMap> { let mut denorm_methods = HashMap::new(); - for directive in module.iter() { + for directive in module { match directive { Directive::Variable(_) | Directive::Method(Function { body: None, .. 
}) => {} Directive::Method(Function { @@ -861,9 +974,12 @@ fn emit_builtins( fn emit_function_header<'a>( builder: &mut dr::Builder, map: &mut TypeWordMap, - global: &GlobalStringIdResolver<'a>, - func_directive: ast::MethodDecl, + defined_globals: &GlobalStringIdResolver<'a>, + synthetic_globals: &[ast::Variable], + func_directive: &ast::MethodDecl, denorm_information: &HashMap, HashMap>, + call_map: &HashMap<&'a str, HashSet>, + direcitves: &[Directive], kernel_info: &mut HashMap, ) -> Result<(), TranslateError> { if let ast::MethodDecl::Kernel { @@ -884,22 +1000,49 @@ fn emit_function_header<'a>( let (ret_type, func_type) = get_function_type(builder, map, &func_directive); let fn_id = match func_directive { ast::MethodDecl::Kernel { name, .. } => { - let fn_id = global.get_id(name)?; - let mut global_variables = global + let fn_id = defined_globals.get_id(name)?; + let mut global_variables = defined_globals .variables_type_check .iter() .filter_map(|(k, t)| t.as_ref().map(|_| *k)) .collect::>(); - let mut interface = global + let mut interface = defined_globals .special_registers .iter() .map(|(_, id)| *id) .collect::>(); + for ast::Variable { name, .. } in synthetic_globals { + interface.push(*name); + } + let empty_hash_set = HashSet::new(); + let child_fns = call_map.get(name).unwrap_or(&empty_hash_set); + for directive in direcitves { + match directive { + Directive::Method(Function { + func_decl: ast::MethodDecl::Func(_, name, _), + globals, + .. 
+ }) => { + if child_fns.contains(name) { + for var in globals { + interface.push(var.name); + } + } + } + _ => {} + } + } + global_variables.append(&mut interface); - builder.entry_point(spirv::ExecutionModel::Kernel, fn_id, name, global_variables); + builder.entry_point( + spirv::ExecutionModel::Kernel, + fn_id, + *name, + global_variables, + ); fn_id } - ast::MethodDecl::Func(_, name, _) => name, + ast::MethodDecl::Func(_, name, _) => *name, }; builder.begin_function( ret_type, @@ -934,9 +1077,10 @@ fn emit_function_header<'a>( pub fn to_spirv<'a>( ast: ast::Module<'a>, -) -> Result<(Vec, HashMap>), TranslateError> { +) -> Result<(Option<&'static [u8]>, Vec, HashMap>), TranslateError> { let module = to_spirv_module(ast)?; Ok(( + module.should_link_ptx_impl, module.spirv.assemble(), module .kernel_info @@ -977,11 +1121,14 @@ fn emit_memory_model(builder: &mut dr::Builder) { fn translate_directive<'input>( id_defs: &mut GlobalStringIdResolver<'input>, + ptx_impl_imports: &mut HashMap, d: ast::Directive<'input, ast::ParsedArgParams<'input>>, ) -> Result, TranslateError> { Ok(match d { ast::Directive::Variable(v) => Directive::Variable(translate_variable(id_defs, v)?), - ast::Directive::Method(f) => Directive::Method(translate_function(id_defs, f)?), + ast::Directive::Method(f) => { + Directive::Method(translate_function(id_defs, ptx_impl_imports, f)?) 
+ } }) } @@ -1000,10 +1147,11 @@ fn translate_variable<'a>( fn translate_function<'a>( id_defs: &mut GlobalStringIdResolver<'a>, + ptx_impl_imports: &mut HashMap, f: ast::ParsedFunction<'a>, ) -> Result, TranslateError> { let (str_resolver, fn_resolver, fn_decl) = id_defs.start_fn(&f.func_directive); - to_ssa(str_resolver, fn_resolver, fn_decl, f.body) + to_ssa(ptx_impl_imports, str_resolver, fn_resolver, fn_decl, f.body) } fn expand_kernel_params<'a, 'b>( @@ -1043,6 +1191,7 @@ fn expand_fn_params<'a, 'b>( } fn to_ssa<'input, 'b>( + ptx_impl_imports: &mut HashMap, mut id_defs: FnStringIdResolver<'input, 'b>, fn_defs: GlobalFnDeclResolver<'input, 'b>, f_args: ast::MethodDecl<'input, spirv::Word>, @@ -1055,6 +1204,7 @@ fn to_ssa<'input, 'b>( func_decl: f_args, body: None, globals: Vec::new(), + import_as: None, }) } }; @@ -1071,19 +1221,90 @@ fn to_ssa<'input, 'b>( insert_implicit_conversions(expanded_statements, &mut numeric_id_defs)?; let mut numeric_id_defs = numeric_id_defs.unmut(); let labeled_statements = normalize_labels(expanded_statements, &mut numeric_id_defs); - let (f_body, globals) = extract_globals(labeled_statements); + let (f_body, globals) = + extract_globals(labeled_statements, ptx_impl_imports, &mut numeric_id_defs); Ok(Function { func_decl: f_args, globals: globals, body: Some(f_body), + import_as: None, }) } -fn extract_globals( +fn extract_globals<'input, 'b>( sorted_statements: Vec, -) -> (Vec, Vec) { - // This fn will be used for SLM - (sorted_statements, Vec::new()) + ptx_impl_imports: &mut HashMap, + id_def: &mut NumericIdResolver, +) -> ( + Vec, + Vec>, +) { + let mut local = Vec::with_capacity(sorted_statements.len()); + let mut global = Vec::new(); + for statement in sorted_statements { + match statement { + Statement::Variable( + var + @ + ast::Variable { + v_type: ast::VariableType::Shared(_), + .. + }, + ) + | Statement::Variable( + var + @ + ast::Variable { + v_type: ast::VariableType::Global(_), + .. 
+ }, + ) => global.push(var), + Statement::Instruction(ast::Instruction::Atom( + d + @ + ast::AtomDetails { + inner: + ast::AtomInnerDetails::Unsigned { + op: ast::AtomUIntOp::Inc, + .. + }, + .. + }, + a, + )) => { + local.push(to_ptx_impl_atomic_call( + id_def, + ptx_impl_imports, + d, + a, + "inc", + )); + } + Statement::Instruction(ast::Instruction::Atom( + d + @ + ast::AtomDetails { + inner: + ast::AtomInnerDetails::Unsigned { + op: ast::AtomUIntOp::Dec, + .. + }, + .. + }, + a, + )) => { + local.push(to_ptx_impl_atomic_call( + id_def, + ptx_impl_imports, + d, + a, + "dec", + )); + } + s => local.push(s), + } + } + (local, global) } fn normalize_variable_decls(directives: &mut Vec) { @@ -1269,6 +1490,15 @@ fn convert_to_typed_statements( ast::Instruction::Selp(d, a) => { result.push(Statement::Instruction(ast::Instruction::Selp(d, a.cast()))) } + ast::Instruction::Bar(d, a) => { + result.push(Statement::Instruction(ast::Instruction::Bar(d, a.cast()))) + } + ast::Instruction::Atom(d, a) => { + result.push(Statement::Instruction(ast::Instruction::Atom(d, a.cast()))) + } + ast::Instruction::AtomCas(d, a) => result.push(Statement::Instruction( + ast::Instruction::AtomCas(d, a.cast()), + )), }, Statement::Label(i) => result.push(Statement::Label(i)), Statement::Variable(v) => result.push(Statement::Variable(v)), @@ -1286,6 +1516,99 @@ fn convert_to_typed_statements( Ok(result) } +fn to_ptx_impl_atomic_call( + id_defs: &mut NumericIdResolver, + ptx_impl_imports: &mut HashMap, + details: ast::AtomDetails, + arg: ast::Arg3, + op: &'static str, +) -> ExpandedStatement { + let semantics = ptx_semantics_name(details.semantics); + let scope = ptx_scope_name(details.scope); + let space = ptx_space_name(details.space); + let fn_name = format!( + "__notcuda_ptx_impl__atom_{}_{}_{}_{}", + semantics, scope, space, op + ); + // TODO: extract to a function + let ptr_space = match details.space { + ast::AtomSpace::Generic => ast::PointerStateSpace::Generic, + 
ast::AtomSpace::Global => ast::PointerStateSpace::Global, + ast::AtomSpace::Shared => ast::PointerStateSpace::Shared, + }; + let fn_id = match ptx_impl_imports.entry(fn_name) { + hash_map::Entry::Vacant(entry) => { + let fn_id = id_defs.new_id(None); + let func_decl = ast::MethodDecl::Func::( + vec![ast::FnArgument { + align: None, + v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar( + ast::ScalarType::U32, + )), + name: id_defs.new_id(None), + array_init: Vec::new(), + }], + fn_id, + vec![ + ast::FnArgument { + align: None, + v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Pointer( + ast::SizedScalarType::U32, + ptr_space, + )), + name: id_defs.new_id(None), + array_init: Vec::new(), + }, + ast::FnArgument { + align: None, + v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar( + ast::ScalarType::U32, + )), + name: id_defs.new_id(None), + array_init: Vec::new(), + }, + ], + ); + let func = Function { + func_decl, + globals: Vec::new(), + body: None, + import_as: Some(entry.key().clone()), + }; + entry.insert(Directive::Method(func)); + fn_id + } + hash_map::Entry::Occupied(entry) => match entry.get() { + Directive::Method(Function { + func_decl: ast::MethodDecl::Func(_, name, _), + .. 
+ }) => *name, + _ => unreachable!(), + }, + }; + Statement::Call(ResolvedCall { + uniform: false, + func: fn_id, + ret_params: vec![( + arg.dst, + ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(ast::ScalarType::U32)), + )], + param_list: vec![ + ( + arg.src1, + ast::FnArgumentType::Reg(ast::VariableRegType::Pointer( + ast::SizedScalarType::U32, + ptr_space, + )), + ), + ( + arg.src2, + ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(ast::ScalarType::U32)), + ), + ], + }) +} + fn to_resolved_fn_args( params: Vec, params_decl: &[ast::FnArgumentType], @@ -1529,6 +1852,9 @@ fn insert_mem_ssa_statement_default<'a, F: VisitVariable>( | (t, ArgumentSemantics::DefaultRelaxed) | (t, ArgumentSemantics::PhysicalPointer) => t, }; + if let ast::Type::Array(_, _) = id_type { + return Ok(desc.op); + } let generated_id = id_def.new_id(id_type.clone()); if !desc.is_dst { result.push(Statement::LoadVar( @@ -1916,6 +2242,12 @@ fn insert_implicit_conversions( if let ast::Instruction::St(d, _) = &inst { state_space = Some(d.state_space.to_ld_ss()); } + if let ast::Instruction::Atom(d, _) = &inst { + state_space = Some(d.space.to_ld_ss()); + } + if let ast::Instruction::AtomCas(d, _) = &inst { + state_space = Some(d.space.to_ld_ss()); + } if let ast::Instruction::Mov(_, ast::Arg2Mov::Normal(_)) = &inst { default_conversion_fn = should_bitcast_packed; } @@ -2387,6 +2719,52 @@ fn emit_function_body_ops( let result_type = map.get_or_add_scalar(builder, ast::ScalarType::from(*t)); builder.select(result_type, Some(a.dst), a.src3, a.src2, a.src2)?; } + // TODO: implement named barriers + ast::Instruction::Bar(d, _) => { + let workgroup_scope = map.get_or_add_constant( + builder, + &ast::Type::Scalar(ast::ScalarType::U32), + &vec_repr(spirv::Scope::Workgroup as u32), + )?; + let barrier_semantics = match d { + ast::BarDetails::SyncAligned => map.get_or_add_constant( + builder, + &ast::Type::Scalar(ast::ScalarType::U32), + &vec_repr( + 
spirv::MemorySemantics::CROSS_WORKGROUP_MEMORY + | spirv::MemorySemantics::WORKGROUP_MEMORY + | spirv::MemorySemantics::SEQUENTIALLY_CONSISTENT, + ), + )?, + }; + builder.control_barrier(workgroup_scope, workgroup_scope, barrier_semantics)?; + } + ast::Instruction::Atom(details, arg) => { + emit_atom(builder, map, details, arg)?; + } + ast::Instruction::AtomCas(details, arg) => { + let result_type = map.get_or_add_scalar(builder, details.typ.into()); + let memory_const = map.get_or_add_constant( + builder, + &ast::Type::Scalar(ast::ScalarType::U32), + &vec_repr(details.scope.to_spirv() as u32), + )?; + let semantics_const = map.get_or_add_constant( + builder, + &ast::Type::Scalar(ast::ScalarType::U32), + &vec_repr(details.semantics.to_spirv().bits()), + )?; + builder.atomic_compare_exchange( + result_type, + Some(arg.dst), + arg.src1, + memory_const, + semantics_const, + semantics_const, + arg.src3, + arg.src2, + )?; + } }, Statement::LoadVar(arg, typ) => { let type_id = map.get_or_add(builder, SpirvType::from(typ.clone())); @@ -2417,6 +2795,99 @@ fn emit_function_body_ops( Ok(()) } +fn emit_atom( + builder: &mut dr::Builder, + map: &mut TypeWordMap, + details: &ast::AtomDetails, + arg: &ast::Arg3, +) -> Result<(), TranslateError> { + let (spirv_op, typ) = match details.inner { + ast::AtomInnerDetails::Bit { op, typ } => { + let spirv_op = match op { + ast::AtomBitOp::And => dr::Builder::atomic_and, + ast::AtomBitOp::Or => dr::Builder::atomic_or, + ast::AtomBitOp::Xor => dr::Builder::atomic_xor, + ast::AtomBitOp::Exchange => dr::Builder::atomic_exchange, + }; + (spirv_op, ast::ScalarType::from(typ)) + } + ast::AtomInnerDetails::Unsigned { op, typ } => { + let spirv_op = match op { + ast::AtomUIntOp::Add => dr::Builder::atomic_i_add, + ast::AtomUIntOp::Inc | ast::AtomUIntOp::Dec => { + return Err(TranslateError::Unreachable); + } + ast::AtomUIntOp::Min => dr::Builder::atomic_u_min, + ast::AtomUIntOp::Max => dr::Builder::atomic_u_max, + }; + (spirv_op, typ.into()) + 
} + ast::AtomInnerDetails::Signed { op, typ } => { + let spirv_op = match op { + ast::AtomSIntOp::Add => dr::Builder::atomic_i_add, + ast::AtomSIntOp::Min => dr::Builder::atomic_s_min, + ast::AtomSIntOp::Max => dr::Builder::atomic_s_max, + }; + (spirv_op, typ.into()) + } + // TODO: Hardware is capable of this, implement it through builtin + ast::AtomInnerDetails::Float { .. } => todo!(), + }; + let result_type = map.get_or_add_scalar(builder, typ); + let memory_const = map.get_or_add_constant( + builder, + &ast::Type::Scalar(ast::ScalarType::U32), + &vec_repr(details.scope.to_spirv() as u32), + )?; + let semantics_const = map.get_or_add_constant( + builder, + &ast::Type::Scalar(ast::ScalarType::U32), + &vec_repr(details.semantics.to_spirv().bits()), + )?; + spirv_op( + builder, + result_type, + Some(arg.dst), + arg.src1, + memory_const, + semantics_const, + arg.src2, + )?; + Ok(()) +} + +#[derive(Clone)] +struct PtxImplImport { + out_arg: ast::Type, + fn_id: u32, + in_args: Vec, +} + +fn ptx_semantics_name(sema: ast::AtomSemantics) -> &'static str { + match sema { + ast::AtomSemantics::Relaxed => "relaxed", + ast::AtomSemantics::Acquire => "acquire", + ast::AtomSemantics::Release => "release", + ast::AtomSemantics::AcquireRelease => "acq_rel", + } +} + +fn ptx_scope_name(scope: ast::MemScope) -> &'static str { + match scope { + ast::MemScope::Cta => "cta", + ast::MemScope::Gpu => "gpu", + ast::MemScope::Sys => "sys", + } +} + +fn ptx_space_name(space: ast::AtomSpace) -> &'static str { + match space { + ast::AtomSpace::Generic => "generic", + ast::AtomSpace::Global => "global", + ast::AtomSpace::Shared => "shared", + } +} + fn emit_mul_float( builder: &mut dr::Builder, map: &mut TypeWordMap, @@ -2652,7 +3123,7 @@ fn emit_cvt( map: &mut TypeWordMap, dets: &ast::CvtDetails, arg: &ast::Arg2, -) -> Result<(), dr::Error> { +) -> Result<(), TranslateError> { match dets { ast::CvtDetails::FloatFromFloat(desc) => { if desc.dst == desc.src { @@ -3011,7 +3482,7 @@ fn 
emit_implicit_conversion( builder: &mut dr::Builder, map: &mut TypeWordMap, cv: &ImplicitConversion, -) -> Result<(), dr::Error> { +) -> Result<(), TranslateError> { let from_parts = cv.from.to_parts(); let to_parts = cv.to.to_parts(); match (from_parts.kind, to_parts.kind, cv.kind) { @@ -3019,7 +3490,7 @@ fn emit_implicit_conversion( let dst_type = map.get_or_add_scalar(builder, ast::ScalarType::B64); builder.convert_ptr_to_u(dst_type, Some(cv.dst), cv.src)?; } - (_, _, ConversionKind::BitToPtr(space)) => { + (_, _, ConversionKind::BitToPtr(_)) => { let dst_type = map.get_or_add(builder, SpirvType::from(cv.to.clone())); builder.convert_u_to_ptr(dst_type, Some(cv.dst), cv.src)?; } @@ -3782,8 +4253,9 @@ enum Directive<'input> { struct Function<'input> { pub func_decl: ast::MethodDecl<'input, spirv::Word>, - pub globals: Vec, + pub globals: Vec>, pub body: Option>, + import_as: Option, } pub trait ArgumentMapVisitor { @@ -4091,6 +4563,13 @@ impl ast::Instruction { a.map_non_shift(visitor, &ast::Type::Scalar(t.into()), false)?, ), ast::Instruction::Selp(t, a) => ast::Instruction::Selp(t, a.map_selp(visitor, t)?), + ast::Instruction::Bar(d, a) => ast::Instruction::Bar(d, a.map(visitor)?), + ast::Instruction::Atom(d, a) => { + ast::Instruction::Atom(d, a.map_atom(visitor, d.inner.get_type(), d.space)?) + } + ast::Instruction::AtomCas(d, a) => { + ast::Instruction::AtomCas(d, a.map_atom(visitor, d.typ, d.space)?) 
+ } }) } } @@ -4337,6 +4816,9 @@ impl ast::Instruction { | ast::Instruction::Rcp(_, _) | ast::Instruction::And(_, _) | ast::Instruction::Selp(_, _) + | ast::Instruction::Bar(_, _) + | ast::Instruction::Atom(_, _) + | ast::Instruction::AtomCas(_, _) | ast::Instruction::Mad(_, _) => None, } } @@ -4358,6 +4840,9 @@ impl ast::Instruction { ast::Instruction::And(_, _) => None, ast::Instruction::Cvta(_, _) => None, ast::Instruction::Selp(_, _) => None, + ast::Instruction::Bar(_, _) => None, + ast::Instruction::Atom(_, _) => None, + ast::Instruction::AtomCas(_, _) => None, ast::Instruction::Sub(ast::ArithDetails::Signed(_), _) => None, ast::Instruction::Sub(ast::ArithDetails::Unsigned(_), _) => None, ast::Instruction::Add(ast::ArithDetails::Signed(_), _) => None, @@ -4612,6 +5097,27 @@ impl ast::Arg1 { } } +impl ast::Arg1Bar { + fn cast>(self) -> ast::Arg1Bar { + ast::Arg1Bar { src: self.src } + } + + fn map>( + self, + visitor: &mut V, + ) -> Result, TranslateError> { + let new_src = visitor.operand( + ArgumentDescriptor { + op: self.src, + is_dst: false, + sema: ArgumentSemantics::Default, + }, + &ast::Type::Scalar(ast::ScalarType::U32), + )?; + Ok(ast::Arg1Bar { src: new_src }) + } +} + impl ast::Arg2 { fn cast>(self) -> ast::Arg2 { ast::Arg2 { @@ -5022,6 +5528,43 @@ impl ast::Arg3 { )?; Ok(ast::Arg3 { dst, src1, src2 }) } + + fn map_atom>( + self, + visitor: &mut V, + t: ast::ScalarType, + state_space: ast::AtomSpace, + ) -> Result, TranslateError> { + let scalar_type = ast::ScalarType::from(t); + let dst = visitor.id( + ArgumentDescriptor { + op: self.dst, + is_dst: true, + sema: ArgumentSemantics::Default, + }, + Some(&ast::Type::Scalar(scalar_type)), + )?; + let src1 = visitor.operand( + ArgumentDescriptor { + op: self.src1, + is_dst: false, + sema: ArgumentSemantics::PhysicalPointer, + }, + &ast::Type::Pointer( + ast::PointerType::Scalar(scalar_type), + state_space.to_ld_ss(), + ), + )?; + let src2 = visitor.operand( + ArgumentDescriptor { + op: self.src2, + 
is_dst: false, + sema: ArgumentSemantics::Default, + }, + &ast::Type::Scalar(scalar_type), + )?; + Ok(ast::Arg3 { dst, src1, src2 }) + } } impl ast::Arg4 { @@ -5129,6 +5672,56 @@ impl ast::Arg4 { src3, }) } + + fn map_atom>( + self, + visitor: &mut V, + t: ast::BitType, + state_space: ast::AtomSpace, + ) -> Result, TranslateError> { + let scalar_type = ast::ScalarType::from(t); + let dst = visitor.id( + ArgumentDescriptor { + op: self.dst, + is_dst: true, + sema: ArgumentSemantics::Default, + }, + Some(&ast::Type::Scalar(scalar_type)), + )?; + let src1 = visitor.operand( + ArgumentDescriptor { + op: self.src1, + is_dst: false, + sema: ArgumentSemantics::PhysicalPointer, + }, + &ast::Type::Pointer( + ast::PointerType::Scalar(scalar_type), + state_space.to_ld_ss(), + ), + )?; + let src2 = visitor.operand( + ArgumentDescriptor { + op: self.src2, + is_dst: false, + sema: ArgumentSemantics::Default, + }, + &ast::Type::Scalar(scalar_type), + )?; + let src3 = visitor.operand( + ArgumentDescriptor { + op: self.src3, + is_dst: false, + sema: ArgumentSemantics::Default, + }, + &ast::Type::Scalar(scalar_type), + )?; + Ok(ast::Arg4 { + dst, + src1, + src2, + src3, + }) + } } impl ast::Arg4Setp { @@ -5434,6 +6027,17 @@ impl ast::MinMaxDetails { } } +impl ast::AtomInnerDetails { + fn get_type(&self) -> ast::ScalarType { + match self { + ast::AtomInnerDetails::Bit { typ, .. } => (*typ).into(), + ast::AtomInnerDetails::Unsigned { typ, .. } => (*typ).into(), + ast::AtomInnerDetails::Signed { typ, .. } => (*typ).into(), + ast::AtomInnerDetails::Float { typ, .. 
} => (*typ).into(), + } + } +} + impl ast::SIntType { fn from_size(width: u8) -> Self { match width { @@ -5509,6 +6113,37 @@ impl ast::MulDetails { } } +impl ast::AtomSpace { + fn to_ld_ss(self) -> ast::LdStateSpace { + match self { + ast::AtomSpace::Generic => ast::LdStateSpace::Generic, + ast::AtomSpace::Global => ast::LdStateSpace::Global, + ast::AtomSpace::Shared => ast::LdStateSpace::Shared, + } + } +} + +impl ast::MemScope { + fn to_spirv(self) -> spirv::Scope { + match self { + ast::MemScope::Cta => spirv::Scope::Workgroup, + ast::MemScope::Gpu => spirv::Scope::Device, + ast::MemScope::Sys => spirv::Scope::CrossDevice, + } + } +} + +impl ast::AtomSemantics { + fn to_spirv(self) -> spirv::MemorySemantics { + match self { + ast::AtomSemantics::Relaxed => spirv::MemorySemantics::RELAXED, + ast::AtomSemantics::Acquire => spirv::MemorySemantics::ACQUIRE, + ast::AtomSemantics::Release => spirv::MemorySemantics::RELEASE, + ast::AtomSemantics::AcquireRelease => spirv::MemorySemantics::ACQUIRE_RELEASE, + } + } +} + fn bitcast_logical_pointer( operand: &ast::Type, instr: &ast::Type, @@ -5528,7 +6163,27 @@ fn bitcast_physical_pointer( ) -> Result, TranslateError> { match operand_type { // array decays to a pointer - ast::Type::Array(_, _) => todo!(), + ast::Type::Array(op_scalar_t, _) => { + if let ast::Type::Pointer(instr_scalar_t, instr_space) = instr_type { + if ss == Some(*instr_space) { + if ast::Type::Scalar(*op_scalar_t) == ast::Type::from(instr_scalar_t.clone()) { + Ok(None) + } else { + Ok(Some(ConversionKind::PtrToPtr { spirv_ptr: false })) + } + } else { + if ss == Some(ast::LdStateSpace::Generic) + || *instr_space == ast::LdStateSpace::Generic + { + Ok(Some(ConversionKind::PtrToPtr { spirv_ptr: false })) + } else { + Err(TranslateError::MismatchedType) + } + } + } else { + Err(TranslateError::MismatchedType) + } + } ast::Type::Scalar(ast::ScalarType::B64) | ast::Type::Scalar(ast::ScalarType::U64) | ast::Type::Scalar(ast::ScalarType::S64) => {