mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-08-03 06:40:21 +00:00
Work on more passes
This commit is contained in:
parent
12ef8dbc90
commit
7ea990edb7
4 changed files with 222 additions and 15 deletions
140
ptx/src/pass/convert_to_typed.rs
Normal file
140
ptx/src/pass/convert_to_typed.rs
Normal file
|
@ -0,0 +1,140 @@
|
||||||
|
use super::*;
|
||||||
|
use ptx_parser as ast;
|
||||||
|
|
||||||
|
pub(crate) fn run(
|
||||||
|
func: Vec<UnconditionalStatement>,
|
||||||
|
fn_defs: &GlobalFnDeclResolver,
|
||||||
|
id_defs: &mut NumericIdResolver,
|
||||||
|
) -> Result<Vec<TypedStatement>, TranslateError> {
|
||||||
|
let mut result = Vec::<TypedStatement>::with_capacity(func.len());
|
||||||
|
for s in func {
|
||||||
|
match s {
|
||||||
|
Statement::Instruction(inst) => match inst {
|
||||||
|
ast::Instruction::Mov {
|
||||||
|
data,
|
||||||
|
arguments:
|
||||||
|
ast::MovArgs {
|
||||||
|
dst: ast::ParsedOperand::Reg(dst_reg),
|
||||||
|
src: ast::ParsedOperand::Reg(src_reg),
|
||||||
|
},
|
||||||
|
} if fn_defs.fns.contains_key(&src_reg) => {
|
||||||
|
if data.typ != ast::Type::Scalar(ast::ScalarType::U64) {
|
||||||
|
return Err(TranslateError::MismatchedType);
|
||||||
|
}
|
||||||
|
result.push(TypedStatement::FunctionPointer(FunctionPointerDetails {
|
||||||
|
dst: dst_reg,
|
||||||
|
src: src_reg,
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
ast::Instruction::Call(call) => {
|
||||||
|
let resolver = fn_defs.get_fn_sig_resolver(call.func)?;
|
||||||
|
let resolved_call = resolver.resolve_in_spirv_repr(call)?;
|
||||||
|
let mut visitor = VectorRepackVisitor::new(&mut result, id_defs);
|
||||||
|
let reresolved_call = resolved_call.visit(&mut visitor)?;
|
||||||
|
visitor.func.push(reresolved_call);
|
||||||
|
visitor.func.extend(visitor.post_stmts);
|
||||||
|
}
|
||||||
|
inst => {
|
||||||
|
let mut visitor = VectorRepackVisitor::new(&mut result, id_defs);
|
||||||
|
let instruction = Statement::Instruction(inst.map(&mut visitor)?);
|
||||||
|
visitor.func.push(instruction);
|
||||||
|
visitor.func.extend(visitor.post_stmts);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
Statement::Label(i) => result.push(Statement::Label(i)),
|
||||||
|
Statement::Variable(v) => result.push(Statement::Variable(v)),
|
||||||
|
Statement::Conditional(c) => result.push(Statement::Conditional(c)),
|
||||||
|
_ => return Err(error_unreachable()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(result)
|
||||||
|
}
|
||||||
|
|
||||||
|
struct VectorRepackVisitor<'a, 'b> {
|
||||||
|
func: &'b mut Vec<TypedStatement>,
|
||||||
|
id_def: &'b mut NumericIdResolver<'a>,
|
||||||
|
post_stmts: Option<TypedStatement>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, 'b> VectorRepackVisitor<'a, 'b> {
|
||||||
|
fn new(func: &'b mut Vec<TypedStatement>, id_def: &'b mut NumericIdResolver<'a>) -> Self {
|
||||||
|
VectorRepackVisitor {
|
||||||
|
func,
|
||||||
|
id_def,
|
||||||
|
post_stmts: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn convert_vector(
|
||||||
|
&mut self,
|
||||||
|
is_dst: bool,
|
||||||
|
non_default_implicit_conversion: Option<
|
||||||
|
fn(
|
||||||
|
(ast::StateSpace, &ast::Type),
|
||||||
|
(ast::StateSpace, &ast::Type),
|
||||||
|
) -> Result<Option<ConversionKind>, TranslateError>,
|
||||||
|
>,
|
||||||
|
typ: &ast::Type,
|
||||||
|
state_space: ast::StateSpace,
|
||||||
|
idx: Vec<SpirvWord>,
|
||||||
|
) -> Result<SpirvWord, TranslateError> {
|
||||||
|
// mov.u32 foobar, {a,b};
|
||||||
|
let scalar_t = match typ {
|
||||||
|
ast::Type::Vector(scalar_t, _) => *scalar_t,
|
||||||
|
_ => return Err(TranslateError::MismatchedType),
|
||||||
|
};
|
||||||
|
let temp_vec = self
|
||||||
|
.id_def
|
||||||
|
.register_intermediate(Some((typ.clone(), state_space)));
|
||||||
|
let statement = Statement::RepackVector(RepackVectorDetails {
|
||||||
|
is_extract: is_dst,
|
||||||
|
typ: scalar_t,
|
||||||
|
packed: temp_vec,
|
||||||
|
unpacked: idx,
|
||||||
|
non_default_implicit_conversion,
|
||||||
|
});
|
||||||
|
if is_dst {
|
||||||
|
self.post_stmts = Some(statement);
|
||||||
|
} else {
|
||||||
|
self.func.push(statement);
|
||||||
|
}
|
||||||
|
Ok(temp_vec)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, 'b> ast::VisitorMap<ast::ParsedOperand<SpirvWord>, TypedOperand, TranslateError>
|
||||||
|
for VectorRepackVisitor<'a, 'b>
|
||||||
|
{
|
||||||
|
fn visit_ident(
|
||||||
|
&mut self,
|
||||||
|
ident: SpirvWord,
|
||||||
|
_: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>,
|
||||||
|
_: bool,
|
||||||
|
) -> Result<SpirvWord, TranslateError> {
|
||||||
|
Ok(ident)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn visit(
|
||||||
|
&mut self,
|
||||||
|
op: ast::ParsedOperand<SpirvWord>,
|
||||||
|
type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>,
|
||||||
|
is_dst: bool,
|
||||||
|
) -> Result<TypedOperand, TranslateError> {
|
||||||
|
Ok(match op {
|
||||||
|
ast::ParsedOperand::Reg(reg) => TypedOperand::Reg(reg),
|
||||||
|
ast::ParsedOperand::RegOffset(reg, offset) => TypedOperand::RegOffset(reg, offset),
|
||||||
|
ast::ParsedOperand::Imm(x) => TypedOperand::Imm(x),
|
||||||
|
ast::ParsedOperand::VecMember(vec, idx) => TypedOperand::VecMember(vec, idx),
|
||||||
|
ast::ParsedOperand::VecPack(vec) => {
|
||||||
|
let (type_, space) = type_space.ok_or(TranslateError::MismatchedType)?;
|
||||||
|
TypedOperand::Reg(self.convert_vector(
|
||||||
|
is_dst,
|
||||||
|
desc.non_default_implicit_conversion,
|
||||||
|
type_,
|
||||||
|
space,
|
||||||
|
vec,
|
||||||
|
)?)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
|
@ -8,7 +8,9 @@ use std::{
|
||||||
rc::Rc,
|
rc::Rc,
|
||||||
};
|
};
|
||||||
|
|
||||||
pub(crate) mod normalize;
|
mod convert_to_typed;
|
||||||
|
mod normalize_identifiers;
|
||||||
|
mod normalize_predicates;
|
||||||
|
|
||||||
static ZLUDA_PTX_IMPL_INTEL: &'static [u8] = include_bytes!("../../lib/zluda_ptx_impl.spv");
|
static ZLUDA_PTX_IMPL_INTEL: &'static [u8] = include_bytes!("../../lib/zluda_ptx_impl.spv");
|
||||||
static ZLUDA_PTX_IMPL_AMD: &'static [u8] = include_bytes!("../../lib/zluda_ptx_impl.bc");
|
static ZLUDA_PTX_IMPL_AMD: &'static [u8] = include_bytes!("../../lib/zluda_ptx_impl.bc");
|
||||||
|
@ -161,13 +163,13 @@ fn to_ssa<'input, 'b>(
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
let normalized_ids = normalize::run(&mut id_defs, &fn_defs, f_body)?;
|
let normalized_ids = normalize_identifiers::run(&mut id_defs, &fn_defs, f_body)?;
|
||||||
|
let mut numeric_id_defs = id_defs.finish();
|
||||||
|
let unadorned_statements = normalize_predicates::run(normalized_ids, &mut numeric_id_defs)?;
|
||||||
|
let typed_statements =
|
||||||
|
convert_to_typed::run(unadorned_statements, &fn_defs, &mut numeric_id_defs)?;
|
||||||
todo!()
|
todo!()
|
||||||
/*
|
/*
|
||||||
let mut numeric_id_defs = id_defs.finish();
|
|
||||||
let unadorned_statements = normalize_predicates(normalized_ids, &mut numeric_id_defs)?;
|
|
||||||
let typed_statements =
|
|
||||||
convert_to_typed_statements(unadorned_statements, &fn_defs, &mut numeric_id_defs)?;
|
|
||||||
let typed_statements =
|
let typed_statements =
|
||||||
fix_special_registers2(ptx_impl_imports, typed_statements, &mut numeric_id_defs)?;
|
fix_special_registers2(ptx_impl_imports, typed_statements, &mut numeric_id_defs)?;
|
||||||
let (func_decl, typed_statements) =
|
let (func_decl, typed_statements) =
|
||||||
|
@ -856,4 +858,33 @@ pub(crate) struct Function<'input> {
|
||||||
linkage: ast::LinkingDirective,
|
linkage: ast::LinkingDirective,
|
||||||
}
|
}
|
||||||
|
|
||||||
type ExpandedStatement = Statement<ast::Instruction<SpirvWord>, SpirvWord>;
|
type ExpandedStatement = Statement<ast::Instruction<SpirvWord>, SpirvWord>;
|
||||||
|
|
||||||
|
type NormalizedStatement = Statement<
|
||||||
|
(
|
||||||
|
Option<ast::PredAt<SpirvWord>>,
|
||||||
|
ast::Instruction<ast::ParsedOperand<SpirvWord>>,
|
||||||
|
),
|
||||||
|
ast::ParsedOperand<SpirvWord>,
|
||||||
|
>;
|
||||||
|
|
||||||
|
type UnconditionalStatement =
|
||||||
|
Statement<ast::Instruction<ast::ParsedOperand<SpirvWord>>, ast::ParsedOperand<SpirvWord>>;
|
||||||
|
|
||||||
|
type TypedStatement = Statement<ast::Instruction<TypedOperand>, TypedOperand>;
|
||||||
|
|
||||||
|
#[derive(Copy, Clone)]
|
||||||
|
enum TypedOperand {
|
||||||
|
Reg(SpirvWord),
|
||||||
|
RegOffset(SpirvWord, i32),
|
||||||
|
Imm(ast::ImmediateValue),
|
||||||
|
VecMember(SpirvWord, u8),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ast::Operand for TypedOperand {
|
||||||
|
type Ident = SpirvWord;
|
||||||
|
|
||||||
|
fn from_ident(ident: Self::Ident) -> Self {
|
||||||
|
TypedOperand::Reg(ident)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -1,14 +1,6 @@
|
||||||
use super::*;
|
use super::*;
|
||||||
use ptx_parser as ast;
|
use ptx_parser as ast;
|
||||||
|
|
||||||
type NormalizedStatement = Statement<
|
|
||||||
(
|
|
||||||
Option<ast::PredAt<SpirvWord>>,
|
|
||||||
ast::Instruction<ast::ParsedOperand<SpirvWord>>,
|
|
||||||
),
|
|
||||||
ast::ParsedOperand<SpirvWord>,
|
|
||||||
>;
|
|
||||||
|
|
||||||
pub(crate) fn run<'input, 'b>(
|
pub(crate) fn run<'input, 'b>(
|
||||||
id_defs: &mut FnStringIdResolver<'input, 'b>,
|
id_defs: &mut FnStringIdResolver<'input, 'b>,
|
||||||
fn_defs: &GlobalFnDeclResolver<'input, 'b>,
|
fn_defs: &GlobalFnDeclResolver<'input, 'b>,
|
44
ptx/src/pass/normalize_predicates.rs
Normal file
44
ptx/src/pass/normalize_predicates.rs
Normal file
|
@ -0,0 +1,44 @@
|
||||||
|
use super::*;
|
||||||
|
use ptx_parser as ast;
|
||||||
|
|
||||||
|
pub(crate) fn run(
|
||||||
|
func: Vec<NormalizedStatement>,
|
||||||
|
id_def: &mut NumericIdResolver,
|
||||||
|
) -> Result<Vec<UnconditionalStatement>, TranslateError> {
|
||||||
|
let mut result = Vec::with_capacity(func.len());
|
||||||
|
for s in func {
|
||||||
|
match s {
|
||||||
|
Statement::Label(id) => result.push(Statement::Label(id)),
|
||||||
|
Statement::Instruction((pred, inst)) => {
|
||||||
|
if let Some(pred) = pred {
|
||||||
|
let if_true = id_def.register_intermediate(None);
|
||||||
|
let if_false = id_def.register_intermediate(None);
|
||||||
|
let folded_bra = match &inst {
|
||||||
|
ast::Instruction::Bra { arguments, .. } => Some(arguments.src),
|
||||||
|
_ => None,
|
||||||
|
};
|
||||||
|
let mut branch = BrachCondition {
|
||||||
|
predicate: pred.label,
|
||||||
|
if_true: folded_bra.unwrap_or(if_true),
|
||||||
|
if_false,
|
||||||
|
};
|
||||||
|
if pred.not {
|
||||||
|
std::mem::swap(&mut branch.if_true, &mut branch.if_false);
|
||||||
|
}
|
||||||
|
result.push(Statement::Conditional(branch));
|
||||||
|
if folded_bra.is_none() {
|
||||||
|
result.push(Statement::Label(if_true));
|
||||||
|
result.push(Statement::Instruction(inst));
|
||||||
|
}
|
||||||
|
result.push(Statement::Label(if_false));
|
||||||
|
} else {
|
||||||
|
result.push(Statement::Instruction(inst));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Statement::Variable(var) => result.push(Statement::Variable(var)),
|
||||||
|
// Blocks are flattened when resolving ids
|
||||||
|
_ => return Err(error_unreachable()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(result)
|
||||||
|
}
|
Loading…
Add table
Add a link
Reference in a new issue