mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-07-29 12:28:38 +00:00
Port expand_arguments
This commit is contained in:
parent
cccd37f6ee
commit
c088cc2171
2 changed files with 191 additions and 3 deletions
181
ptx/src/pass/expand_arguments.rs
Normal file
181
ptx/src/pass/expand_arguments.rs
Normal file
|
@ -0,0 +1,181 @@
|
||||||
|
use super::*;
|
||||||
|
use ptx_parser as ast;
|
||||||
|
|
||||||
|
pub(super) fn run<'a, 'b>(
|
||||||
|
func: Vec<TypedStatement>,
|
||||||
|
id_def: &'b mut MutableNumericIdResolver<'a>,
|
||||||
|
) -> Result<Vec<ExpandedStatement>, TranslateError> {
|
||||||
|
let mut result = Vec::with_capacity(func.len());
|
||||||
|
for s in func {
|
||||||
|
match s {
|
||||||
|
Statement::Label(id) => result.push(Statement::Label(id)),
|
||||||
|
Statement::Conditional(bra) => result.push(Statement::Conditional(bra)),
|
||||||
|
Statement::LoadVar(details) => result.push(Statement::LoadVar(details)),
|
||||||
|
Statement::StoreVar(details) => result.push(Statement::StoreVar(details)),
|
||||||
|
Statement::RetValue(d, id) => result.push(Statement::RetValue(d, id)),
|
||||||
|
Statement::Conversion(conv) => result.push(Statement::Conversion(conv)),
|
||||||
|
Statement::Constant(c) => result.push(Statement::Constant(c)),
|
||||||
|
Statement::FunctionPointer(d) => result.push(Statement::FunctionPointer(d)),
|
||||||
|
s => {
|
||||||
|
let (new_statement, post_stmts) = {
|
||||||
|
let mut visitor = FlattenArguments::new(&mut result, id_def);
|
||||||
|
(s.visit_map(&mut visitor)?, visitor.post_stmts)
|
||||||
|
};
|
||||||
|
result.push(new_statement);
|
||||||
|
result.extend(post_stmts);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(result)
|
||||||
|
}
|
||||||
|
|
||||||
|
struct FlattenArguments<'a, 'b> {
|
||||||
|
func: &'b mut Vec<ExpandedStatement>,
|
||||||
|
id_def: &'b mut MutableNumericIdResolver<'a>,
|
||||||
|
post_stmts: Vec<ExpandedStatement>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, 'b> FlattenArguments<'a, 'b> {
|
||||||
|
fn new(
|
||||||
|
func: &'b mut Vec<ExpandedStatement>,
|
||||||
|
id_def: &'b mut MutableNumericIdResolver<'a>,
|
||||||
|
) -> Self {
|
||||||
|
FlattenArguments {
|
||||||
|
func,
|
||||||
|
id_def,
|
||||||
|
post_stmts: Vec::new(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn reg(&mut self, name: SpirvWord) -> Result<SpirvWord, TranslateError> {
|
||||||
|
Ok(name)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn reg_offset(
|
||||||
|
&mut self,
|
||||||
|
reg: SpirvWord,
|
||||||
|
offset: i32,
|
||||||
|
type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>,
|
||||||
|
_is_dst: bool,
|
||||||
|
) -> Result<SpirvWord, TranslateError> {
|
||||||
|
let (type_, state_space) = if let Some((type_, state_space)) = type_space {
|
||||||
|
(type_, state_space)
|
||||||
|
} else {
|
||||||
|
return Err(TranslateError::UntypedSymbol);
|
||||||
|
};
|
||||||
|
if state_space != ast::StateSpace::Reg && state_space != ast::StateSpace::Sreg {
|
||||||
|
let (reg_type, reg_space) = self.id_def.get_typed(reg)?;
|
||||||
|
if !state_is_compatible(reg_space, ast::StateSpace::Reg) {
|
||||||
|
return Err(TranslateError::MismatchedType);
|
||||||
|
}
|
||||||
|
let reg_scalar_type = match reg_type {
|
||||||
|
ast::Type::Scalar(underlying_type) => underlying_type,
|
||||||
|
_ => return Err(TranslateError::MismatchedType),
|
||||||
|
};
|
||||||
|
let id_constant_stmt = self
|
||||||
|
.id_def
|
||||||
|
.register_intermediate(reg_type.clone(), ast::StateSpace::Reg);
|
||||||
|
self.func.push(Statement::Constant(ConstantDefinition {
|
||||||
|
dst: id_constant_stmt,
|
||||||
|
typ: reg_scalar_type,
|
||||||
|
value: ast::ImmediateValue::S64(offset as i64),
|
||||||
|
}));
|
||||||
|
let arith_details = match reg_scalar_type.kind() {
|
||||||
|
ast::ScalarKind::Signed => ast::ArithDetails::Integer(ast::ArithInteger {
|
||||||
|
type_: reg_scalar_type,
|
||||||
|
saturate: false,
|
||||||
|
}),
|
||||||
|
ast::ScalarKind::Unsigned | ast::ScalarKind::Bit => {
|
||||||
|
ast::ArithDetails::Integer(ast::ArithInteger {
|
||||||
|
type_: reg_scalar_type,
|
||||||
|
saturate: false,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
_ => return Err(error_unreachable()),
|
||||||
|
};
|
||||||
|
let id_add_result = self.id_def.register_intermediate(reg_type, state_space);
|
||||||
|
self.func
|
||||||
|
.push(Statement::Instruction(ast::Instruction::Add {
|
||||||
|
data: arith_details,
|
||||||
|
arguments: ast::AddArgs {
|
||||||
|
dst: id_add_result,
|
||||||
|
src1: reg,
|
||||||
|
src2: id_constant_stmt,
|
||||||
|
},
|
||||||
|
}));
|
||||||
|
Ok(id_add_result)
|
||||||
|
} else {
|
||||||
|
let id_constant_stmt = self.id_def.register_intermediate(
|
||||||
|
ast::Type::Scalar(ast::ScalarType::S64),
|
||||||
|
ast::StateSpace::Reg,
|
||||||
|
);
|
||||||
|
self.func.push(Statement::Constant(ConstantDefinition {
|
||||||
|
dst: id_constant_stmt,
|
||||||
|
typ: ast::ScalarType::S64,
|
||||||
|
value: ast::ImmediateValue::S64(offset as i64),
|
||||||
|
}));
|
||||||
|
let dst = self
|
||||||
|
.id_def
|
||||||
|
.register_intermediate(type_.clone(), state_space);
|
||||||
|
self.func.push(Statement::PtrAccess(PtrAccess {
|
||||||
|
underlying_type: type_.clone(),
|
||||||
|
state_space: state_space,
|
||||||
|
dst,
|
||||||
|
ptr_src: reg,
|
||||||
|
offset_src: id_constant_stmt,
|
||||||
|
}));
|
||||||
|
Ok(dst)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn immediate(
|
||||||
|
&mut self,
|
||||||
|
value: ast::ImmediateValue,
|
||||||
|
type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>,
|
||||||
|
) -> Result<SpirvWord, TranslateError> {
|
||||||
|
let (scalar_t, state_space) =
|
||||||
|
if let Some((ast::Type::Scalar(scalar), state_space)) = type_space {
|
||||||
|
(*scalar, state_space)
|
||||||
|
} else {
|
||||||
|
return Err(TranslateError::UntypedSymbol);
|
||||||
|
};
|
||||||
|
let id = self
|
||||||
|
.id_def
|
||||||
|
.register_intermediate(ast::Type::Scalar(scalar_t), state_space);
|
||||||
|
self.func.push(Statement::Constant(ConstantDefinition {
|
||||||
|
dst: id,
|
||||||
|
typ: scalar_t,
|
||||||
|
value,
|
||||||
|
}));
|
||||||
|
Ok(id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, 'b> ast::VisitorMap<TypedOperand, SpirvWord, TranslateError> for FlattenArguments<'a, 'b> {
|
||||||
|
fn visit(
|
||||||
|
&mut self,
|
||||||
|
args: TypedOperand,
|
||||||
|
type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>,
|
||||||
|
is_dst: bool,
|
||||||
|
_relaxed_type_check: bool,
|
||||||
|
) -> Result<SpirvWord, TranslateError> {
|
||||||
|
match args {
|
||||||
|
TypedOperand::Reg(r) => self.reg(r),
|
||||||
|
TypedOperand::Imm(x) => self.immediate(x, type_space),
|
||||||
|
TypedOperand::RegOffset(reg, offset) => {
|
||||||
|
self.reg_offset(reg, offset, type_space, is_dst)
|
||||||
|
}
|
||||||
|
TypedOperand::VecMember(..) => Err(error_unreachable()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn visit_ident(
|
||||||
|
&mut self,
|
||||||
|
name: <TypedOperand as ptx_parser::Operand>::Ident,
|
||||||
|
_type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>,
|
||||||
|
_is_dst: bool,
|
||||||
|
_relaxed_type_check: bool,
|
||||||
|
) -> Result<<SpirvWord as ptx_parser::Operand>::Ident, TranslateError> {
|
||||||
|
self.reg(name)
|
||||||
|
}
|
||||||
|
}
|
|
@ -11,6 +11,7 @@ use std::{
|
||||||
|
|
||||||
mod convert_to_stateful_memory_access;
|
mod convert_to_stateful_memory_access;
|
||||||
mod convert_to_typed;
|
mod convert_to_typed;
|
||||||
|
mod expand_arguments;
|
||||||
mod fix_special_registers;
|
mod fix_special_registers;
|
||||||
mod insert_mem_ssa_statements;
|
mod insert_mem_ssa_statements;
|
||||||
mod normalize_identifiers;
|
mod normalize_identifiers;
|
||||||
|
@ -181,10 +182,10 @@ fn to_ssa<'input, 'b>(
|
||||||
&mut numeric_id_defs,
|
&mut numeric_id_defs,
|
||||||
&mut (*func_decl).borrow_mut(),
|
&mut (*func_decl).borrow_mut(),
|
||||||
)?;
|
)?;
|
||||||
|
let mut numeric_id_defs = numeric_id_defs.finish();
|
||||||
|
let expanded_statements = expand_arguments::run(ssa_statements, &mut numeric_id_defs)?;
|
||||||
todo!()
|
todo!()
|
||||||
/*
|
/*
|
||||||
let mut numeric_id_defs = numeric_id_defs.finish();
|
|
||||||
let expanded_statements = expand_arguments(ssa_statements, &mut numeric_id_defs)?;
|
|
||||||
let expanded_statements =
|
let expanded_statements =
|
||||||
insert_implicit_conversions(expanded_statements, &mut numeric_id_defs)?;
|
insert_implicit_conversions(expanded_statements, &mut numeric_id_defs)?;
|
||||||
let mut numeric_id_defs = numeric_id_defs.unmut();
|
let mut numeric_id_defs = numeric_id_defs.unmut();
|
||||||
|
@ -743,7 +744,7 @@ impl<T: ast::Operand<Ident = SpirvWord>> Statement<ast::Instruction<T>, T> {
|
||||||
fn visit_map<To: ast::Operand<Ident = SpirvWord>, Err>(
|
fn visit_map<To: ast::Operand<Ident = SpirvWord>, Err>(
|
||||||
self,
|
self,
|
||||||
visitor: &mut impl ast::VisitorMap<T, To, Err>,
|
visitor: &mut impl ast::VisitorMap<T, To, Err>,
|
||||||
) -> std::result::Result<Statement<ast::Instruction<To>, T>, Err> {
|
) -> std::result::Result<Statement<ast::Instruction<To>, To>, Err> {
|
||||||
Ok(match self {
|
Ok(match self {
|
||||||
Statement::Instruction(i) => {
|
Statement::Instruction(i) => {
|
||||||
return ast::visit_map(i, visitor).map(Statement::Instruction)
|
return ast::visit_map(i, visitor).map(Statement::Instruction)
|
||||||
|
@ -883,6 +884,12 @@ impl<T: ast::Operand<Ident = SpirvWord>> Statement<ast::Instruction<T>, T> {
|
||||||
false,
|
false,
|
||||||
false,
|
false,
|
||||||
)?;
|
)?;
|
||||||
|
let offset_src = visitor.visit(
|
||||||
|
offset_src,
|
||||||
|
Some((&underlying_type, state_space)),
|
||||||
|
false,
|
||||||
|
false,
|
||||||
|
)?;
|
||||||
Statement::PtrAccess(PtrAccess {
|
Statement::PtrAccess(PtrAccess {
|
||||||
underlying_type,
|
underlying_type,
|
||||||
state_space,
|
state_space,
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue