mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-08-04 07:09:53 +00:00
More fixes
This commit is contained in:
parent
0c9339325e
commit
8d15499acc
9 changed files with 487 additions and 104 deletions
|
@ -467,8 +467,22 @@ fn convert_to_stateful_memory_access_postprocess(
|
||||||
Some(new_id) => {
|
Some(new_id) => {
|
||||||
let (new_operand_type, new_operand_space, _) = id_defs.get_typed(*new_id)?;
|
let (new_operand_type, new_operand_space, _) = id_defs.get_typed(*new_id)?;
|
||||||
// TODO: readd if required
|
// TODO: readd if required
|
||||||
if let Some(..) = type_space {
|
if let Some((expected_type, expected_space)) = type_space {
|
||||||
if relaxed_conversion {
|
let implicit_conversion = if relaxed_conversion {
|
||||||
|
if is_dst {
|
||||||
|
super::insert_implicit_conversions::should_convert_relaxed_dst_wrapper
|
||||||
|
} else {
|
||||||
|
super::insert_implicit_conversions::should_convert_relaxed_src_wrapper
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
super::insert_implicit_conversions::default_implicit_conversion
|
||||||
|
};
|
||||||
|
if implicit_conversion(
|
||||||
|
(new_operand_space, &new_operand_type),
|
||||||
|
(expected_space, expected_type),
|
||||||
|
)
|
||||||
|
.is_ok()
|
||||||
|
{
|
||||||
return Ok(*new_id);
|
return Ok(*new_id);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -67,7 +67,7 @@ impl<'a, 'b> VectorRepackVisitor<'a, 'b> {
|
||||||
) -> Result<SpirvWord, TranslateError> {
|
) -> Result<SpirvWord, TranslateError> {
|
||||||
// mov.u32 foobar, {a,b};
|
// mov.u32 foobar, {a,b};
|
||||||
let scalar_t = match typ {
|
let scalar_t = match typ {
|
||||||
ast::Type::Vector(scalar_t, _) => *scalar_t,
|
ast::Type::Vector(_, scalar_t) => *scalar_t,
|
||||||
_ => return Err(error_mismatched_type()),
|
_ => return Err(error_mismatched_type()),
|
||||||
};
|
};
|
||||||
let temp_vec = self
|
let temp_vec = self
|
||||||
|
|
|
@ -291,7 +291,7 @@ impl TypeWordMap {
|
||||||
| ast::ScalarType::BF16x2
|
| ast::ScalarType::BF16x2
|
||||||
| ast::ScalarType::B128 => todo!(),
|
| ast::ScalarType::B128 => todo!(),
|
||||||
},
|
},
|
||||||
ast::Type::Vector(typ, len) => {
|
ast::Type::Vector(len, typ) => {
|
||||||
let result_type =
|
let result_type =
|
||||||
self.get_or_add(b, SpirvType::Vector(SpirvScalarKey::from(*typ), *len));
|
self.get_or_add(b, SpirvType::Vector(SpirvScalarKey::from(*typ), *len));
|
||||||
let size_of_t = typ.size_of();
|
let size_of_t = typ.size_of();
|
||||||
|
@ -309,7 +309,7 @@ impl TypeWordMap {
|
||||||
.collect::<Result<Vec<_>, _>>()?;
|
.collect::<Result<Vec<_>, _>>()?;
|
||||||
SpirvWord(b.constant_composite(result_type.0, None, components.into_iter()))
|
SpirvWord(b.constant_composite(result_type.0, None, components.into_iter()))
|
||||||
}
|
}
|
||||||
ast::Type::Array(typ, dims) => match dims.as_slice() {
|
ast::Type::Array(_, typ, dims) => match dims.as_slice() {
|
||||||
[] => return Err(error_unreachable()),
|
[] => return Err(error_unreachable()),
|
||||||
[dim] => {
|
[dim] => {
|
||||||
let result_type = self
|
let result_type = self
|
||||||
|
@ -342,7 +342,7 @@ impl TypeWordMap {
|
||||||
Ok::<_, TranslateError>(
|
Ok::<_, TranslateError>(
|
||||||
self.get_or_add_constant(
|
self.get_or_add_constant(
|
||||||
b,
|
b,
|
||||||
&ast::Type::Array(*typ, rest.to_vec()),
|
&ast::Type::Array(None, *typ, rest.to_vec()),
|
||||||
&init[((size_of_t as usize) * (x as usize))..],
|
&init[((size_of_t as usize) * (x as usize))..],
|
||||||
)?
|
)?
|
||||||
.0,
|
.0,
|
||||||
|
@ -397,8 +397,8 @@ impl SpirvType {
|
||||||
fn new(t: ast::Type) -> Self {
|
fn new(t: ast::Type) -> Self {
|
||||||
match t {
|
match t {
|
||||||
ast::Type::Scalar(t) => SpirvType::Base(t.into()),
|
ast::Type::Scalar(t) => SpirvType::Base(t.into()),
|
||||||
ast::Type::Vector(typ, len) => SpirvType::Vector(typ.into(), len),
|
ast::Type::Vector(len, typ) => SpirvType::Vector(typ.into(), len),
|
||||||
ast::Type::Array(t, len) => SpirvType::Array(t.into(), len),
|
ast::Type::Array(_, t, len) => SpirvType::Array(t.into(), len),
|
||||||
ast::Type::Pointer(pointer_t, space) => SpirvType::Pointer(
|
ast::Type::Pointer(pointer_t, space) => SpirvType::Pointer(
|
||||||
Box::new(SpirvType::Base(pointer_t.into())),
|
Box::new(SpirvType::Base(pointer_t.into())),
|
||||||
space_to_spirv(space),
|
space_to_spirv(space),
|
||||||
|
@ -809,8 +809,8 @@ fn emit_function_header<'input>(
|
||||||
pub fn type_size_of(this: &ast::Type) -> usize {
|
pub fn type_size_of(this: &ast::Type) -> usize {
|
||||||
match this {
|
match this {
|
||||||
ast::Type::Scalar(typ) => typ.size_of() as usize,
|
ast::Type::Scalar(typ) => typ.size_of() as usize,
|
||||||
ast::Type::Vector(typ, len) => (typ.size_of() as usize) * (*len as usize),
|
ast::Type::Vector(len, typ) => (typ.size_of() as usize) * (*len as usize),
|
||||||
ast::Type::Array(typ, len) => len
|
ast::Type::Array(_, typ, len) => len
|
||||||
.iter()
|
.iter()
|
||||||
.fold(typ.size_of() as usize, |x, y| (x as usize) * (*y as usize)),
|
.fold(typ.size_of() as usize, |x, y| (x as usize) * (*y as usize)),
|
||||||
ast::Type::Pointer(..) => mem::size_of::<usize>(),
|
ast::Type::Pointer(..) => mem::size_of::<usize>(),
|
||||||
|
@ -1853,11 +1853,16 @@ fn emit_mul_int(
|
||||||
builder.i_mul(inst_type.0, Some(arg.dst.0), arg.src1.0, arg.src2.0)?;
|
builder.i_mul(inst_type.0, Some(arg.dst.0), arg.src1.0, arg.src2.0)?;
|
||||||
}
|
}
|
||||||
ast::MulIntControl::High => {
|
ast::MulIntControl::High => {
|
||||||
|
let opencl_inst = if type_.kind() == ast::ScalarKind::Signed {
|
||||||
|
spirv::CLOp::s_mul_hi
|
||||||
|
} else {
|
||||||
|
spirv::CLOp::u_mul_hi
|
||||||
|
};
|
||||||
builder.ext_inst(
|
builder.ext_inst(
|
||||||
inst_type.0,
|
inst_type.0,
|
||||||
Some(arg.dst.0),
|
Some(arg.dst.0),
|
||||||
opencl,
|
opencl,
|
||||||
spirv::CLOp::s_mul_hi as spirv::Word,
|
opencl_inst as spirv::Word,
|
||||||
[
|
[
|
||||||
dr::Operand::IdRef(arg.src1.0),
|
dr::Operand::IdRef(arg.src1.0),
|
||||||
dr::Operand::IdRef(arg.src2.0),
|
dr::Operand::IdRef(arg.src2.0),
|
||||||
|
@ -2646,7 +2651,7 @@ fn emit_load_var(
|
||||||
match details.member_index {
|
match details.member_index {
|
||||||
Some((index, Some(width))) => {
|
Some((index, Some(width))) => {
|
||||||
let vector_type = match details.typ {
|
let vector_type = match details.typ {
|
||||||
ast::Type::Scalar(scalar_t) => ast::Type::Vector(scalar_t, width),
|
ast::Type::Scalar(scalar_t) => ast::Type::Vector(width, scalar_t),
|
||||||
_ => return Err(error_mismatched_type()),
|
_ => return Err(error_mismatched_type()),
|
||||||
};
|
};
|
||||||
let vector_type_spirv = map.get_or_add(builder, SpirvType::new(vector_type));
|
let vector_type_spirv = map.get_or_add(builder, SpirvType::new(vector_type));
|
||||||
|
@ -2710,14 +2715,14 @@ fn to_parts(this: &ast::Type) -> TypeParts {
|
||||||
width: scalar.size_of(),
|
width: scalar.size_of(),
|
||||||
components: Vec::new(),
|
components: Vec::new(),
|
||||||
},
|
},
|
||||||
ast::Type::Vector(scalar, components) => TypeParts {
|
ast::Type::Vector(components, scalar) => TypeParts {
|
||||||
kind: TypeKind::Vector,
|
kind: TypeKind::Vector,
|
||||||
state_space: ast::StateSpace::Reg,
|
state_space: ast::StateSpace::Reg,
|
||||||
scalar_kind: scalar.kind(),
|
scalar_kind: scalar.kind(),
|
||||||
width: scalar.size_of(),
|
width: scalar.size_of(),
|
||||||
components: vec![*components as u32],
|
components: vec![*components as u32],
|
||||||
},
|
},
|
||||||
ast::Type::Array(scalar, components) => TypeParts {
|
ast::Type::Array(_, scalar, components) => TypeParts {
|
||||||
kind: TypeKind::Array,
|
kind: TypeKind::Array,
|
||||||
state_space: ast::StateSpace::Reg,
|
state_space: ast::StateSpace::Reg,
|
||||||
scalar_kind: scalar.kind(),
|
scalar_kind: scalar.kind(),
|
||||||
|
@ -2738,12 +2743,14 @@ fn type_from_parts(t: TypeParts) -> ast::Type {
|
||||||
match t.kind {
|
match t.kind {
|
||||||
TypeKind::Scalar => ast::Type::Scalar(scalar_from_parts(t.width, t.scalar_kind)),
|
TypeKind::Scalar => ast::Type::Scalar(scalar_from_parts(t.width, t.scalar_kind)),
|
||||||
TypeKind::Vector => ast::Type::Vector(
|
TypeKind::Vector => ast::Type::Vector(
|
||||||
scalar_from_parts(t.width, t.scalar_kind),
|
|
||||||
t.components[0] as u8,
|
t.components[0] as u8,
|
||||||
|
scalar_from_parts(t.width, t.scalar_kind),
|
||||||
|
),
|
||||||
|
TypeKind::Array => ast::Type::Array(
|
||||||
|
None,
|
||||||
|
scalar_from_parts(t.width, t.scalar_kind),
|
||||||
|
t.components,
|
||||||
),
|
),
|
||||||
TypeKind::Array => {
|
|
||||||
ast::Type::Array(scalar_from_parts(t.width, t.scalar_kind), t.components)
|
|
||||||
}
|
|
||||||
TypeKind::Pointer => {
|
TypeKind::Pointer => {
|
||||||
ast::Type::Pointer(scalar_from_parts(t.width, t.scalar_kind), t.state_space)
|
ast::Type::Pointer(scalar_from_parts(t.width, t.scalar_kind), t.state_space)
|
||||||
}
|
}
|
||||||
|
|
|
@ -123,13 +123,13 @@ fn insert_implicit_conversions_impl(
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn default_implicit_conversion(
|
pub(crate) fn default_implicit_conversion(
|
||||||
(operand_space, operand_type): (ast::StateSpace, &ast::Type),
|
(operand_space, operand_type): (ast::StateSpace, &ast::Type),
|
||||||
(instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
|
(instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
|
||||||
) -> Result<Option<ConversionKind>, TranslateError> {
|
) -> Result<Option<ConversionKind>, TranslateError> {
|
||||||
if instruction_space == ast::StateSpace::Reg {
|
if instruction_space == ast::StateSpace::Reg {
|
||||||
if space_is_compatible(operand_space, ast::StateSpace::Reg) {
|
if space_is_compatible(operand_space, ast::StateSpace::Reg) {
|
||||||
if let (ast::Type::Vector(vec_underlying_type, vec_len), ast::Type::Scalar(scalar)) =
|
if let (ast::Type::Vector(vec_len, vec_underlying_type), ast::Type::Scalar(scalar)) =
|
||||||
(operand_type, instruction_type)
|
(operand_type, instruction_type)
|
||||||
{
|
{
|
||||||
if scalar.kind() == ast::ScalarKind::Bit
|
if scalar.kind() == ast::ScalarKind::Bit
|
||||||
|
@ -282,15 +282,15 @@ fn should_bitcast(instr: &ast::Type, operand: &ast::Type) -> bool {
|
||||||
ast::ScalarKind::Pred => false,
|
ast::ScalarKind::Pred => false,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
(ast::Type::Vector(inst, _), ast::Type::Vector(operand, _))
|
(ast::Type::Vector(_, inst), ast::Type::Vector(_, operand))
|
||||||
| (ast::Type::Array(inst, _), ast::Type::Array(operand, _)) => {
|
| (ast::Type::Array(_, inst, _), ast::Type::Array(_, operand, _)) => {
|
||||||
should_bitcast(&ast::Type::Scalar(*inst), &ast::Type::Scalar(*operand))
|
should_bitcast(&ast::Type::Scalar(*inst), &ast::Type::Scalar(*operand))
|
||||||
}
|
}
|
||||||
_ => false,
|
_ => false,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn should_convert_relaxed_dst_wrapper(
|
pub(crate) fn should_convert_relaxed_dst_wrapper(
|
||||||
(operand_space, operand_type): (ast::StateSpace, &ast::Type),
|
(operand_space, operand_type): (ast::StateSpace, &ast::Type),
|
||||||
(instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
|
(instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
|
||||||
) -> Result<Option<ConversionKind>, TranslateError> {
|
) -> Result<Option<ConversionKind>, TranslateError> {
|
||||||
|
@ -356,8 +356,8 @@ fn should_convert_relaxed_dst(
|
||||||
}
|
}
|
||||||
ast::ScalarKind::Pred => None,
|
ast::ScalarKind::Pred => None,
|
||||||
},
|
},
|
||||||
(ast::Type::Vector(dst_type, _), ast::Type::Vector(instr_type, _))
|
(ast::Type::Vector(_, dst_type), ast::Type::Vector(_, instr_type))
|
||||||
| (ast::Type::Array(dst_type, _), ast::Type::Array(instr_type, _)) => {
|
| (ast::Type::Array(_, dst_type, _), ast::Type::Array(_, instr_type, _)) => {
|
||||||
should_convert_relaxed_dst(
|
should_convert_relaxed_dst(
|
||||||
&ast::Type::Scalar(*dst_type),
|
&ast::Type::Scalar(*dst_type),
|
||||||
&ast::Type::Scalar(*instr_type),
|
&ast::Type::Scalar(*instr_type),
|
||||||
|
@ -367,7 +367,7 @@ fn should_convert_relaxed_dst(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn should_convert_relaxed_src_wrapper(
|
pub(crate) fn should_convert_relaxed_src_wrapper(
|
||||||
(operand_space, operand_type): (ast::StateSpace, &ast::Type),
|
(operand_space, operand_type): (ast::StateSpace, &ast::Type),
|
||||||
(instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
|
(instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
|
||||||
) -> Result<Option<ConversionKind>, TranslateError> {
|
) -> Result<Option<ConversionKind>, TranslateError> {
|
||||||
|
@ -420,8 +420,8 @@ fn should_convert_relaxed_src(
|
||||||
}
|
}
|
||||||
ast::ScalarKind::Pred => None,
|
ast::ScalarKind::Pred => None,
|
||||||
},
|
},
|
||||||
(ast::Type::Vector(dst_type, _), ast::Type::Vector(instr_type, _))
|
(ast::Type::Vector(_, dst_type), ast::Type::Vector(_, instr_type))
|
||||||
| (ast::Type::Array(dst_type, _), ast::Type::Array(instr_type, _)) => {
|
| (ast::Type::Array(_, dst_type, _), ast::Type::Array(_, instr_type, _)) => {
|
||||||
should_convert_relaxed_src(
|
should_convert_relaxed_src(
|
||||||
&ast::Type::Scalar(*dst_type),
|
&ast::Type::Scalar(*dst_type),
|
||||||
&ast::Type::Scalar(*instr_type),
|
&ast::Type::Scalar(*instr_type),
|
||||||
|
|
|
@ -195,7 +195,7 @@ impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> {
|
||||||
let member_index = match member_index {
|
let member_index = match member_index {
|
||||||
Some(idx) => {
|
Some(idx) => {
|
||||||
let vector_width = match var_type {
|
let vector_width = match var_type {
|
||||||
ast::Type::Vector(scalar_t, width) => {
|
ast::Type::Vector(width, scalar_t) => {
|
||||||
var_type = ast::Type::Scalar(scalar_t);
|
var_type = ast::Type::Scalar(scalar_t);
|
||||||
width
|
width
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
use ptx_parser as ast;
|
use ptx_parser as ast;
|
||||||
use rspirv::{binary::Assemble, dr};
|
use rspirv::{binary::Assemble, dr};
|
||||||
use std::hash::Hash;
|
use std::hash::Hash;
|
||||||
|
use std::num::NonZeroU8;
|
||||||
use std::{
|
use std::{
|
||||||
borrow::Cow,
|
borrow::Cow,
|
||||||
cell::RefCell,
|
cell::RefCell,
|
||||||
|
@ -360,7 +361,7 @@ impl PtxSpecialRegister {
|
||||||
PtxSpecialRegister::Tid
|
PtxSpecialRegister::Tid
|
||||||
| PtxSpecialRegister::Ntid
|
| PtxSpecialRegister::Ntid
|
||||||
| PtxSpecialRegister::Ctaid
|
| PtxSpecialRegister::Ctaid
|
||||||
| PtxSpecialRegister::Nctaid => ast::Type::Vector(self.get_function_return_type(), 4),
|
| PtxSpecialRegister::Nctaid => ast::Type::Vector(4, self.get_function_return_type()),
|
||||||
_ => ast::Type::Scalar(self.get_function_return_type()),
|
_ => ast::Type::Scalar(self.get_function_return_type()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -764,7 +765,12 @@ impl<T: ast::Operand<Ident = SpirvWord>> Statement<ast::Instruction<T>, T> {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
Statement::Conditional(conditional) => {
|
Statement::Conditional(conditional) => {
|
||||||
let predicate = visitor.visit_ident(conditional.predicate, None, false, false)?;
|
let predicate = visitor.visit_ident(
|
||||||
|
conditional.predicate,
|
||||||
|
Some((&ast::ScalarType::Pred.into(), ast::StateSpace::Reg)),
|
||||||
|
false,
|
||||||
|
false,
|
||||||
|
)?;
|
||||||
let if_true = visitor.visit_ident(conditional.if_true, None, false, false)?;
|
let if_true = visitor.visit_ident(conditional.if_true, None, false, false)?;
|
||||||
let if_false = visitor.visit_ident(conditional.if_false, None, false, false)?;
|
let if_false = visitor.visit_ident(conditional.if_false, None, false, false)?;
|
||||||
Statement::Conditional(BrachCondition {
|
Statement::Conditional(BrachCondition {
|
||||||
|
@ -919,7 +925,7 @@ impl<T: ast::Operand<Ident = SpirvWord>> Statement<ast::Instruction<T>, T> {
|
||||||
let packed = visitor.visit_ident(
|
let packed = visitor.visit_ident(
|
||||||
packed,
|
packed,
|
||||||
Some((
|
Some((
|
||||||
&ast::Type::Vector(typ, unpacked.len() as u8),
|
&ast::Type::Vector(unpacked.len() as u8, typ),
|
||||||
ast::StateSpace::Reg,
|
ast::StateSpace::Reg,
|
||||||
)),
|
)),
|
||||||
false,
|
false,
|
||||||
|
@ -930,7 +936,7 @@ impl<T: ast::Operand<Ident = SpirvWord>> Statement<ast::Instruction<T>, T> {
|
||||||
let packed = visitor.visit_ident(
|
let packed = visitor.visit_ident(
|
||||||
packed,
|
packed,
|
||||||
Some((
|
Some((
|
||||||
&ast::Type::Vector(typ, unpacked.len() as u8),
|
&ast::Type::Vector(unpacked.len() as u8, typ),
|
||||||
ast::StateSpace::Reg,
|
ast::StateSpace::Reg,
|
||||||
)),
|
)),
|
||||||
true,
|
true,
|
||||||
|
|
|
@ -4,7 +4,7 @@ use super::{
|
||||||
};
|
};
|
||||||
use crate::{PtxError, PtxParserState};
|
use crate::{PtxError, PtxParserState};
|
||||||
use bitflags::bitflags;
|
use bitflags::bitflags;
|
||||||
use std::cmp::Ordering;
|
use std::{cmp::Ordering, num::NonZeroU8};
|
||||||
|
|
||||||
pub enum Statement<P: Operand> {
|
pub enum Statement<P: Operand> {
|
||||||
Label(P::Ident),
|
Label(P::Ident),
|
||||||
|
@ -760,19 +760,37 @@ pub enum Type {
|
||||||
// .param.b32 foo;
|
// .param.b32 foo;
|
||||||
Scalar(ScalarType),
|
Scalar(ScalarType),
|
||||||
// .param.v2.b32 foo;
|
// .param.v2.b32 foo;
|
||||||
Vector(ScalarType, u8),
|
Vector(u8, ScalarType),
|
||||||
// .param.b32 foo[4];
|
// .param.b32 foo[4];
|
||||||
Array(ScalarType, Vec<u32>),
|
Array(Option<NonZeroU8>, ScalarType, Vec<u32>),
|
||||||
Pointer(ScalarType, StateSpace),
|
Pointer(ScalarType, StateSpace),
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Type {
|
impl Type {
|
||||||
pub(crate) fn maybe_vector(vector: Option<VectorPrefix>, scalar: ScalarType) -> Self {
|
pub(crate) fn maybe_vector(vector: Option<VectorPrefix>, scalar: ScalarType) -> Self {
|
||||||
match vector {
|
match vector {
|
||||||
Some(prefix) => Type::Vector(scalar, prefix.len()),
|
Some(prefix) => Type::Vector(prefix.len().get(), scalar),
|
||||||
None => Type::Scalar(scalar),
|
None => Type::Scalar(scalar),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(crate) fn maybe_vector_parsed(prefix: Option<NonZeroU8>, scalar: ScalarType) -> Self {
|
||||||
|
match prefix {
|
||||||
|
Some(prefix) => Type::Vector(prefix.get(), scalar),
|
||||||
|
None => Type::Scalar(scalar),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn maybe_array(
|
||||||
|
prefix: Option<NonZeroU8>,
|
||||||
|
scalar: ScalarType,
|
||||||
|
array: Option<Vec<u32>>,
|
||||||
|
) -> Self {
|
||||||
|
match array {
|
||||||
|
Some(dimensions) => Type::Array(prefix, scalar, dimensions),
|
||||||
|
None => Self::maybe_vector_parsed(prefix, scalar),
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ScalarType {
|
impl ScalarType {
|
||||||
|
@ -1304,7 +1322,9 @@ impl<T: Operand> CallArgs<T> {
|
||||||
.input_arguments
|
.input_arguments
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.zip(details.input_arguments.iter())
|
.zip(details.input_arguments.iter())
|
||||||
.map(|(param, (type_, space))| visitor.visit(param, Some((type_, *space)), false, false))
|
.map(|(param, (type_, space))| {
|
||||||
|
visitor.visit(param, Some((type_, *space)), false, false)
|
||||||
|
})
|
||||||
.collect::<Result<Vec<_>, _>>()?;
|
.collect::<Result<Vec<_>, _>>()?;
|
||||||
Ok(CallArgs {
|
Ok(CallArgs {
|
||||||
return_arguments,
|
return_arguments,
|
||||||
|
|
69
ptx_parser/src/check_args.py
Normal file
69
ptx_parser/src/check_args.py
Normal file
|
@ -0,0 +1,69 @@
|
||||||
|
import os, sys, subprocess
|
||||||
|
|
||||||
|
|
||||||
|
SPACE = [".reg", ".sreg", ".param", ".param::entry", ".param::func", ".local", ".global", ".const", ".shared", ".shared::cta", ".shared::cluster"]
|
||||||
|
TYPE_AND_INIT = ["", " = 1", "[1]", "[1] = {1}"]
|
||||||
|
MULTIVAR = ["", "<1>" ]
|
||||||
|
VECTOR = ["", ".v2" ]
|
||||||
|
|
||||||
|
HEADER = """
|
||||||
|
.version 8.5
|
||||||
|
.target sm_90
|
||||||
|
.address_size 64
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def directive(space, variable, multivar, vector):
|
||||||
|
return """{3}
|
||||||
|
{0} {4} .b32 variable{2} {1};
|
||||||
|
""".format(space, variable, multivar, HEADER, vector)
|
||||||
|
|
||||||
|
def entry_arg(space, variable, multivar, vector):
|
||||||
|
return """{3}
|
||||||
|
.entry foobar ( {0} {4} .b32 variable{2} {1})
|
||||||
|
{{
|
||||||
|
ret;
|
||||||
|
}}
|
||||||
|
""".format(space, variable, multivar, HEADER, vector)
|
||||||
|
|
||||||
|
|
||||||
|
def fn_arg(space, variable, multivar, vector):
|
||||||
|
return """{3}
|
||||||
|
.func foobar ( {0} {4} .b32 variable{2} {1})
|
||||||
|
{{
|
||||||
|
ret;
|
||||||
|
}}
|
||||||
|
""".format(space, variable, multivar, HEADER, vector)
|
||||||
|
|
||||||
|
|
||||||
|
def fn_body(space, variable, multivar, vector):
|
||||||
|
return """{3}
|
||||||
|
.func foobar ()
|
||||||
|
{{
|
||||||
|
{0} {4} .b32 variable{2} {1};
|
||||||
|
ret;
|
||||||
|
}}
|
||||||
|
""".format(space, variable, multivar, HEADER, vector)
|
||||||
|
|
||||||
|
|
||||||
|
def generate(generator):
|
||||||
|
legal = []
|
||||||
|
for space in SPACE:
|
||||||
|
for init in TYPE_AND_INIT:
|
||||||
|
for multi in MULTIVAR:
|
||||||
|
for vector in VECTOR:
|
||||||
|
ptx = generator(space, init, multi, vector)
|
||||||
|
if 0 == subprocess.call(["C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v12.6\\bin\\ptxas.exe", "-arch", "sm_90", "-ias", ptx], stdout = subprocess.DEVNULL): #
|
||||||
|
legal.append((space, vector, init, multi))
|
||||||
|
print(generator.__name__)
|
||||||
|
print(legal)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
generate(directive)
|
||||||
|
generate(entry_arg)
|
||||||
|
generate(fn_arg)
|
||||||
|
generate(fn_body)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
|
@ -3,9 +3,10 @@ use logos::Logos;
|
||||||
use ptx_parser_macros::derive_parser;
|
use ptx_parser_macros::derive_parser;
|
||||||
use rustc_hash::FxHashMap;
|
use rustc_hash::FxHashMap;
|
||||||
use std::fmt::Debug;
|
use std::fmt::Debug;
|
||||||
use std::num::{ParseFloatError, ParseIntError};
|
use std::num::{NonZeroU8, ParseFloatError, ParseIntError};
|
||||||
use winnow::ascii::dec_uint;
|
use winnow::ascii::dec_uint;
|
||||||
use winnow::combinator::*;
|
use winnow::combinator::*;
|
||||||
|
use winnow::error::{ErrMode, ErrorKind};
|
||||||
use winnow::stream::Accumulate;
|
use winnow::stream::Accumulate;
|
||||||
use winnow::token::any;
|
use winnow::token::any;
|
||||||
use winnow::{
|
use winnow::{
|
||||||
|
@ -72,11 +73,13 @@ impl From<RawRoundingMode> for ast::RoundingMode {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl VectorPrefix {
|
impl VectorPrefix {
|
||||||
pub(crate) fn len(self) -> u8 {
|
pub(crate) fn len(self) -> NonZeroU8 {
|
||||||
match self {
|
unsafe {
|
||||||
VectorPrefix::V2 => 2,
|
match self {
|
||||||
VectorPrefix::V4 => 4,
|
VectorPrefix::V2 => NonZeroU8::new_unchecked(2),
|
||||||
VectorPrefix::V8 => 8,
|
VectorPrefix::V4 => NonZeroU8::new_unchecked(4),
|
||||||
|
VectorPrefix::V8 => NonZeroU8::new_unchecked(8),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -386,22 +389,14 @@ fn module_variable<'a, 'input>(
|
||||||
) -> PResult<(ast::LinkingDirective, ast::Variable<&'input str>)> {
|
) -> PResult<(ast::LinkingDirective, ast::Variable<&'input str>)> {
|
||||||
(
|
(
|
||||||
linking_directives,
|
linking_directives,
|
||||||
module_variable_state_space.flat_map(variable_scalar_or_vector),
|
global_space
|
||||||
|
.flat_map(multi_variable)
|
||||||
|
// TODO: support multi var in globals
|
||||||
|
.map(|multi_var| multi_var.var),
|
||||||
)
|
)
|
||||||
.parse_next(stream)
|
.parse_next(stream)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn module_variable_state_space<'a, 'input>(
|
|
||||||
stream: &mut PtxParser<'a, 'input>,
|
|
||||||
) -> PResult<StateSpace> {
|
|
||||||
alt((
|
|
||||||
Token::DotConst.value(StateSpace::Const),
|
|
||||||
Token::DotGlobal.value(StateSpace::Global),
|
|
||||||
Token::DotShared.value(StateSpace::Shared),
|
|
||||||
))
|
|
||||||
.parse_next(stream)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn file<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<()> {
|
fn file<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<()> {
|
||||||
(
|
(
|
||||||
Token::DotFile,
|
Token::DotFile,
|
||||||
|
@ -547,17 +542,13 @@ fn kernel_arguments<'a, 'input>(
|
||||||
fn kernel_input<'a, 'input>(
|
fn kernel_input<'a, 'input>(
|
||||||
stream: &mut PtxParser<'a, 'input>,
|
stream: &mut PtxParser<'a, 'input>,
|
||||||
) -> PResult<ast::Variable<&'input str>> {
|
) -> PResult<ast::Variable<&'input str>> {
|
||||||
preceded(
|
preceded(Token::DotParam, method_parameter(StateSpace::Param)).parse_next(stream)
|
||||||
Token::DotParam,
|
|
||||||
variable_scalar_or_vector(StateSpace::Param),
|
|
||||||
)
|
|
||||||
.parse_next(stream)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn fn_input<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<ast::Variable<&'input str>> {
|
fn fn_input<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<ast::Variable<&'input str>> {
|
||||||
dispatch! { any;
|
dispatch! { any;
|
||||||
Token::DotParam => variable_scalar_or_vector(StateSpace::Param),
|
Token::DotParam => method_parameter(StateSpace::Param),
|
||||||
Token::DotReg => variable_scalar_or_vector(StateSpace::Reg),
|
Token::DotReg => method_parameter(StateSpace::Reg),
|
||||||
_ => fail
|
_ => fail
|
||||||
}
|
}
|
||||||
.parse_next(stream)
|
.parse_next(stream)
|
||||||
|
@ -596,7 +587,7 @@ fn tuple1to3_u32<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<(u32
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
separated::<_, _, Tuple3AccumulateU32, _, _, _, _>(1..3, u32, Token::Comma)
|
separated::<_, _, Tuple3AccumulateU32, _, _, _, _>(1..=3, u32, Token::Comma)
|
||||||
.map(|acc| acc.value)
|
.map(|acc| acc.value)
|
||||||
.parse_next(stream)
|
.parse_next(stream)
|
||||||
}
|
}
|
||||||
|
@ -618,7 +609,12 @@ fn statement<'a, 'input>(
|
||||||
alt((
|
alt((
|
||||||
label.map(Some),
|
label.map(Some),
|
||||||
debug_directive.map(|_| None),
|
debug_directive.map(|_| None),
|
||||||
multi_variable.map(Some),
|
terminated(
|
||||||
|
method_space
|
||||||
|
.flat_map(multi_variable)
|
||||||
|
.map(|var| Some(Statement::Variable(var))),
|
||||||
|
Token::Semicolon,
|
||||||
|
),
|
||||||
predicated_instruction.map(Some),
|
predicated_instruction.map(Some),
|
||||||
pragma.map(|_| None),
|
pragma.map(|_| None),
|
||||||
block_statement.map(Some),
|
block_statement.map(Some),
|
||||||
|
@ -632,59 +628,328 @@ fn pragma<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<()> {
|
||||||
.parse_next(stream)
|
.parse_next(stream)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn multi_variable<'a, 'input>(
|
fn method_parameter<'a, 'input: 'a>(
|
||||||
|
state_space: StateSpace,
|
||||||
|
) -> impl Parser<PtxParser<'a, 'input>, Variable<&'input str>, ContextError> {
|
||||||
|
move |stream: &mut PtxParser<'a, 'input>| {
|
||||||
|
let (align, vector, type_, name) = variable_declaration.parse_next(stream)?;
|
||||||
|
let array_dimensions = if state_space != StateSpace::Reg {
|
||||||
|
opt(array_dimensions).parse_next(stream)?
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
// TODO: push this check into array_dimensions(...)
|
||||||
|
if let Some(ref dims) = array_dimensions {
|
||||||
|
if dims[0] == 0 {
|
||||||
|
return Err(ErrMode::from_error_kind(stream, ErrorKind::Verify));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(Variable {
|
||||||
|
align,
|
||||||
|
v_type: Type::maybe_array(vector, type_, array_dimensions),
|
||||||
|
state_space,
|
||||||
|
name,
|
||||||
|
array_init: Vec::new(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: split to a separate type
|
||||||
|
fn variable_declaration<'a, 'input>(
|
||||||
stream: &mut PtxParser<'a, 'input>,
|
stream: &mut PtxParser<'a, 'input>,
|
||||||
) -> PResult<ast::Statement<ParsedOperandStr<'input>>> {
|
) -> PResult<(Option<u32>, Option<NonZeroU8>, ScalarType, &'input str)> {
|
||||||
(
|
(
|
||||||
variable,
|
opt(align.verify(|x| x.count_ones() == 1)),
|
||||||
opt(delimited(Token::Lt, u32, Token::Gt)),
|
vector_prefix,
|
||||||
Token::Semicolon,
|
scalar_type,
|
||||||
|
ident,
|
||||||
)
|
)
|
||||||
.map(|(var, count, _)| ast::Statement::Variable(ast::MultiVariable { var, count }))
|
|
||||||
.parse_next(stream)
|
.parse_next(stream)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn variable<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<ast::Variable<&'input str>> {
|
fn multi_variable<'a, 'input: 'a>(
|
||||||
dispatch! {any;
|
state_space: StateSpace,
|
||||||
Token::DotReg => variable_scalar_or_vector(StateSpace::Reg),
|
) -> impl Parser<PtxParser<'a, 'input>, MultiVariable<&'input str>, ContextError> {
|
||||||
Token::DotLocal => variable_scalar_or_vector(StateSpace::Local),
|
move |stream: &mut PtxParser<'a, 'input>| {
|
||||||
Token::DotParam => variable_scalar_or_vector(StateSpace::Param),
|
let ((align, vector, type_, name), count) = (
|
||||||
Token::DotShared => variable_scalar_or_vector(StateSpace::Shared),
|
variable_declaration,
|
||||||
_ => fail
|
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parameterized-variable-names
|
||||||
|
opt(delimited(Token::Lt, u32.verify(|x| *x != 0), Token::Gt)),
|
||||||
|
)
|
||||||
|
.parse_next(stream)?;
|
||||||
|
if count.is_some() {
|
||||||
|
return Ok(MultiVariable {
|
||||||
|
var: Variable {
|
||||||
|
align,
|
||||||
|
v_type: Type::maybe_vector_parsed(vector, type_),
|
||||||
|
state_space,
|
||||||
|
name,
|
||||||
|
array_init: Vec::new(),
|
||||||
|
},
|
||||||
|
count,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
let mut array_dimensions = if state_space != StateSpace::Reg {
|
||||||
|
opt(array_dimensions).parse_next(stream)?
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
let initializer = match state_space {
|
||||||
|
StateSpace::Global | StateSpace::Const => match array_dimensions {
|
||||||
|
Some(ref mut dimensions) => {
|
||||||
|
opt(array_initializer(vector, type_, dimensions)).parse_next(stream)?
|
||||||
|
}
|
||||||
|
None => opt(value_initializer(vector, type_)).parse_next(stream)?,
|
||||||
|
},
|
||||||
|
_ => None,
|
||||||
|
};
|
||||||
|
if let Some(ref dims) = array_dimensions {
|
||||||
|
if dims[0] == 0 {
|
||||||
|
return Err(ErrMode::from_error_kind(stream, ErrorKind::Verify));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(MultiVariable {
|
||||||
|
var: Variable {
|
||||||
|
align,
|
||||||
|
v_type: Type::maybe_array(vector, type_, array_dimensions),
|
||||||
|
state_space,
|
||||||
|
name,
|
||||||
|
array_init: initializer.unwrap_or(Vec::new()),
|
||||||
|
},
|
||||||
|
count,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn array_initializer<'a, 'input: 'a>(
|
||||||
|
vector: Option<NonZeroU8>,
|
||||||
|
type_: ScalarType,
|
||||||
|
array_dimensions: &mut Vec<u32>,
|
||||||
|
) -> impl Parser<PtxParser<'a, 'input>, Vec<u8>, ContextError> + '_ {
|
||||||
|
move |stream: &mut PtxParser<'a, 'input>| {
|
||||||
|
Token::Eq.parse_next(stream)?;
|
||||||
|
let mut result = Vec::new();
|
||||||
|
// TODO: vector constants and multi dim arrays
|
||||||
|
if vector.is_some() || array_dimensions[0] == 0 || array_dimensions.len() > 1 {
|
||||||
|
return Err(ErrMode::from_error_kind(stream, ErrorKind::Verify));
|
||||||
|
}
|
||||||
|
delimited(
|
||||||
|
Token::LBracket,
|
||||||
|
separated(
|
||||||
|
array_dimensions[0] as usize..=array_dimensions[0] as usize,
|
||||||
|
single_value_append(&mut result, type_),
|
||||||
|
Token::Comma,
|
||||||
|
),
|
||||||
|
Token::RBracket,
|
||||||
|
)
|
||||||
|
.parse_next(stream)?;
|
||||||
|
Ok(result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn value_initializer<'a, 'input: 'a>(
|
||||||
|
vector: Option<NonZeroU8>,
|
||||||
|
type_: ScalarType,
|
||||||
|
) -> impl Parser<PtxParser<'a, 'input>, Vec<u8>, ContextError> {
|
||||||
|
move |stream: &mut PtxParser<'a, 'input>| {
|
||||||
|
Token::Eq.parse_next(stream)?;
|
||||||
|
let mut result = Vec::new();
|
||||||
|
// TODO: vector constants
|
||||||
|
if vector.is_some() {
|
||||||
|
return Err(ErrMode::from_error_kind(stream, ErrorKind::Verify));
|
||||||
|
}
|
||||||
|
single_value_append(&mut result, type_).parse_next(stream)?;
|
||||||
|
Ok(result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn single_value_append<'a, 'input: 'a>(
|
||||||
|
accumulator: &mut Vec<u8>,
|
||||||
|
type_: ScalarType,
|
||||||
|
) -> impl Parser<PtxParser<'a, 'input>, (), ContextError> + '_ {
|
||||||
|
move |stream: &mut PtxParser<'a, 'input>| {
|
||||||
|
let value = immediate_value.parse_next(stream)?;
|
||||||
|
match (type_, value) {
|
||||||
|
(ScalarType::U8, ImmediateValue::U64(x)) => accumulator.extend_from_slice(
|
||||||
|
&u8::try_from(x)
|
||||||
|
.map_err(|_| ErrMode::from_error_kind(stream, ErrorKind::Verify))?
|
||||||
|
.to_le_bytes(),
|
||||||
|
),
|
||||||
|
(ScalarType::U8, ImmediateValue::S64(x)) => accumulator.extend_from_slice(
|
||||||
|
&u8::try_from(x)
|
||||||
|
.map_err(|_| ErrMode::from_error_kind(stream, ErrorKind::Verify))?
|
||||||
|
.to_le_bytes(),
|
||||||
|
),
|
||||||
|
(ScalarType::U16, ImmediateValue::U64(x)) => accumulator.extend_from_slice(
|
||||||
|
&u16::try_from(x)
|
||||||
|
.map_err(|_| ErrMode::from_error_kind(stream, ErrorKind::Verify))?
|
||||||
|
.to_le_bytes(),
|
||||||
|
),
|
||||||
|
(ScalarType::U16, ImmediateValue::S64(x)) => accumulator.extend_from_slice(
|
||||||
|
&u16::try_from(x)
|
||||||
|
.map_err(|_| ErrMode::from_error_kind(stream, ErrorKind::Verify))?
|
||||||
|
.to_le_bytes(),
|
||||||
|
),
|
||||||
|
(ScalarType::U32, ImmediateValue::U64(x)) => accumulator.extend_from_slice(
|
||||||
|
&u32::try_from(x)
|
||||||
|
.map_err(|_| ErrMode::from_error_kind(stream, ErrorKind::Verify))?
|
||||||
|
.to_le_bytes(),
|
||||||
|
),
|
||||||
|
(ScalarType::U32, ImmediateValue::S64(x)) => accumulator.extend_from_slice(
|
||||||
|
&u32::try_from(x)
|
||||||
|
.map_err(|_| ErrMode::from_error_kind(stream, ErrorKind::Verify))?
|
||||||
|
.to_le_bytes(),
|
||||||
|
),
|
||||||
|
(ScalarType::U64, ImmediateValue::U64(x)) => accumulator.extend_from_slice(
|
||||||
|
&u64::try_from(x)
|
||||||
|
.map_err(|_| ErrMode::from_error_kind(stream, ErrorKind::Verify))?
|
||||||
|
.to_le_bytes(),
|
||||||
|
),
|
||||||
|
(ScalarType::U64, ImmediateValue::S64(x)) => accumulator.extend_from_slice(
|
||||||
|
&u64::try_from(x)
|
||||||
|
.map_err(|_| ErrMode::from_error_kind(stream, ErrorKind::Verify))?
|
||||||
|
.to_le_bytes(),
|
||||||
|
),
|
||||||
|
(ScalarType::S8, ImmediateValue::U64(x)) => accumulator.extend_from_slice(
|
||||||
|
&i8::try_from(x)
|
||||||
|
.map_err(|_| ErrMode::from_error_kind(stream, ErrorKind::Verify))?
|
||||||
|
.to_le_bytes(),
|
||||||
|
),
|
||||||
|
(ScalarType::S8, ImmediateValue::S64(x)) => accumulator.extend_from_slice(
|
||||||
|
&i8::try_from(x)
|
||||||
|
.map_err(|_| ErrMode::from_error_kind(stream, ErrorKind::Verify))?
|
||||||
|
.to_le_bytes(),
|
||||||
|
),
|
||||||
|
(ScalarType::S16, ImmediateValue::U64(x)) => accumulator.extend_from_slice(
|
||||||
|
&i16::try_from(x)
|
||||||
|
.map_err(|_| ErrMode::from_error_kind(stream, ErrorKind::Verify))?
|
||||||
|
.to_le_bytes(),
|
||||||
|
),
|
||||||
|
(ScalarType::S16, ImmediateValue::S64(x)) => accumulator.extend_from_slice(
|
||||||
|
&i16::try_from(x)
|
||||||
|
.map_err(|_| ErrMode::from_error_kind(stream, ErrorKind::Verify))?
|
||||||
|
.to_le_bytes(),
|
||||||
|
),
|
||||||
|
(ScalarType::S32, ImmediateValue::U64(x)) => accumulator.extend_from_slice(
|
||||||
|
&i32::try_from(x)
|
||||||
|
.map_err(|_| ErrMode::from_error_kind(stream, ErrorKind::Verify))?
|
||||||
|
.to_le_bytes(),
|
||||||
|
),
|
||||||
|
(ScalarType::S32, ImmediateValue::S64(x)) => accumulator.extend_from_slice(
|
||||||
|
&i32::try_from(x)
|
||||||
|
.map_err(|_| ErrMode::from_error_kind(stream, ErrorKind::Verify))?
|
||||||
|
.to_le_bytes(),
|
||||||
|
),
|
||||||
|
(ScalarType::S64, ImmediateValue::U64(x)) => accumulator.extend_from_slice(
|
||||||
|
&i64::try_from(x)
|
||||||
|
.map_err(|_| ErrMode::from_error_kind(stream, ErrorKind::Verify))?
|
||||||
|
.to_le_bytes(),
|
||||||
|
),
|
||||||
|
(ScalarType::S64, ImmediateValue::S64(x)) => accumulator.extend_from_slice(
|
||||||
|
&i64::try_from(x)
|
||||||
|
.map_err(|_| ErrMode::from_error_kind(stream, ErrorKind::Verify))?
|
||||||
|
.to_le_bytes(),
|
||||||
|
),
|
||||||
|
(ScalarType::F32, ImmediateValue::F32(x)) => {
|
||||||
|
accumulator.extend_from_slice(&x.to_le_bytes())
|
||||||
|
}
|
||||||
|
(ScalarType::F64, ImmediateValue::F64(x)) => {
|
||||||
|
accumulator.extend_from_slice(&x.to_le_bytes())
|
||||||
|
}
|
||||||
|
_ => return Err(ErrMode::from_error_kind(stream, ErrorKind::Verify)),
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn array_dimensions<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<Vec<u32>> {
|
||||||
|
let dimension = delimited(
|
||||||
|
Token::LBracket,
|
||||||
|
opt(u32).verify(|dim| *dim != Some(0)),
|
||||||
|
Token::RBracket,
|
||||||
|
)
|
||||||
|
.parse_next(stream)?;
|
||||||
|
let result = vec![dimension.unwrap_or(0)];
|
||||||
|
repeat_fold_0_or_more(
|
||||||
|
delimited(
|
||||||
|
Token::LBracket,
|
||||||
|
u32.verify(|dim| *dim != 0),
|
||||||
|
Token::RBracket,
|
||||||
|
),
|
||||||
|
move || result,
|
||||||
|
|mut result: Vec<u32>, x| {
|
||||||
|
result.push(x);
|
||||||
|
result
|
||||||
|
},
|
||||||
|
stream,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Copied and fixed from Winnow sources (fold_repeat0_)
|
||||||
|
// Winnow Repeat::fold takes FnMut() -> Result to initalize accumulator,
|
||||||
|
// this really should be FnOnce() -> Result
|
||||||
|
fn repeat_fold_0_or_more<I, O, E, F, G, H, R>(
|
||||||
|
mut f: F,
|
||||||
|
init: H,
|
||||||
|
mut g: G,
|
||||||
|
input: &mut I,
|
||||||
|
) -> PResult<R, E>
|
||||||
|
where
|
||||||
|
I: Stream,
|
||||||
|
F: Parser<I, O, E>,
|
||||||
|
G: FnMut(R, O) -> R,
|
||||||
|
H: FnOnce() -> R,
|
||||||
|
E: ParserError<I>,
|
||||||
|
{
|
||||||
|
use winnow::error::ErrMode;
|
||||||
|
let mut res = init();
|
||||||
|
loop {
|
||||||
|
let start = input.checkpoint();
|
||||||
|
match f.parse_next(input) {
|
||||||
|
Ok(o) => {
|
||||||
|
res = g(res, o);
|
||||||
|
}
|
||||||
|
Err(ErrMode::Backtrack(_)) => {
|
||||||
|
input.reset(&start);
|
||||||
|
return Ok(res);
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
return Err(e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn global_space<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<StateSpace> {
|
||||||
|
alt((
|
||||||
|
Token::DotGlobal.value(StateSpace::Global),
|
||||||
|
Token::DotConst.value(StateSpace::Const),
|
||||||
|
Token::DotShared.value(StateSpace::Shared),
|
||||||
|
))
|
||||||
.parse_next(stream)
|
.parse_next(stream)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn variable_scalar_or_vector<'a, 'input: 'a>(
|
fn method_space<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<StateSpace> {
|
||||||
state_space: StateSpace,
|
alt((
|
||||||
) -> impl Parser<PtxParser<'a, 'input>, ast::Variable<&'input str>, ContextError> {
|
Token::DotReg.value(StateSpace::Reg),
|
||||||
move |stream: &mut PtxParser<'a, 'input>| {
|
Token::DotLocal.value(StateSpace::Local),
|
||||||
(opt(align), scalar_vector_type, ident)
|
Token::DotParam.value(StateSpace::Param),
|
||||||
.map(|(align, v_type, name)| ast::Variable {
|
global_space,
|
||||||
align,
|
))
|
||||||
v_type,
|
.parse_next(stream)
|
||||||
state_space,
|
|
||||||
name,
|
|
||||||
array_init: Vec::new(),
|
|
||||||
})
|
|
||||||
.parse_next(stream)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn align<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<u32> {
|
fn align<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<u32> {
|
||||||
preceded(Token::DotAlign, u32).parse_next(stream)
|
preceded(Token::DotAlign, u32).parse_next(stream)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn scalar_vector_type<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<Type> {
|
fn vector_prefix<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<Option<NonZeroU8>> {
|
||||||
(
|
opt(alt((
|
||||||
opt(alt((
|
Token::DotV2.value(unsafe { NonZeroU8::new_unchecked(2) }),
|
||||||
Token::DotV2.value(VectorPrefix::V2),
|
Token::DotV4.value(unsafe { NonZeroU8::new_unchecked(4) }),
|
||||||
Token::DotV4.value(VectorPrefix::V4),
|
Token::DotV8.value(unsafe { NonZeroU8::new_unchecked(8) }),
|
||||||
))),
|
)))
|
||||||
scalar_type,
|
.parse_next(stream)
|
||||||
)
|
|
||||||
.map(|(prefix, scalar)| ast::Type::maybe_vector(prefix, scalar))
|
|
||||||
.parse_next(stream)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn scalar_type<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<ScalarType> {
|
fn scalar_type<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<ScalarType> {
|
||||||
|
@ -1157,6 +1422,8 @@ derive_parser!(
|
||||||
Minus,
|
Minus,
|
||||||
#[token("+")]
|
#[token("+")]
|
||||||
Plus,
|
Plus,
|
||||||
|
#[token("=")]
|
||||||
|
Eq,
|
||||||
#[token(".version")]
|
#[token(".version")]
|
||||||
DotVersion,
|
DotVersion,
|
||||||
#[token(".loc")]
|
#[token(".loc")]
|
||||||
|
@ -2509,7 +2776,7 @@ derive_parser!(
|
||||||
scope: scope.unwrap_or(MemScope::Gpu),
|
scope: scope.unwrap_or(MemScope::Gpu),
|
||||||
space: global.unwrap_or(StateSpace::Generic),
|
space: global.unwrap_or(StateSpace::Generic),
|
||||||
op: ast::AtomicOp::new(float_op, f32.kind()),
|
op: ast::AtomicOp::new(float_op, f32.kind()),
|
||||||
type_: ast::Type::Vector(f32, vec_32_bit.len())
|
type_: ast::Type::Vector(vec_32_bit.len().get(), f32)
|
||||||
},
|
},
|
||||||
arguments: AtomArgs { dst: d, src1: a, src2: b }
|
arguments: AtomArgs { dst: d, src1: a, src2: b }
|
||||||
}
|
}
|
||||||
|
@ -2840,7 +3107,7 @@ derive_parser!(
|
||||||
// .mode = { .f4e, .b4e, .rc8, .ecl, .ecr, .rc16 };
|
// .mode = { .f4e, .b4e, .rc8, .ecl, .ecr, .rc16 };
|
||||||
prmt.b32 d, a, b, c => {
|
prmt.b32 d, a, b, c => {
|
||||||
match c {
|
match c {
|
||||||
ast::ParsedOperand::Imm(ImmediateValue::U64(control)) => ast::Instruction::Prmt {
|
ast::ParsedOperand::Imm(ImmediateValue::S64(control)) => ast::Instruction::Prmt {
|
||||||
data: control as u16,
|
data: control as u16,
|
||||||
arguments: PrmtArgs {
|
arguments: PrmtArgs {
|
||||||
dst: d, src1: a, src2: b
|
dst: d, src1: a, src2: b
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue