This commit is contained in:
Andrzej Janik 2024-08-23 02:19:36 +02:00
parent 71e025845c
commit 1ec1ca0c30
11 changed files with 903 additions and 148 deletions

View file

@ -7,7 +7,7 @@ edition = "2018"
[lib]
[dependencies]
lalrpop-util = "0.19"
ptx_parser = { path = "../ptx_parser" }
regex = "1"
rspirv = "0.7"
spirv_headers = "1.5"
@ -17,8 +17,12 @@ bit-vec = "0.6"
half ="1.6"
bitflags = "1.2"
[dependencies.lalrpop-util]
version = "0.19.12"
features = ["lexer"]
[build-dependencies.lalrpop]
version = "0.19"
version = "0.19.12"
features = ["lexer"]
[dev-dependencies]

View file

@ -34,15 +34,9 @@ pub enum PtxError {
#[error("")]
NonExternPointer,
#[error("{start}:{end}")]
UnrecognizedStatement {
start: usize,
end: usize,
},
UnrecognizedStatement { start: usize, end: usize },
#[error("{start}:{end}")]
UnrecognizedDirective {
start: usize,
end: usize,
},
UnrecognizedDirective { start: usize, end: usize },
}
// For some weird reson this is illegal:
@ -578,11 +572,15 @@ impl CvtDetails {
if saturate {
if src.kind() == ScalarKind::Signed {
if dst.kind() == ScalarKind::Signed && dst.size_of() >= src.size_of() {
err.push(ParseError::from(PtxError::SyntaxError));
err.push(ParseError::User {
error: PtxError::SyntaxError,
});
}
} else {
if dst == src || dst.size_of() >= src.size_of() {
err.push(ParseError::from(PtxError::SyntaxError));
err.push(ParseError::User {
error: PtxError::SyntaxError,
});
}
}
}
@ -598,7 +596,9 @@ impl CvtDetails {
err: &'err mut Vec<ParseError<usize, Token<'input>, PtxError>>,
) -> Self {
if flush_to_zero && dst != ScalarType::F32 {
err.push(ParseError::from(PtxError::NonF32Ftz));
err.push(ParseError::from(lalrpop_util::ParseError::User {
error: PtxError::NonF32Ftz,
}));
}
CvtDetails::FloatFromInt(CvtDesc {
dst,
@ -618,7 +618,9 @@ impl CvtDetails {
err: &'err mut Vec<ParseError<usize, Token<'input>, PtxError>>,
) -> Self {
if flush_to_zero && src != ScalarType::F32 {
err.push(ParseError::from(PtxError::NonF32Ftz));
err.push(ParseError::from(lalrpop_util::ParseError::User {
error: PtxError::NonF32Ftz,
}));
}
CvtDetails::IntFromFloat(CvtDesc {
dst,

View file

@ -24,9 +24,11 @@ lalrpop_mod!(
);
pub mod ast;
mod pass;
#[cfg(test)]
mod test;
mod translate;
mod translate2;
use std::fmt;

531
ptx/src/pass/mod.rs Normal file
View file

@ -0,0 +1,531 @@
use ptx_parser as ast;
use std::{
borrow::Cow,
cell::RefCell,
collections::{hash_map, HashMap},
rc::Rc,
};
mod normalize;
#[derive(Ord, PartialOrd, Eq, PartialEq, Hash, Copy, Clone)]
enum PtxSpecialRegister {
Tid,
Ntid,
Ctaid,
Nctaid,
Clock,
LanemaskLt,
}
impl PtxSpecialRegister {
fn try_parse(s: &str) -> Option<Self> {
match s {
"%tid" => Some(Self::Tid),
"%ntid" => Some(Self::Ntid),
"%ctaid" => Some(Self::Ctaid),
"%nctaid" => Some(Self::Nctaid),
"%clock" => Some(Self::Clock),
"%lanemask_lt" => Some(Self::LanemaskLt),
_ => None,
}
}
fn get_type(self) -> ast::Type {
match self {
PtxSpecialRegister::Tid
| PtxSpecialRegister::Ntid
| PtxSpecialRegister::Ctaid
| PtxSpecialRegister::Nctaid => ast::Type::Vector(self.get_function_return_type(), 4),
_ => ast::Type::Scalar(self.get_function_return_type()),
}
}
fn get_function_return_type(self) -> ast::ScalarType {
match self {
PtxSpecialRegister::Tid => ast::ScalarType::U32,
PtxSpecialRegister::Ntid => ast::ScalarType::U32,
PtxSpecialRegister::Ctaid => ast::ScalarType::U32,
PtxSpecialRegister::Nctaid => ast::ScalarType::U32,
PtxSpecialRegister::Clock => ast::ScalarType::U32,
PtxSpecialRegister::LanemaskLt => ast::ScalarType::U32,
}
}
fn get_function_input_type(self) -> Option<ast::ScalarType> {
match self {
PtxSpecialRegister::Tid
| PtxSpecialRegister::Ntid
| PtxSpecialRegister::Ctaid
| PtxSpecialRegister::Nctaid => Some(ast::ScalarType::U8),
PtxSpecialRegister::Clock | PtxSpecialRegister::LanemaskLt => None,
}
}
fn get_unprefixed_function_name(self) -> &'static str {
match self {
PtxSpecialRegister::Tid => "sreg_tid",
PtxSpecialRegister::Ntid => "sreg_ntid",
PtxSpecialRegister::Ctaid => "sreg_ctaid",
PtxSpecialRegister::Nctaid => "sreg_nctaid",
PtxSpecialRegister::Clock => "sreg_clock",
PtxSpecialRegister::LanemaskLt => "sreg_lanemask_lt",
}
}
}
struct SpecialRegistersMap {
reg_to_id: HashMap<PtxSpecialRegister, SpirvWord>,
id_to_reg: HashMap<SpirvWord, PtxSpecialRegister>,
}
impl SpecialRegistersMap {
fn new() -> Self {
SpecialRegistersMap {
reg_to_id: HashMap::new(),
id_to_reg: HashMap::new(),
}
}
fn get(&self, id: SpirvWord) -> Option<PtxSpecialRegister> {
self.id_to_reg.get(&id).copied()
}
fn get_or_add(&mut self, current_id: &mut SpirvWord, reg: PtxSpecialRegister) -> SpirvWord {
match self.reg_to_id.entry(reg) {
hash_map::Entry::Occupied(e) => *e.get(),
hash_map::Entry::Vacant(e) => {
let numeric_id = SpirvWord(current_id.0);
current_id.0 += 1;
e.insert(numeric_id);
self.id_to_reg.insert(numeric_id, reg);
numeric_id
}
}
}
}
struct FnStringIdResolver<'input, 'b> {
current_id: &'b mut SpirvWord,
global_variables: &'b HashMap<Cow<'input, str>, SpirvWord>,
global_type_check: &'b HashMap<u32, Option<(ast::Type, ast::StateSpace, bool)>>,
special_registers: &'b mut SpecialRegistersMap,
variables: Vec<HashMap<Cow<'input, str>, SpirvWord>>,
type_check: HashMap<u32, Option<(ast::Type, ast::StateSpace, bool)>>,
}
impl<'a, 'b> FnStringIdResolver<'a, 'b> {
fn finish(self) -> NumericIdResolver<'b> {
NumericIdResolver {
current_id: self.current_id,
global_type_check: self.global_type_check,
type_check: self.type_check,
special_registers: self.special_registers,
}
}
fn start_block(&mut self) {
self.variables.push(HashMap::new())
}
fn end_block(&mut self) {
self.variables.pop();
}
fn get_id(&mut self, id: &str) -> Result<SpirvWord, TranslateError> {
for scope in self.variables.iter().rev() {
match scope.get(id) {
Some(id) => return Ok(*id),
None => continue,
}
}
match self.global_variables.get(id) {
Some(id) => Ok(*id),
None => {
let sreg = PtxSpecialRegister::try_parse(id).ok_or_else(error_unknown_symbol)?;
Ok(self.special_registers.get_or_add(self.current_id, sreg))
}
}
}
fn add_def(
&mut self,
id: &'a str,
typ: Option<(ast::Type, ast::StateSpace)>,
is_variable: bool,
) -> SpirvWord {
let numeric_id = *self.current_id;
self.variables
.last_mut()
.unwrap()
.insert(Cow::Borrowed(id), numeric_id);
self.type_check.insert(
numeric_id.0,
typ.map(|(typ, space)| (typ, space, is_variable)),
);
self.current_id.0 += 1;
numeric_id
}
#[must_use]
fn add_defs(
&mut self,
base_id: &'a str,
count: u32,
typ: ast::Type,
state_space: ast::StateSpace,
is_variable: bool,
) -> impl Iterator<Item = SpirvWord> {
let numeric_id = *self.current_id;
for i in 0..count {
self.variables.last_mut().unwrap().insert(
Cow::Owned(format!("{}{}", base_id, i)),
SpirvWord(numeric_id.0 + i),
);
self.type_check.insert(
numeric_id.0 + i,
Some((typ.clone(), state_space, is_variable)),
);
}
self.current_id.0 += count;
(0..count)
.into_iter()
.map(move |i| SpirvWord(i + numeric_id.0))
}
}
struct NumericIdResolver<'b> {
current_id: &'b mut SpirvWord,
global_type_check: &'b HashMap<u32, Option<(ast::Type, ast::StateSpace, bool)>>,
type_check: HashMap<u32, Option<(ast::Type, ast::StateSpace, bool)>>,
special_registers: &'b mut SpecialRegistersMap,
}
impl<'b> NumericIdResolver<'b> {
fn finish(self) -> MutableNumericIdResolver<'b> {
MutableNumericIdResolver { base: self }
}
fn get_typed(
&self,
id: SpirvWord,
) -> Result<(ast::Type, ast::StateSpace, bool), TranslateError> {
match self.type_check.get(&id.0) {
Some(Some(x)) => Ok(x.clone()),
Some(None) => Err(TranslateError::UntypedSymbol),
None => match self.special_registers.get(id) {
Some(x) => Ok((x.get_type(), ast::StateSpace::Sreg, true)),
None => match self.global_type_check.get(&id.0) {
Some(Some(result)) => Ok(result.clone()),
Some(None) | None => Err(TranslateError::UntypedSymbol),
},
},
}
}
// This is for identifiers which will be emitted later as OpVariable
// They are candidates for insertion of LoadVar/StoreVar
fn register_variable(&mut self, typ: ast::Type, state_space: ast::StateSpace) -> SpirvWord {
let new_id = *self.current_id;
self.type_check
.insert(new_id.0, Some((typ, state_space, true)));
self.current_id.0 += 1;
new_id
}
fn register_intermediate(&mut self, typ: Option<(ast::Type, ast::StateSpace)>) -> SpirvWord {
let new_id = *self.current_id;
self.type_check
.insert(new_id.0, typ.map(|(t, space)| (t, space, false)));
self.current_id.0 += 1;
new_id
}
}
struct MutableNumericIdResolver<'b> {
base: NumericIdResolver<'b>,
}
impl<'b> MutableNumericIdResolver<'b> {
fn unmut(self) -> NumericIdResolver<'b> {
self.base
}
fn get_typed(&self, id: SpirvWord) -> Result<(ast::Type, ast::StateSpace), TranslateError> {
self.base.get_typed(id).map(|(t, space, _)| (t, space))
}
fn register_intermediate(&mut self, typ: ast::Type, state_space: ast::StateSpace) -> SpirvWord {
self.base.register_intermediate(Some((typ, state_space)))
}
}
quick_error! {
#[derive(Debug)]
pub enum TranslateError {
UnknownSymbol {}
UntypedSymbol {}
MismatchedType {}
Spirv(err: rspirv::dr::Error) {
from()
display("{}", err)
cause(err)
}
Unreachable {}
Todo {}
}
}
#[cfg(debug_assertions)]
fn error_unreachable() -> TranslateError {
unreachable!()
}
#[cfg(not(debug_assertions))]
fn error_unreachable() -> TranslateError {
TranslateError::Unreachable
}
fn error_unknown_symbol() -> TranslateError {
TranslateError::UnknownSymbol
}
pub struct GlobalFnDeclResolver<'input, 'a> {
fns: &'a HashMap<SpirvWord, FnSigMapper<'input>>,
}
impl<'input, 'a> GlobalFnDeclResolver<'input, 'a> {
fn get_fn_sig_resolver(&self, id: SpirvWord) -> Result<&FnSigMapper<'input>, TranslateError> {
self.fns.get(&id).ok_or_else(error_unknown_symbol)
}
}
struct FnSigMapper<'input> {
// true - stays as return argument
// false - is moved to input argument
return_param_args: Vec<bool>,
func_decl: Rc<RefCell<ast::MethodDeclaration<'input, SpirvWord>>>,
}
impl<'input> FnSigMapper<'input> {
fn remap_to_spirv_repr(mut method: ast::MethodDeclaration<'input, SpirvWord>) -> Self {
let return_param_args = method
.return_arguments
.iter()
.map(|a| a.state_space != ast::StateSpace::Param)
.collect::<Vec<_>>();
let mut new_return_arguments = Vec::new();
for arg in method.return_arguments.into_iter() {
if arg.state_space == ast::StateSpace::Param {
method.input_arguments.push(arg);
} else {
new_return_arguments.push(arg);
}
}
method.return_arguments = new_return_arguments;
FnSigMapper {
return_param_args,
func_decl: Rc::new(RefCell::new(method)),
}
}
/*
fn resolve_in_spirv_repr(
&self,
call_inst: ast::CallInst<NormalizedArgParams>,
) -> Result<ResolvedCall<NormalizedArgParams>, TranslateError> {
let func_decl = (*self.func_decl).borrow();
let mut return_arguments = Vec::new();
let mut input_arguments = call_inst
.param_list
.into_iter()
.zip(func_decl.input_arguments.iter())
.map(|(id, var)| (id, var.v_type.clone(), var.state_space))
.collect::<Vec<_>>();
let mut func_decl_return_iter = func_decl.return_arguments.iter();
let mut func_decl_input_iter = func_decl.input_arguments[input_arguments.len()..].iter();
for (idx, id) in call_inst.ret_params.iter().enumerate() {
let stays_as_return = match self.return_param_args.get(idx) {
Some(x) => *x,
None => return Err(TranslateError::MismatchedType),
};
if stays_as_return {
if let Some(var) = func_decl_return_iter.next() {
return_arguments.push((*id, var.v_type.clone(), var.state_space));
} else {
return Err(TranslateError::MismatchedType);
}
} else {
if let Some(var) = func_decl_input_iter.next() {
input_arguments.push((
ast::Operand::Reg(*id),
var.v_type.clone(),
var.state_space,
));
} else {
return Err(TranslateError::MismatchedType);
}
}
}
if return_arguments.len() != func_decl.return_arguments.len()
|| input_arguments.len() != func_decl.input_arguments.len()
{
return Err(TranslateError::MismatchedType);
}
Ok(ResolvedCall {
return_arguments,
input_arguments,
uniform: call_inst.uniform,
name: call_inst.func,
})
}
*/
}
enum Statement<I, P: ast::Operand> {
Label(SpirvWord),
Variable(ast::Variable<P::Ident>),
Instruction(I),
// SPIR-V compatible replacement for PTX predicates
Conditional(BrachCondition),
LoadVar(LoadVarDetails),
StoreVar(StoreVarDetails),
Conversion(ImplicitConversion),
Constant(ConstantDefinition),
RetValue(ast::RetData, SpirvWord),
PtrAccess(PtrAccess<P>),
RepackVector(RepackVectorDetails),
FunctionPointer(FunctionPointerDetails),
}
struct BrachCondition {
predicate: SpirvWord,
if_true: SpirvWord,
if_false: SpirvWord,
}
struct LoadVarDetails {
arg: ast::LdArgs<SpirvWord>,
typ: ast::Type,
state_space: ast::StateSpace,
// (index, vector_width)
// HACK ALERT
// For some reason IGC explodes when you try to load from builtin vectors
// using OpInBoundsAccessChain, the one true way to do it is to
// OpLoad+OpCompositeExtract
member_index: Option<(u8, Option<u8>)>,
}
struct StoreVarDetails {
arg: ast::StArgs<SpirvWord>,
typ: ast::Type,
member_index: Option<u8>,
}
#[derive(Clone)]
struct ImplicitConversion {
src: SpirvWord,
dst: SpirvWord,
from_type: ast::Type,
to_type: ast::Type,
from_space: ast::StateSpace,
to_space: ast::StateSpace,
kind: ConversionKind,
}
#[derive(PartialEq, Clone)]
enum ConversionKind {
Default,
// zero-extend/chop/bitcast depending on types
SignExtend,
BitToPtr,
PtrToPtr,
AddressOf,
}
struct ConstantDefinition {
pub dst: SpirvWord,
pub typ: ast::ScalarType,
pub value: ast::ImmediateValue,
}
pub struct PtrAccess<T> {
underlying_type: ast::Type,
state_space: ast::StateSpace,
dst: SpirvWord,
ptr_src: SpirvWord,
offset_src: T,
}
struct RepackVectorDetails {
is_extract: bool,
typ: ast::ScalarType,
packed: SpirvWord,
unpacked: Vec<SpirvWord>,
non_default_implicit_conversion: Option<
fn(
(ast::StateSpace, &ast::Type),
(ast::StateSpace, &ast::Type),
) -> Result<Option<ConversionKind>, TranslateError>,
>,
}
struct FunctionPointerDetails {
dst: SpirvWord,
src: SpirvWord,
}
#[derive(Copy, Clone, PartialEq, Eq, Hash)]
struct SpirvWord(spirv::Word);
impl From<spirv::Word> for SpirvWord {
fn from(value: spirv::Word) -> Self {
Self(value)
}
}
impl From<SpirvWord> for spirv::Word {
fn from(value: SpirvWord) -> Self {
value.0
}
}
impl ast::Operand for SpirvWord {
type Ident = Self;
}
fn pred_map_variable<U, T, F: FnMut(T) -> Result<U, TranslateError>>(
this: ast::PredAt<T>,
f: &mut F,
) -> Result<ast::PredAt<U>, TranslateError> {
let new_label = f(this.label)?;
Ok(ast::PredAt {
not: this.not,
label: new_label,
})
}
impl<T: ast::Operand, U: ast::Operand, X: FnMut(&str) -> Result<SpirvWord, Err>, Err> ast::VisitorMap<T, U, Err> for X {
fn visit(
&mut self,
args: T,
type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>,
is_dst: bool,
) -> U {
todo!()
}
fn visit_ident(
&mut self,
args: <T as ptx_parser::Operand>::Ident,
type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>,
is_dst: bool,
) -> <U as ptx_parser::Operand>::Ident {
todo!()
}
}
fn op_map_variable<'a, F: FnMut(&str) -> Result<SpirvWord, TranslateError>>(
this: ast::Instruction<ast::ParsedOperand<&'a str>>,
f: &mut F,
) -> Result<ast::Instruction<ast::ParsedOperand<SpirvWord>>, TranslateError> {
ast::visit_map(this , f)
}

