Add relaxed type check information to visitors

This commit is contained in:
Andrzej Janik 2024-08-24 02:51:46 +02:00
commit 69175d27ed
3 changed files with 76 additions and 36 deletions

View file

@ -27,7 +27,10 @@ ptx_parser_macros::generate_instruction_type!(
type: { &data.typ }, type: { &data.typ },
data: LdDetails, data: LdDetails,
arguments<T>: { arguments<T>: {
dst: T, dst: {
repr: T,
relaxed_type_check: true,
},
src: { src: {
repr: T, repr: T,
space: { data.state_space }, space: { data.state_space },
@ -51,7 +54,10 @@ ptx_parser_macros::generate_instruction_type!(
repr: T, repr: T,
space: { data.state_space }, space: { data.state_space },
}, },
src2: T, src2: {
repr: T,
relaxed_type_check: true,
}
} }
}, },
Mul { Mul {
@ -157,10 +163,13 @@ ptx_parser_macros::generate_instruction_type!(
dst: { dst: {
repr: T, repr: T,
type: { Type::Scalar(data.to) }, type: { Type::Scalar(data.to) },
// TODO: double check
relaxed_type_check: true,
}, },
src: { src: {
repr: T, repr: T,
type: { Type::Scalar(data.from) }, type: { Type::Scalar(data.from) },
relaxed_type_check: true,
}, },
} }
}, },
@ -494,16 +503,18 @@ pub trait Visitor<T: Operand, Err> {
args: &T, args: &T,
type_space: Option<(&Type, StateSpace)>, type_space: Option<(&Type, StateSpace)>,
is_dst: bool, is_dst: bool,
relaxed_type_check: bool,
) -> Result<(), Err>; ) -> Result<(), Err>;
fn visit_ident( fn visit_ident(
&mut self, &mut self,
args: &T::Ident, args: &T::Ident,
type_space: Option<(&Type, StateSpace)>, type_space: Option<(&Type, StateSpace)>,
is_dst: bool, is_dst: bool,
relaxed_type_check: bool,
) -> Result<(), Err>; ) -> Result<(), Err>;
} }
impl<T: Operand, Err, Fn: FnMut(&T, Option<(&Type, StateSpace)>, bool) -> Result<(), Err>> impl<T: Operand, Err, Fn: FnMut(&T, Option<(&Type, StateSpace)>, bool, bool) -> Result<(), Err>>
Visitor<T, Err> for Fn Visitor<T, Err> for Fn
{ {
fn visit( fn visit(
@ -511,8 +522,9 @@ impl<T: Operand, Err, Fn: FnMut(&T, Option<(&Type, StateSpace)>, bool) -> Result
args: &T, args: &T,
type_space: Option<(&Type, StateSpace)>, type_space: Option<(&Type, StateSpace)>,
is_dst: bool, is_dst: bool,
relaxed_type_check: bool,
) -> Result<(), Err> { ) -> Result<(), Err> {
(self)(args, type_space, is_dst) (self)(args, type_space, is_dst, relaxed_type_check)
} }
fn visit_ident( fn visit_ident(
@ -520,8 +532,14 @@ impl<T: Operand, Err, Fn: FnMut(&T, Option<(&Type, StateSpace)>, bool) -> Result
args: &T::Ident, args: &T::Ident,
type_space: Option<(&Type, StateSpace)>, type_space: Option<(&Type, StateSpace)>,
is_dst: bool, is_dst: bool,
relaxed_type_check: bool,
) -> Result<(), Err> { ) -> Result<(), Err> {
(self)(&T::from_ident(*args), type_space, is_dst) (self)(
&T::from_ident(*args),
type_space,
is_dst,
relaxed_type_check,
)
} }
} }
@ -531,12 +549,14 @@ pub trait VisitorMut<T: Operand, Err> {
args: &mut T, args: &mut T,
type_space: Option<(&Type, StateSpace)>, type_space: Option<(&Type, StateSpace)>,
is_dst: bool, is_dst: bool,
relaxed_type_check: bool,
) -> Result<(), Err>; ) -> Result<(), Err>;
fn visit_ident( fn visit_ident(
&mut self, &mut self,
args: &mut T::Ident, args: &mut T::Ident,
type_space: Option<(&Type, StateSpace)>, type_space: Option<(&Type, StateSpace)>,
is_dst: bool, is_dst: bool,
relaxed_type_check: bool,
) -> Result<(), Err>; ) -> Result<(), Err>;
} }
@ -546,37 +566,44 @@ pub trait VisitorMap<From: Operand, To: Operand, Err> {
args: From, args: From,
type_space: Option<(&Type, StateSpace)>, type_space: Option<(&Type, StateSpace)>,
is_dst: bool, is_dst: bool,
relaxed_type_check: bool,
) -> Result<To, Err>; ) -> Result<To, Err>;
fn visit_ident( fn visit_ident(
&mut self, &mut self,
args: From::Ident, args: From::Ident,
type_space: Option<(&Type, StateSpace)>, type_space: Option<(&Type, StateSpace)>,
is_dst: bool, is_dst: bool,
relaxed_type_check: bool,
) -> Result<To::Ident, Err>; ) -> Result<To::Ident, Err>;
} }
impl<T: Copy, U: Copy, Err, Fn> VisitorMap<ParsedOperand<T>, ParsedOperand<U>, Err> for Fn impl<T: Copy, U: Copy, Err, Fn> VisitorMap<ParsedOperand<T>, ParsedOperand<U>, Err> for Fn
where where
Fn: FnMut(T, Option<(&Type, StateSpace)>, bool) -> Result<U, Err>, Fn: FnMut(T, Option<(&Type, StateSpace)>, bool, bool) -> Result<U, Err>,
{ {
fn visit( fn visit(
&mut self, &mut self,
args: ParsedOperand<T>, args: ParsedOperand<T>,
type_space: Option<(&Type, StateSpace)>, type_space: Option<(&Type, StateSpace)>,
is_dst: bool, is_dst: bool,
relaxed_type_check: bool,
) -> Result<ParsedOperand<U>, Err> { ) -> Result<ParsedOperand<U>, Err> {
Ok(match args { Ok(match args {
ParsedOperand::Reg(ident) => ParsedOperand::Reg((self)(ident, type_space, is_dst)?), ParsedOperand::Reg(ident) => {
ParsedOperand::RegOffset(ident, imm) => { ParsedOperand::Reg((self)(ident, type_space, is_dst, relaxed_type_check)?)
ParsedOperand::RegOffset((self)(ident, type_space, is_dst)?, imm)
} }
ParsedOperand::RegOffset(ident, imm) => ParsedOperand::RegOffset(
(self)(ident, type_space, is_dst, relaxed_type_check)?,
imm,
),
ParsedOperand::Imm(imm) => ParsedOperand::Imm(imm), ParsedOperand::Imm(imm) => ParsedOperand::Imm(imm),
ParsedOperand::VecMember(ident, index) => { ParsedOperand::VecMember(ident, index) => ParsedOperand::VecMember(
ParsedOperand::VecMember((self)(ident, type_space, is_dst)?, index) (self)(ident, type_space, is_dst, relaxed_type_check)?,
} index,
),
ParsedOperand::VecPack(vec) => ParsedOperand::VecPack( ParsedOperand::VecPack(vec) => ParsedOperand::VecPack(
vec.into_iter() vec.into_iter()
.map(|ident| (self)(ident, type_space, is_dst)) .map(|ident| (self)(ident, type_space, is_dst, relaxed_type_check))
.collect::<Result<Vec<_>, _>>()?, .collect::<Result<Vec<_>, _>>()?,
), ),
}) })
@ -587,22 +614,24 @@ where
args: T, args: T,
type_space: Option<(&Type, StateSpace)>, type_space: Option<(&Type, StateSpace)>,
is_dst: bool, is_dst: bool,
relaxed_type_check: bool,
) -> Result<U, Err> { ) -> Result<U, Err> {
(self)(args, type_space, is_dst) (self)(args, type_space, is_dst, relaxed_type_check)
} }
} }
impl<T: Operand<Ident = T>, U: Operand<Ident = U>, Err, Fn> VisitorMap<T, U, Err> for Fn impl<T: Operand<Ident = T>, U: Operand<Ident = U>, Err, Fn> VisitorMap<T, U, Err> for Fn
where where
Fn: FnMut(T, Option<(&Type, StateSpace)>, bool) -> Result<U, Err>, Fn: FnMut(T, Option<(&Type, StateSpace)>, bool, bool) -> Result<U, Err>,
{ {
fn visit( fn visit(
&mut self, &mut self,
args: T, args: T,
type_space: Option<(&Type, StateSpace)>, type_space: Option<(&Type, StateSpace)>,
is_dst: bool, is_dst: bool,
relaxed_type_check: bool,
) -> Result<U, Err> { ) -> Result<U, Err> {
(self)(args, type_space, is_dst) (self)(args, type_space, is_dst, relaxed_type_check)
} }
fn visit_ident( fn visit_ident(
@ -610,8 +639,9 @@ where
args: T, args: T,
type_space: Option<(&Type, StateSpace)>, type_space: Option<(&Type, StateSpace)>,
is_dst: bool, is_dst: bool,
relaxed_type_check: bool,
) -> Result<U, Err> { ) -> Result<U, Err> {
(self)(args, type_space, is_dst) (self)(args, type_space, is_dst, relaxed_type_check)
} }
} }
@ -1198,15 +1228,15 @@ impl<T: Operand> CallArgs<T> {
.iter() .iter()
.zip(details.return_arguments.iter()) .zip(details.return_arguments.iter())
{ {
visitor.visit_ident(param, Some((type_, *space)), true)?; visitor.visit_ident(param, Some((type_, *space)), true, false)?;
} }
visitor.visit_ident(&self.func, None, false)?; visitor.visit_ident(&self.func, None, false, false)?;
for (param, (type_, space)) in self for (param, (type_, space)) in self
.input_arguments .input_arguments
.iter() .iter()
.zip(details.input_arguments.iter()) .zip(details.input_arguments.iter())
{ {
visitor.visit(param, Some((type_, *space)), true)?; visitor.visit(param, Some((type_, *space)), true, false)?;
} }
Ok(()) Ok(())
} }
@ -1222,15 +1252,15 @@ impl<T: Operand> CallArgs<T> {
.iter_mut() .iter_mut()
.zip(details.return_arguments.iter()) .zip(details.return_arguments.iter())
{ {
visitor.visit_ident(param, Some((type_, *space)), true)?; visitor.visit_ident(param, Some((type_, *space)), true, false)?;
} }
visitor.visit_ident(&mut self.func, None, false)?; visitor.visit_ident(&mut self.func, None, false, false)?;
for (param, (type_, space)) in self for (param, (type_, space)) in self
.input_arguments .input_arguments
.iter_mut() .iter_mut()
.zip(details.input_arguments.iter()) .zip(details.input_arguments.iter())
{ {
visitor.visit(param, Some((type_, *space)), true)?; visitor.visit(param, Some((type_, *space)), true, false)?;
} }
Ok(()) Ok(())
} }
@ -1245,14 +1275,14 @@ impl<T: Operand> CallArgs<T> {
.return_arguments .return_arguments
.into_iter() .into_iter()
.zip(details.return_arguments.iter()) .zip(details.return_arguments.iter())
.map(|(param, (type_, space))| visitor.visit_ident(param, Some((type_, *space)), true)) .map(|(param, (type_, space))| visitor.visit_ident(param, Some((type_, *space)), true, false))
.collect::<Result<Vec<_>, _>>()?; .collect::<Result<Vec<_>, _>>()?;
let func = visitor.visit_ident(self.func, None, false)?; let func = visitor.visit_ident(self.func, None, false, false)?;
let input_arguments = self let input_arguments = self
.input_arguments .input_arguments
.into_iter() .into_iter()
.zip(details.input_arguments.iter()) .zip(details.input_arguments.iter())
.map(|(param, (type_, space))| visitor.visit(param, Some((type_, *space)), true)) .map(|(param, (type_, space))| visitor.visit(param, Some((type_, *space)), true, false))
.collect::<Result<Vec<_>, _>>()?; .collect::<Result<Vec<_>, _>>()?;
Ok(CallArgs { Ok(CallArgs {
return_arguments, return_arguments,

View file

@ -1017,7 +1017,7 @@ pub fn generate_instruction_type(tokens: proc_macro::TokenStream) -> proc_macro:
input.emit_arg_types(&mut result); input.emit_arg_types(&mut result);
input.emit_instruction_type(&mut result); input.emit_instruction_type(&mut result);
input.emit_visit(&mut result); input.emit_visit(&mut result);
//input.emit_visit_mut(&mut result); input.emit_visit_mut(&mut result);
input.emit_visit_map(&mut result); input.emit_visit_map(&mut result);
result.into() result.into()
} }

View file

@ -512,12 +512,13 @@ pub struct ArgumentField {
pub repr: Type, pub repr: Type,
pub space: Option<Expr>, pub space: Option<Expr>,
pub type_: Option<Expr>, pub type_: Option<Expr>,
pub relaxed_type_check: bool,
} }
impl ArgumentField { impl ArgumentField {
fn parse_block( fn parse_block(
input: syn::parse::ParseStream, input: syn::parse::ParseStream,
) -> syn::Result<(Type, Option<Expr>, Option<Expr>, Option<bool>)> { ) -> syn::Result<(Type, Option<Expr>, Option<Expr>, Option<bool>, bool)> {
let content; let content;
braced!(content in input); braced!(content in input);
let all_fields = let all_fields =
@ -531,6 +532,9 @@ impl ArgumentField {
let name_ident = content.parse::<Ident>()?; let name_ident = content.parse::<Ident>()?;
content.parse::<Token![:]>()?; content.parse::<Token![:]>()?;
match &*name_ident.to_string() { match &*name_ident.to_string() {
"relaxed_type_check" => {
ExprOrPath::RelaxedTypeCheck(content.parse::<LitBool>()?.value)
}
"repr" => ExprOrPath::Repr(content.parse::<Type>()?), "repr" => ExprOrPath::Repr(content.parse::<Type>()?),
"space" => ExprOrPath::Space(content.parse::<Expr>()?), "space" => ExprOrPath::Space(content.parse::<Expr>()?),
"dst" => { "dst" => {
@ -552,15 +556,17 @@ impl ArgumentField {
let mut type_ = None; let mut type_ = None;
let mut space = None; let mut space = None;
let mut is_dst = None; let mut is_dst = None;
let mut relaxed_type_check = false;
for exp_or_path in all_fields { for exp_or_path in all_fields {
match exp_or_path { match exp_or_path {
ExprOrPath::Repr(r) => repr = Some(r), ExprOrPath::Repr(r) => repr = Some(r),
ExprOrPath::Type(t) => type_ = Some(t), ExprOrPath::Type(t) => type_ = Some(t),
ExprOrPath::Space(s) => space = Some(s), ExprOrPath::Space(s) => space = Some(s),
ExprOrPath::Dst(x) => is_dst = Some(x), ExprOrPath::Dst(x) => is_dst = Some(x),
ExprOrPath::RelaxedTypeCheck(relaxed) => relaxed_type_check = relaxed,
} }
} }
Ok((repr.unwrap(), type_, space, is_dst)) Ok((repr.unwrap(), type_, space, is_dst, relaxed_type_check))
} }
fn parse_basic(input: &syn::parse::ParseBuffer) -> syn::Result<Type> { fn parse_basic(input: &syn::parse::ParseBuffer) -> syn::Result<Type> {
@ -605,6 +611,7 @@ impl ArgumentField {
.map(|space| quote! { #space }) .map(|space| quote! { #space })
.unwrap_or_else(|| quote! { StateSpace::Reg }); .unwrap_or_else(|| quote! { StateSpace::Reg });
let is_dst = self.is_dst; let is_dst = self.is_dst;
let relaxed_type_check = self.relaxed_type_check;
let name = &self.name; let name = &self.name;
let type_space = if is_typeless { let type_space = if is_typeless {
quote! { quote! {
@ -622,14 +629,14 @@ impl ArgumentField {
quote! { quote! {
{ {
#type_space #type_space
visitor.visit_ident(&mut arguments.#name, type_space, #is_dst)?; visitor.visit_ident(&mut arguments.#name, type_space, #is_dst, #relaxed_type_check)?;
} }
} }
} else { } else {
quote! { quote! {
{ {
#type_space #type_space
visitor.visit_ident(& arguments.#name, type_space, #is_dst)?; visitor.visit_ident(& arguments.#name, type_space, #is_dst, #relaxed_type_check)?;
} }
} }
} }
@ -655,7 +662,7 @@ impl ArgumentField {
}; };
quote! {{ quote! {{
#type_space #type_space
#operand_fn(#arguments_name, |x| visitor.visit(x, type_space, #is_dst))?; #operand_fn(#arguments_name, |x| visitor.visit(x, type_space, #is_dst, #relaxed_type_check))?;
}} }}
} }
} }
@ -679,6 +686,7 @@ impl ArgumentField {
.map(|space| quote! { #space }) .map(|space| quote! { #space })
.unwrap_or_else(|| quote! { StateSpace::Reg }); .unwrap_or_else(|| quote! { StateSpace::Reg });
let is_dst = self.is_dst; let is_dst = self.is_dst;
let relaxed_type_check = self.relaxed_type_check;
let name = &self.name; let name = &self.name;
let type_space = if is_typeless { let type_space = if is_typeless {
quote! { quote! {
@ -693,11 +701,11 @@ impl ArgumentField {
}; };
let map_call = if is_ident { let map_call = if is_ident {
quote! { quote! {
visitor.visit_ident(arguments.#name, type_space, #is_dst)? visitor.visit_ident(arguments.#name, type_space, #is_dst, #relaxed_type_check)?
} }
} else { } else {
quote! { quote! {
MapOperand::map(arguments.#name, |x| visitor.visit(x, type_space, #is_dst))? MapOperand::map(arguments.#name, |x| visitor.visit(x, type_space, #is_dst, #relaxed_type_check))?
} }
}; };
quote! { quote! {
@ -739,10 +747,10 @@ impl Parse for ArgumentField {
input.parse::<Token![:]>()?; input.parse::<Token![:]>()?;
let lookahead = input.lookahead1(); let lookahead = input.lookahead1();
let (repr, type_, space, is_dst) = if lookahead.peek(token::Brace) { let (repr, type_, space, is_dst, relaxed_type_check) = if lookahead.peek(token::Brace) {
Self::parse_block(input)? Self::parse_block(input)?
} else if lookahead.peek(syn::Ident) { } else if lookahead.peek(syn::Ident) {
(Self::parse_basic(input)?, None, None, None) (Self::parse_basic(input)?, None, None, None, false)
} else { } else {
return Err(lookahead.error()); return Err(lookahead.error());
}; };
@ -756,6 +764,7 @@ impl Parse for ArgumentField {
repr, repr,
type_, type_,
space, space,
relaxed_type_check
}) })
} }
} }
@ -765,6 +774,7 @@ enum ExprOrPath {
Type(Expr), Type(Expr),
Space(Expr), Space(Expr),
Dst(bool), Dst(bool),
RelaxedTypeCheck(bool),
} }
#[cfg(test)] #[cfg(test)]