Parse vector movs (mov.type a.x b.y;)

This commit is contained in:
Andrzej Janik 2020-09-12 02:33:20 +02:00
parent 1238796dfd
commit 48dac43540
6 changed files with 178 additions and 133 deletions

View file

@ -8,7 +8,7 @@ pub struct Module {
pub enum ModuleCompileError<'a> {
Parse(
Vec<ptx::ast::PtxError>,
Option<ptx::ParseError<usize, ptx::Token<'a>, &'a str>>,
Option<ptx::ParseError<usize, ptx::Token<'a>, ptx::ast::PtxError>>,
),
Compile(ptx::SpirvError),
}

View file

@ -316,7 +316,8 @@ pub struct PredAt<ID> {
pub enum Instruction<P: ArgParams> {
Ld(LdData, Arg2<P>),
Mov(MovData, Arg2Mov<P>),
Mov(MovType, Arg2<P>),
MovVector(MovVectorType, Arg2Vec<P>),
Mul(MulDetails, Arg3<P>),
Add(AddDetails, Arg3<P>),
Setp(SetpData, Arg4<P>),
@ -348,7 +349,6 @@ pub trait ArgParams {
type ID;
type Operand;
type CallOperand;
type MovOperand;
}
pub struct ParsedArgParams<'a> {
@ -359,7 +359,6 @@ impl<'a> ArgParams for ParsedArgParams<'a> {
type ID = &'a str;
type Operand = Operand<&'a str>;
type CallOperand = CallOperand<&'a str>;
type MovOperand = MovOperand<&'a str>;
}
pub struct Arg1<P: ArgParams> {
@ -376,9 +375,10 @@ pub struct Arg2St<P: ArgParams> {
pub src2: P::Operand,
}
pub struct Arg2Mov<P: ArgParams> {
pub dst: P::ID,
pub src: P::MovOperand,
pub enum Arg2Vec<P: ArgParams> {
Dst((P::ID, u8), P::ID),
Src(P::ID, (P::ID, u8)),
Both((P::ID, u8), (P::ID, u8)),
}
pub struct Arg3<P: ArgParams> {
@ -415,11 +415,6 @@ pub enum CallOperand<ID> {
Imm(i128),
}
pub enum MovOperand<ID> {
Op(Operand<ID>),
Vec(ID, u8),
}
pub enum VectorPrefix {
V2,
V4,
@ -467,10 +462,6 @@ pub enum LdCacheOperator {
Uncached,
}
pub struct MovData {
pub typ: Type,
}
sub_scalar_type!(MovScalarType {
B16,
B32,
@ -486,19 +477,25 @@ sub_scalar_type!(MovScalarType {
Pred,
});
enum MovType {
Scalar(MovScalarType),
Vector(MovScalarType, u8),
Array(MovScalarType, u32),
}
// pred vectors are illegal
sub_scalar_type!(MovVectorType {
B16,
B32,
B64,
U16,
U32,
U64,
S16,
S32,
S64,
F32,
F64,
});
impl From<MovType> for Type {
fn from(t: MovType) -> Self {
match t {
MovType::Scalar(t) => Type::Scalar(t.into()),
MovType::Vector(t, len) => Type::Vector(t.into(), len),
MovType::Array(t, len) => Type::Array(t.into(), len),
}
sub_type! {
MovType {
Scalar(MovScalarType),
Vector(MovVectorType, u8),
}
}

View file

@ -6,13 +6,13 @@ extern crate lalrpop_util;
extern crate quick_error;
extern crate bit_vec;
extern crate half;
#[cfg(test)]
extern crate level_zero as ze;
#[cfg(test)]
extern crate level_zero_sys as l0;
extern crate rspirv;
extern crate spirv_headers as spirv;
extern crate half;
#[cfg(test)]
extern crate spirv_tools_sys as spirv_tools;
@ -27,12 +27,26 @@ pub mod ast;
mod test;
mod translate;
pub use lalrpop_util::ParseError as ParseError;
pub use lalrpop_util::lexer::Token as Token;
pub use crate::ptx::ModuleParser as ModuleParser;
pub use translate::to_spirv as to_spirv;
pub use crate::ptx::ModuleParser;
pub use lalrpop_util::lexer::Token;
pub use lalrpop_util::ParseError;
pub use rspirv::dr::Error as SpirvError;
pub use translate::to_spirv;
pub(crate) fn without_none<T>(x: Vec<Option<T>>) -> Vec<T> {
x.into_iter().filter_map(|x| x).collect()
}
pub(crate) fn vector_index<'input>(
inp: &'input str,
) -> Result<u8, ParseError<usize, lalrpop_util::lexer::Token<'input>, ast::PtxError>> {
match inp {
"x" | "r" => Ok(0),
"y" | "g" => Ok(1),
"z" | "b" => Ok(2),
"w" | "a" => Ok(3),
_ => Err(ParseError::User {
error: ast::PtxError::WrongVectorElement,
}),
}
}

View file

@ -1,9 +1,13 @@
use crate::ast;
use crate::ast::UnwrapWithVec;
use crate::without_none;
use crate::{without_none, vector_index};
grammar<'a>(errors: &mut Vec<ast::PtxError>);
extern {
type Error = ast::PtxError;
}
match {
r"\s+" => { },
r"//[^\n\r]*[\n\r]*" => { },
@ -487,24 +491,49 @@ LdCacheOperator: ast::LdCacheOperator = {
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-mov
InstMov: ast::Instruction<ast::ParsedArgParams<'input>> = {
"mov" <t:MovType> <a:Arg2Mov> => {
ast::Instruction::Mov(ast::MovData{ typ:t }, a)
"mov" <t:MovType> <a:Arg2> => {
ast::Instruction::Mov(t, a)
},
"mov" <t:MovVectorType> <a:Arg2Vec> => {
ast::Instruction::MovVector(t, a)
}
};
MovType: ast::Type = {
".b16" => ast::Type::Scalar(ast::ScalarType::B16),
".b32" => ast::Type::Scalar(ast::ScalarType::B32),
".b64" => ast::Type::Scalar(ast::ScalarType::B64),
".u16" => ast::Type::Scalar(ast::ScalarType::U16),
".u32" => ast::Type::Scalar(ast::ScalarType::U32),
".u64" => ast::Type::Scalar(ast::ScalarType::U64),
".s16" => ast::Type::Scalar(ast::ScalarType::S16),
".s32" => ast::Type::Scalar(ast::ScalarType::S32),
".s64" => ast::Type::Scalar(ast::ScalarType::S64),
".f32" => ast::Type::Scalar(ast::ScalarType::F32),
".f64" => ast::Type::Scalar(ast::ScalarType::F64),
".pred" => ast::Type::Scalar(ast::ScalarType::Pred)
#[inline]
MovType: ast::MovType = {
<t:MovScalarType> => ast::MovType::Scalar(t),
<pref:VectorPrefix> <t:MovVectorType> => ast::MovType::Vector(t, pref)
}
#[inline]
MovScalarType: ast::MovScalarType = {
".b16" => ast::MovScalarType::B16,
".b32" => ast::MovScalarType::B32,
".b64" => ast::MovScalarType::B64,
".u16" => ast::MovScalarType::U16,
".u32" => ast::MovScalarType::U32,
".u64" => ast::MovScalarType::U64,
".s16" => ast::MovScalarType::S16,
".s32" => ast::MovScalarType::S32,
".s64" => ast::MovScalarType::S64,
".f32" => ast::MovScalarType::F32,
".f64" => ast::MovScalarType::F64,
".pred" => ast::MovScalarType::Pred
};
#[inline]
MovVectorType: ast::MovVectorType = {
".b16" => ast::MovVectorType::B16,
".b32" => ast::MovVectorType::B32,
".b64" => ast::MovVectorType::B64,
".u16" => ast::MovVectorType::U16,
".u32" => ast::MovVectorType::U32,
".u64" => ast::MovVectorType::U64,
".s16" => ast::MovVectorType::S16,
".s32" => ast::MovVectorType::S32,
".s64" => ast::MovVectorType::S64,
".f32" => ast::MovVectorType::F32,
".f64" => ast::MovVectorType::F64,
};
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-mul
@ -989,29 +1018,6 @@ CallOperand: ast::CallOperand<&'input str> = {
}
};
MovOperand: ast::MovOperand<&'input str> = {
<o:Operand> => ast::MovOperand::Op(o),
<o:VectorOperand> => {
let (pref, suf) = o;
let suf_idx = match suf {
"x" | "r" => 0,
"y" | "g" => 1,
"z" | "b" => 2,
"w" | "a" => 3,
_ => {
errors.push(ast::PtxError::WrongVectorElement);
0
}
};
ast::MovOperand::Vec(pref, suf_idx)
}
};
VectorOperand: (&'input str, &'input str) = {
<pref:ExtendedID> "." <suf:ExtendedID> => (pref, suf),
<pref:ExtendedID> <suf:DotID> => (pref, &suf[1..]),
};
Arg1: ast::Arg1<ast::ParsedArgParams<'input>> = {
<src:ExtendedID> => ast::Arg1{<>}
};
@ -1020,8 +1026,21 @@ Arg2: ast::Arg2<ast::ParsedArgParams<'input>> = {
<dst:ExtendedID> "," <src:Operand> => ast::Arg2{<>}
};
Arg2Mov: ast::Arg2Mov<ast::ParsedArgParams<'input>> = {
<dst:ExtendedID> "," <src:MovOperand> => ast::Arg2Mov{<>}
Arg2Vec: ast::Arg2Vec<ast::ParsedArgParams<'input>> = {
<dst:VectorOperand> "," <src:ExtendedID> => ast::Arg2Vec::Dst(dst, src),
<dst:ExtendedID> "," <src:VectorOperand> => ast::Arg2Vec::Src(dst, src),
<dst:VectorOperand> "," <src:VectorOperand> => ast::Arg2Vec::Both(dst, src),
};
VectorOperand: (&'input str, u8) = {
<pref:ExtendedID> "." <suf:ExtendedID> =>? {
let suf_idx = vector_index(suf)?;
Ok((pref, suf_idx))
},
<pref:ExtendedID> <suf:DotID> =>? {
let suf_idx = vector_index(&suf[1..])?;
Ok((pref, suf_idx))
}
};
Arg3: ast::Arg3<ast::ParsedArgParams<'input>> = {

View file

@ -1,7 +1,7 @@
// Excersise as many features of vector types as possible
.version 6.5
.target sm_53
.target sm_60
.address_size 64
.func (.reg .v2 .u32 output) impl(
@ -17,6 +17,7 @@
add.u32 temp2, temp1, temp2;
mov.u32 temp_v.x, temp2;
mov.u32 temp_v.y, temp2;
mov.u32 temp_v.x, temp_v.y;
mov.v2.u32 output, temp_v;
ret;
}

View file

@ -737,14 +737,11 @@ impl<'a, 'b> ArgumentMapVisitor<NormalizedArgParams, ExpandedArgParams>
}
}
fn src_mov_operand(
fn src_vec_operand(
&mut self,
desc: ArgumentDescriptor<ast::MovOperand<spirv::Word>>,
) -> spirv::Word {
match &desc.op {
ast::MovOperand::Op(opr) => self.operand(desc.new_op(*opr)),
ast::MovOperand::Vec(opr, _) => self.variable(desc.new_op(*opr)),
}
desc: ArgumentDescriptor<(spirv::Word, u8)>,
) -> (spirv::Word, u8) {
(self.variable(desc.new_op(desc.op.0)), desc.op.1)
}
}
@ -986,8 +983,9 @@ fn emit_function_body_ops(
}
// SPIR-V does not support ret as guaranteed-converged
ast::Instruction::Ret(_) => builder.ret()?,
ast::Instruction::Mov(mov, arg) => {
let result_type = map.get_or_add(builder, SpirvType::from(mov.typ));
ast::Instruction::Mov(mov_type, arg) => {
let result_type =
map.get_or_add(builder, SpirvType::from(ast::Type::from(*mov_type)));
builder.copy_object(result_type, Some(arg.dst), arg.src)?;
}
ast::Instruction::Mul(mul, arg) => match mul {
@ -1032,6 +1030,7 @@ fn emit_function_body_ops(
builder.copy_object(result_type, Some(arg.dst), arg.src)?;
}
ast::Instruction::SetpBool(_, _) => todo!(),
ast::Instruction::MovVector(_, _) => todo!(),
},
Statement::LoadVar(arg, typ) => {
let type_id = map.get_or_add(builder, SpirvType::from(*typ));
@ -1751,7 +1750,6 @@ impl ast::ArgParams for NormalizedArgParams {
type ID = spirv::Word;
type Operand = ast::Operand<spirv::Word>;
type CallOperand = ast::CallOperand<spirv::Word>;
type MovOperand = ast::MovOperand<spirv::Word>;
}
impl ArgParamsEx for NormalizedArgParams {
@ -1768,7 +1766,6 @@ impl ast::ArgParams for ExpandedArgParams {
type ID = spirv::Word;
type Operand = spirv::Word;
type CallOperand = spirv::Word;
type MovOperand = spirv::Word;
}
impl ArgParamsEx for ExpandedArgParams {
@ -1781,7 +1778,7 @@ trait ArgumentMapVisitor<T: ArgParamsEx, U: ArgParamsEx> {
fn variable(&mut self, desc: ArgumentDescriptor<T::ID>) -> U::ID;
fn operand(&mut self, desc: ArgumentDescriptor<T::Operand>) -> U::Operand;
fn src_call_operand(&mut self, desc: ArgumentDescriptor<T::CallOperand>) -> U::CallOperand;
fn src_mov_operand(&mut self, desc: ArgumentDescriptor<T::MovOperand>) -> U::MovOperand;
fn src_vec_operand(&mut self, desc: ArgumentDescriptor<(T::ID, u8)>) -> (U::ID, u8);
}
impl<T> ArgumentMapVisitor<ExpandedArgParams, ExpandedArgParams> for T
@ -1794,12 +1791,14 @@ where
fn operand(&mut self, desc: ArgumentDescriptor<spirv::Word>) -> spirv::Word {
self(desc)
}
fn src_call_operand(&mut self, mut desc: ArgumentDescriptor<spirv::Word>) -> spirv::Word {
desc.op = self(desc.new_op(desc.op));
desc.op
fn src_call_operand(&mut self, desc: ArgumentDescriptor<spirv::Word>) -> spirv::Word {
self(desc.new_op(desc.op))
}
fn src_mov_operand(&mut self, desc: ArgumentDescriptor<spirv::Word>) -> spirv::Word {
self(desc)
fn src_vec_operand(
&mut self,
desc: ArgumentDescriptor<(spirv::Word, u8)>,
) -> (spirv::Word, u8) {
(self(desc.new_op(desc.op.0)), desc.op.1)
}
}
@ -1832,16 +1831,8 @@ where
}
}
fn src_mov_operand(
&mut self,
desc: ArgumentDescriptor<ast::MovOperand<&str>>,
) -> ast::MovOperand<spirv::Word> {
match desc.op {
ast::MovOperand::Op(op) => ast::MovOperand::Op(self.operand(desc.new_op(op))),
ast::MovOperand::Vec(reg, x2) => {
ast::MovOperand::Vec(self.variable(desc.new_op(reg)), x2)
}
}
fn src_vec_operand(&mut self, desc: ArgumentDescriptor<(&str, u8)>) -> (spirv::Word, u8) {
(self(desc.op.0), desc.op.1)
}
}
@ -1869,6 +1860,7 @@ impl<T: ArgParamsEx> ast::Instruction<T> {
visitor: &mut V,
) -> ast::Instruction<U> {
match self {
ast::Instruction::MovVector(_, _) => todo!(),
ast::Instruction::Abs(_, _) => todo!(),
ast::Instruction::Call(_) => unreachable!(),
ast::Instruction::Ld(d, a) => {
@ -1879,9 +1871,8 @@ impl<T: ArgParamsEx> ast::Instruction<T> {
a.map_ld(visitor, Some(ast::Type::Scalar(inst_type)), src_is_pointer),
)
}
ast::Instruction::Mov(d, a) => {
let inst_type = d.typ;
ast::Instruction::Mov(d, a.map(visitor, Some(inst_type)))
ast::Instruction::Mov(mov_type, a) => {
ast::Instruction::Mov(mov_type, a.map(visitor, Some(mov_type.into())))
}
ast::Instruction::Mul(d, a) => {
let inst_type = d.get_type();
@ -1982,19 +1973,11 @@ where
}
}
fn src_mov_operand(
fn src_vec_operand(
&mut self,
desc: ArgumentDescriptor<ast::MovOperand<spirv::Word>>,
) -> ast::MovOperand<spirv::Word> {
match desc.op {
ast::MovOperand::Op(op) => ast::MovOperand::Op(ArgumentMapVisitor::<
NormalizedArgParams,
NormalizedArgParams,
>::operand(
self, desc.new_op(op)
)),
ast::MovOperand::Vec(x1, x2) => ast::MovOperand::Vec(x1, x2),
}
desc: ArgumentDescriptor<(spirv::Word, u8)>,
) -> (spirv::Word, u8) {
(self(desc.new_op(desc.op.0)), desc.op.1)
}
}
@ -2004,6 +1987,7 @@ impl ast::Instruction<ExpandedArgParams> {
ast::Instruction::Bra(_, a) => Some(a.src),
ast::Instruction::Ld(_, _)
| ast::Instruction::Mov(_, _)
| ast::Instruction::MovVector(_, _)
| ast::Instruction::Mul(_, _)
| ast::Instruction::Add(_, _)
| ast::Instruction::Setp(_, _)
@ -2201,25 +2185,55 @@ impl<T: ArgParamsEx> ast::Arg2St<T> {
}
}
impl<T: ArgParamsEx> ast::Arg2Mov<T> {
impl<T: ArgParamsEx> ast::Arg2Vec<T> {
fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
t: Option<ast::Type>,
) -> ast::Arg2Mov<U> {
ast::Arg2Mov {
dst: visitor.variable(ArgumentDescriptor {
op: self.dst,
typ: t,
is_dst: true,
is_pointer: false,
}),
src: visitor.src_mov_operand(ArgumentDescriptor {
op: self.src,
typ: t,
is_dst: false,
is_pointer: false,
}),
t: ast::Type,
) -> ast::Arg2Vec<U> {
match self {
ast::Arg2Vec::Dst(dst, src) => ast::Arg2Vec::Dst(
visitor.src_vec_operand(ArgumentDescriptor {
op: dst,
typ: Some(t),
is_dst: true,
is_pointer: false,
}),
visitor.variable(ArgumentDescriptor {
op: src,
typ: Some(t),
is_dst: false,
is_pointer: false,
}),
),
ast::Arg2Vec::Src(dst, src) => ast::Arg2Vec::Src (
visitor.variable(ArgumentDescriptor {
op: dst,
typ: Some(t),
is_dst: true,
is_pointer: false,
}),
visitor.src_vec_operand(ArgumentDescriptor {
op: src,
typ: Some(t),
is_dst: false,
is_pointer: false,
}),
),
ast::Arg2Vec::Both(dst, src) => ast::Arg2Vec::Both (
visitor.src_vec_operand(ArgumentDescriptor {
op: dst,
typ: Some(t),
is_dst: true,
is_pointer: false,
}),
visitor.src_vec_operand(ArgumentDescriptor {
op: src,
typ: Some(t),
is_dst: false,
is_pointer: false,
}),
),
}
}
}