83
ptx/src/pass/normalize.rs Normal file
View file

@ -0,0 +1,83 @@
use super::*;
use ptx_parser as ast;
type NormalizedStatement = Statement<
(
Option<ast::PredAt<SpirvWord>>,
ast::Instruction<ast::ParsedOperand<SpirvWord>>,
),
ast::ParsedOperand<SpirvWord>,
>;
fn run<'input, 'b>(
id_defs: &mut FnStringIdResolver<'input, 'b>,
fn_defs: &GlobalFnDeclResolver<'input, 'b>,
func: Vec<ast::Statement<ast::ParsedOperand<&'input str>>>,
) -> Result<Vec<NormalizedStatement>, TranslateError> {
for s in func.iter() {
match s {
ast::Statement::Label(id) => {
id_defs.add_def(*id, None, false);
}
_ => (),
}
}
let mut result = Vec::new();
for s in func {
expand_map_variables(id_defs, fn_defs, &mut result, s)?;
}
Ok(result)
}
fn expand_map_variables<'a, 'b>(
id_defs: &mut FnStringIdResolver<'a, 'b>,
fn_defs: &GlobalFnDeclResolver<'a, 'b>,
result: &mut Vec<NormalizedStatement>,
s: ast::Statement<ast::ParsedOperand<&'a str>>,
) -> Result<(), TranslateError> {
match s {
ast::Statement::Block(block) => {
id_defs.start_block();
for s in block {
expand_map_variables(id_defs, fn_defs, result, s)?;
}
id_defs.end_block();
}
ast::Statement::Label(name) => result.push(Statement::Label(id_defs.get_id(name)?)),
ast::Statement::Instruction(p, i) => result.push(Statement::Instruction((
p.map(|p| pred_map_variable(p, &mut |id| id_defs.get_id(id)))
.transpose()?,
op_map_variable(i, &mut |id| id_defs.get_id(id))?,
))),
ast::Statement::Variable(var) => {
let var_type = var.var.v_type.clone();
match var.count {
Some(count) => {
for new_id in
id_defs.add_defs(var.var.name, count, var_type, var.var.state_space, true)
{
result.push(Statement::Variable(ast::Variable {
align: var.var.align,
v_type: var.var.v_type.clone(),
state_space: var.var.state_space,
name: new_id,
array_init: var.var.array_init.clone(),
}))
}
}
None => {
let new_id =
id_defs.add_def(var.var.name, Some((var_type, var.var.state_space)), true);
result.push(Statement::Variable(ast::Variable {
align: var.var.align,
v_type: var.var.v_type.clone(),
state_space: var.var.state_space,
name: new_id,
array_init: var.var.array_init,
}));
}
}
}
};
Ok(())
}

