mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-04-20 00:14:45 +00:00
Attempt #2
This commit is contained in:
parent
71e025845c
commit
1ec1ca0c30
11 changed files with 903 additions and 148 deletions
|
@ -7,7 +7,7 @@ edition = "2018"
|
|||
[lib]
|
||||
|
||||
[dependencies]
|
||||
lalrpop-util = "0.19"
|
||||
ptx_parser = { path = "../ptx_parser" }
|
||||
regex = "1"
|
||||
rspirv = "0.7"
|
||||
spirv_headers = "1.5"
|
||||
|
@ -17,8 +17,12 @@ bit-vec = "0.6"
|
|||
half ="1.6"
|
||||
bitflags = "1.2"
|
||||
|
||||
[dependencies.lalrpop-util]
|
||||
version = "0.19.12"
|
||||
features = ["lexer"]
|
||||
|
||||
[build-dependencies.lalrpop]
|
||||
version = "0.19"
|
||||
version = "0.19.12"
|
||||
features = ["lexer"]
|
||||
|
||||
[dev-dependencies]
|
||||
|
|
|
@ -34,15 +34,9 @@ pub enum PtxError {
|
|||
#[error("")]
|
||||
NonExternPointer,
|
||||
#[error("{start}:{end}")]
|
||||
UnrecognizedStatement {
|
||||
start: usize,
|
||||
end: usize,
|
||||
},
|
||||
UnrecognizedStatement { start: usize, end: usize },
|
||||
#[error("{start}:{end}")]
|
||||
UnrecognizedDirective {
|
||||
start: usize,
|
||||
end: usize,
|
||||
},
|
||||
UnrecognizedDirective { start: usize, end: usize },
|
||||
}
|
||||
|
||||
// For some weird reson this is illegal:
|
||||
|
@ -578,11 +572,15 @@ impl CvtDetails {
|
|||
if saturate {
|
||||
if src.kind() == ScalarKind::Signed {
|
||||
if dst.kind() == ScalarKind::Signed && dst.size_of() >= src.size_of() {
|
||||
err.push(ParseError::from(PtxError::SyntaxError));
|
||||
err.push(ParseError::User {
|
||||
error: PtxError::SyntaxError,
|
||||
});
|
||||
}
|
||||
} else {
|
||||
if dst == src || dst.size_of() >= src.size_of() {
|
||||
err.push(ParseError::from(PtxError::SyntaxError));
|
||||
err.push(ParseError::User {
|
||||
error: PtxError::SyntaxError,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -598,7 +596,9 @@ impl CvtDetails {
|
|||
err: &'err mut Vec<ParseError<usize, Token<'input>, PtxError>>,
|
||||
) -> Self {
|
||||
if flush_to_zero && dst != ScalarType::F32 {
|
||||
err.push(ParseError::from(PtxError::NonF32Ftz));
|
||||
err.push(ParseError::from(lalrpop_util::ParseError::User {
|
||||
error: PtxError::NonF32Ftz,
|
||||
}));
|
||||
}
|
||||
CvtDetails::FloatFromInt(CvtDesc {
|
||||
dst,
|
||||
|
@ -618,7 +618,9 @@ impl CvtDetails {
|
|||
err: &'err mut Vec<ParseError<usize, Token<'input>, PtxError>>,
|
||||
) -> Self {
|
||||
if flush_to_zero && src != ScalarType::F32 {
|
||||
err.push(ParseError::from(PtxError::NonF32Ftz));
|
||||
err.push(ParseError::from(lalrpop_util::ParseError::User {
|
||||
error: PtxError::NonF32Ftz,
|
||||
}));
|
||||
}
|
||||
CvtDetails::IntFromFloat(CvtDesc {
|
||||
dst,
|
||||
|
|
|
@ -24,9 +24,11 @@ lalrpop_mod!(
|
|||
);
|
||||
|
||||
pub mod ast;
|
||||
mod pass;
|
||||
#[cfg(test)]
|
||||
mod test;
|
||||
mod translate;
|
||||
mod translate2;
|
||||
|
||||
use std::fmt;
|
||||
|
||||
|
|
531
ptx/src/pass/mod.rs
Normal file
531
ptx/src/pass/mod.rs
Normal file
|
@ -0,0 +1,531 @@
|
|||
use ptx_parser as ast;
|
||||
use std::{
|
||||
borrow::Cow,
|
||||
cell::RefCell,
|
||||
collections::{hash_map, HashMap},
|
||||
rc::Rc,
|
||||
};
|
||||
|
||||
mod normalize;
|
||||
|
||||
#[derive(Ord, PartialOrd, Eq, PartialEq, Hash, Copy, Clone)]
|
||||
enum PtxSpecialRegister {
|
||||
Tid,
|
||||
Ntid,
|
||||
Ctaid,
|
||||
Nctaid,
|
||||
Clock,
|
||||
LanemaskLt,
|
||||
}
|
||||
|
||||
impl PtxSpecialRegister {
|
||||
fn try_parse(s: &str) -> Option<Self> {
|
||||
match s {
|
||||
"%tid" => Some(Self::Tid),
|
||||
"%ntid" => Some(Self::Ntid),
|
||||
"%ctaid" => Some(Self::Ctaid),
|
||||
"%nctaid" => Some(Self::Nctaid),
|
||||
"%clock" => Some(Self::Clock),
|
||||
"%lanemask_lt" => Some(Self::LanemaskLt),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
fn get_type(self) -> ast::Type {
|
||||
match self {
|
||||
PtxSpecialRegister::Tid
|
||||
| PtxSpecialRegister::Ntid
|
||||
| PtxSpecialRegister::Ctaid
|
||||
| PtxSpecialRegister::Nctaid => ast::Type::Vector(self.get_function_return_type(), 4),
|
||||
_ => ast::Type::Scalar(self.get_function_return_type()),
|
||||
}
|
||||
}
|
||||
|
||||
fn get_function_return_type(self) -> ast::ScalarType {
|
||||
match self {
|
||||
PtxSpecialRegister::Tid => ast::ScalarType::U32,
|
||||
PtxSpecialRegister::Ntid => ast::ScalarType::U32,
|
||||
PtxSpecialRegister::Ctaid => ast::ScalarType::U32,
|
||||
PtxSpecialRegister::Nctaid => ast::ScalarType::U32,
|
||||
PtxSpecialRegister::Clock => ast::ScalarType::U32,
|
||||
PtxSpecialRegister::LanemaskLt => ast::ScalarType::U32,
|
||||
}
|
||||
}
|
||||
|
||||
fn get_function_input_type(self) -> Option<ast::ScalarType> {
|
||||
match self {
|
||||
PtxSpecialRegister::Tid
|
||||
| PtxSpecialRegister::Ntid
|
||||
| PtxSpecialRegister::Ctaid
|
||||
| PtxSpecialRegister::Nctaid => Some(ast::ScalarType::U8),
|
||||
PtxSpecialRegister::Clock | PtxSpecialRegister::LanemaskLt => None,
|
||||
}
|
||||
}
|
||||
|
||||
fn get_unprefixed_function_name(self) -> &'static str {
|
||||
match self {
|
||||
PtxSpecialRegister::Tid => "sreg_tid",
|
||||
PtxSpecialRegister::Ntid => "sreg_ntid",
|
||||
PtxSpecialRegister::Ctaid => "sreg_ctaid",
|
||||
PtxSpecialRegister::Nctaid => "sreg_nctaid",
|
||||
PtxSpecialRegister::Clock => "sreg_clock",
|
||||
PtxSpecialRegister::LanemaskLt => "sreg_lanemask_lt",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct SpecialRegistersMap {
|
||||
reg_to_id: HashMap<PtxSpecialRegister, SpirvWord>,
|
||||
id_to_reg: HashMap<SpirvWord, PtxSpecialRegister>,
|
||||
}
|
||||
|
||||
impl SpecialRegistersMap {
|
||||
fn new() -> Self {
|
||||
SpecialRegistersMap {
|
||||
reg_to_id: HashMap::new(),
|
||||
id_to_reg: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn get(&self, id: SpirvWord) -> Option<PtxSpecialRegister> {
|
||||
self.id_to_reg.get(&id).copied()
|
||||
}
|
||||
|
||||
fn get_or_add(&mut self, current_id: &mut SpirvWord, reg: PtxSpecialRegister) -> SpirvWord {
|
||||
match self.reg_to_id.entry(reg) {
|
||||
hash_map::Entry::Occupied(e) => *e.get(),
|
||||
hash_map::Entry::Vacant(e) => {
|
||||
let numeric_id = SpirvWord(current_id.0);
|
||||
current_id.0 += 1;
|
||||
e.insert(numeric_id);
|
||||
self.id_to_reg.insert(numeric_id, reg);
|
||||
numeric_id
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct FnStringIdResolver<'input, 'b> {
|
||||
current_id: &'b mut SpirvWord,
|
||||
global_variables: &'b HashMap<Cow<'input, str>, SpirvWord>,
|
||||
global_type_check: &'b HashMap<u32, Option<(ast::Type, ast::StateSpace, bool)>>,
|
||||
special_registers: &'b mut SpecialRegistersMap,
|
||||
variables: Vec<HashMap<Cow<'input, str>, SpirvWord>>,
|
||||
type_check: HashMap<u32, Option<(ast::Type, ast::StateSpace, bool)>>,
|
||||
}
|
||||
|
||||
impl<'a, 'b> FnStringIdResolver<'a, 'b> {
|
||||
fn finish(self) -> NumericIdResolver<'b> {
|
||||
NumericIdResolver {
|
||||
current_id: self.current_id,
|
||||
global_type_check: self.global_type_check,
|
||||
type_check: self.type_check,
|
||||
special_registers: self.special_registers,
|
||||
}
|
||||
}
|
||||
|
||||
fn start_block(&mut self) {
|
||||
self.variables.push(HashMap::new())
|
||||
}
|
||||
|
||||
fn end_block(&mut self) {
|
||||
self.variables.pop();
|
||||
}
|
||||
|
||||
fn get_id(&mut self, id: &str) -> Result<SpirvWord, TranslateError> {
|
||||
for scope in self.variables.iter().rev() {
|
||||
match scope.get(id) {
|
||||
Some(id) => return Ok(*id),
|
||||
None => continue,
|
||||
}
|
||||
}
|
||||
match self.global_variables.get(id) {
|
||||
Some(id) => Ok(*id),
|
||||
None => {
|
||||
let sreg = PtxSpecialRegister::try_parse(id).ok_or_else(error_unknown_symbol)?;
|
||||
Ok(self.special_registers.get_or_add(self.current_id, sreg))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn add_def(
|
||||
&mut self,
|
||||
id: &'a str,
|
||||
typ: Option<(ast::Type, ast::StateSpace)>,
|
||||
is_variable: bool,
|
||||
) -> SpirvWord {
|
||||
let numeric_id = *self.current_id;
|
||||
self.variables
|
||||
.last_mut()
|
||||
.unwrap()
|
||||
.insert(Cow::Borrowed(id), numeric_id);
|
||||
self.type_check.insert(
|
||||
numeric_id.0,
|
||||
typ.map(|(typ, space)| (typ, space, is_variable)),
|
||||
);
|
||||
self.current_id.0 += 1;
|
||||
numeric_id
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
fn add_defs(
|
||||
&mut self,
|
||||
base_id: &'a str,
|
||||
count: u32,
|
||||
typ: ast::Type,
|
||||
state_space: ast::StateSpace,
|
||||
is_variable: bool,
|
||||
) -> impl Iterator<Item = SpirvWord> {
|
||||
let numeric_id = *self.current_id;
|
||||
for i in 0..count {
|
||||
self.variables.last_mut().unwrap().insert(
|
||||
Cow::Owned(format!("{}{}", base_id, i)),
|
||||
SpirvWord(numeric_id.0 + i),
|
||||
);
|
||||
self.type_check.insert(
|
||||
numeric_id.0 + i,
|
||||
Some((typ.clone(), state_space, is_variable)),
|
||||
);
|
||||
}
|
||||
self.current_id.0 += count;
|
||||
(0..count)
|
||||
.into_iter()
|
||||
.map(move |i| SpirvWord(i + numeric_id.0))
|
||||
}
|
||||
}
|
||||
|
||||
struct NumericIdResolver<'b> {
|
||||
current_id: &'b mut SpirvWord,
|
||||
global_type_check: &'b HashMap<u32, Option<(ast::Type, ast::StateSpace, bool)>>,
|
||||
type_check: HashMap<u32, Option<(ast::Type, ast::StateSpace, bool)>>,
|
||||
special_registers: &'b mut SpecialRegistersMap,
|
||||
}
|
||||
|
||||
impl<'b> NumericIdResolver<'b> {
|
||||
fn finish(self) -> MutableNumericIdResolver<'b> {
|
||||
MutableNumericIdResolver { base: self }
|
||||
}
|
||||
|
||||
fn get_typed(
|
||||
&self,
|
||||
id: SpirvWord,
|
||||
) -> Result<(ast::Type, ast::StateSpace, bool), TranslateError> {
|
||||
match self.type_check.get(&id.0) {
|
||||
Some(Some(x)) => Ok(x.clone()),
|
||||
Some(None) => Err(TranslateError::UntypedSymbol),
|
||||
None => match self.special_registers.get(id) {
|
||||
Some(x) => Ok((x.get_type(), ast::StateSpace::Sreg, true)),
|
||||
None => match self.global_type_check.get(&id.0) {
|
||||
Some(Some(result)) => Ok(result.clone()),
|
||||
Some(None) | None => Err(TranslateError::UntypedSymbol),
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// This is for identifiers which will be emitted later as OpVariable
|
||||
// They are candidates for insertion of LoadVar/StoreVar
|
||||
fn register_variable(&mut self, typ: ast::Type, state_space: ast::StateSpace) -> SpirvWord {
|
||||
let new_id = *self.current_id;
|
||||
self.type_check
|
||||
.insert(new_id.0, Some((typ, state_space, true)));
|
||||
self.current_id.0 += 1;
|
||||
new_id
|
||||
}
|
||||
|
||||
fn register_intermediate(&mut self, typ: Option<(ast::Type, ast::StateSpace)>) -> SpirvWord {
|
||||
let new_id = *self.current_id;
|
||||
self.type_check
|
||||
.insert(new_id.0, typ.map(|(t, space)| (t, space, false)));
|
||||
self.current_id.0 += 1;
|
||||
new_id
|
||||
}
|
||||
}
|
||||
|
||||
struct MutableNumericIdResolver<'b> {
|
||||
base: NumericIdResolver<'b>,
|
||||
}
|
||||
|
||||
impl<'b> MutableNumericIdResolver<'b> {
|
||||
fn unmut(self) -> NumericIdResolver<'b> {
|
||||
self.base
|
||||
}
|
||||
|
||||
fn get_typed(&self, id: SpirvWord) -> Result<(ast::Type, ast::StateSpace), TranslateError> {
|
||||
self.base.get_typed(id).map(|(t, space, _)| (t, space))
|
||||
}
|
||||
|
||||
fn register_intermediate(&mut self, typ: ast::Type, state_space: ast::StateSpace) -> SpirvWord {
|
||||
self.base.register_intermediate(Some((typ, state_space)))
|
||||
}
|
||||
}
|
||||
|
||||
quick_error! {
|
||||
#[derive(Debug)]
|
||||
pub enum TranslateError {
|
||||
UnknownSymbol {}
|
||||
UntypedSymbol {}
|
||||
MismatchedType {}
|
||||
Spirv(err: rspirv::dr::Error) {
|
||||
from()
|
||||
display("{}", err)
|
||||
cause(err)
|
||||
}
|
||||
Unreachable {}
|
||||
Todo {}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
fn error_unreachable() -> TranslateError {
|
||||
unreachable!()
|
||||
}
|
||||
|
||||
#[cfg(not(debug_assertions))]
|
||||
fn error_unreachable() -> TranslateError {
|
||||
TranslateError::Unreachable
|
||||
}
|
||||
|
||||
fn error_unknown_symbol() -> TranslateError {
|
||||
TranslateError::UnknownSymbol
|
||||
}
|
||||
|
||||
pub struct GlobalFnDeclResolver<'input, 'a> {
|
||||
fns: &'a HashMap<SpirvWord, FnSigMapper<'input>>,
|
||||
}
|
||||
|
||||
impl<'input, 'a> GlobalFnDeclResolver<'input, 'a> {
|
||||
fn get_fn_sig_resolver(&self, id: SpirvWord) -> Result<&FnSigMapper<'input>, TranslateError> {
|
||||
self.fns.get(&id).ok_or_else(error_unknown_symbol)
|
||||
}
|
||||
}
|
||||
|
||||
struct FnSigMapper<'input> {
|
||||
// true - stays as return argument
|
||||
// false - is moved to input argument
|
||||
return_param_args: Vec<bool>,
|
||||
func_decl: Rc<RefCell<ast::MethodDeclaration<'input, SpirvWord>>>,
|
||||
}
|
||||
|
||||
impl<'input> FnSigMapper<'input> {
|
||||
fn remap_to_spirv_repr(mut method: ast::MethodDeclaration<'input, SpirvWord>) -> Self {
|
||||
let return_param_args = method
|
||||
.return_arguments
|
||||
.iter()
|
||||
.map(|a| a.state_space != ast::StateSpace::Param)
|
||||
.collect::<Vec<_>>();
|
||||
let mut new_return_arguments = Vec::new();
|
||||
for arg in method.return_arguments.into_iter() {
|
||||
if arg.state_space == ast::StateSpace::Param {
|
||||
method.input_arguments.push(arg);
|
||||
} else {
|
||||
new_return_arguments.push(arg);
|
||||
}
|
||||
}
|
||||
method.return_arguments = new_return_arguments;
|
||||
FnSigMapper {
|
||||
return_param_args,
|
||||
func_decl: Rc::new(RefCell::new(method)),
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
fn resolve_in_spirv_repr(
|
||||
&self,
|
||||
call_inst: ast::CallInst<NormalizedArgParams>,
|
||||
) -> Result<ResolvedCall<NormalizedArgParams>, TranslateError> {
|
||||
let func_decl = (*self.func_decl).borrow();
|
||||
let mut return_arguments = Vec::new();
|
||||
let mut input_arguments = call_inst
|
||||
.param_list
|
||||
.into_iter()
|
||||
.zip(func_decl.input_arguments.iter())
|
||||
.map(|(id, var)| (id, var.v_type.clone(), var.state_space))
|
||||
.collect::<Vec<_>>();
|
||||
let mut func_decl_return_iter = func_decl.return_arguments.iter();
|
||||
let mut func_decl_input_iter = func_decl.input_arguments[input_arguments.len()..].iter();
|
||||
for (idx, id) in call_inst.ret_params.iter().enumerate() {
|
||||
let stays_as_return = match self.return_param_args.get(idx) {
|
||||
Some(x) => *x,
|
||||
None => return Err(TranslateError::MismatchedType),
|
||||
};
|
||||
if stays_as_return {
|
||||
if let Some(var) = func_decl_return_iter.next() {
|
||||
return_arguments.push((*id, var.v_type.clone(), var.state_space));
|
||||
} else {
|
||||
return Err(TranslateError::MismatchedType);
|
||||
}
|
||||
} else {
|
||||
if let Some(var) = func_decl_input_iter.next() {
|
||||
input_arguments.push((
|
||||
ast::Operand::Reg(*id),
|
||||
var.v_type.clone(),
|
||||
var.state_space,
|
||||
));
|
||||
} else {
|
||||
return Err(TranslateError::MismatchedType);
|
||||
}
|
||||
}
|
||||
}
|
||||
if return_arguments.len() != func_decl.return_arguments.len()
|
||||
|| input_arguments.len() != func_decl.input_arguments.len()
|
||||
{
|
||||
return Err(TranslateError::MismatchedType);
|
||||
}
|
||||
Ok(ResolvedCall {
|
||||
return_arguments,
|
||||
input_arguments,
|
||||
uniform: call_inst.uniform,
|
||||
name: call_inst.func,
|
||||
})
|
||||
}
|
||||
*/
|
||||
}
|
||||
|
||||
enum Statement<I, P: ast::Operand> {
|
||||
Label(SpirvWord),
|
||||
Variable(ast::Variable<P::Ident>),
|
||||
Instruction(I),
|
||||
// SPIR-V compatible replacement for PTX predicates
|
||||
Conditional(BrachCondition),
|
||||
LoadVar(LoadVarDetails),
|
||||
StoreVar(StoreVarDetails),
|
||||
Conversion(ImplicitConversion),
|
||||
Constant(ConstantDefinition),
|
||||
RetValue(ast::RetData, SpirvWord),
|
||||
PtrAccess(PtrAccess<P>),
|
||||
RepackVector(RepackVectorDetails),
|
||||
FunctionPointer(FunctionPointerDetails),
|
||||
}
|
||||
|
||||
struct BrachCondition {
|
||||
predicate: SpirvWord,
|
||||
if_true: SpirvWord,
|
||||
if_false: SpirvWord,
|
||||
}
|
||||
struct LoadVarDetails {
|
||||
arg: ast::LdArgs<SpirvWord>,
|
||||
typ: ast::Type,
|
||||
state_space: ast::StateSpace,
|
||||
// (index, vector_width)
|
||||
// HACK ALERT
|
||||
// For some reason IGC explodes when you try to load from builtin vectors
|
||||
// using OpInBoundsAccessChain, the one true way to do it is to
|
||||
// OpLoad+OpCompositeExtract
|
||||
member_index: Option<(u8, Option<u8>)>,
|
||||
}
|
||||
|
||||
struct StoreVarDetails {
|
||||
arg: ast::StArgs<SpirvWord>,
|
||||
typ: ast::Type,
|
||||
member_index: Option<u8>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct ImplicitConversion {
|
||||
src: SpirvWord,
|
||||
dst: SpirvWord,
|
||||
from_type: ast::Type,
|
||||
to_type: ast::Type,
|
||||
from_space: ast::StateSpace,
|
||||
to_space: ast::StateSpace,
|
||||
kind: ConversionKind,
|
||||
}
|
||||
|
||||
#[derive(PartialEq, Clone)]
|
||||
enum ConversionKind {
|
||||
Default,
|
||||
// zero-extend/chop/bitcast depending on types
|
||||
SignExtend,
|
||||
BitToPtr,
|
||||
PtrToPtr,
|
||||
AddressOf,
|
||||
}
|
||||
|
||||
struct ConstantDefinition {
|
||||
pub dst: SpirvWord,
|
||||
pub typ: ast::ScalarType,
|
||||
pub value: ast::ImmediateValue,
|
||||
}
|
||||
|
||||
pub struct PtrAccess<T> {
|
||||
underlying_type: ast::Type,
|
||||
state_space: ast::StateSpace,
|
||||
dst: SpirvWord,
|
||||
ptr_src: SpirvWord,
|
||||
offset_src: T,
|
||||
}
|
||||
|
||||
struct RepackVectorDetails {
|
||||
is_extract: bool,
|
||||
typ: ast::ScalarType,
|
||||
packed: SpirvWord,
|
||||
unpacked: Vec<SpirvWord>,
|
||||
non_default_implicit_conversion: Option<
|
||||
fn(
|
||||
(ast::StateSpace, &ast::Type),
|
||||
(ast::StateSpace, &ast::Type),
|
||||
) -> Result<Option<ConversionKind>, TranslateError>,
|
||||
>,
|
||||
}
|
||||
|
||||
struct FunctionPointerDetails {
|
||||
dst: SpirvWord,
|
||||
src: SpirvWord,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, PartialEq, Eq, Hash)]
|
||||
struct SpirvWord(spirv::Word);
|
||||
|
||||
impl From<spirv::Word> for SpirvWord {
|
||||
fn from(value: spirv::Word) -> Self {
|
||||
Self(value)
|
||||
}
|
||||
}
|
||||
impl From<SpirvWord> for spirv::Word {
|
||||
fn from(value: SpirvWord) -> Self {
|
||||
value.0
|
||||
}
|
||||
}
|
||||
|
||||
impl ast::Operand for SpirvWord {
|
||||
type Ident = Self;
|
||||
}
|
||||
|
||||
fn pred_map_variable<U, T, F: FnMut(T) -> Result<U, TranslateError>>(
|
||||
this: ast::PredAt<T>,
|
||||
f: &mut F,
|
||||
) -> Result<ast::PredAt<U>, TranslateError> {
|
||||
let new_label = f(this.label)?;
|
||||
Ok(ast::PredAt {
|
||||
not: this.not,
|
||||
label: new_label,
|
||||
})
|
||||
}
|
||||
|
||||
impl<T: ast::Operand, U: ast::Operand, X: FnMut(&str) -> Result<SpirvWord, Err>, Err> ast::VisitorMap<T, U, Err> for X {
|
||||
fn visit(
|
||||
&mut self,
|
||||
args: T,
|
||||
type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>,
|
||||
is_dst: bool,
|
||||
) -> U {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn visit_ident(
|
||||
&mut self,
|
||||
args: <T as ptx_parser::Operand>::Ident,
|
||||
type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>,
|
||||
is_dst: bool,
|
||||
) -> <U as ptx_parser::Operand>::Ident {
|
||||
todo!()
|
||||
}
|
||||
}
|
||||
|
||||
fn op_map_variable<'a, F: FnMut(&str) -> Result<SpirvWord, TranslateError>>(
|
||||
this: ast::Instruction<ast::ParsedOperand<&'a str>>,
|
||||
f: &mut F,
|
||||
) -> Result<ast::Instruction<ast::ParsedOperand<SpirvWord>>, TranslateError> {
|
||||
ast::visit_map(this , f)
|
||||
}
|
83
ptx/src/pass/normalize.rs
Normal file
83
ptx/src/pass/normalize.rs
Normal file
|
@ -0,0 +1,83 @@
|
|||
use super::*;
|
||||
use ptx_parser as ast;
|
||||
|
||||
type NormalizedStatement = Statement<
|
||||
(
|
||||
Option<ast::PredAt<SpirvWord>>,
|
||||
ast::Instruction<ast::ParsedOperand<SpirvWord>>,
|
||||
),
|
||||
ast::ParsedOperand<SpirvWord>,
|
||||
>;
|
||||
|
||||
fn run<'input, 'b>(
|
||||
id_defs: &mut FnStringIdResolver<'input, 'b>,
|
||||
fn_defs: &GlobalFnDeclResolver<'input, 'b>,
|
||||
func: Vec<ast::Statement<ast::ParsedOperand<&'input str>>>,
|
||||
) -> Result<Vec<NormalizedStatement>, TranslateError> {
|
||||
for s in func.iter() {
|
||||
match s {
|
||||
ast::Statement::Label(id) => {
|
||||
id_defs.add_def(*id, None, false);
|
||||
}
|
||||
_ => (),
|
||||
}
|
||||
}
|
||||
let mut result = Vec::new();
|
||||
for s in func {
|
||||
expand_map_variables(id_defs, fn_defs, &mut result, s)?;
|
||||
}
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
fn expand_map_variables<'a, 'b>(
|
||||
id_defs: &mut FnStringIdResolver<'a, 'b>,
|
||||
fn_defs: &GlobalFnDeclResolver<'a, 'b>,
|
||||
result: &mut Vec<NormalizedStatement>,
|
||||
s: ast::Statement<ast::ParsedOperand<&'a str>>,
|
||||
) -> Result<(), TranslateError> {
|
||||
match s {
|
||||
ast::Statement::Block(block) => {
|
||||
id_defs.start_block();
|
||||
for s in block {
|
||||
expand_map_variables(id_defs, fn_defs, result, s)?;
|
||||
}
|
||||
id_defs.end_block();
|
||||
}
|
||||
ast::Statement::Label(name) => result.push(Statement::Label(id_defs.get_id(name)?)),
|
||||
ast::Statement::Instruction(p, i) => result.push(Statement::Instruction((
|
||||
p.map(|p| pred_map_variable(p, &mut |id| id_defs.get_id(id)))
|
||||
.transpose()?,
|
||||
op_map_variable(i, &mut |id| id_defs.get_id(id))?,
|
||||
))),
|
||||
ast::Statement::Variable(var) => {
|
||||
let var_type = var.var.v_type.clone();
|
||||
match var.count {
|
||||
Some(count) => {
|
||||
for new_id in
|
||||
id_defs.add_defs(var.var.name, count, var_type, var.var.state_space, true)
|
||||
{
|
||||
result.push(Statement::Variable(ast::Variable {
|
||||
align: var.var.align,
|
||||
v_type: var.var.v_type.clone(),
|
||||
state_space: var.var.state_space,
|
||||
name: new_id,
|
||||
array_init: var.var.array_init.clone(),
|
||||
}))
|
||||
}
|
||||
}
|
||||
None => {
|
||||
let new_id =
|
||||
id_defs.add_def(var.var.name, Some((var_type, var.var.state_space)), true);
|
||||
result.push(Statement::Variable(ast::Variable {
|
||||
align: var.var.align,
|
||||
v_type: var.var.v_type.clone(),
|
||||
state_space: var.var.state_space,
|
||||
name: new_id,
|
||||
array_init: var.var.array_init,
|
||||
}));
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
Ok(())
|
||||
}
|
60
ptx/src/translate2.rs
Normal file
60
ptx/src/translate2.rs
Normal file
|
@ -0,0 +1,60 @@
|
|||
use std::collections::HashMap;
|
||||
use half::f16;
|
||||
use ptx_parser as ast;
|
||||
|
||||
fn to_ssa<'input, 'b>(
|
||||
ptx_impl_imports: &'b mut HashMap<String, Directive<'input>>,
|
||||
mut id_defs: FnStringIdResolver<'input, 'b>,
|
||||
fn_defs: GlobalFnDeclResolver<'input, 'b>,
|
||||
func_decl: Rc<RefCell<ast::MethodDeclaration<'input, spirv::Word>>>,
|
||||
f_body: Option<Vec<ast::Statement<ast::ParsedArgParams<'input>>>>,
|
||||
tuning: Vec<ast::TuningDirective>,
|
||||
linkage: ast::LinkingDirective,
|
||||
) -> Result<Function<'input>, TranslateError> {
|
||||
//deparamize_function_decl(&func_decl)?;
|
||||
let f_body = match f_body {
|
||||
Some(vec) => vec,
|
||||
None => {
|
||||
return Ok(Function {
|
||||
func_decl: func_decl,
|
||||
body: None,
|
||||
globals: Vec::new(),
|
||||
import_as: None,
|
||||
tuning,
|
||||
linkage,
|
||||
})
|
||||
}
|
||||
};
|
||||
let normalized_ids = normalize_identifiers(&mut id_defs, &fn_defs, f_body)?;
|
||||
/*
|
||||
let mut numeric_id_defs = id_defs.finish();
|
||||
let unadorned_statements = normalize_predicates(normalized_ids, &mut numeric_id_defs)?;
|
||||
let typed_statements =
|
||||
convert_to_typed_statements(unadorned_statements, &fn_defs, &mut numeric_id_defs)?;
|
||||
let typed_statements =
|
||||
fix_special_registers2(ptx_impl_imports, typed_statements, &mut numeric_id_defs)?;
|
||||
let (func_decl, typed_statements) =
|
||||
convert_to_stateful_memory_access(func_decl, typed_statements, &mut numeric_id_defs)?;
|
||||
let ssa_statements = insert_mem_ssa_statements(
|
||||
typed_statements,
|
||||
&mut numeric_id_defs,
|
||||
&mut (*func_decl).borrow_mut(),
|
||||
)?;
|
||||
let mut numeric_id_defs = numeric_id_defs.finish();
|
||||
let expanded_statements = expand_arguments(ssa_statements, &mut numeric_id_defs)?;
|
||||
let expanded_statements =
|
||||
insert_implicit_conversions(expanded_statements, &mut numeric_id_defs)?;
|
||||
let mut numeric_id_defs = numeric_id_defs.unmut();
|
||||
let labeled_statements = normalize_labels(expanded_statements, &mut numeric_id_defs);
|
||||
let (f_body, globals) =
|
||||
extract_globals(labeled_statements, ptx_impl_imports, &mut numeric_id_defs)?;
|
||||
Ok(Function {
|
||||
func_decl: func_decl,
|
||||
globals: globals,
|
||||
body: Some(f_body),
|
||||
import_as: None,
|
||||
tuning,
|
||||
linkage,
|
||||
})
|
||||
*/
|
||||
}
|
|
@ -4,6 +4,8 @@ version = "0.0.0"
|
|||
authors = ["Andrzej Janik <vosen@vosen.pl>"]
|
||||
edition = "2021"
|
||||
|
||||
[lib]
|
||||
|
||||
[dependencies]
|
||||
logos = "0.14"
|
||||
winnow = { version = "0.6.18" }
|
||||
|
@ -11,3 +13,4 @@ ptx_parser_macros = { path = "../ptx_parser_macros" }
|
|||
thiserror = "1.0"
|
||||
bitflags = "1.2"
|
||||
rustc-hash = "2.0.0"
|
||||
derive_more = { version = "1", features = ["display"] }
|
||||
|
|
|
@ -147,9 +147,9 @@ ptx_parser_macros::generate_instruction_type!(
|
|||
Call {
|
||||
data: CallDetails,
|
||||
arguments: CallArgs<T>,
|
||||
visit: arguments.visit(data, visitor),
|
||||
visit_mut: arguments.visit_mut(data, visitor),
|
||||
map: Instruction::Call{ arguments: arguments.map(&data, visitor), data }
|
||||
visit: arguments.visit(data, visitor)?,
|
||||
visit_mut: arguments.visit_mut(data, visitor)?,
|
||||
map: Instruction::Call{ arguments: arguments.map(&data, visitor)?, data }
|
||||
},
|
||||
Cvt {
|
||||
data: CvtDetails,
|
||||
|
@ -488,93 +488,185 @@ ptx_parser_macros::generate_instruction_type!(
|
|||
}
|
||||
);
|
||||
|
||||
pub trait Visitor<T: Operand> {
|
||||
fn visit(&mut self, args: &T, type_space: Option<(&Type, StateSpace)>, is_dst: bool);
|
||||
fn visit_ident(&self, args: &T::Ident, type_space: Option<(&Type, StateSpace)>, is_dst: bool);
|
||||
pub trait Visitor<T: Operand, Err> {
|
||||
fn visit(
|
||||
&mut self,
|
||||
args: &T,
|
||||
type_space: Option<(&Type, StateSpace)>,
|
||||
is_dst: bool,
|
||||
) -> Result<(), Err>;
|
||||
fn visit_ident(
|
||||
&mut self,
|
||||
args: &T::Ident,
|
||||
type_space: Option<(&Type, StateSpace)>,
|
||||
is_dst: bool,
|
||||
) -> Result<(), Err>;
|
||||
}
|
||||
|
||||
pub trait VisitorMut<T: Operand> {
|
||||
fn visit(&mut self, args: &mut T, type_space: Option<(&Type, StateSpace)>, is_dst: bool);
|
||||
impl<T: Operand, Err, Fn: FnMut(&T, Option<(&Type, StateSpace)>, bool) -> Result<(), Err>>
|
||||
Visitor<T, Err> for Fn
|
||||
{
|
||||
fn visit(
|
||||
&mut self,
|
||||
args: &T,
|
||||
type_space: Option<(&Type, StateSpace)>,
|
||||
is_dst: bool,
|
||||
) -> Result<(), Err> {
|
||||
(self)(args, type_space, is_dst)
|
||||
}
|
||||
|
||||
fn visit_ident(
|
||||
&mut self,
|
||||
args: &T::Ident,
|
||||
type_space: Option<(&Type, StateSpace)>,
|
||||
is_dst: bool,
|
||||
) -> Result<(), Err> {
|
||||
(self)(&T::from_ident(*args), type_space, is_dst)
|
||||
}
|
||||
}
|
||||
|
||||
pub trait VisitorMut<T: Operand, Err> {
|
||||
fn visit(
|
||||
&mut self,
|
||||
args: &mut T,
|
||||
type_space: Option<(&Type, StateSpace)>,
|
||||
is_dst: bool,
|
||||
) -> Result<(), Err>;
|
||||
fn visit_ident(
|
||||
&mut self,
|
||||
args: &mut T::Ident,
|
||||
type_space: Option<(&Type, StateSpace)>,
|
||||
is_dst: bool,
|
||||
);
|
||||
) -> Result<(), Err>;
|
||||
}
|
||||
|
||||
pub trait VisitorMap<From: Operand, To: Operand> {
|
||||
fn visit(&mut self, args: From, type_space: Option<(&Type, StateSpace)>, is_dst: bool) -> To;
|
||||
pub trait VisitorMap<From: Operand, To: Operand, Err> {
|
||||
fn visit(
|
||||
&mut self,
|
||||
args: From,
|
||||
type_space: Option<(&Type, StateSpace)>,
|
||||
is_dst: bool,
|
||||
) -> Result<To, Err>;
|
||||
fn visit_ident(
|
||||
&mut self,
|
||||
args: From::Ident,
|
||||
type_space: Option<(&Type, StateSpace)>,
|
||||
is_dst: bool,
|
||||
) -> To::Ident;
|
||||
) -> Result<To::Ident, Err>;
|
||||
}
|
||||
|
||||
trait VisitOperand {
|
||||
impl<
|
||||
T: Operand,
|
||||
U: Operand,
|
||||
Err,
|
||||
Fn: FnMut(T, Option<(&Type, StateSpace)>, bool) -> Result<U, Err>,
|
||||
> VisitorMap<T, U, Err> for Fn
|
||||
{
|
||||
fn visit(
|
||||
&mut self,
|
||||
args: T,
|
||||
type_space: Option<(&Type, StateSpace)>,
|
||||
is_dst: bool,
|
||||
) -> Result<U, Err> {
|
||||
(self)(args, type_space, is_dst)
|
||||
}
|
||||
|
||||
fn visit_ident(
|
||||
&mut self,
|
||||
args: T::Ident,
|
||||
type_space: Option<(&Type, StateSpace)>,
|
||||
is_dst: bool,
|
||||
) -> Result<U::Ident, Err> {
|
||||
let value: U = (self)(T::from_ident(args), type_space, is_dst)?;
|
||||
Ok(value)
|
||||
}
|
||||
}
|
||||
|
||||
trait VisitOperand<Err> {
|
||||
type Operand: Operand;
|
||||
#[allow(unused)] // Used by generated code
|
||||
fn visit(&self, fn_: impl FnMut(&Self::Operand));
|
||||
fn visit(&self, fn_: impl FnMut(&Self::Operand) -> Result<(), Err>) -> Result<(), Err>;
|
||||
#[allow(unused)] // Used by generated code
|
||||
fn visit_mut(&mut self, fn_: impl FnMut(&mut Self::Operand));
|
||||
fn visit_mut(
|
||||
&mut self,
|
||||
fn_: impl FnMut(&mut Self::Operand) -> Result<(), Err>,
|
||||
) -> Result<(), Err>;
|
||||
}
|
||||
|
||||
impl<T: Operand> VisitOperand for T {
|
||||
impl<T: Operand, Err> VisitOperand<Err> for T {
|
||||
type Operand = Self;
|
||||
fn visit(&self, mut fn_: impl FnMut(&Self::Operand)) {
|
||||
fn visit(&self, mut fn_: impl FnMut(&Self::Operand) -> Result<(), Err>) -> Result<(), Err> {
|
||||
fn_(self)
|
||||
}
|
||||
fn visit_mut(&mut self, mut fn_: impl FnMut(&mut Self::Operand)) {
|
||||
fn visit_mut(
|
||||
&mut self,
|
||||
mut fn_: impl FnMut(&mut Self::Operand) -> Result<(), Err>,
|
||||
) -> Result<(), Err> {
|
||||
fn_(self)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Operand> VisitOperand for Option<T> {
|
||||
impl<T: Operand, Err> VisitOperand<Err> for Option<T> {
|
||||
type Operand = T;
|
||||
fn visit(&self, fn_: impl FnMut(&Self::Operand)) {
|
||||
self.as_ref().map(fn_);
|
||||
fn visit(&self, mut fn_: impl FnMut(&Self::Operand) -> Result<(), Err>) -> Result<(), Err> {
|
||||
if let Some(x) = self {
|
||||
fn_(x)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
fn visit_mut(&mut self, fn_: impl FnMut(&mut Self::Operand)) {
|
||||
self.as_mut().map(fn_);
|
||||
fn visit_mut(
|
||||
&mut self,
|
||||
mut fn_: impl FnMut(&mut Self::Operand) -> Result<(), Err>,
|
||||
) -> Result<(), Err> {
|
||||
if let Some(x) = self {
|
||||
fn_(x)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Operand> VisitOperand for Vec<T> {
|
||||
impl<T: Operand, Err> VisitOperand<Err> for Vec<T> {
|
||||
type Operand = T;
|
||||
fn visit(&self, mut fn_: impl FnMut(&Self::Operand)) {
|
||||
fn visit(&self, mut fn_: impl FnMut(&Self::Operand) -> Result<(), Err>) -> Result<(), Err> {
|
||||
for o in self {
|
||||
fn_(o)
|
||||
fn_(o)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
fn visit_mut(&mut self, mut fn_: impl FnMut(&mut Self::Operand)) {
|
||||
fn visit_mut(
|
||||
&mut self,
|
||||
mut fn_: impl FnMut(&mut Self::Operand) -> Result<(), Err>,
|
||||
) -> Result<(), Err> {
|
||||
for o in self {
|
||||
fn_(o)
|
||||
fn_(o)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
trait MapOperand: Sized {
|
||||
trait MapOperand<Err>: Sized {
|
||||
type Input;
|
||||
type Output<U>;
|
||||
#[allow(unused)] // Used by generated code
|
||||
fn map<U>(self, fn_: impl FnOnce(Self::Input) -> U) -> Self::Output<U>;
|
||||
fn map<U>(
|
||||
self,
|
||||
fn_: impl FnOnce(Self::Input) -> Result<U, Err>,
|
||||
) -> Result<Self::Output<U>, Err>;
|
||||
}
|
||||
|
||||
impl<T: Operand> MapOperand for T {
|
||||
impl<T: Operand, Err> MapOperand<Err> for T {
|
||||
type Input = Self;
|
||||
type Output<U> = U;
|
||||
fn map<U>(self, fn_: impl FnOnce(T) -> U) -> U {
|
||||
fn map<U>(self, fn_: impl FnOnce(T) -> Result<U, Err>) -> Result<U, Err> {
|
||||
fn_(self)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Operand> MapOperand for Option<T> {
|
||||
impl<T: Operand, Err> MapOperand<Err> for Option<T> {
|
||||
type Input = T;
|
||||
type Output<U> = Option<U>;
|
||||
fn map<U>(self, fn_: impl FnOnce(T) -> U) -> Option<U> {
|
||||
self.map(|x| fn_(x))
|
||||
fn map<U>(self, fn_: impl FnOnce(T) -> Result<U, Err>) -> Result<Option<U>, Err> {
|
||||
self.map(|x| fn_(x)).transpose()
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -715,10 +807,16 @@ pub enum ParsedOperand<Ident> {
|
|||
|
||||
impl<Ident: Copy> Operand for ParsedOperand<Ident> {
|
||||
type Ident = Ident;
|
||||
|
||||
fn from_ident(ident: Self::Ident) -> Self {
|
||||
ParsedOperand::Reg(ident)
|
||||
}
|
||||
}
|
||||
|
||||
pub trait Operand {
|
||||
pub trait Operand: Sized {
|
||||
type Ident: Copy;
|
||||
|
||||
fn from_ident(ident: Self::Ident) -> Self;
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone)]
|
||||
|
@ -1048,67 +1146,77 @@ pub struct CallArgs<T: Operand> {
|
|||
|
||||
impl<T: Operand> CallArgs<T> {
|
||||
#[allow(dead_code)] // Used by generated code
|
||||
fn visit(&self, details: &CallDetails, visitor: &mut impl Visitor<T>) {
|
||||
fn visit<Err>(
|
||||
&self,
|
||||
details: &CallDetails,
|
||||
visitor: &mut impl Visitor<T, Err>,
|
||||
) -> Result<(), Err> {
|
||||
for (param, (type_, space)) in self
|
||||
.return_arguments
|
||||
.iter()
|
||||
.zip(details.return_arguments.iter())
|
||||
{
|
||||
visitor.visit_ident(param, Some((type_, *space)), true);
|
||||
visitor.visit_ident(param, Some((type_, *space)), true)?;
|
||||
}
|
||||
visitor.visit_ident(&self.func, None, false);
|
||||
visitor.visit_ident(&self.func, None, false)?;
|
||||
for (param, (type_, space)) in self
|
||||
.input_arguments
|
||||
.iter()
|
||||
.zip(details.input_arguments.iter())
|
||||
{
|
||||
visitor.visit(param, Some((type_, *space)), true);
|
||||
visitor.visit(param, Some((type_, *space)), true)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(dead_code)] // Used by generated code
|
||||
fn visit_mut(&mut self, details: &CallDetails, visitor: &mut impl VisitorMut<T>) {
|
||||
fn visit_mut<Err>(
|
||||
&mut self,
|
||||
details: &CallDetails,
|
||||
visitor: &mut impl VisitorMut<T, Err>,
|
||||
) -> Result<(), Err> {
|
||||
for (param, (type_, space)) in self
|
||||
.return_arguments
|
||||
.iter_mut()
|
||||
.zip(details.return_arguments.iter())
|
||||
{
|
||||
visitor.visit_ident(param, Some((type_, *space)), true);
|
||||
visitor.visit_ident(param, Some((type_, *space)), true)?;
|
||||
}
|
||||
visitor.visit_ident(&mut self.func, None, false);
|
||||
visitor.visit_ident(&mut self.func, None, false)?;
|
||||
for (param, (type_, space)) in self
|
||||
.input_arguments
|
||||
.iter_mut()
|
||||
.zip(details.input_arguments.iter())
|
||||
{
|
||||
visitor.visit(param, Some((type_, *space)), true);
|
||||
visitor.visit(param, Some((type_, *space)), true)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(dead_code)] // Used by generated code
|
||||
fn map<U: Operand>(
|
||||
fn map<U: Operand, Err>(
|
||||
self,
|
||||
details: &CallDetails,
|
||||
visitor: &mut impl VisitorMap<T, U>,
|
||||
) -> CallArgs<U> {
|
||||
visitor: &mut impl VisitorMap<T, U, Err>,
|
||||
) -> Result<CallArgs<U>, Err> {
|
||||
let return_arguments = self
|
||||
.return_arguments
|
||||
.into_iter()
|
||||
.zip(details.return_arguments.iter())
|
||||
.map(|(param, (type_, space))| visitor.visit_ident(param, Some((type_, *space)), true))
|
||||
.collect::<Vec<_>>();
|
||||
let func = visitor.visit_ident(self.func, None, false);
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
let func = visitor.visit_ident(self.func, None, false)?;
|
||||
let input_arguments = self
|
||||
.input_arguments
|
||||
.into_iter()
|
||||
.zip(details.input_arguments.iter())
|
||||
.map(|(param, (type_, space))| visitor.visit(param, Some((type_, *space)), true))
|
||||
.collect::<Vec<_>>();
|
||||
CallArgs {
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
Ok(CallArgs {
|
||||
return_arguments,
|
||||
func,
|
||||
input_arguments,
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
use derive_more::Display;
|
||||
use logos::Logos;
|
||||
use ptx_parser_macros::derive_parser;
|
||||
use rustc_hash::FxHashMap;
|
||||
use std::fmt::Debug;
|
||||
use std::mem;
|
||||
use std::num::{ParseFloatError, ParseIntError};
|
||||
use winnow::ascii::dec_uint;
|
||||
use winnow::combinator::*;
|
||||
|
@ -81,16 +81,16 @@ impl VectorPrefix {
|
|||
}
|
||||
}
|
||||
|
||||
struct PtxParserState<'input> {
|
||||
errors: Vec<PtxError>,
|
||||
struct PtxParserState<'a, 'input> {
|
||||
errors: &'a mut Vec<PtxError>,
|
||||
function_declarations:
|
||||
FxHashMap<&'input str, (Vec<(ast::Type, StateSpace)>, Vec<(ast::Type, StateSpace)>)>,
|
||||
}
|
||||
|
||||
impl<'input> PtxParserState<'input> {
|
||||
fn new() -> Self {
|
||||
impl<'a, 'input> PtxParserState<'a, 'input> {
|
||||
fn new(errors: &'a mut Vec<PtxError>) -> Self {
|
||||
Self {
|
||||
errors: Vec::new(),
|
||||
errors,
|
||||
function_declarations: FxHashMap::default(),
|
||||
}
|
||||
}
|
||||
|
@ -115,7 +115,7 @@ impl<'input> PtxParserState<'input> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<'input> Debug for PtxParserState<'input> {
|
||||
impl<'a, 'input> Debug for PtxParserState<'a, 'input> {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("PtxParserState")
|
||||
.field("errors", &self.errors) /* .field("function_decl", &self.function_decl) */
|
||||
|
@ -123,7 +123,7 @@ impl<'input> Debug for PtxParserState<'input> {
|
|||
}
|
||||
}
|
||||
|
||||
type PtxParser<'a, 'input> = Stateful<&'a [Token<'input>], PtxParserState<'input>>;
|
||||
type PtxParser<'a, 'input> = Stateful<&'a [Token<'input>], PtxParserState<'a, 'input>>;
|
||||
|
||||
fn ident<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<&'input str> {
|
||||
any.verify_map(|t| {
|
||||
|
@ -277,6 +277,18 @@ fn immediate_value<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<as
|
|||
.parse_next(stream)
|
||||
}
|
||||
|
||||
pub fn parse_module_unchecked<'input>(text: &'input str) -> Option<ast::Module<'input>> {
|
||||
let lexer = Token::lexer(text);
|
||||
let input = lexer.collect::<Result<Vec<_>, _>>().ok()?;
|
||||
let mut errors = Vec::new();
|
||||
let state = PtxParserState::new(&mut errors);
|
||||
let parser = PtxParser {
|
||||
state,
|
||||
input: &input[..],
|
||||
};
|
||||
module.parse(parser).ok()
|
||||
}
|
||||
|
||||
fn module<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<ast::Module<'input>> {
|
||||
(
|
||||
version,
|
||||
|
@ -818,6 +830,8 @@ pub enum PtxError {
|
|||
source: ParseFloatError,
|
||||
},
|
||||
#[error("")]
|
||||
Lexer(#[from] TokenError),
|
||||
#[error("")]
|
||||
Todo,
|
||||
#[error("")]
|
||||
SyntaxError,
|
||||
|
@ -1042,9 +1056,15 @@ fn empty_call<'input>(
|
|||
|
||||
type ParsedOperandStr<'input> = ast::ParsedOperand<&'input str>;
|
||||
|
||||
#[derive(Clone, PartialEq, Default, Debug, Display)]
|
||||
pub struct TokenError;
|
||||
|
||||
impl std::error::Error for TokenError {}
|
||||
|
||||
derive_parser!(
|
||||
#[derive(Logos, PartialEq, Eq, Debug, Clone, Copy)]
|
||||
#[logos(skip r"\s+")]
|
||||
#[logos(error = TokenError)]
|
||||
enum Token<'input> {
|
||||
#[token(",")]
|
||||
Comma,
|
||||
|
@ -1134,6 +1154,7 @@ derive_parser!(
|
|||
pub enum StateSpace {
|
||||
Reg,
|
||||
Generic,
|
||||
Sreg,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, PartialEq, Eq, Hash)]
|
||||
|
@ -2825,57 +2846,6 @@ derive_parser!(
|
|||
|
||||
);
|
||||
|
||||
fn main() {
|
||||
use winnow::Parser;
|
||||
|
||||
let lexer = Token::lexer(
|
||||
"
|
||||
.version 6.5
|
||||
.target sm_30
|
||||
.address_size 64
|
||||
|
||||
.const .align 8 .b32 constparams;
|
||||
|
||||
.visible .entry const(
|
||||
.param .u64 input,
|
||||
.param .u64 output
|
||||
)
|
||||
{
|
||||
.reg .u64 in_addr;
|
||||
.reg .u64 out_addr;
|
||||
.reg .b16 temp1;
|
||||
.reg .b16 temp2;
|
||||
.reg .b16 temp3;
|
||||
.reg .b16 temp4;
|
||||
|
||||
ld.param.u64 in_addr, [input];
|
||||
ld.param.u64 out_addr, [output];
|
||||
|
||||
ld.const.b16 temp1, [constparams];
|
||||
ld.const.b16 temp2, [constparams+2];
|
||||
ld.const.b16 temp3, [constparams+4];
|
||||
ld.const.b16 temp4, [constparams+6];
|
||||
st.u16 [out_addr], temp1;
|
||||
st.u16 [out_addr+2], temp2;
|
||||
st.u16 [out_addr+4], temp3;
|
||||
st.u16 [out_addr+6], temp4;
|
||||
ret;
|
||||
}
|
||||
|
||||
",
|
||||
);
|
||||
let tokens = lexer.clone().collect::<Vec<_>>();
|
||||
println!("{:?}", &tokens);
|
||||
let tokens = lexer.map(|t| t.unwrap()).collect::<Vec<_>>();
|
||||
println!("{:?}", &tokens);
|
||||
let stream = PtxParser {
|
||||
input: &tokens[..],
|
||||
state: PtxParserState::new(),
|
||||
};
|
||||
let _module = module.parse(stream).unwrap();
|
||||
println!("{}", mem::size_of::<Token>());
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::target;
|
|
@ -1017,7 +1017,7 @@ pub fn generate_instruction_type(tokens: proc_macro::TokenStream) -> proc_macro:
|
|||
input.emit_arg_types(&mut result);
|
||||
input.emit_instruction_type(&mut result);
|
||||
input.emit_visit(&mut result);
|
||||
input.emit_visit_mut(&mut result);
|
||||
//input.emit_visit_mut(&mut result);
|
||||
input.emit_visit_map(&mut result);
|
||||
result.into()
|
||||
}
|
||||
|
|
|
@ -67,37 +67,29 @@ impl GenerateInstructionType {
|
|||
let visit_ref = kind.reference();
|
||||
let visitor_type = format_ident!("Visitor{}", kind.type_suffix());
|
||||
let visit_fn = format_ident!("visit{}", kind.fn_suffix());
|
||||
let visit_slice_fn = format_ident!("visit{}_slice", kind.fn_suffix());
|
||||
let (type_parameters, visitor_parameters, return_type) = if kind == VisitKind::Map {
|
||||
(
|
||||
quote! { <#type_parameters, To: Operand> },
|
||||
quote! { <#short_parameters, To> },
|
||||
quote! { #type_name<To> },
|
||||
quote! { <#type_parameters, To: Operand, Err> },
|
||||
quote! { <#short_parameters, To, Err> },
|
||||
quote! { std::result::Result<#type_name<To>, Err> },
|
||||
)
|
||||
} else {
|
||||
(
|
||||
quote! { <#type_parameters> },
|
||||
quote! { <#short_parameters> },
|
||||
quote! { () },
|
||||
quote! { <#type_parameters, Err> },
|
||||
quote! { <#short_parameters, Err> },
|
||||
quote! { std::result::Result<(), Err> },
|
||||
)
|
||||
};
|
||||
quote! {
|
||||
fn #visit_fn #type_parameters (i: #visit_ref #type_name<#short_parameters>, visitor: &mut impl #visitor_type #visitor_parameters ) -> #return_type {
|
||||
match i {
|
||||
pub fn #visit_fn #type_parameters (i: #visit_ref #type_name<#short_parameters>, visitor: &mut impl #visitor_type #visitor_parameters ) -> #return_type {
|
||||
Ok(match i {
|
||||
#inner_tokens
|
||||
}
|
||||
})
|
||||
}
|
||||
}.to_tokens(tokens);
|
||||
if kind == VisitKind::Map {
|
||||
return;
|
||||
}
|
||||
quote! {
|
||||
fn #visit_slice_fn #type_parameters (instructions: #visit_ref [#type_name<#short_parameters>], visitor: &mut impl #visitor_type #visitor_parameters) {
|
||||
for i in instructions {
|
||||
#visit_fn(i, visitor)
|
||||
}
|
||||
}
|
||||
}.to_tokens(tokens);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -630,14 +622,14 @@ impl ArgumentField {
|
|||
quote! {
|
||||
{
|
||||
#type_space
|
||||
visitor.visit_ident(&mut arguments.#name, type_space, #is_dst);
|
||||
visitor.visit_ident(&mut arguments.#name, type_space, #is_dst)?;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
quote! {
|
||||
{
|
||||
#type_space
|
||||
visitor.visit_ident(& arguments.#name, type_space, #is_dst);
|
||||
visitor.visit_ident(& arguments.#name, type_space, #is_dst)?;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -663,7 +655,7 @@ impl ArgumentField {
|
|||
};
|
||||
quote! {{
|
||||
#type_space
|
||||
#operand_fn(#arguments_name, |x| visitor.visit(x, type_space, #is_dst));
|
||||
#operand_fn(#arguments_name, |x| visitor.visit(x, type_space, #is_dst))?;
|
||||
}}
|
||||
}
|
||||
}
|
||||
|
@ -701,11 +693,11 @@ impl ArgumentField {
|
|||
};
|
||||
let map_call = if is_ident {
|
||||
quote! {
|
||||
visitor.visit_ident(arguments.#name, type_space, #is_dst)
|
||||
visitor.visit_ident(arguments.#name, type_space, #is_dst)?
|
||||
}
|
||||
} else {
|
||||
quote! {
|
||||
MapOperand::map(arguments.#name, |x| visitor.visit(x, type_space, #is_dst))
|
||||
MapOperand::map(arguments.#name, |x| visitor.visit(x, type_space, #is_dst))?
|
||||
}
|
||||
};
|
||||
quote! {
|
||||
|
|
Loading…
Add table
Reference in a new issue