Translate instruction ld

Andrzej Janik 2020-05-07 00:37:10 +02:00
parent 3b433456a1
commit fa075abc22
3 changed files with 152 additions and 59 deletions
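
The change wires the PTX ld instruction through all three stages of the translator: the AST gains typed fields for a load's qualifier, state space, cache operator and vector prefix; the grammar parses those modifiers instead of discarding them; and the SPIR-V emitter starts lowering the result to an OpLoad. For orientation, the instruction forms being modeled are roughly the following (paraphrased from the PTX ISA page linked in the grammar; .ss, .cop and .vec stand for the state-space, cache-operator and vector suffixes):

    ld{.weak}{.ss}{.cop}{.vec}.type    d, [a];
    ld.volatile{.ss}{.vec}.type        d, [a];
    ld.relaxed.scope{.ss}{.vec}.type   d, [a];
    ld.acquire.scope{.ss}{.vec}.type   d, [a];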

File 1 of 3

@@ -187,7 +187,53 @@ pub enum MovOperand<ID> {
     Vec(String, String),
 }

-pub struct LdData {}
+pub enum VectorPrefix {
+    V2,
+    V4
+}
+
+pub struct LdData {
+    pub qualifier: LdQualifier,
+    pub state_space: LdStateSpace,
+    pub caching: LdCacheOperator,
+    pub vector: Option<VectorPrefix>,
+    pub typ: ScalarType
+}
+
+#[derive(PartialEq, Eq)]
+pub enum LdQualifier {
+    Weak,
+    Volatile,
+    Relaxed(LdScope),
+    Acquire(LdScope),
+}
+
+#[derive(PartialEq, Eq)]
+pub enum LdScope {
+    Cta,
+    Gpu,
+    Sys
+}
+
+#[derive(PartialEq, Eq)]
+pub enum LdStateSpace {
+    Generic,
+    Const,
+    Global,
+    Local,
+    Param,
+    Shared,
+}
+
+#[derive(PartialEq, Eq)]
+pub enum LdCacheOperator {
+    Cached,
+    L2Only,
+    Streaming,
+    LastUse,
+    Uncached
+}

 pub struct MovData {}
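
To make the new fields concrete, here is a hypothetical LdData value, an illustration only and not part of the commit, for the load ld.volatile.shared.v2.f32 %f0, [addr];. The cache operator is the default that the grammar below applies when no .cop suffix is written:

    // Illustration (not in the commit): the AST produced for
    //     ld.volatile.shared.v2.f32 %f0, [addr];
    let _data = LdData {
        qualifier: LdQualifier::Volatile,
        state_space: LdStateSpace::Shared,
        caching: LdCacheOperator::Cached, // default when no .cop is written
        vector: Some(VectorPrefix::V2),
        typ: ScalarType::F32,
    };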
@@ -201,7 +247,9 @@ pub struct SetpBoolData {}

 pub struct NotData {}

-pub struct BraData {}
+pub struct BraData {
+    pub uniform: bool
+}

 pub struct CvtData {}

File 2 of 3

@@ -106,6 +106,16 @@ Type: ast::Type = {
 };

 ScalarType: ast::ScalarType = {
+    ".f16" => ast::ScalarType::F16,
+    MemoryType
+};
+
+ExtendedScalarType: ast::ExtendedScalarType = {
+    ".f16x2" => ast::ExtendedScalarType::F16x2,
+    ".pred" => ast::ExtendedScalarType::Pred,
+};
+
+MemoryType: ast::ScalarType = {
     ".b8" => ast::ScalarType::B8,
     ".b16" => ast::ScalarType::B16,
     ".b32" => ast::ScalarType::B32,
@@ -118,23 +128,10 @@ ScalarType: ast::ScalarType = {
     ".s16" => ast::ScalarType::S16,
     ".s32" => ast::ScalarType::S32,
     ".s64" => ast::ScalarType::S64,
-    ".f16" => ast::ScalarType::F16,
     ".f32" => ast::ScalarType::F32,
     ".f64" => ast::ScalarType::F64,
 };

-ExtendedScalarType: ast::ExtendedScalarType = {
-    ".f16x2" => ast::ExtendedScalarType::F16x2,
-    ".pred" => ast::ExtendedScalarType::Pred,
-};
-
-BaseType = {
-    ".b8", ".b16", ".b32", ".b64",
-    ".u8", ".u16", ".u32", ".u64",
-    ".s8", ".s16", ".s32", ".s64",
-    ".f32", ".f64"
-};
-
 Statement: Option<ast::Statement<&'input str>> = {
     <l:Label> => Some(ast::Statement::Label(l)),
     DebugDirective => None,
@@ -191,36 +188,47 @@ Instruction: ast::Instruction<&'input str> = {

 // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-ld
 InstLd: ast::Instruction<&'input str> = {
-    "ld" LdQualifier? LdStateSpace? LdCacheOperator? Vector? BaseType <dst:ID> "," "[" <src:Operand> "]" => {
-        ast::Instruction::Ld(ast::LdData{}, ast::Arg2{dst:dst, src:src})
+    "ld" <q:LdQualifier?> <ss:LdStateSpace?> <cop:LdCacheOperator?> <v:VectorPrefix?> <t:MemoryType> <dst:ID> "," "[" <src:Operand> "]" => {
+        ast::Instruction::Ld(
+            ast::LdData {
+                qualifier: q.unwrap_or(ast::LdQualifier::Weak),
+                state_space: ss.unwrap_or(ast::LdStateSpace::Generic),
+                caching: cop.unwrap_or(ast::LdCacheOperator::Cached),
+                vector: v,
+                typ: t
+            },
+            ast::Arg2 { dst:dst, src:src }
+        )
     }
 };

-LdQualifier: () = {
-    ".weak",
-    ".volatile",
-    ".relaxed" LdScope,
-    ".acquire" LdScope,
+LdQualifier: ast::LdQualifier = {
+    ".weak" => ast::LdQualifier::Weak,
+    ".volatile" => ast::LdQualifier::Volatile,
+    ".relaxed" <s:LdScope> => ast::LdQualifier::Relaxed(s),
+    ".acquire" <s:LdScope> => ast::LdQualifier::Acquire(s),
 };

-LdScope = {
-    ".cta", ".gpu", ".sys"
+LdScope: ast::LdScope = {
+    ".cta" => ast::LdScope::Cta,
+    ".gpu" => ast::LdScope::Gpu,
+    ".sys" => ast::LdScope::Sys
 };

-LdStateSpace = {
-    ".const",
-    ".global",
-    ".local",
-    ".param",
-    ".shared",
+LdStateSpace: ast::LdStateSpace = {
+    ".const" => ast::LdStateSpace::Const,
+    ".global" => ast::LdStateSpace::Global,
+    ".local" => ast::LdStateSpace::Local,
+    ".param" => ast::LdStateSpace::Param,
+    ".shared" => ast::LdStateSpace::Shared,
 };

-LdCacheOperator = {
-    ".ca",
-    ".cg",
-    ".cs",
-    ".lu",
-    ".cv",
+LdCacheOperator: ast::LdCacheOperator = {
+    ".ca" => ast::LdCacheOperator::Cached,
+    ".cg" => ast::LdCacheOperator::L2Only,
+    ".cs" => ast::LdCacheOperator::Streaming,
+    ".lu" => ast::LdCacheOperator::LastUse,
+    ".cv" => ast::LdCacheOperator::Uncached,
 };

 // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-mov
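
A sketch of driving the reworked rule, with two caveats: lalrpop only generates an InstLdParser entry point if the rule is declared pub (this grammar may not do so), the module name is assumed from the grammar file, and the register names are invented:

    // Hypothetical usage; only valid if InstLd is a `pub` rule.
    let inst = ptx::InstLdParser::new()
        .parse("ld.relaxed.gpu.global.u32 r1, [r2]")
        .unwrap();
    // Yields qualifier Relaxed(Gpu), state_space Global, caching Cached
    // (the default), vector None, typ U32.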
@@ -332,7 +340,7 @@ PredAt: ast::PredAt<&'input str> = {

 // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#control-flow-instructions-bra
 InstBra: ast::Instruction<&'input str> = {
-    "bra" ".uni"? <a:Arg1> => ast::Instruction::Bra(ast::BraData{}, a)
+    "bra" <u:".uni"?> <a:Arg1> => ast::Instruction::Bra(ast::BraData{ uniform: u.is_some() }, a)
 };

 // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt
@@ -372,7 +380,7 @@ ShlType = {

 // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-st
 InstSt: ast::Instruction<&'input str> = {
-    "st" LdQualifier? StStateSpace? StCacheOperator? Vector? BaseType "[" <dst:ID> "]" "," <src:Operand> => {
+    "st" LdQualifier? StStateSpace? StCacheOperator? VectorPrefix? MemoryType "[" <dst:ID> "]" "," <src:Operand> => {
         ast::Instruction::St(ast::StData{}, ast::Arg2{dst:dst, src:src})
     }
 };
@@ -454,9 +462,9 @@ OptionalDst: &'input str = {
     "|" <dst2:ID> => dst2
 }

-Vector = {
-    ".v2",
-    ".v4"
+VectorPrefix: ast::VectorPrefix = {
+    ".v2" => ast::VectorPrefix::V2,
+    ".v4" => ast::VectorPrefix::V4
 };

 // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#debugging-directives-file

File 3 of 3

@@ -8,6 +8,7 @@ use std::fmt;
 #[derive(PartialEq, Eq, Hash, Clone, Copy)]
 enum SpirvType {
     Base(ast::ScalarType),
+    Pointer(ast::ScalarType, spirv::StorageClass),
 }

 struct TypeWordMap {
@@ -33,29 +34,41 @@ impl TypeWordMap {
         self.fn_void
     }

-    fn get_or_add(&mut self, b: &mut dr::Builder, t: SpirvType) -> spirv::Word {
-        *self.complex.entry(t).or_insert_with(|| match t {
-            SpirvType::Base(ast::ScalarType::B8) | SpirvType::Base(ast::ScalarType::U8) => {
+    fn get_or_add_scalar(&mut self, b: &mut dr::Builder, t: ast::ScalarType) -> spirv::Word {
+        *self.complex.entry(SpirvType::Base(t)).or_insert_with(|| match t {
+            ast::ScalarType::B8 | ast::ScalarType::U8 => {
                 b.type_int(8, 0)
             }
-            SpirvType::Base(ast::ScalarType::B16) | SpirvType::Base(ast::ScalarType::U16) => {
+            ast::ScalarType::B16 | ast::ScalarType::U16 => {
                 b.type_int(16, 0)
             }
-            SpirvType::Base(ast::ScalarType::B32) | SpirvType::Base(ast::ScalarType::U32) => {
+            ast::ScalarType::B32 | ast::ScalarType::U32 => {
                 b.type_int(32, 0)
             }
-            SpirvType::Base(ast::ScalarType::B64) | SpirvType::Base(ast::ScalarType::U64) => {
+            ast::ScalarType::B64 | ast::ScalarType::U64 => {
                 b.type_int(64, 0)
             }
-            SpirvType::Base(ast::ScalarType::S8) => b.type_int(8, 1),
-            SpirvType::Base(ast::ScalarType::S16) => b.type_int(16, 1),
-            SpirvType::Base(ast::ScalarType::S32) => b.type_int(32, 1),
-            SpirvType::Base(ast::ScalarType::S64) => b.type_int(64, 1),
-            SpirvType::Base(ast::ScalarType::F16) => b.type_float(16),
-            SpirvType::Base(ast::ScalarType::F32) => b.type_float(32),
-            SpirvType::Base(ast::ScalarType::F64) => b.type_float(64),
+            ast::ScalarType::S8 => b.type_int(8, 1),
+            ast::ScalarType::S16 => b.type_int(16, 1),
+            ast::ScalarType::S32 => b.type_int(32, 1),
+            ast::ScalarType::S64 => b.type_int(64, 1),
+            ast::ScalarType::F16 => b.type_float(16),
+            ast::ScalarType::F32 => b.type_float(32),
+            ast::ScalarType::F64 => b.type_float(64),
         })
     }
+
+    fn get_or_add(&mut self, b: &mut dr::Builder, t: SpirvType) -> spirv::Word {
+        match t {
+            SpirvType::Base(scalar) => self.get_or_add_scalar(b, scalar),
+            SpirvType::Pointer(scalar, storage) => {
+                let base = self.get_or_add_scalar(b, scalar);
+                *self.complex.entry(t).or_insert_with(|| {
+                    b.type_pointer(None, storage, base)
+                })
+            }
+        }
+    }
 }

 pub fn to_spirv(ast: ast::Module) -> Result<Vec<u32>, rspirv::dr::Error> {
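
The split makes the recursion explicit: a SpirvType::Pointer needs the id of its pointee, so get_or_add interns the scalar first and then the pointer type around it, and the HashMap entry API keeps both operations idempotent. A usage sketch under assumed setup (b is a dr::Builder, map is a TypeWordMap):

    // First calls emit OpTypeInt/OpTypePointer; later calls hit the cache.
    let u32_id = map.get_or_add(&mut b, SpirvType::Base(ast::ScalarType::U32));
    let ptr_id = map.get_or_add(
        &mut b,
        SpirvType::Pointer(ast::ScalarType::U32, spirv::StorageClass::Generic),
    );
    assert_ne!(u32_id, ptr_id);
    assert_eq!(u32_id, map.get_or_add(&mut b, SpirvType::Base(ast::ScalarType::U32)));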
@@ -123,7 +136,7 @@ fn emit_function<'a>(
     );
     let id_offset = builder.reserve_ids(unique_ids);
     emit_function_args(builder, id_offset, map, &f.args);
-    emit_function_body_ops(builder, id_offset, &normalized_ids, &bbs)?;
+    emit_function_body_ops(builder, id_offset, map, &normalized_ids, &bbs)?;
     builder.end_function()?;
     builder.ret()?;
     builder.end_function()?;
@@ -178,6 +191,7 @@ fn collect_label_ids<'a>(
 fn emit_function_body_ops(
     builder: &mut dr::Builder,
     id_offset: spirv::Word,
+    map: &mut TypeWordMap,
     func: &[Statement],
     cfg: &[BasicBlock],
 ) -> Result<(), dr::Error> {
@@ -193,12 +207,35 @@ fn emit_function_body_ops(
         };
         builder.begin_block(header_id)?;
         for s in body {
-            /*
             match s {
-                Statement::Instruction(pred, inst) => (),
+                // If a block starts with a label it has already been emitted,
+                // all other labels in the block are unused
+                Statement::Label(_) => (),
+                Statement::Conditional(bra) => {
+                    builder.branch_conditional(bra.predicate, bra.if_true, bra.if_false, [])?;
+                }
+                Statement::Instruction(inst) => match inst {
+                    // Sadly, SPIR-V does not support marking jumps as guaranteed-converged
+                    ast::Instruction::Bra(_, arg) => {
+                        builder.branch(arg.src)?;
+                    }
+                    ast::Instruction::Ld(data, arg) => {
+                        if data.qualifier != ast::LdQualifier::Weak || data.vector.is_some() {
+                            todo!()
+                        }
+                        let storage_class = match data.state_space {
+                            ast::LdStateSpace::Generic => spirv::StorageClass::Generic,
+                            ast::LdStateSpace::Param => spirv::StorageClass::CrossWorkgroup,
+                            _ => todo!(),
+                        };
+                        let result_type = map.get_or_add(builder, SpirvType::Base(data.typ));
+                        let pointer_type =
+                            map.get_or_add(builder, SpirvType::Pointer(data.typ, storage_class));
+                        builder.load(result_type, None, pointer_type, None, [])?;
+                    }
+                    _ => todo!(),
+                },
             }
-            */
         }
     }
     Ok(())
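
Tracing the Ld arm with a load that passes the guard, such as ld.param.f32 %f0, [addr];: the .param space maps to spirv::StorageClass::CrossWorkgroup, two types get interned, and an OpLoad is emitted. Note that at this stage the pointer type id is passed where OpLoad expects a pointer value, and arg is never read in this arm; the address wiring presumably comes in a later commit. A minimal restatement of the calls involved (assuming builder b and type map map):

    // What the Ld arm does for ld.param.f32, spelled out:
    let result_type = map.get_or_add(&mut b, SpirvType::Base(ast::ScalarType::F32));
    let pointer_type = map.get_or_add(
        &mut b,
        SpirvType::Pointer(ast::ScalarType::F32, spirv::StorageClass::CrossWorkgroup),
    );
    // rspirv: load(result_type, result_id, pointer, memory_access, extra_operands)
    b.load(result_type, None, pointer_type, None, [])?;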
@@ -1273,7 +1310,7 @@ mod tests {
     let func = vec![
         Statement::Label(12),
         Statement::Instruction(ast::Instruction::Bra(
-            ast::BraData {},
+            ast::BraData { uniform: false },
             ast::Arg1 { src: 12 },
         )),
     ];