60
ptx/src/translate2.rs Normal file
View file

@ -0,0 +1,60 @@
use std::collections::HashMap;
use half::f16;
use ptx_parser as ast;
fn to_ssa<'input, 'b>(
ptx_impl_imports: &'b mut HashMap<String, Directive<'input>>,
mut id_defs: FnStringIdResolver<'input, 'b>,
fn_defs: GlobalFnDeclResolver<'input, 'b>,
func_decl: Rc<RefCell<ast::MethodDeclaration<'input, spirv::Word>>>,
f_body: Option<Vec<ast::Statement<ast::ParsedArgParams<'input>>>>,
tuning: Vec<ast::TuningDirective>,
linkage: ast::LinkingDirective,
) -> Result<Function<'input>, TranslateError> {
//deparamize_function_decl(&func_decl)?;
let f_body = match f_body {
Some(vec) => vec,
None => {
return Ok(Function {
func_decl: func_decl,
body: None,
globals: Vec::new(),
import_as: None,
tuning,
linkage,
})
}
};
let normalized_ids = normalize_identifiers(&mut id_defs, &fn_defs, f_body)?;
/*
let mut numeric_id_defs = id_defs.finish();
let unadorned_statements = normalize_predicates(normalized_ids, &mut numeric_id_defs)?;
let typed_statements =
convert_to_typed_statements(unadorned_statements, &fn_defs, &mut numeric_id_defs)?;
let typed_statements =
fix_special_registers2(ptx_impl_imports, typed_statements, &mut numeric_id_defs)?;
let (func_decl, typed_statements) =
convert_to_stateful_memory_access(func_decl, typed_statements, &mut numeric_id_defs)?;
let ssa_statements = insert_mem_ssa_statements(
typed_statements,
&mut numeric_id_defs,
&mut (*func_decl).borrow_mut(),
)?;
let mut numeric_id_defs = numeric_id_defs.finish();
let expanded_statements = expand_arguments(ssa_statements, &mut numeric_id_defs)?;
let expanded_statements =
insert_implicit_conversions(expanded_statements, &mut numeric_id_defs)?;
let mut numeric_id_defs = numeric_id_defs.unmut();
let labeled_statements = normalize_labels(expanded_statements, &mut numeric_id_defs);
let (f_body, globals) =
extract_globals(labeled_statements, ptx_impl_imports, &mut numeric_id_defs)?;
Ok(Function {
func_decl: func_decl,
globals: globals,
body: Some(f_body),
import_as: None,
tuning,
linkage,
})
*/
}

