diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs
index 82580aa..190c21a 100644
--- a/ptx/src/ast.rs
+++ b/ptx/src/ast.rs
@@ -187,7 +187,53 @@ pub enum MovOperand {
     Vec(String, String),
 }
 
-pub struct LdData {}
+pub enum VectorPrefix {
+    V2,
+    V4
+}
+
+pub struct LdData {
+    pub qualifier: LdQualifier,
+    pub state_space: LdStateSpace,
+    pub caching: LdCacheOperator,
+    pub vector: Option<VectorPrefix>,
+    pub typ: ScalarType
+}
+
+#[derive(PartialEq, Eq)]
+pub enum LdQualifier {
+    Weak,
+    Volatile,
+    Relaxed(LdScope),
+    Acquire(LdScope),
+}
+
+#[derive(PartialEq, Eq)]
+pub enum LdScope {
+    Cta,
+    Gpu,
+    Sys
+}
+
+#[derive(PartialEq, Eq)]
+pub enum LdStateSpace {
+    Generic,
+    Const,
+    Global,
+    Local,
+    Param,
+    Shared,
+}
+
+
+#[derive(PartialEq, Eq)]
+pub enum LdCacheOperator {
+    Cached,
+    L2Only,
+    Streaming,
+    LastUse,
+    Uncached
+}
 
 pub struct MovData {}
 
@@ -201,7 +247,9 @@ pub struct SetpBoolData {}
 
 pub struct NotData {}
 
-pub struct BraData {}
+pub struct BraData {
+    pub uniform: bool
+}
 
 pub struct CvtData {}
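Note: with these AST types in place, an `ld` with no modifiers falls back to the same defaults the PTX spec documents: weak ordering, generic state space, `.ca` caching. A minimal sketch of the value that `ld.global.u32 dst, [src];` should parse into under the grammar defaults further down; the `example_ld` helper and the `crate::ast` path are illustrative assumptions, not part of the patch:

```rust
use crate::ast::{LdCacheOperator, LdData, LdQualifier, LdStateSpace, ScalarType};

// Illustrative only: the LdData that `ld.global.u32 dst, [src];` is expected
// to produce, given the unwrap_or defaults in ptx.lalrpop below.
fn example_ld() -> LdData {
    LdData {
        qualifier: LdQualifier::Weak,      // no ordering qualifier written
        state_space: LdStateSpace::Global, // from `.global`
        caching: LdCacheOperator::Cached,  // `.ca` is the implied cache operator
        vector: None,                      // no `.v2`/`.v4` prefix
        typ: ScalarType::U32,              // from `.u32`
    }
}
```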
BaseType "," "[" "]" => { - ast::Instruction::Ld(ast::LdData{}, ast::Arg2{dst:dst, src:src}) + "ld" "," "[" "]" => { + ast::Instruction::Ld( + ast::LdData { + qualifier: q.unwrap_or(ast::LdQualifier::Weak), + state_space: ss.unwrap_or(ast::LdStateSpace::Generic), + caching: cop.unwrap_or(ast::LdCacheOperator::Cached), + vector: v, + typ: t + }, + ast::Arg2 { dst:dst, src:src } + ) } }; -LdQualifier: () = { - ".weak", - ".volatile", - ".relaxed" LdScope, - ".acquire" LdScope, +LdQualifier: ast::LdQualifier = { + ".weak" => ast::LdQualifier::Weak, + ".volatile" => ast::LdQualifier::Volatile, + ".relaxed" => ast::LdQualifier::Relaxed(s), + ".acquire" => ast::LdQualifier::Acquire(s), }; -LdScope = { - ".cta", ".gpu", ".sys" +LdScope: ast::LdScope = { + ".cta" => ast::LdScope::Cta, + ".gpu" => ast::LdScope::Gpu, + ".sys" => ast::LdScope::Sys }; -LdStateSpace = { - ".const", - ".global", - ".local", - ".param", - ".shared", +LdStateSpace: ast::LdStateSpace = { + ".const" => ast::LdStateSpace::Const, + ".global" => ast::LdStateSpace::Global, + ".local" => ast::LdStateSpace::Local, + ".param" => ast::LdStateSpace::Param, + ".shared" => ast::LdStateSpace::Shared, }; -LdCacheOperator = { - ".ca", - ".cg", - ".cs", - ".lu", - ".cv", +LdCacheOperator: ast::LdCacheOperator = { + ".ca" => ast::LdCacheOperator::Cached, + ".cg" => ast::LdCacheOperator::L2Only, + ".cs" => ast::LdCacheOperator::Streaming, + ".lu" => ast::LdCacheOperator::LastUse, + ".cv" => ast::LdCacheOperator::Uncached, }; // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-mov @@ -332,7 +340,7 @@ PredAt: ast::PredAt<&'input str> = { // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#control-flow-instructions-bra InstBra: ast::Instruction<&'input str> = { - "bra" ".uni"? => ast::Instruction::Bra(ast::BraData{}, a) + "bra" => ast::Instruction::Bra(ast::BraData{ uniform: u.is_some() }, a) }; // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt @@ -372,7 +380,7 @@ ShlType = { // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-st InstSt: ast::Instruction<&'input str> = { - "st" LdQualifier? StStateSpace? StCacheOperator? Vector? BaseType "[" "]" "," => { + "st" LdQualifier? StStateSpace? StCacheOperator? VectorPrefix? 
MemoryType "[" "]" "," => { ast::Instruction::St(ast::StData{}, ast::Arg2{dst:dst, src:src}) } }; @@ -454,9 +462,9 @@ OptionalDst: &'input str = { "|" => dst2 } -Vector = { - ".v2", - ".v4" +VectorPrefix: ast::VectorPrefix = { + ".v2" => ast::VectorPrefix::V2, + ".v4" => ast::VectorPrefix::V4 }; // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#debugging-directives-file diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index 52de35d..f5c5107 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -8,6 +8,7 @@ use std::fmt; #[derive(PartialEq, Eq, Hash, Clone, Copy)] enum SpirvType { Base(ast::ScalarType), + Pointer(ast::ScalarType, spirv::StorageClass), } struct TypeWordMap { @@ -33,29 +34,41 @@ impl TypeWordMap { self.fn_void } - fn get_or_add(&mut self, b: &mut dr::Builder, t: SpirvType) -> spirv::Word { - *self.complex.entry(t).or_insert_with(|| match t { - SpirvType::Base(ast::ScalarType::B8) | SpirvType::Base(ast::ScalarType::U8) => { + fn get_or_add_scalar(&mut self, b: &mut dr::Builder, t: ast::ScalarType) -> spirv::Word { + *self.complex.entry(SpirvType::Base(t)).or_insert_with(|| match t { + ast::ScalarType::B8 | ast::ScalarType::U8 => { b.type_int(8, 0) } - SpirvType::Base(ast::ScalarType::B16) | SpirvType::Base(ast::ScalarType::U16) => { + ast::ScalarType::B16 | ast::ScalarType::U16 => { b.type_int(16, 0) } - SpirvType::Base(ast::ScalarType::B32) | SpirvType::Base(ast::ScalarType::U32) => { + ast::ScalarType::B32 | ast::ScalarType::U32 => { b.type_int(32, 0) } - SpirvType::Base(ast::ScalarType::B64) | SpirvType::Base(ast::ScalarType::U64) => { + ast::ScalarType::B64 | ast::ScalarType::U64 => { b.type_int(64, 0) } - SpirvType::Base(ast::ScalarType::S8) => b.type_int(8, 1), - SpirvType::Base(ast::ScalarType::S16) => b.type_int(16, 1), - SpirvType::Base(ast::ScalarType::S32) => b.type_int(32, 1), - SpirvType::Base(ast::ScalarType::S64) => b.type_int(64, 1), - SpirvType::Base(ast::ScalarType::F16) => b.type_float(16), - SpirvType::Base(ast::ScalarType::F32) => b.type_float(32), - SpirvType::Base(ast::ScalarType::F64) => b.type_float(64), + ast::ScalarType::S8 => b.type_int(8, 1), + ast::ScalarType::S16 => b.type_int(16, 1), + ast::ScalarType::S32 => b.type_int(32, 1), + ast::ScalarType::S64 => b.type_int(64, 1), + ast::ScalarType::F16 => b.type_float(16), + ast::ScalarType::F32 => b.type_float(32), + ast::ScalarType::F64 => b.type_float(64), }) } + + fn get_or_add(&mut self, b: &mut dr::Builder, t: SpirvType) -> spirv::Word { + match t { + SpirvType::Base(scalar) => self.get_or_add_scalar(b, scalar), + SpirvType::Pointer(scalar, storage) => { + let base = self.get_or_add_scalar(b, scalar); + *self.complex.entry(t).or_insert_with(|| { + b.type_pointer(None, storage, base) + }) + } + } + } } pub fn to_spirv(ast: ast::Module) -> Result, rspirv::dr::Error> { @@ -123,7 +136,7 @@ fn emit_function<'a>( ); let id_offset = builder.reserve_ids(unique_ids); emit_function_args(builder, id_offset, map, &f.args); - emit_function_body_ops(builder, id_offset, &normalized_ids, &bbs)?; + emit_function_body_ops(builder, id_offset, map, &normalized_ids, &bbs)?; builder.end_function()?; builder.ret()?; builder.end_function()?; @@ -178,6 +191,7 @@ fn collect_label_ids<'a>( fn emit_function_body_ops( builder: &mut dr::Builder, id_offset: spirv::Word, + map: &mut TypeWordMap, func: &[Statement], cfg: &[BasicBlock], ) -> Result<(), dr::Error> { @@ -193,12 +207,35 @@ fn emit_function_body_ops( }; builder.begin_block(header_id)?; for s in body { - /* match s { - 
@@ -123,7 +136,7 @@ fn emit_function<'a>(
     );
     let id_offset = builder.reserve_ids(unique_ids);
     emit_function_args(builder, id_offset, map, &f.args);
-    emit_function_body_ops(builder, id_offset, &normalized_ids, &bbs)?;
+    emit_function_body_ops(builder, id_offset, map, &normalized_ids, &bbs)?;
     builder.end_function()?;
     builder.ret()?;
     builder.end_function()?;
@@ -178,6 +191,7 @@ fn collect_label_ids<'a>(
 fn emit_function_body_ops(
     builder: &mut dr::Builder,
     id_offset: spirv::Word,
+    map: &mut TypeWordMap,
     func: &[Statement],
     cfg: &[BasicBlock],
 ) -> Result<(), dr::Error> {
@@ -193,12 +207,35 @@ fn emit_function_body_ops(
         };
         builder.begin_block(header_id)?;
         for s in body {
-            /*
             match s {
-                Statement::Instruction(pred, inst) => (),
+                // If a block starts with a label it has already been emitted,
+                // all other labels in the block are unused
                 Statement::Label(_) => (),
+                Statement::Conditional(bra) => {
+                    builder.branch_conditional(bra.predicate, bra.if_true, bra.if_false, [])?;
+                }
+                Statement::Instruction(inst) => match inst {
+                    // Sadly, SPIR-V does not support marking jumps as guaranteed-converged
+                    ast::Instruction::Bra(_, arg) => {
+                        builder.branch(arg.src)?;
+                    }
+                    ast::Instruction::Ld(data, arg) => {
+                        if data.qualifier != ast::LdQualifier::Weak || data.vector.is_some() {
+                            todo!()
+                        }
+                        let storage_class = match data.state_space {
+                            ast::LdStateSpace::Generic => spirv::StorageClass::Generic,
+                            ast::LdStateSpace::Param => spirv::StorageClass::CrossWorkgroup,
+                            _ => todo!(),
+                        };
+                        let result_type = map.get_or_add(builder, SpirvType::Base(data.typ));
+                        let pointer_type =
+                            map.get_or_add(builder, SpirvType::Pointer(data.typ, storage_class));
+                        builder.load(result_type, None, pointer_type, None, [])?;
+                    }
+                    _ => todo!(),
+                },
             }
-            */
         }
     }
     Ok(())
@@ -1273,7 +1310,7 @@ mod tests {
         let func = vec![
             Statement::Label(12),
             Statement::Instruction(ast::Instruction::Bra(
-                ast::BraData {},
+                ast::BraData { uniform: false },
                 ast::Arg1 { src: 12 },
             )),
         ];
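Note: a natural follow-up to the `Bra` test above would be one that pushes a weak parameter-space load through `emit_function_body_ops` and checks for the resulting `OpLoad`. The statement below only sketches the input vector; the ids and the surrounding test harness are assumptions modelled on the existing tests, not part of the patch:

```rust
// Sketch of a test input, not part of the patch: a single weak ld.param.u64.
let func = vec![Statement::Instruction(ast::Instruction::Ld(
    ast::LdData {
        qualifier: ast::LdQualifier::Weak,
        state_space: ast::LdStateSpace::Param,
        caching: ast::LdCacheOperator::Cached,
        vector: None,
        typ: ast::ScalarType::U64,
    },
    ast::Arg2 { dst: 13, src: 14 },
))];
```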