mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-04-20 00:14:45 +00:00
Work on more passes
This commit is contained in:
parent
12ef8dbc90
commit
7ea990edb7
4 changed files with 222 additions and 15 deletions
140
ptx/src/pass/convert_to_typed.rs
Normal file
140
ptx/src/pass/convert_to_typed.rs
Normal file
|
@ -0,0 +1,140 @@
|
|||
use super::*;
|
||||
use ptx_parser as ast;
|
||||
|
||||
pub(crate) fn run(
|
||||
func: Vec<UnconditionalStatement>,
|
||||
fn_defs: &GlobalFnDeclResolver,
|
||||
id_defs: &mut NumericIdResolver,
|
||||
) -> Result<Vec<TypedStatement>, TranslateError> {
|
||||
let mut result = Vec::<TypedStatement>::with_capacity(func.len());
|
||||
for s in func {
|
||||
match s {
|
||||
Statement::Instruction(inst) => match inst {
|
||||
ast::Instruction::Mov {
|
||||
data,
|
||||
arguments:
|
||||
ast::MovArgs {
|
||||
dst: ast::ParsedOperand::Reg(dst_reg),
|
||||
src: ast::ParsedOperand::Reg(src_reg),
|
||||
},
|
||||
} if fn_defs.fns.contains_key(&src_reg) => {
|
||||
if data.typ != ast::Type::Scalar(ast::ScalarType::U64) {
|
||||
return Err(TranslateError::MismatchedType);
|
||||
}
|
||||
result.push(TypedStatement::FunctionPointer(FunctionPointerDetails {
|
||||
dst: dst_reg,
|
||||
src: src_reg,
|
||||
}));
|
||||
}
|
||||
ast::Instruction::Call(call) => {
|
||||
let resolver = fn_defs.get_fn_sig_resolver(call.func)?;
|
||||
let resolved_call = resolver.resolve_in_spirv_repr(call)?;
|
||||
let mut visitor = VectorRepackVisitor::new(&mut result, id_defs);
|
||||
let reresolved_call = resolved_call.visit(&mut visitor)?;
|
||||
visitor.func.push(reresolved_call);
|
||||
visitor.func.extend(visitor.post_stmts);
|
||||
}
|
||||
inst => {
|
||||
let mut visitor = VectorRepackVisitor::new(&mut result, id_defs);
|
||||
let instruction = Statement::Instruction(inst.map(&mut visitor)?);
|
||||
visitor.func.push(instruction);
|
||||
visitor.func.extend(visitor.post_stmts);
|
||||
}
|
||||
},
|
||||
Statement::Label(i) => result.push(Statement::Label(i)),
|
||||
Statement::Variable(v) => result.push(Statement::Variable(v)),
|
||||
Statement::Conditional(c) => result.push(Statement::Conditional(c)),
|
||||
_ => return Err(error_unreachable()),
|
||||
}
|
||||
}
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
struct VectorRepackVisitor<'a, 'b> {
|
||||
func: &'b mut Vec<TypedStatement>,
|
||||
id_def: &'b mut NumericIdResolver<'a>,
|
||||
post_stmts: Option<TypedStatement>,
|
||||
}
|
||||
|
||||
impl<'a, 'b> VectorRepackVisitor<'a, 'b> {
|
||||
fn new(func: &'b mut Vec<TypedStatement>, id_def: &'b mut NumericIdResolver<'a>) -> Self {
|
||||
VectorRepackVisitor {
|
||||
func,
|
||||
id_def,
|
||||
post_stmts: None,
|
||||
}
|
||||
}
|
||||
|
||||
fn convert_vector(
|
||||
&mut self,
|
||||
is_dst: bool,
|
||||
non_default_implicit_conversion: Option<
|
||||
fn(
|
||||
(ast::StateSpace, &ast::Type),
|
||||
(ast::StateSpace, &ast::Type),
|
||||
) -> Result<Option<ConversionKind>, TranslateError>,
|
||||
>,
|
||||
typ: &ast::Type,
|
||||
state_space: ast::StateSpace,
|
||||
idx: Vec<SpirvWord>,
|
||||
) -> Result<SpirvWord, TranslateError> {
|
||||
// mov.u32 foobar, {a,b};
|
||||
let scalar_t = match typ {
|
||||
ast::Type::Vector(scalar_t, _) => *scalar_t,
|
||||
_ => return Err(TranslateError::MismatchedType),
|
||||
};
|
||||
let temp_vec = self
|
||||
.id_def
|
||||
.register_intermediate(Some((typ.clone(), state_space)));
|
||||
let statement = Statement::RepackVector(RepackVectorDetails {
|
||||
is_extract: is_dst,
|
||||
typ: scalar_t,
|
||||
packed: temp_vec,
|
||||
unpacked: idx,
|
||||
non_default_implicit_conversion,
|
||||
});
|
||||
if is_dst {
|
||||
self.post_stmts = Some(statement);
|
||||
} else {
|
||||
self.func.push(statement);
|
||||
}
|
||||
Ok(temp_vec)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, 'b> ast::VisitorMap<ast::ParsedOperand<SpirvWord>, TypedOperand, TranslateError>
|
||||
for VectorRepackVisitor<'a, 'b>
|
||||
{
|
||||
fn visit_ident(
|
||||
&mut self,
|
||||
ident: SpirvWord,
|
||||
_: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>,
|
||||
_: bool,
|
||||
) -> Result<SpirvWord, TranslateError> {
|
||||
Ok(ident)
|
||||
}
|
||||
|
||||
fn visit(
|
||||
&mut self,
|
||||
op: ast::ParsedOperand<SpirvWord>,
|
||||
type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>,
|
||||
is_dst: bool,
|
||||
) -> Result<TypedOperand, TranslateError> {
|
||||
Ok(match op {
|
||||
ast::ParsedOperand::Reg(reg) => TypedOperand::Reg(reg),
|
||||
ast::ParsedOperand::RegOffset(reg, offset) => TypedOperand::RegOffset(reg, offset),
|
||||
ast::ParsedOperand::Imm(x) => TypedOperand::Imm(x),
|
||||
ast::ParsedOperand::VecMember(vec, idx) => TypedOperand::VecMember(vec, idx),
|
||||
ast::ParsedOperand::VecPack(vec) => {
|
||||
let (type_, space) = type_space.ok_or(TranslateError::MismatchedType)?;
|
||||
TypedOperand::Reg(self.convert_vector(
|
||||
is_dst,
|
||||
desc.non_default_implicit_conversion,
|
||||
type_,
|
||||
space,
|
||||
vec,
|
||||
)?)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -8,7 +8,9 @@ use std::{
|
|||
rc::Rc,
|
||||
};
|
||||
|
||||
pub(crate) mod normalize;
|
||||
mod convert_to_typed;
|
||||
mod normalize_identifiers;
|
||||
mod normalize_predicates;
|
||||
|
||||
static ZLUDA_PTX_IMPL_INTEL: &'static [u8] = include_bytes!("../../lib/zluda_ptx_impl.spv");
|
||||
static ZLUDA_PTX_IMPL_AMD: &'static [u8] = include_bytes!("../../lib/zluda_ptx_impl.bc");
|
||||
|
@ -161,13 +163,13 @@ fn to_ssa<'input, 'b>(
|
|||
})
|
||||
}
|
||||
};
|
||||
let normalized_ids = normalize::run(&mut id_defs, &fn_defs, f_body)?;
|
||||
let normalized_ids = normalize_identifiers::run(&mut id_defs, &fn_defs, f_body)?;
|
||||
let mut numeric_id_defs = id_defs.finish();
|
||||
let unadorned_statements = normalize_predicates::run(normalized_ids, &mut numeric_id_defs)?;
|
||||
let typed_statements =
|
||||
convert_to_typed::run(unadorned_statements, &fn_defs, &mut numeric_id_defs)?;
|
||||
todo!()
|
||||
/*
|
||||
let mut numeric_id_defs = id_defs.finish();
|
||||
let unadorned_statements = normalize_predicates(normalized_ids, &mut numeric_id_defs)?;
|
||||
let typed_statements =
|
||||
convert_to_typed_statements(unadorned_statements, &fn_defs, &mut numeric_id_defs)?;
|
||||
let typed_statements =
|
||||
fix_special_registers2(ptx_impl_imports, typed_statements, &mut numeric_id_defs)?;
|
||||
let (func_decl, typed_statements) =
|
||||
|
@ -856,4 +858,33 @@ pub(crate) struct Function<'input> {
|
|||
linkage: ast::LinkingDirective,
|
||||
}
|
||||
|
||||
type ExpandedStatement = Statement<ast::Instruction<SpirvWord>, SpirvWord>;
|
||||
type ExpandedStatement = Statement<ast::Instruction<SpirvWord>, SpirvWord>;
|
||||
|
||||
type NormalizedStatement = Statement<
|
||||
(
|
||||
Option<ast::PredAt<SpirvWord>>,
|
||||
ast::Instruction<ast::ParsedOperand<SpirvWord>>,
|
||||
),
|
||||
ast::ParsedOperand<SpirvWord>,
|
||||
>;
|
||||
|
||||
type UnconditionalStatement =
|
||||
Statement<ast::Instruction<ast::ParsedOperand<SpirvWord>>, ast::ParsedOperand<SpirvWord>>;
|
||||
|
||||
type TypedStatement = Statement<ast::Instruction<TypedOperand>, TypedOperand>;
|
||||
|
||||
#[derive(Copy, Clone)]
|
||||
enum TypedOperand {
|
||||
Reg(SpirvWord),
|
||||
RegOffset(SpirvWord, i32),
|
||||
Imm(ast::ImmediateValue),
|
||||
VecMember(SpirvWord, u8),
|
||||
}
|
||||
|
||||
impl ast::Operand for TypedOperand {
|
||||
type Ident = SpirvWord;
|
||||
|
||||
fn from_ident(ident: Self::Ident) -> Self {
|
||||
TypedOperand::Reg(ident)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,14 +1,6 @@
|
|||
use super::*;
|
||||
use ptx_parser as ast;
|
||||
|
||||
type NormalizedStatement = Statement<
|
||||
(
|
||||
Option<ast::PredAt<SpirvWord>>,
|
||||
ast::Instruction<ast::ParsedOperand<SpirvWord>>,
|
||||
),
|
||||
ast::ParsedOperand<SpirvWord>,
|
||||
>;
|
||||
|
||||
pub(crate) fn run<'input, 'b>(
|
||||
id_defs: &mut FnStringIdResolver<'input, 'b>,
|
||||
fn_defs: &GlobalFnDeclResolver<'input, 'b>,
|
44
ptx/src/pass/normalize_predicates.rs
Normal file
44
ptx/src/pass/normalize_predicates.rs
Normal file
|
@ -0,0 +1,44 @@
|
|||
use super::*;
|
||||
use ptx_parser as ast;
|
||||
|
||||
pub(crate) fn run(
|
||||
func: Vec<NormalizedStatement>,
|
||||
id_def: &mut NumericIdResolver,
|
||||
) -> Result<Vec<UnconditionalStatement>, TranslateError> {
|
||||
let mut result = Vec::with_capacity(func.len());
|
||||
for s in func {
|
||||
match s {
|
||||
Statement::Label(id) => result.push(Statement::Label(id)),
|
||||
Statement::Instruction((pred, inst)) => {
|
||||
if let Some(pred) = pred {
|
||||
let if_true = id_def.register_intermediate(None);
|
||||
let if_false = id_def.register_intermediate(None);
|
||||
let folded_bra = match &inst {
|
||||
ast::Instruction::Bra { arguments, .. } => Some(arguments.src),
|
||||
_ => None,
|
||||
};
|
||||
let mut branch = BrachCondition {
|
||||
predicate: pred.label,
|
||||
if_true: folded_bra.unwrap_or(if_true),
|
||||
if_false,
|
||||
};
|
||||
if pred.not {
|
||||
std::mem::swap(&mut branch.if_true, &mut branch.if_false);
|
||||
}
|
||||
result.push(Statement::Conditional(branch));
|
||||
if folded_bra.is_none() {
|
||||
result.push(Statement::Label(if_true));
|
||||
result.push(Statement::Instruction(inst));
|
||||
}
|
||||
result.push(Statement::Label(if_false));
|
||||
} else {
|
||||
result.push(Statement::Instruction(inst));
|
||||
}
|
||||
}
|
||||
Statement::Variable(var) => result.push(Statement::Variable(var)),
|
||||
// Blocks are flattened when resolving ids
|
||||
_ => return Err(error_unreachable()),
|
||||
}
|
||||
}
|
||||
Ok(result)
|
||||
}
|
Loading…
Add table
Reference in a new issue