View file

@ -4,6 +4,8 @@ version = "0.0.0"
authors = ["Andrzej Janik <vosen@vosen.pl>"]
edition = "2021"
[lib]
[dependencies]
logos = "0.14"
winnow = { version = "0.6.18" }
@ -11,3 +13,4 @@ ptx_parser_macros = { path = "../ptx_parser_macros" }
thiserror = "1.0"
bitflags = "1.2"
rustc-hash = "2.0.0"
derive_more = { version = "1", features = ["display"] }

View file

@ -147,9 +147,9 @@ ptx_parser_macros::generate_instruction_type!(
Call {
data: CallDetails,
arguments: CallArgs<T>,
visit: arguments.visit(data, visitor),
visit_mut: arguments.visit_mut(data, visitor),
map: Instruction::Call{ arguments: arguments.map(&data, visitor), data }
visit: arguments.visit(data, visitor)?,
visit_mut: arguments.visit_mut(data, visitor)?,
map: Instruction::Call{ arguments: arguments.map(&data, visitor)?, data }
},
Cvt {
data: CvtDetails,
@ -488,93 +488,185 @@ ptx_parser_macros::generate_instruction_type!(
}
);
pub trait Visitor<T: Operand> {
fn visit(&mut self, args: &T, type_space: Option<(&Type, StateSpace)>, is_dst: bool);
fn visit_ident(&self, args: &T::Ident, type_space: Option<(&Type, StateSpace)>, is_dst: bool);
pub trait Visitor<T: Operand, Err> {
fn visit(
&mut self,
args: &T,
type_space: Option<(&Type, StateSpace)>,
is_dst: bool,
) -> Result<(), Err>;
fn visit_ident(
&mut self,
args: &T::Ident,
type_space: Option<(&Type, StateSpace)>,
is_dst: bool,
) -> Result<(), Err>;
}
pub trait VisitorMut<T: Operand> {
fn visit(&mut self, args: &mut T, type_space: Option<(&Type, StateSpace)>, is_dst: bool);
impl<T: Operand, Err, Fn: FnMut(&T, Option<(&Type, StateSpace)>, bool) -> Result<(), Err>>
Visitor<T, Err> for Fn
{
fn visit(
&mut self,
args: &T,
type_space: Option<(&Type, StateSpace)>,
is_dst: bool,
) -> Result<(), Err> {
(self)(args, type_space, is_dst)
}
fn visit_ident(
&mut self,
args: &T::Ident,
type_space: Option<(&Type, StateSpace)>,
is_dst: bool,
) -> Result<(), Err> {
(self)(&T::from_ident(*args), type_space, is_dst)
}
}
pub trait VisitorMut<T: Operand, Err> {
fn visit(
&mut self,
args: &mut T,
type_space: Option<(&Type, StateSpace)>,
is_dst: bool,
) -> Result<(), Err>;
fn visit_ident(
&mut self,
args: &mut T::Ident,
type_space: Option<(&Type, StateSpace)>,
is_dst: bool,
);
) -> Result<(), Err>;
}
pub trait VisitorMap<From: Operand, To: Operand> {
fn visit(&mut self, args: From, type_space: Option<(&Type, StateSpace)>, is_dst: bool) -> To;
pub trait VisitorMap<From: Operand, To: Operand, Err> {
fn visit(
&mut self,
args: From,
type_space: Option<(&Type, StateSpace)>,
is_dst: bool,
) -> Result<To, Err>;
fn visit_ident(
&mut self,
args: From::Ident,
type_space: Option<(&Type, StateSpace)>,
is_dst: bool,
) -> To::Ident;
) -> Result<To::Ident, Err>;
}
trait VisitOperand {
impl<
T: Operand,
U: Operand,
Err,
Fn: FnMut(T, Option<(&Type, StateSpace)>, bool) -> Result<U, Err>,
> VisitorMap<T, U, Err> for Fn
{
fn visit(
&mut self,
args: T,
type_space: Option<(&Type, StateSpace)>,
is_dst: bool,
) -> Result<U, Err> {
(self)(args, type_space, is_dst)
}
fn visit_ident(
&mut self,
args: T::Ident,
type_space: Option<(&Type, StateSpace)>,
is_dst: bool,
) -> Result<U::Ident, Err> {
let value: U = (self)(T::from_ident(args), type_space, is_dst)?;
Ok(value)
}
}
trait VisitOperand<Err> {
type Operand: Operand;
#[allow(unused)] // Used by generated code
fn visit(&self, fn_: impl FnMut(&Self::Operand));
fn visit(&self, fn_: impl FnMut(&Self::Operand) -> Result<(), Err>) -> Result<(), Err>;
#[allow(unused)] // Used by generated code
fn visit_mut(&mut self, fn_: impl FnMut(&mut Self::Operand));
fn visit_mut(
&mut self,
fn_: impl FnMut(&mut Self::Operand) -> Result<(), Err>,
) -> Result<(), Err>;
}
impl<T: Operand> VisitOperand for T {
impl<T: Operand, Err> VisitOperand<Err> for T {
type Operand = Self;
fn visit(&self, mut fn_: impl FnMut(&Self::Operand)) {
fn visit(&self, mut fn_: impl FnMut(&Self::Operand) -> Result<(), Err>) -> Result<(), Err> {
fn_(self)
}
fn visit_mut(&mut self, mut fn_: impl FnMut(&mut Self::Operand)) {
fn visit_mut(
&mut self,
mut fn_: impl FnMut(&mut Self::Operand) -> Result<(), Err>,
) -> Result<(), Err> {
fn_(self)
}
}
impl<T: Operand> VisitOperand for Option<T> {
impl<T: Operand, Err> VisitOperand<Err> for Option<T> {
type Operand = T;
fn visit(&self, fn_: impl FnMut(&Self::Operand)) {
self.as_ref().map(fn_);
fn visit(&self, mut fn_: impl FnMut(&Self::Operand) -> Result<(), Err>) -> Result<(), Err> {
if let Some(x) = self {
fn_(x)?;
}
Ok(())
}
fn visit_mut(&mut self, fn_: impl FnMut(&mut Self::Operand)) {
self.as_mut().map(fn_);
fn visit_mut(
&mut self,
mut fn_: impl FnMut(&mut Self::Operand) -> Result<(), Err>,
) -> Result<(), Err> {
if let Some(x) = self {
fn_(x)?;
}
Ok(())
}
}
impl<T: Operand> VisitOperand for Vec<T> {
impl<T: Operand, Err> VisitOperand<Err> for Vec<T> {
type Operand = T;
fn visit(&self, mut fn_: impl FnMut(&Self::Operand)) {
fn visit(&self, mut fn_: impl FnMut(&Self::Operand) -> Result<(), Err>) -> Result<(), Err> {
for o in self {
fn_(o)
fn_(o)?;
}
Ok(())
}
fn visit_mut(&mut self, mut fn_: impl FnMut(&mut Self::Operand)) {
fn visit_mut(
&mut self,
mut fn_: impl FnMut(&mut Self::Operand) -> Result<(), Err>,
) -> Result<(), Err> {
for o in self {
fn_(o)
fn_(o)?;
}
Ok(())
}
}
trait MapOperand: Sized {
trait MapOperand<Err>: Sized {
type Input;
type Output<U>;
#[allow(unused)] // Used by generated code
fn map<U>(self, fn_: impl FnOnce(Self::Input) -> U) -> Self::Output<U>;
fn map<U>(
self,
fn_: impl FnOnce(Self::Input) -> Result<U, Err>,
) -> Result<Self::Output<U>, Err>;
}
impl<T: Operand> MapOperand for T {
impl<T: Operand, Err> MapOperand<Err> for T {
type Input = Self;
type Output<U> = U;
fn map<U>(self, fn_: impl FnOnce(T) -> U) -> U {
fn map<U>(self, fn_: impl FnOnce(T) -> Result<U, Err>) -> Result<U, Err> {
fn_(self)
}
}
impl<T: Operand> MapOperand for Option<T> {
impl<T: Operand, Err> MapOperand<Err> for Option<T> {
type Input = T;
type Output<U> = Option<U>;
fn map<U>(self, fn_: impl FnOnce(T) -> U) -> Option<U> {
self.map(|x| fn_(x))
fn map<U>(self, fn_: impl FnOnce(T) -> Result<U, Err>) -> Result<Option<U>, Err> {
self.map(|x| fn_(x)).transpose()
}
}
@ -715,10 +807,16 @@ pub enum ParsedOperand<Ident> {
impl<Ident: Copy> Operand for ParsedOperand<Ident> {
type Ident = Ident;
fn from_ident(ident: Self::Ident) -> Self {
ParsedOperand::Reg(ident)
}
}
pub trait Operand {
pub trait Operand: Sized {
type Ident: Copy;
fn from_ident(ident: Self::Ident) -> Self;
}
#[derive(Copy, Clone)]
@ -1048,67 +1146,77 @@ pub struct CallArgs<T: Operand> {
impl<T: Operand> CallArgs<T> {
#[allow(dead_code)] // Used by generated code
fn visit(&self, details: &CallDetails, visitor: &mut impl Visitor<T>) {
fn visit<Err>(
&self,
details: &CallDetails,
visitor: &mut impl Visitor<T, Err>,
) -> Result<(), Err> {
for (param, (type_, space)) in self
.return_arguments
.iter()
.zip(details.return_arguments.iter())
{
visitor.visit_ident(param, Some((type_, *space)), true);
visitor.visit_ident(param, Some((type_, *space)), true)?;
}
visitor.visit_ident(&self.func, None, false);
visitor.visit_ident(&self.func, None, false)?;
for (param, (type_, space)) in self
.input_arguments
.iter()
.zip(details.input_arguments.iter())
{
visitor.visit(param, Some((type_, *space)), true);
visitor.visit(param, Some((type_, *space)), true)?;
}
Ok(())
}
#[allow(dead_code)] // Used by generated code
fn visit_mut(&mut self, details: &CallDetails, visitor: &mut impl VisitorMut<T>) {
fn visit_mut<Err>(
&mut self,
details: &CallDetails,
visitor: &mut impl VisitorMut<T, Err>,
) -> Result<(), Err> {
for (param, (type_, space)) in self
.return_arguments
.iter_mut()
.zip(details.return_arguments.iter())
{
visitor.visit_ident(param, Some((type_, *space)), true);
visitor.visit_ident(param, Some((type_, *space)), true)?;
}
visitor.visit_ident(&mut self.func, None, false);
visitor.visit_ident(&mut self.func, None, false)?;
for (param, (type_, space)) in self
.input_arguments
.iter_mut()
.zip(details.input_arguments.iter())
{
visitor.visit(param, Some((type_, *space)), true);
visitor.visit(param, Some((type_, *space)), true)?;
}
Ok(())
}
#[allow(dead_code)] // Used by generated code
fn map<U: Operand>(
fn map<U: Operand, Err>(
self,
details: &CallDetails,
visitor: &mut impl VisitorMap<T, U>,
) -> CallArgs<U> {
visitor: &mut impl VisitorMap<T, U, Err>,
) -> Result<CallArgs<U>, Err> {
let return_arguments = self
.return_arguments
.into_iter()
.zip(details.return_arguments.iter())
.map(|(param, (type_, space))| visitor.visit_ident(param, Some((type_, *space)), true))
.collect::<Vec<_>>();
let func = visitor.visit_ident(self.func, None, false);
.collect::<Result<Vec<_>, _>>()?;
let func = visitor.visit_ident(self.func, None, false)?;
let input_arguments = self
.input_arguments
.into_iter()
.zip(details.input_arguments.iter())
.map(|(param, (type_, space))| visitor.visit(param, Some((type_, *space)), true))
.collect::<Vec<_>>();
CallArgs {
.collect::<Result<Vec<_>, _>>()?;
Ok(CallArgs {
return_arguments,
func,
input_arguments,
}
})
}
}

View file

@ -1,8 +1,8 @@
use derive_more::Display;
use logos::Logos;
use ptx_parser_macros::derive_parser;
use rustc_hash::FxHashMap;
use std::fmt::Debug;
use std::mem;
use std::num::{ParseFloatError, ParseIntError};
use winnow::ascii::dec_uint;
use winnow::combinator::*;
@ -81,16 +81,16 @@ impl VectorPrefix {
}
}
struct PtxParserState<'input> {
errors: Vec<PtxError>,
struct PtxParserState<'a, 'input> {
errors: &'a mut Vec<PtxError>,
function_declarations:
FxHashMap<&'input str, (Vec<(ast::Type, StateSpace)>, Vec<(ast::Type, StateSpace)>)>,
}
impl<'input> PtxParserState<'input> {
fn new() -> Self {
impl<'a, 'input> PtxParserState<'a, 'input> {
fn new(errors: &'a mut Vec<PtxError>) -> Self {
Self {
errors: Vec::new(),
errors,
function_declarations: FxHashMap::default(),
}
}
@ -115,7 +115,7 @@ impl<'input> PtxParserState<'input> {
}
}
impl<'input> Debug for PtxParserState<'input> {
impl<'a, 'input> Debug for PtxParserState<'a, 'input> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("PtxParserState")
.field("errors", &self.errors) /* .field("function_decl", &self.function_decl) */
@ -123,7 +123,7 @@ impl<'input> Debug for PtxParserState<'input> {
}
}
type PtxParser<'a, 'input> = Stateful<&'a [Token<'input>], PtxParserState<'input>>;
type PtxParser<'a, 'input> = Stateful<&'a [Token<'input>], PtxParserState<'a, 'input>>;
fn ident<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<&'input str> {
any.verify_map(|t| {
@ -277,6 +277,18 @@ fn immediate_value<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<as
.parse_next(stream)
}
pub fn parse_module_unchecked<'input>(text: &'input str) -> Option<ast::Module<'input>> {
let lexer = Token::lexer(text);
let input = lexer.collect::<Result<Vec<_>, _>>().ok()?;
let mut errors = Vec::new();
let state = PtxParserState::new(&mut errors);
let parser = PtxParser {
state,
input: &input[..],
};
module.parse(parser).ok()
}
fn module<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<ast::Module<'input>> {
(
version,
@ -818,6 +830,8 @@ pub enum PtxError {
source: ParseFloatError,
},
#[error("")]
Lexer(#[from] TokenError),
#[error("")]
Todo,
#[error("")]
SyntaxError,
@ -1042,9 +1056,15 @@ fn empty_call<'input>(
type ParsedOperandStr<'input> = ast::ParsedOperand<&'input str>;
#[derive(Clone, PartialEq, Default, Debug, Display)]
pub struct TokenError;
impl std::error::Error for TokenError {}
derive_parser!(
#[derive(Logos, PartialEq, Eq, Debug, Clone, Copy)]
#[logos(skip r"\s+")]
#[logos(error = TokenError)]
enum Token<'input> {
#[token(",")]
Comma,
@ -1134,6 +1154,7 @@ derive_parser!(
pub enum StateSpace {
Reg,
Generic,
Sreg,
}
#[derive(Copy, Clone, PartialEq, Eq, Hash)]
@ -2825,57 +2846,6 @@ derive_parser!(
);
fn main() {
use winnow::Parser;
let lexer = Token::lexer(
"
.version 6.5
.target sm_30
.address_size 64
.const .align 8 .b32 constparams;
.visible .entry const(
.param .u64 input,
.param .u64 output
)
{
.reg .u64 in_addr;
.reg .u64 out_addr;
.reg .b16 temp1;
.reg .b16 temp2;
.reg .b16 temp3;
.reg .b16 temp4;
ld.param.u64 in_addr, [input];
ld.param.u64 out_addr, [output];
ld.const.b16 temp1, [constparams];
ld.const.b16 temp2, [constparams+2];
ld.const.b16 temp3, [constparams+4];
ld.const.b16 temp4, [constparams+6];
st.u16 [out_addr], temp1;
st.u16 [out_addr+2], temp2;
st.u16 [out_addr+4], temp3;
st.u16 [out_addr+6], temp4;
ret;
}
",
);
let tokens = lexer.clone().collect::<Vec<_>>();
println!("{:?}", &tokens);
let tokens = lexer.map(|t| t.unwrap()).collect::<Vec<_>>();
println!("{:?}", &tokens);
let stream = PtxParser {
input: &tokens[..],
state: PtxParserState::new(),
};
let _module = module.parse(stream).unwrap();
println!("{}", mem::size_of::<Token>());
}
#[cfg(test)]
mod tests {
use super::target;

View file

@ -1017,7 +1017,7 @@ pub fn generate_instruction_type(tokens: proc_macro::TokenStream) -> proc_macro:
input.emit_arg_types(&mut result);
input.emit_instruction_type(&mut result);
input.emit_visit(&mut result);
input.emit_visit_mut(&mut result);
//input.emit_visit_mut(&mut result);
input.emit_visit_map(&mut result);
result.into()
}

View file

@ -67,37 +67,29 @@ impl GenerateInstructionType {
let visit_ref = kind.reference();
let visitor_type = format_ident!("Visitor{}", kind.type_suffix());
let visit_fn = format_ident!("visit{}", kind.fn_suffix());
let visit_slice_fn = format_ident!("visit{}_slice", kind.fn_suffix());
let (type_parameters, visitor_parameters, return_type) = if kind == VisitKind::Map {
(
quote! { <#type_parameters, To: Operand> },
quote! { <#short_parameters, To> },
quote! { #type_name<To> },
quote! { <#type_parameters, To: Operand, Err> },
quote! { <#short_parameters, To, Err> },
quote! { std::result::Result<#type_name<To>, Err> },
)
} else {
(
quote! { <#type_parameters> },
quote! { <#short_parameters> },
quote! { () },
quote! { <#type_parameters, Err> },
quote! { <#short_parameters, Err> },
quote! { std::result::Result<(), Err> },
)
};
quote! {
fn #visit_fn #type_parameters (i: #visit_ref #type_name<#short_parameters>, visitor: &mut impl #visitor_type #visitor_parameters ) -> #return_type {
match i {
pub fn #visit_fn #type_parameters (i: #visit_ref #type_name<#short_parameters>, visitor: &mut impl #visitor_type #visitor_parameters ) -> #return_type {
Ok(match i {
#inner_tokens
}
})
}
}.to_tokens(tokens);
if kind == VisitKind::Map {
return;
}
quote! {
fn #visit_slice_fn #type_parameters (instructions: #visit_ref [#type_name<#short_parameters>], visitor: &mut impl #visitor_type #visitor_parameters) {
for i in instructions {
#visit_fn(i, visitor)
}
}
}.to_tokens(tokens);
}
}
@ -630,14 +622,14 @@ impl ArgumentField {
quote! {
{
#type_space
visitor.visit_ident(&mut arguments.#name, type_space, #is_dst);
visitor.visit_ident(&mut arguments.#name, type_space, #is_dst)?;
}
}
} else {
quote! {
{
#type_space
visitor.visit_ident(& arguments.#name, type_space, #is_dst);
visitor.visit_ident(& arguments.#name, type_space, #is_dst)?;
}
}
}
@ -663,7 +655,7 @@ impl ArgumentField {
};
quote! {{
#type_space
#operand_fn(#arguments_name, |x| visitor.visit(x, type_space, #is_dst));
#operand_fn(#arguments_name, |x| visitor.visit(x, type_space, #is_dst))?;
}}
}
}
@ -701,11 +693,11 @@ impl ArgumentField {
};
let map_call = if is_ident {
quote! {
visitor.visit_ident(arguments.#name, type_space, #is_dst)
visitor.visit_ident(arguments.#name, type_space, #is_dst)?
}
} else {
quote! {
MapOperand::map(arguments.#name, |x| visitor.visit(x, type_space, #is_dst))
MapOperand::map(arguments.#name, |x| visitor.visit(x, type_space, #is_dst))?
}
};
quote! {