mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-08-01 13:49:25 +00:00
Port sreg fix pass
This commit is contained in:
parent
4e6dc07a52
commit
107f1eb17f
2 changed files with 428 additions and 2 deletions
183
ptx/src/pass/fix_special_registers.rs
Normal file
183
ptx/src/pass/fix_special_registers.rs
Normal file
|
@ -0,0 +1,183 @@
|
||||||
|
use super::*;
|
||||||
|
use std::collections::HashMap;
|
||||||
|
|
||||||
|
fn run<'a, 'b, 'input>(
|
||||||
|
ptx_impl_imports: &'a mut HashMap<String, Directive<'input>>,
|
||||||
|
typed_statements: Vec<TypedStatement>,
|
||||||
|
numeric_id_defs: &'a mut NumericIdResolver<'b>,
|
||||||
|
) -> Result<Vec<TypedStatement>, TranslateError> {
|
||||||
|
let result = Vec::with_capacity(typed_statements.len());
|
||||||
|
let mut sreg_sresolver = SpecialRegisterResolver {
|
||||||
|
ptx_impl_imports,
|
||||||
|
numeric_id_defs,
|
||||||
|
result,
|
||||||
|
};
|
||||||
|
for statement in typed_statements {
|
||||||
|
let statement = statement.visit_map(&mut sreg_sresolver)?;
|
||||||
|
sreg_sresolver.result.push(statement);
|
||||||
|
}
|
||||||
|
Ok(sreg_sresolver.result)
|
||||||
|
}
|
||||||
|
|
||||||
|
struct SpecialRegisterResolver<'a, 'b, 'input> {
|
||||||
|
ptx_impl_imports: &'a mut HashMap<String, Directive<'input>>,
|
||||||
|
numeric_id_defs: &'a mut NumericIdResolver<'b>,
|
||||||
|
result: Vec<TypedStatement>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, 'b, 'input> ast::VisitorMap<TypedOperand, TypedOperand, TranslateError>
|
||||||
|
for SpecialRegisterResolver<'a, 'b, 'input>
|
||||||
|
{
|
||||||
|
fn visit(
|
||||||
|
&mut self,
|
||||||
|
operand: TypedOperand,
|
||||||
|
_type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>,
|
||||||
|
is_dst: bool,
|
||||||
|
_relaxed_type_check: bool,
|
||||||
|
) -> Result<TypedOperand, TranslateError> {
|
||||||
|
operand.map(|name, vector_index| self.replace_sreg(name, is_dst, vector_index))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn visit_ident(
|
||||||
|
&mut self,
|
||||||
|
args: SpirvWord,
|
||||||
|
_type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>,
|
||||||
|
is_dst: bool,
|
||||||
|
_relaxed_type_check: bool,
|
||||||
|
) -> Result<SpirvWord, TranslateError> {
|
||||||
|
self.replace_sreg(args, is_dst, None)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, 'b, 'input> SpecialRegisterResolver<'a, 'b, 'input> {
|
||||||
|
fn replace_sreg(
|
||||||
|
&mut self,
|
||||||
|
name: SpirvWord,
|
||||||
|
is_dst: bool,
|
||||||
|
vector_index: Option<u8>,
|
||||||
|
) -> Result<SpirvWord, TranslateError> {
|
||||||
|
if let Some(sreg) = self.numeric_id_defs.special_registers.get(name) {
|
||||||
|
if is_dst {
|
||||||
|
return Err(TranslateError::MismatchedType);
|
||||||
|
}
|
||||||
|
let input_arguments = match (vector_index, sreg.get_function_input_type()) {
|
||||||
|
(Some(idx), Some(inp_type)) => {
|
||||||
|
if inp_type != ast::ScalarType::U8 {
|
||||||
|
return Err(TranslateError::Unreachable);
|
||||||
|
}
|
||||||
|
let constant = self.numeric_id_defs.register_intermediate(Some((
|
||||||
|
ast::Type::Scalar(inp_type),
|
||||||
|
ast::StateSpace::Reg,
|
||||||
|
)));
|
||||||
|
self.result.push(Statement::Constant(ConstantDefinition {
|
||||||
|
dst: constant,
|
||||||
|
typ: inp_type,
|
||||||
|
value: ast::ImmediateValue::U64(idx as u64),
|
||||||
|
}));
|
||||||
|
vec![(
|
||||||
|
TypedOperand::Reg(constant),
|
||||||
|
ast::Type::Scalar(inp_type),
|
||||||
|
ast::StateSpace::Reg,
|
||||||
|
)]
|
||||||
|
}
|
||||||
|
(None, None) => Vec::new(),
|
||||||
|
_ => return Err(TranslateError::MismatchedType),
|
||||||
|
};
|
||||||
|
let ocl_fn_name = [ZLUDA_PTX_PREFIX, sreg.get_unprefixed_function_name()].concat();
|
||||||
|
let return_type = sreg.get_function_return_type();
|
||||||
|
let fn_result = self.numeric_id_defs.register_intermediate(Some((
|
||||||
|
ast::Type::Scalar(return_type),
|
||||||
|
ast::StateSpace::Reg,
|
||||||
|
)));
|
||||||
|
let return_arguments = vec![(
|
||||||
|
fn_result,
|
||||||
|
ast::Type::Scalar(return_type),
|
||||||
|
ast::StateSpace::Reg,
|
||||||
|
)];
|
||||||
|
let fn_call = register_external_fn_call(
|
||||||
|
self.numeric_id_defs,
|
||||||
|
self.ptx_impl_imports,
|
||||||
|
ocl_fn_name.to_string(),
|
||||||
|
return_arguments.iter().map(|(_, typ, space)| (typ, *space)),
|
||||||
|
input_arguments.iter().map(|(_, typ, space)| (typ, *space)),
|
||||||
|
)?;
|
||||||
|
let data = ast::CallDetails {
|
||||||
|
uniform: false,
|
||||||
|
return_arguments: return_arguments
|
||||||
|
.iter()
|
||||||
|
.map(|(_, typ, space)| (typ.clone(), *space))
|
||||||
|
.collect(),
|
||||||
|
input_arguments: input_arguments
|
||||||
|
.iter()
|
||||||
|
.map(|(_, typ, space)| (typ.clone(), *space))
|
||||||
|
.collect(),
|
||||||
|
};
|
||||||
|
let arguments = ast::CallArgs {
|
||||||
|
return_arguments: return_arguments.iter().map(|(name, _, _)| *name).collect(),
|
||||||
|
func: fn_call,
|
||||||
|
input_arguments: input_arguments.iter().map(|(name, _, _)| *name).collect(),
|
||||||
|
};
|
||||||
|
self.result
|
||||||
|
.push(Statement::Instruction(ast::Instruction::Call {
|
||||||
|
data,
|
||||||
|
arguments,
|
||||||
|
}));
|
||||||
|
Ok(fn_result)
|
||||||
|
} else {
|
||||||
|
Ok(name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn register_external_fn_call<'a>(
|
||||||
|
id_defs: &mut NumericIdResolver,
|
||||||
|
ptx_impl_imports: &mut HashMap<String, Directive>,
|
||||||
|
name: String,
|
||||||
|
return_arguments: impl Iterator<Item = (&'a ast::Type, ast::StateSpace)>,
|
||||||
|
input_arguments: impl Iterator<Item = (&'a ast::Type, ast::StateSpace)>,
|
||||||
|
) -> Result<SpirvWord, TranslateError> {
|
||||||
|
match ptx_impl_imports.entry(name) {
|
||||||
|
hash_map::Entry::Vacant(entry) => {
|
||||||
|
let fn_id = id_defs.register_intermediate(None);
|
||||||
|
let return_arguments = fn_arguments_to_variables(id_defs, return_arguments);
|
||||||
|
let input_arguments = fn_arguments_to_variables(id_defs, input_arguments);
|
||||||
|
let func_decl = ast::MethodDeclaration::<SpirvWord> {
|
||||||
|
return_arguments,
|
||||||
|
name: ast::MethodName::Func(fn_id),
|
||||||
|
input_arguments,
|
||||||
|
shared_mem: None,
|
||||||
|
};
|
||||||
|
let func = Function {
|
||||||
|
func_decl: Rc::new(RefCell::new(func_decl)),
|
||||||
|
globals: Vec::new(),
|
||||||
|
body: None,
|
||||||
|
import_as: Some(entry.key().clone()),
|
||||||
|
tuning: Vec::new(),
|
||||||
|
linkage: ast::LinkingDirective::EXTERN,
|
||||||
|
};
|
||||||
|
entry.insert(Directive::Method(func));
|
||||||
|
Ok(fn_id)
|
||||||
|
}
|
||||||
|
hash_map::Entry::Occupied(entry) => match entry.get() {
|
||||||
|
Directive::Method(Function { func_decl, .. }) => match (**func_decl).borrow().name {
|
||||||
|
ast::MethodName::Func(fn_id) => Ok(fn_id),
|
||||||
|
ast::MethodName::Kernel(_) => Err(error_unreachable()),
|
||||||
|
},
|
||||||
|
_ => Err(error_unreachable()),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn fn_arguments_to_variables<'a>(
|
||||||
|
id_defs: &mut NumericIdResolver,
|
||||||
|
args: impl Iterator<Item = (&'a ast::Type, ast::StateSpace)>,
|
||||||
|
) -> Vec<ast::Variable<SpirvWord>> {
|
||||||
|
args.map(|(typ, space)| ast::Variable {
|
||||||
|
align: None,
|
||||||
|
v_type: typ.clone(),
|
||||||
|
state_space: space,
|
||||||
|
name: id_defs.register_intermediate(None),
|
||||||
|
array_init: Vec::new(),
|
||||||
|
})
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
}
|
|
@ -9,6 +9,7 @@ use std::{
|
||||||
};
|
};
|
||||||
|
|
||||||
mod convert_to_typed;
|
mod convert_to_typed;
|
||||||
|
mod fix_special_registers;
|
||||||
mod normalize_identifiers;
|
mod normalize_identifiers;
|
||||||
mod normalize_predicates;
|
mod normalize_predicates;
|
||||||
|
|
||||||
|
@ -735,6 +736,235 @@ enum Statement<I, P: ast::Operand> {
|
||||||
FunctionPointer(FunctionPointerDetails),
|
FunctionPointer(FunctionPointerDetails),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl<T: ast::Operand<Ident = SpirvWord>> Statement<ast::Instruction<T>, T> {
|
||||||
|
fn visit_map<To: ast::Operand<Ident = SpirvWord>, Err>(
|
||||||
|
self,
|
||||||
|
visitor: &mut impl ast::VisitorMap<T, To, Err>,
|
||||||
|
) -> std::result::Result<Statement<ast::Instruction<To>, T>, Err> {
|
||||||
|
Ok(match self {
|
||||||
|
Statement::Instruction(i) => {
|
||||||
|
return ast::visit_map(i, visitor).map(Statement::Instruction)
|
||||||
|
}
|
||||||
|
Statement::Label(label) => {
|
||||||
|
Statement::Label(visitor.visit_ident(label, None, false, false)?)
|
||||||
|
}
|
||||||
|
Statement::Variable(var) => {
|
||||||
|
let name = visitor.visit_ident(
|
||||||
|
var.name,
|
||||||
|
Some((&var.v_type, var.state_space)),
|
||||||
|
true,
|
||||||
|
false,
|
||||||
|
)?;
|
||||||
|
Statement::Variable(ast::Variable {
|
||||||
|
align: var.align,
|
||||||
|
v_type: var.v_type,
|
||||||
|
state_space: var.state_space,
|
||||||
|
name,
|
||||||
|
array_init: var.array_init,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
Statement::Conditional(conditional) => {
|
||||||
|
let predicate = visitor.visit_ident(conditional.predicate, None, false, false)?;
|
||||||
|
let if_true = visitor.visit_ident(conditional.if_true, None, false, false)?;
|
||||||
|
let if_false = visitor.visit_ident(conditional.if_false, None, false, false)?;
|
||||||
|
Statement::Conditional(BrachCondition {
|
||||||
|
predicate,
|
||||||
|
if_true,
|
||||||
|
if_false,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
Statement::LoadVar(LoadVarDetails {
|
||||||
|
arg,
|
||||||
|
typ,
|
||||||
|
member_index,
|
||||||
|
}) => {
|
||||||
|
let dst = visitor.visit_ident(
|
||||||
|
arg.dst,
|
||||||
|
Some((&typ, ast::StateSpace::Reg)),
|
||||||
|
true,
|
||||||
|
false,
|
||||||
|
)?;
|
||||||
|
let src = visitor.visit_ident(
|
||||||
|
arg.src,
|
||||||
|
Some((&typ, ast::StateSpace::Local)),
|
||||||
|
false,
|
||||||
|
false,
|
||||||
|
)?;
|
||||||
|
Statement::LoadVar(LoadVarDetails {
|
||||||
|
arg: ast::LdArgs { dst, src },
|
||||||
|
typ,
|
||||||
|
member_index,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
Statement::StoreVar(StoreVarDetails {
|
||||||
|
arg,
|
||||||
|
typ,
|
||||||
|
member_index,
|
||||||
|
}) => {
|
||||||
|
let src1 = visitor.visit_ident(
|
||||||
|
arg.src1,
|
||||||
|
Some((&typ, ast::StateSpace::Local)),
|
||||||
|
false,
|
||||||
|
false,
|
||||||
|
)?;
|
||||||
|
let src2 = visitor.visit_ident(
|
||||||
|
arg.src2,
|
||||||
|
Some((&typ, ast::StateSpace::Reg)),
|
||||||
|
false,
|
||||||
|
false,
|
||||||
|
)?;
|
||||||
|
Statement::StoreVar(StoreVarDetails {
|
||||||
|
arg: ast::StArgs { src1, src2 },
|
||||||
|
typ,
|
||||||
|
member_index,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
Statement::Conversion(ImplicitConversion {
|
||||||
|
src,
|
||||||
|
dst,
|
||||||
|
from_type,
|
||||||
|
to_type,
|
||||||
|
from_space,
|
||||||
|
to_space,
|
||||||
|
kind,
|
||||||
|
}) => {
|
||||||
|
let dst = visitor.visit_ident(
|
||||||
|
dst,
|
||||||
|
Some((&to_type, ast::StateSpace::Reg)),
|
||||||
|
true,
|
||||||
|
false,
|
||||||
|
)?;
|
||||||
|
let src = visitor.visit_ident(
|
||||||
|
src,
|
||||||
|
Some((&from_type, ast::StateSpace::Reg)),
|
||||||
|
false,
|
||||||
|
false,
|
||||||
|
)?;
|
||||||
|
Statement::Conversion(ImplicitConversion {
|
||||||
|
src,
|
||||||
|
dst,
|
||||||
|
from_type,
|
||||||
|
to_type,
|
||||||
|
from_space,
|
||||||
|
to_space,
|
||||||
|
kind,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
Statement::Constant(ConstantDefinition { dst, typ, value }) => {
|
||||||
|
let dst = visitor.visit_ident(
|
||||||
|
dst,
|
||||||
|
Some((&typ.into(), ast::StateSpace::Reg)),
|
||||||
|
true,
|
||||||
|
false,
|
||||||
|
)?;
|
||||||
|
Statement::Constant(ConstantDefinition { dst, typ, value })
|
||||||
|
}
|
||||||
|
Statement::RetValue(data, value) => {
|
||||||
|
// TODO:
|
||||||
|
// We should report type here
|
||||||
|
let value = visitor.visit_ident(value, None, false, false)?;
|
||||||
|
Statement::RetValue(data, value)
|
||||||
|
}
|
||||||
|
Statement::PtrAccess(PtrAccess {
|
||||||
|
underlying_type,
|
||||||
|
state_space,
|
||||||
|
dst,
|
||||||
|
ptr_src,
|
||||||
|
offset_src,
|
||||||
|
}) => {
|
||||||
|
let dst =
|
||||||
|
visitor.visit_ident(dst, Some((&underlying_type, state_space)), true, false)?;
|
||||||
|
let ptr_src = visitor.visit_ident(
|
||||||
|
ptr_src,
|
||||||
|
Some((&underlying_type, state_space)),
|
||||||
|
false,
|
||||||
|
false,
|
||||||
|
)?;
|
||||||
|
Statement::PtrAccess(PtrAccess {
|
||||||
|
underlying_type,
|
||||||
|
state_space,
|
||||||
|
dst,
|
||||||
|
ptr_src,
|
||||||
|
offset_src,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
Statement::RepackVector(RepackVectorDetails {
|
||||||
|
is_extract,
|
||||||
|
typ,
|
||||||
|
packed,
|
||||||
|
unpacked,
|
||||||
|
relaxed_type_check,
|
||||||
|
}) => {
|
||||||
|
let (packed, unpacked) = if is_extract {
|
||||||
|
let unpacked = unpacked
|
||||||
|
.into_iter()
|
||||||
|
.map(|ident| {
|
||||||
|
visitor.visit_ident(
|
||||||
|
ident,
|
||||||
|
Some((&typ.into(), ast::StateSpace::Reg)),
|
||||||
|
true,
|
||||||
|
relaxed_type_check,
|
||||||
|
)
|
||||||
|
})
|
||||||
|
.collect::<Result<Vec<_>, _>>()?;
|
||||||
|
let packed = visitor.visit_ident(
|
||||||
|
packed,
|
||||||
|
Some((
|
||||||
|
&ast::Type::Vector(typ, unpacked.len() as u8),
|
||||||
|
ast::StateSpace::Reg,
|
||||||
|
)),
|
||||||
|
false,
|
||||||
|
false,
|
||||||
|
)?;
|
||||||
|
(packed, unpacked)
|
||||||
|
} else {
|
||||||
|
let packed = visitor.visit_ident(
|
||||||
|
packed,
|
||||||
|
Some((
|
||||||
|
&ast::Type::Vector(typ, unpacked.len() as u8),
|
||||||
|
ast::StateSpace::Reg,
|
||||||
|
)),
|
||||||
|
true,
|
||||||
|
false,
|
||||||
|
)?;
|
||||||
|
let unpacked = unpacked
|
||||||
|
.into_iter()
|
||||||
|
.map(|ident| {
|
||||||
|
visitor.visit_ident(
|
||||||
|
ident,
|
||||||
|
Some((&typ.into(), ast::StateSpace::Reg)),
|
||||||
|
false,
|
||||||
|
relaxed_type_check,
|
||||||
|
)
|
||||||
|
})
|
||||||
|
.collect::<Result<Vec<_>, _>>()?;
|
||||||
|
(packed, unpacked)
|
||||||
|
};
|
||||||
|
Statement::RepackVector(RepackVectorDetails {
|
||||||
|
is_extract,
|
||||||
|
typ,
|
||||||
|
packed,
|
||||||
|
unpacked,
|
||||||
|
relaxed_type_check,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
Statement::FunctionPointer(FunctionPointerDetails { dst, src }) => {
|
||||||
|
let dst = visitor.visit_ident(
|
||||||
|
dst,
|
||||||
|
Some((
|
||||||
|
&ast::Type::Scalar(ast::ScalarType::U64),
|
||||||
|
ast::StateSpace::Reg,
|
||||||
|
)),
|
||||||
|
true,
|
||||||
|
false,
|
||||||
|
)?;
|
||||||
|
let src = visitor.visit_ident(src, None, false, false)?;
|
||||||
|
Statement::FunctionPointer(FunctionPointerDetails { dst, src })
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
struct BrachCondition {
|
struct BrachCondition {
|
||||||
predicate: SpirvWord,
|
predicate: SpirvWord,
|
||||||
if_true: SpirvWord,
|
if_true: SpirvWord,
|
||||||
|
@ -743,7 +973,6 @@ struct BrachCondition {
|
||||||
struct LoadVarDetails {
|
struct LoadVarDetails {
|
||||||
arg: ast::LdArgs<SpirvWord>,
|
arg: ast::LdArgs<SpirvWord>,
|
||||||
typ: ast::Type,
|
typ: ast::Type,
|
||||||
state_space: ast::StateSpace,
|
|
||||||
// (index, vector_width)
|
// (index, vector_width)
|
||||||
// HACK ALERT
|
// HACK ALERT
|
||||||
// For some reason IGC explodes when you try to load from builtin vectors
|
// For some reason IGC explodes when you try to load from builtin vectors
|
||||||
|
@ -798,7 +1027,7 @@ struct RepackVectorDetails {
|
||||||
typ: ast::ScalarType,
|
typ: ast::ScalarType,
|
||||||
packed: SpirvWord,
|
packed: SpirvWord,
|
||||||
unpacked: Vec<SpirvWord>,
|
unpacked: Vec<SpirvWord>,
|
||||||
relaxed_type_check: bool
|
relaxed_type_check: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
struct FunctionPointerDetails {
|
struct FunctionPointerDetails {
|
||||||
|
@ -876,6 +1105,20 @@ enum TypedOperand {
|
||||||
VecMember(SpirvWord, u8),
|
VecMember(SpirvWord, u8),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl TypedOperand {
|
||||||
|
fn map<Err>(
|
||||||
|
self,
|
||||||
|
fn_: impl FnOnce(SpirvWord, Option<u8>) -> Result<SpirvWord, Err>,
|
||||||
|
) -> Result<Self, Err> {
|
||||||
|
Ok(match self {
|
||||||
|
TypedOperand::Reg(reg) => TypedOperand::Reg(fn_(reg, None)?),
|
||||||
|
TypedOperand::RegOffset(reg, off) => TypedOperand::RegOffset(fn_(reg, None)?, off),
|
||||||
|
TypedOperand::Imm(imm) => TypedOperand::Imm(imm),
|
||||||
|
TypedOperand::VecMember(reg, idx) => TypedOperand::VecMember(fn_(reg, Some(idx))?, idx),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl ast::Operand for TypedOperand {
|
impl ast::Operand for TypedOperand {
|
||||||
type Ident = SpirvWord;
|
type Ident = SpirvWord;
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue