Add more passes

This commit is contained in:
Andrzej Janik 2024-09-22 19:47:08 +02:00
parent c84d257bb7
commit 7bd4179d1d
13 changed files with 1208 additions and 105 deletions

View file

@ -18,6 +18,8 @@ bit-vec = "0.6"
half ="1.6"
bitflags = "1.2"
rustc-hash = "2.0.0"
strum = "0.26"
strum_macros = "0.26"
[dependencies.lalrpop-util]
version = "0.19.12"

View file

@ -0,0 +1,141 @@
use std::collections::BTreeMap;
use super::*;
pub(super) fn run<'a, 'input>(
resolver: &mut GlobalStringIdentResolver2<'input>,
directives: Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>,
) -> Result<Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
directives
.into_iter()
.map(|directive| run_directive(resolver, directive))
.collect::<Result<Vec<_>, _>>()
}
fn run_directive<'input>(
resolver: &mut GlobalStringIdentResolver2,
directive: Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>,
) -> Result<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
Ok(match directive {
var @ Directive2::Variable(..) => var,
Directive2::Method(method) => Directive2::Method(run_method(resolver, method)?),
})
}
fn run_method<'input>(
resolver: &mut GlobalStringIdentResolver2,
mut method: Function2<'input, ast::Instruction<SpirvWord>, SpirvWord>,
) -> Result<Function2<'input, ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
if method.func_decl.name.is_kernel() {
return Ok(method);
}
let is_declaration = method.body.is_none();
let mut body = Vec::new();
let mut remap_returns = Vec::new();
for arg in method.func_decl.return_arguments.iter_mut() {
match arg.state_space {
ptx_parser::StateSpace::Param => {
arg.state_space = ptx_parser::StateSpace::Reg;
let old_name = arg.name;
arg.name = resolver.register_unnamed(Some((arg.v_type.clone(), arg.state_space)));
if is_declaration {
continue;
}
remap_returns.push((old_name, arg.name, arg.v_type.clone()));
body.push(Statement::Variable(ast::Variable {
align: None,
name: old_name,
v_type: arg.v_type.clone(),
state_space: ptx_parser::StateSpace::Param,
array_init: Vec::new(),
}));
}
ptx_parser::StateSpace::Reg => {}
_ => return Err(error_unreachable()),
}
}
for arg in method.func_decl.input_arguments.iter_mut() {
match arg.state_space {
ptx_parser::StateSpace::Param => {
arg.state_space = ptx_parser::StateSpace::Reg;
let old_name = arg.name;
arg.name = resolver.register_unnamed(Some((arg.v_type.clone(), arg.state_space)));
if is_declaration {
continue;
}
body.push(Statement::Variable(ast::Variable {
align: None,
name: old_name,
v_type: arg.v_type.clone(),
state_space: ptx_parser::StateSpace::Param,
array_init: Vec::new(),
}));
body.push(Statement::Instruction(ast::Instruction::St {
data: ast::StData {
qualifier: ast::LdStQualifier::Weak,
state_space: ast::StateSpace::Param,
caching: ast::StCacheOperator::Writethrough,
typ: arg.v_type.clone(),
},
arguments: ast::StArgs {
src1: old_name,
src2: arg.name,
},
}));
}
ptx_parser::StateSpace::Reg => {}
_ => return Err(error_unreachable()),
}
}
if remap_returns.is_empty() {
return Ok(method);
}
let body = method
.body
.map(|statements| {
for statement in statements {
run_statement(&remap_returns, &mut body, statement)?;
}
Ok::<_, TranslateError>(body)
})
.transpose()?;
Ok(Function2 {
func_decl: method.func_decl,
globals: method.globals,
body,
import_as: method.import_as,
tuning: method.tuning,
linkage: method.linkage,
})
}
fn run_statement<'input>(
remap_returns: &Vec<(SpirvWord, SpirvWord, ast::Type)>,
result: &mut Vec<Statement<ast::Instruction<SpirvWord>, SpirvWord>>,
statement: Statement<ast::Instruction<SpirvWord>, SpirvWord>,
) -> Result<(), TranslateError> {
match statement {
Statement::Instruction(ast::Instruction::Ret { .. }) => {
for (old_name, new_name, type_) in remap_returns.iter().cloned() {
result.push(Statement::Instruction(ast::Instruction::Ld {
data: ast::LdDetails {
qualifier: ast::LdStQualifier::Weak,
state_space: ast::StateSpace::Reg,
caching: ast::LdCacheOperator::Cached,
typ: type_,
non_coherent: false,
},
arguments: ast::LdArgs {
dst: new_name,
src: old_name,
},
}));
}
result.push(statement);
}
statement => {
result.push(statement);
}
}
Ok(())
}

View file

@ -308,6 +308,7 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> {
Statement::PtrAccess(_) => todo!(),
Statement::RepackVector(_) => todo!(),
Statement::FunctionPointer(_) => todo!(),
Statement::VectorAccess(_) => todo!(),
})
}

View file

@ -1561,6 +1561,7 @@ fn emit_function_body_ops<'input>(
builder.copy_object(vector_type.0, Some(repack.packed.0), temp_vec)?;
}
}
Statement::VectorAccess(vector_access) => todo!(),
}
}
Ok(())

View file

@ -0,0 +1,289 @@
use super::*;
pub(super) fn run<'a, 'input>(
resolver: &mut GlobalStringIdentResolver2<'input>,
directives: Vec<UnconditionalDirective<'input>>,
) -> Result<Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
directives
.into_iter()
.map(|directive| run_directive(resolver, directive))
.collect::<Result<Vec<_>, _>>()
}
fn run_directive<'input>(
resolver: &mut GlobalStringIdentResolver2<'input>,
directive: Directive2<
'input,
ast::Instruction<ast::ParsedOperand<SpirvWord>>,
ast::ParsedOperand<SpirvWord>,
>,
) -> Result<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
Ok(match directive {
Directive2::Variable(linking, var) => Directive2::Variable(linking, var),
Directive2::Method(method) => Directive2::Method(run_method(resolver, method)?),
})
}
fn run_method<'input>(
resolver: &mut GlobalStringIdentResolver2<'input>,
method: Function2<
'input,
ast::Instruction<ast::ParsedOperand<SpirvWord>>,
ast::ParsedOperand<SpirvWord>,
>,
) -> Result<Function2<'input, ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
let body = method
.body
.map(|statements| {
let mut result = Vec::with_capacity(statements.len());
for statement in statements {
run_statement(resolver, &mut result, statement)?;
}
Ok::<_, TranslateError>(result)
})
.transpose()?;
Ok(Function2 {
func_decl: method.func_decl,
globals: method.globals,
body,
import_as: method.import_as,
tuning: method.tuning,
linkage: method.linkage,
})
}
fn run_statement<'input>(
resolver: &mut GlobalStringIdentResolver2<'input>,
result: &mut Vec<Statement<ast::Instruction<SpirvWord>, SpirvWord>>,
statement: UnconditionalStatement,
) -> Result<(), TranslateError> {
let mut visitor = FlattenArguments::new(resolver, result);
let new_statement = statement.visit_map(&mut visitor)?;
visitor.result.push(new_statement);
Ok(())
}
struct FlattenArguments<'a, 'input> {
result: &'a mut Vec<ExpandedStatement>,
resolver: &'a mut GlobalStringIdentResolver2<'input>,
post_stmts: Vec<ExpandedStatement>,
}
impl<'a, 'input> FlattenArguments<'a, 'input> {
fn new(
resolver: &'a mut GlobalStringIdentResolver2<'input>,
result: &'a mut Vec<ExpandedStatement>,
) -> Self {
FlattenArguments {
result,
resolver,
post_stmts: Vec::new(),
}
}
fn reg(&mut self, name: SpirvWord) -> Result<SpirvWord, TranslateError> {
Ok(name)
}
fn reg_offset(
&mut self,
reg: SpirvWord,
offset: i32,
type_space: Option<(&ast::Type, ast::StateSpace)>,
_is_dst: bool,
) -> Result<SpirvWord, TranslateError> {
let (type_, state_space) = if let Some((type_, state_space)) = type_space {
(type_, state_space)
} else {
return Err(TranslateError::UntypedSymbol);
};
if state_space == ast::StateSpace::Reg {
let (reg_type, reg_space) = self.resolver.get_typed(reg)?;
if *reg_space != ast::StateSpace::Reg {
return Err(error_mismatched_type());
}
let reg_scalar_type = match reg_type {
ast::Type::Scalar(underlying_type) => *underlying_type,
_ => return Err(error_mismatched_type()),
};
let reg_type = reg_type.clone();
let id_constant_stmt = self
.resolver
.register_unnamed(Some((reg_type.clone(), ast::StateSpace::Reg)));
self.result.push(Statement::Constant(ConstantDefinition {
dst: id_constant_stmt,
typ: reg_scalar_type,
value: ast::ImmediateValue::S64(offset as i64),
}));
let arith_details = match reg_scalar_type.kind() {
ast::ScalarKind::Signed => ast::ArithDetails::Integer(ast::ArithInteger {
type_: reg_scalar_type,
saturate: false,
}),
ast::ScalarKind::Unsigned | ast::ScalarKind::Bit => {
ast::ArithDetails::Integer(ast::ArithInteger {
type_: reg_scalar_type,
saturate: false,
})
}
_ => return Err(error_unreachable()),
};
let id_add_result = self
.resolver
.register_unnamed(Some((reg_type, state_space)));
self.result
.push(Statement::Instruction(ast::Instruction::Add {
data: arith_details,
arguments: ast::AddArgs {
dst: id_add_result,
src1: reg,
src2: id_constant_stmt,
},
}));
Ok(id_add_result)
} else {
let id_constant_stmt = self.resolver.register_unnamed(Some((
ast::Type::Scalar(ast::ScalarType::S64),
ast::StateSpace::Reg,
)));
self.result.push(Statement::Constant(ConstantDefinition {
dst: id_constant_stmt,
typ: ast::ScalarType::S64,
value: ast::ImmediateValue::S64(offset as i64),
}));
let dst = self
.resolver
.register_unnamed(Some((type_.clone(), state_space)));
self.result.push(Statement::PtrAccess(PtrAccess {
underlying_type: type_.clone(),
state_space: state_space,
dst,
ptr_src: reg,
offset_src: id_constant_stmt,
}));
Ok(dst)
}
}
fn immediate(
&mut self,
value: ast::ImmediateValue,
type_space: Option<(&ast::Type, ast::StateSpace)>,
) -> Result<SpirvWord, TranslateError> {
let (scalar_t, state_space) =
if let Some((ast::Type::Scalar(scalar), state_space)) = type_space {
(*scalar, state_space)
} else {
return Err(TranslateError::UntypedSymbol);
};
let id = self
.resolver
.register_unnamed(Some((ast::Type::Scalar(scalar_t), state_space)));
self.result.push(Statement::Constant(ConstantDefinition {
dst: id,
typ: scalar_t,
value,
}));
Ok(id)
}
fn vec_member(
&mut self,
vector_src: SpirvWord,
member: u8,
_type_space: Option<(&ast::Type, ast::StateSpace)>,
is_dst: bool,
) -> Result<SpirvWord, TranslateError> {
if is_dst {
return Err(error_mismatched_type());
}
let (vector_width, scalar_type, space) = match self.resolver.get_typed(vector_src)? {
(ast::Type::Vector(vector_width, scalar_t), space) => {
(*vector_width, *scalar_t, *space)
}
_ => return Err(error_mismatched_type()),
};
let temporary = self
.resolver
.register_unnamed(Some((scalar_type.into(), space)));
self.result.push(Statement::VectorAccess(VectorAccess {
scalar_type,
vector_width,
dst: temporary,
src: vector_src,
member: member,
}));
Ok(temporary)
}
fn vec_pack(
&mut self,
vecs: Vec<SpirvWord>,
type_space: Option<(&ast::Type, ast::StateSpace)>,
is_dst: bool,
relaxed_type_check: bool,
) -> Result<SpirvWord, TranslateError> {
let (scalar_t, state_space) = match type_space {
Some((ast::Type::Vector(_, scalar_t), space)) => (*scalar_t, space),
_ => return Err(error_mismatched_type()),
};
let temp_vec = self
.resolver
.register_unnamed(Some((scalar_t.into(), state_space)));
let statement = Statement::RepackVector(RepackVectorDetails {
is_extract: is_dst,
typ: scalar_t,
packed: temp_vec,
unpacked: vecs,
relaxed_type_check,
});
if is_dst {
self.post_stmts.push(statement);
} else {
self.result.push(statement);
}
Ok(temp_vec)
}
}
impl<'a, 'b> ast::VisitorMap<ast::ParsedOperand<SpirvWord>, SpirvWord, TranslateError>
for FlattenArguments<'a, 'b>
{
fn visit(
&mut self,
args: ast::ParsedOperand<SpirvWord>,
type_space: Option<(&ast::Type, ast::StateSpace)>,
is_dst: bool,
relaxed_type_check: bool,
) -> Result<SpirvWord, TranslateError> {
match args {
ast::ParsedOperand::Reg(r) => self.reg(r),
ast::ParsedOperand::Imm(x) => self.immediate(x, type_space),
ast::ParsedOperand::RegOffset(reg, offset) => {
self.reg_offset(reg, offset, type_space, is_dst)
}
ast::ParsedOperand::VecMember(vec, member) => {
self.vec_member(vec, member, type_space, is_dst)
}
ast::ParsedOperand::VecPack(vecs) => {
self.vec_pack(vecs, type_space, is_dst, relaxed_type_check)
}
}
}
fn visit_ident(
&mut self,
name: <TypedOperand as ast::Operand>::Ident,
_type_space: Option<(&ast::Type, ast::StateSpace)>,
_is_dst: bool,
_relaxed_type_check: bool,
) -> Result<<SpirvWord as ast::Operand>::Ident, TranslateError> {
self.reg(name)
}
}
impl Drop for FlattenArguments<'_, '_> {
fn drop(&mut self) {
self.result.extend(self.post_stmts.drain(..));
}
}

View file

@ -0,0 +1,209 @@
use super::*;
pub(super) fn run<'a, 'input>(
resolver: &mut GlobalStringIdentResolver2<'input>,
special_registers: &'a SpecialRegistersMap2,
directives: Vec<UnconditionalDirective<'input>>,
) -> Result<Vec<UnconditionalDirective<'input>>, TranslateError> {
let declarations = SpecialRegistersMap2::generate_declarations(resolver);
let mut result = Vec::with_capacity(declarations.len() + directives.len());
let mut sreg_to_function =
FxHashMap::with_capacity_and_hasher(declarations.len(), Default::default());
for (sreg, declaration) in declarations {
let name = if let ast::MethodName::Func(name) = declaration.name {
name
} else {
return Err(error_unreachable());
};
result.push(UnconditionalDirective::Method(UnconditionalFunction {
func_decl: declaration,
globals: Vec::new(),
body: None,
import_as: None,
tuning: Vec::new(),
linkage: ast::LinkingDirective::EXTERN,
}));
sreg_to_function.insert(sreg, name);
}
let mut visitor = SpecialRegisterResolver {
resolver,
special_registers,
sreg_to_function,
result: Vec::new(),
};
directives
.into_iter()
.map(|directive| run_directive(&mut visitor, directive))
.collect::<Result<Vec<_>, _>>()
}
fn run_directive<'a, 'input>(
visitor: &mut SpecialRegisterResolver<'a, 'input>,
directive: UnconditionalDirective<'input>,
) -> Result<UnconditionalDirective<'input>, TranslateError> {
Ok(match directive {
var @ Directive2::Variable(..) => var,
Directive2::Method(method) => Directive2::Method(run_method(visitor, method)?),
})
}
fn run_method<'a, 'input>(
visitor: &mut SpecialRegisterResolver<'a, 'input>,
method: UnconditionalFunction<'input>,
) -> Result<UnconditionalFunction<'input>, TranslateError> {
let body = method
.body
.map(|statements| {
let mut result = Vec::with_capacity(statements.len());
for statement in statements {
run_statement(visitor, &mut result, statement)?;
}
Ok::<_, TranslateError>(result)
})
.transpose()?;
Ok(Function2 {
func_decl: method.func_decl,
globals: method.globals,
body,
import_as: method.import_as,
tuning: method.tuning,
linkage: method.linkage,
})
}
fn run_statement<'a, 'input>(
visitor: &mut SpecialRegisterResolver<'a, 'input>,
result: &mut Vec<UnconditionalStatement>,
statement: UnconditionalStatement,
) -> Result<(), TranslateError> {
let converted_statement = statement.visit_map(visitor)?;
result.extend(visitor.result.drain(..));
result.push(converted_statement);
Ok(())
}
struct SpecialRegisterResolver<'a, 'input> {
resolver: &'a mut GlobalStringIdentResolver2<'input>,
special_registers: &'a SpecialRegistersMap2,
sreg_to_function: FxHashMap<PtxSpecialRegister, SpirvWord>,
result: Vec<UnconditionalStatement>,
}
impl<'a, 'b, 'input>
ast::VisitorMap<ast::ParsedOperand<SpirvWord>, ast::ParsedOperand<SpirvWord>, TranslateError>
for SpecialRegisterResolver<'a, 'input>
{
fn visit(
&mut self,
operand: ast::ParsedOperand<SpirvWord>,
_type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>,
is_dst: bool,
_relaxed_type_check: bool,
) -> Result<ast::ParsedOperand<SpirvWord>, TranslateError> {
map_operand(operand, &mut |ident, vector_index| {
self.replace_sreg(ident, vector_index, is_dst)
})
}
fn visit_ident(
&mut self,
args: SpirvWord,
_type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>,
is_dst: bool,
_relaxed_type_check: bool,
) -> Result<SpirvWord, TranslateError> {
self.replace_sreg(args, None, is_dst)
}
}
impl<'a, 'b, 'input> SpecialRegisterResolver<'a, 'input> {
fn replace_sreg(
&mut self,
name: SpirvWord,
vector_index: Option<u8>,
is_dst: bool,
) -> Result<SpirvWord, TranslateError> {
if let Some(sreg) = self.special_registers.get(name) {
if is_dst {
return Err(error_mismatched_type());
}
let input_arguments = match (vector_index, sreg.get_function_input_type()) {
(Some(idx), Some(inp_type)) => {
if inp_type != ast::ScalarType::U8 {
return Err(TranslateError::Unreachable);
}
let constant = self.resolver.register_unnamed(Some((
ast::Type::Scalar(inp_type),
ast::StateSpace::Reg,
)));
self.result.push(Statement::Constant(ConstantDefinition {
dst: constant,
typ: inp_type,
value: ast::ImmediateValue::U64(idx as u64),
}));
vec![(constant, ast::Type::Scalar(inp_type), ast::StateSpace::Reg)]
}
(None, None) => Vec::new(),
_ => return Err(error_mismatched_type()),
};
let return_type = sreg.get_function_return_type();
let fn_result = self
.resolver
.register_unnamed(Some((ast::Type::Scalar(return_type), ast::StateSpace::Reg)));
let return_arguments = vec![(
fn_result,
ast::Type::Scalar(return_type),
ast::StateSpace::Reg,
)];
let data = ast::CallDetails {
uniform: false,
return_arguments: return_arguments
.iter()
.map(|(_, typ, space)| (typ.clone(), *space))
.collect(),
input_arguments: input_arguments
.iter()
.map(|(_, typ, space)| (typ.clone(), *space))
.collect(),
};
let arguments = ast::CallArgs::<ast::ParsedOperand<SpirvWord>> {
return_arguments: return_arguments.iter().map(|(name, _, _)| *name).collect(),
func: self.sreg_to_function[&sreg],
input_arguments: input_arguments
.iter()
.map(|(name, _, _)| ast::ParsedOperand::Reg(*name))
.collect(),
};
self.result
.push(Statement::Instruction(ast::Instruction::Call {
data,
arguments,
}));
Ok(fn_result)
} else {
Ok(name)
}
}
}
pub fn map_operand<T, U, Err>(
this: ast::ParsedOperand<T>,
fn_: &mut impl FnMut(T, Option<u8>) -> Result<U, Err>,
) -> Result<ast::ParsedOperand<U>, Err> {
Ok(match this {
ast::ParsedOperand::Reg(ident) => ast::ParsedOperand::Reg(fn_(ident, None)?),
ast::ParsedOperand::RegOffset(ident, offset) => {
ast::ParsedOperand::RegOffset(fn_(ident, None)?, offset)
}
ast::ParsedOperand::Imm(imm) => ast::ParsedOperand::Imm(imm),
ast::ParsedOperand::VecMember(ident, member) => {
ast::ParsedOperand::Reg(fn_(ident, Some(member))?)
}
ast::ParsedOperand::VecPack(idents) => ast::ParsedOperand::VecPack(
idents
.into_iter()
.map(|ident| fn_(ident, None))
.collect::<Result<Vec<_>, _>>()?,
),
})
}

View file

@ -0,0 +1,273 @@
use super::*;
use ptx_parser::VisitorMap;
use rustc_hash::FxHashSet;
// This pass:
// * Turns all .local, .param and .reg in-body variables into .local variables
// (if _not_ an input method argument)
// * Inserts explicit `ld`/`st` for newly converted .reg variables
// * Fixup state space of all existing `ld`/`st` instructions into newly
// converted variables
// * Turns `.entry` input arguments into param::entry and all related `.param`
// loads into `param::entry` loads
// * All `.func` input arguments are turned into `.reg` arguments by another
// pass, so we do nothing there
pub(super) fn run<'a, 'input>(
resolver: &mut GlobalStringIdentResolver2<'input>,
directives: Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>,
) -> Result<Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
directives
.into_iter()
.map(|directive| run_directive(resolver, directive))
.collect::<Result<Vec<_>, _>>()
}
fn run_directive<'a, 'input>(
resolver: &mut GlobalStringIdentResolver2<'input>,
directive: Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>,
) -> Result<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
Ok(match directive {
var @ Directive2::Variable(..) => var,
Directive2::Method(method) => {
let visitor = InsertMemSSAVisitor::new(resolver);
Directive2::Method(run_method(visitor, method)?)
}
})
}
fn run_method<'a, 'input>(
mut visitor: InsertMemSSAVisitor<'a, 'input>,
method: Function2<'input, ast::Instruction<SpirvWord>, SpirvWord>,
) -> Result<Function2<'input, ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
let mut func_decl = method.func_decl;
for arg in func_decl.return_arguments.iter_mut() {
visitor.visit_variable(arg);
}
let is_kernel = func_decl.name.is_kernel();
// let mut prelude = Vec::with_capacity(method.body.as_ref().map(Vec::len).unwrap_or(0));
if is_kernel {
for arg in func_decl.input_arguments.iter_mut() {
let old_name = arg.name;
let old_space = arg.state_space;
let new_space = ast::StateSpace::ParamEntry;
let new_name = visitor
.resolver
.register_unnamed(Some((arg.v_type.clone(), new_space)));
visitor.input_argument(old_name, new_name, old_space);
arg.name = new_name;
arg.state_space = new_space;
}
};
let body = method
.body
.map(move |statements| {
let mut result = Vec::with_capacity(statements.len());
for statement in statements {
run_statement(&mut visitor, &mut result, statement)?;
}
Ok::<_, TranslateError>(result)
})
.transpose()?;
Ok(Function2 {
func_decl: func_decl,
globals: method.globals,
body,
import_as: method.import_as,
tuning: method.tuning,
linkage: method.linkage,
})
}
fn run_statement<'a, 'input>(
visitor: &mut InsertMemSSAVisitor<'a, 'input>,
result: &mut Vec<ExpandedStatement>,
statement: ExpandedStatement,
) -> Result<(), TranslateError> {
match statement {
Statement::Variable(mut var) => {
visitor.visit_variable(&mut var);
result.push(Statement::Variable(var));
}
Statement::Instruction(ast::Instruction::Ld { data, arguments }) => {
let instruction = visitor.visit_ld(data, arguments)?;
let instruction = ast::visit_map(instruction, visitor)?;
result.push(Statement::Instruction(instruction));
}
Statement::Instruction(ast::Instruction::St {
data,
mut arguments,
}) => {
let instruction = visitor.visit_st(data, arguments)?;
let instruction = ast::visit_map(instruction, visitor)?;
result.push(Statement::Instruction(instruction));
}
s => result.push(s.visit_map(visitor)?),
}
Ok(())
}
struct InsertMemSSAVisitor<'a, 'input> {
resolver: &'a mut GlobalStringIdentResolver2<'input>,
variables: FxHashMap<SpirvWord, RemapAction>,
}
impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> {
fn new(resolver: &'a mut GlobalStringIdentResolver2<'input>) -> Self {
Self {
resolver,
variables: FxHashMap::default(),
}
}
fn input_argument(
&mut self,
old_name: SpirvWord,
new_name: SpirvWord,
old_space: ast::StateSpace,
) -> Result<(), TranslateError> {
if old_space != ast::StateSpace::Param {
return Err(error_unreachable());
}
self.variables.insert(
old_name,
RemapAction::LDStSpaceChange {
name: new_name,
old_space,
new_space: ast::StateSpace::ParamEntry,
},
);
Ok(())
}
fn variable(
&mut self,
old_name: SpirvWord,
new_name: SpirvWord,
old_space: ast::StateSpace,
) -> Result<(), TranslateError> {
Ok(match old_space {
ast::StateSpace::Reg => {
self.variables
.insert(old_name, RemapAction::PreLdPostSt(new_name));
}
ast::StateSpace::Param => {
self.variables.insert(
old_name,
RemapAction::LDStSpaceChange {
old_space,
new_space: ast::StateSpace::Local,
name: new_name,
},
);
}
// Good as-is
ast::StateSpace::Local => {}
// Will be pulled into global scope later
ast::StateSpace::Generic
| ast::StateSpace::SharedCluster
| ast::StateSpace::Global
| ast::StateSpace::Const
| ast::StateSpace::SharedCta
| ast::StateSpace::Shared => {}
ast::StateSpace::ParamEntry | ast::StateSpace::ParamFunc => {
return Err(error_unreachable())
}
})
}
fn visit_st(
&self,
mut data: ast::StData,
mut arguments: ast::StArgs<SpirvWord>,
) -> Result<ast::Instruction<SpirvWord>, TranslateError> {
if let Some(remap) = self.variables.get(&arguments.src1) {
match remap {
RemapAction::PreLdPostSt(_) => return Err(error_mismatched_type()),
RemapAction::LDStSpaceChange {
old_space,
new_space,
name,
} => {
if data.state_space != *old_space {
return Err(error_mismatched_type());
}
data.state_space = *new_space;
arguments.src1 = *name;
}
}
}
Ok(ast::Instruction::St { data, arguments })
}
fn visit_ld(
&self,
mut data: ast::LdDetails,
mut arguments: ast::LdArgs<SpirvWord>,
) -> Result<ast::Instruction<SpirvWord>, TranslateError> {
if let Some(remap) = self.variables.get(&arguments.src) {
match remap {
RemapAction::PreLdPostSt(_) => return Err(error_mismatched_type()),
RemapAction::LDStSpaceChange {
old_space,
new_space,
name,
} => {
if data.state_space != *old_space {
return Err(error_mismatched_type());
}
data.state_space = *new_space;
arguments.src = *name;
}
}
}
Ok(ast::Instruction::Ld { data, arguments })
}
fn visit_variable(&mut self, var: &mut ast::Variable<SpirvWord>) {
if var.state_space != ast::StateSpace::Local {
let old_name = var.name;
let old_space = var.state_space;
let new_space = ast::StateSpace::Local;
let new_name = self
.resolver
.register_unnamed(Some((var.v_type.clone(), new_space)));
self.variable(old_name, new_name, old_space);
var.name = new_name;
var.state_space = new_space;
}
}
}
impl<'a, 'input> ast::VisitorMap<SpirvWord, SpirvWord, TranslateError>
for InsertMemSSAVisitor<'a, 'input>
{
fn visit(
&mut self,
args: SpirvWord,
type_space: Option<(&ast::Type, ast::StateSpace)>,
is_dst: bool,
relaxed_type_check: bool,
) -> Result<SpirvWord, TranslateError> {
todo!()
}
fn visit_ident(
&mut self,
args: SpirvWord,
type_space: Option<(&ast::Type, ast::StateSpace)>,
is_dst: bool,
relaxed_type_check: bool,
) -> Result<SpirvWord, TranslateError> {
self.visit(args, type_space, is_dst, relaxed_type_check)
}
}
#[derive(Clone, Copy)]
enum RemapAction {
PreLdPostSt(SpirvWord),
LDStSpaceChange {
old_space: ast::StateSpace,
new_space: ast::StateSpace,
name: SpirvWord,
},
}

View file

@ -45,6 +45,13 @@ pub(super) fn run(
Statement::RepackVector(repack),
)?;
}
Statement::VectorAccess(vector_access) => {
insert_implicit_conversions_impl(
&mut result,
id_def,
Statement::VectorAccess(vector_access),
)?;
}
s @ Statement::Conditional(_)
| s @ Statement::Conversion(_)
| s @ Statement::Label(_)

View file

@ -13,15 +13,21 @@ use std::{
mem,
rc::Rc,
};
use strum::IntoEnumIterator;
use strum_macros::EnumIter;
mod convert_dynamic_shared_memory_usage;
mod convert_to_stateful_memory_access;
mod convert_to_typed;
mod deparamize_functions;
pub(crate) mod emit_llvm;
mod emit_spirv;
mod expand_arguments;
mod expand_operands;
mod extract_globals;
mod fix_special_registers;
mod fix_special_registers2;
mod insert_explicit_load_store;
mod insert_implicit_conversions;
mod insert_mem_ssa_statements;
mod normalize_identifiers;
@ -68,6 +74,20 @@ pub fn to_llvm_module<'input>(ast: ast::Module<'input>) -> Result<Module, Transl
})
}
pub fn to_llvm_module2<'input>(ast: ast::Module<'input>) -> Result<Module, TranslateError> {
let mut flat_resolver = GlobalStringIdentResolver2::<'input>::new(SpirvWord(1));
let mut scoped_resolver = ScopedResolver::new(&mut flat_resolver);
let sreg_map = SpecialRegistersMap2::new(&mut scoped_resolver)?;
let directives = normalize_identifiers2::run(&mut scoped_resolver, ast.directives)?;
let directives = normalize_predicates2::run(&mut flat_resolver, directives)?;
let directives = resolve_function_pointers::run(directives)?;
let directives = fix_special_registers2::run(&mut flat_resolver, &sreg_map, directives)?;
let directives = expand_operands::run(&mut flat_resolver, directives)?;
let directives = deparamize_functions::run(&mut flat_resolver, directives)?;
let directives = insert_explicit_load_store::run(&mut flat_resolver, directives)?;
todo!()
}
fn translate_directive<'input, 'a>(
id_defs: &'a mut GlobalStringIdResolver<'input>,
ptx_impl_imports: &'a mut HashMap<String, Directive<'input>>,
@ -323,7 +343,7 @@ pub struct KernelInfo {
pub uses_shared_mem: bool,
}
#[derive(Ord, PartialOrd, Eq, PartialEq, Hash, Copy, Clone)]
#[derive(Ord, PartialOrd, Eq, PartialEq, Hash, Copy, Clone, EnumIter)]
enum PtxSpecialRegister {
Tid,
Ntid,
@ -346,6 +366,17 @@ impl PtxSpecialRegister {
}
}
fn as_str(self) -> &'static str {
match self {
Self::Tid => "%tid",
Self::Ntid => "%ntid",
Self::Ctaid => "%ctaid",
Self::Nctaid => "%nctaid",
Self::Clock => "%clock",
Self::LanemaskLt => "%lanemask_lt",
}
}
fn get_type(self) -> ast::Type {
match self {
PtxSpecialRegister::Tid
@ -726,6 +757,7 @@ enum Statement<I, P: ast::Operand> {
PtrAccess(PtrAccess<P>),
RepackVector(RepackVectorDetails),
FunctionPointer(FunctionPointerDetails),
VectorAccess(VectorAccess),
}
impl<T: ast::Operand<Ident = SpirvWord>> Statement<ast::Instruction<T>, T> {
@ -894,6 +926,36 @@ impl<T: ast::Operand<Ident = SpirvWord>> Statement<ast::Instruction<T>, T> {
offset_src,
})
}
Statement::VectorAccess(VectorAccess {
scalar_type,
vector_width,
dst,
src: vector_src,
member,
}) => {
let dst: SpirvWord = visitor.visit_ident(
dst,
Some((&scalar_type.into(), ast::StateSpace::Reg)),
true,
false,
)?;
let src = visitor.visit_ident(
vector_src,
Some((
&ast::Type::Vector(vector_width, scalar_type),
ast::StateSpace::Reg,
)),
false,
false,
)?;
Statement::VectorAccess(VectorAccess {
scalar_type,
vector_width,
dst,
src,
member,
})
}
Statement::RepackVector(RepackVectorDetails {
is_extract,
typ,
@ -1448,6 +1510,7 @@ fn compute_denorm_information<'input>(
Statement::Label(_) => {}
Statement::Variable(_) => {}
Statement::PtrAccess { .. } => {}
Statement::VectorAccess { .. } => {}
Statement::RepackVector(_) => {}
Statement::FunctionPointer(_) => {}
}
@ -1668,7 +1731,7 @@ pub(crate) enum Directive2<'input, Instruction, Operand: ast::Operand> {
}
pub(crate) struct Function2<'input, Instruction, Operand: ast::Operand> {
pub func_decl: Rc<RefCell<ast::MethodDeclaration<'input, SpirvWord>>>,
pub func_decl: ast::MethodDeclaration<'input, SpirvWord>,
pub globals: Vec<ast::Variable<SpirvWord>>,
pub body: Option<Vec<Statement<Instruction, Operand>>>,
import_as: Option<String>,
@ -1712,10 +1775,31 @@ struct GlobalStringIdentResolver2<'input> {
}
impl<'input> GlobalStringIdentResolver2<'input> {
fn register_intermediate(
fn new(spirv_word: SpirvWord) -> Self {
Self {
current_id: spirv_word,
ident_map: FxHashMap::default(),
}
}
fn register_named(
&mut self,
name: Cow<'input, str>,
type_space: Option<(ast::Type, ast::StateSpace)>,
) -> SpirvWord {
let new_id = self.current_id;
self.ident_map.insert(
new_id,
IdentEntry {
name: Some(name),
type_space,
},
);
self.current_id.0 += 1;
new_id
}
fn register_unnamed(&mut self, type_space: Option<(ast::Type, ast::StateSpace)>) -> SpirvWord {
let new_id = self.current_id;
self.ident_map.insert(
new_id,
@ -1727,9 +1811,191 @@ impl<'input> GlobalStringIdentResolver2<'input> {
self.current_id.0 += 1;
new_id
}
fn get_typed(&self, id: SpirvWord) -> Result<&(ast::Type, ast::StateSpace), TranslateError> {
match self.ident_map.get(&id) {
Some(IdentEntry {
type_space: Some(type_space),
..
}) => Ok(type_space),
_ => Err(error_unknown_symbol()),
}
}
}
struct IdentEntry<'input> {
name: Option<Cow<'input, str>>,
type_space: Option<(ast::Type, ast::StateSpace)>,
}
struct ScopedResolver<'input, 'b> {
flat_resolver: &'b mut GlobalStringIdentResolver2<'input>,
scopes: Vec<ScopeMarker<'input>>,
}
impl<'input, 'b> ScopedResolver<'input, 'b> {
fn new(flat_resolver: &'b mut GlobalStringIdentResolver2<'input>) -> Self {
Self {
flat_resolver,
scopes: vec![ScopeMarker::new()],
}
}
fn start_scope(&mut self) {
self.scopes.push(ScopeMarker::new());
}
fn end_scope(&mut self) {
let scope = self.scopes.pop().unwrap();
scope.flush(self.flat_resolver);
}
fn add(
&mut self,
name: Cow<'input, str>,
type_space: Option<(ast::Type, ast::StateSpace)>,
) -> Result<SpirvWord, TranslateError> {
let result = self.flat_resolver.current_id;
self.flat_resolver.current_id.0 += 1;
let current_scope = self.scopes.last_mut().unwrap();
if current_scope
.name_to_ident
.insert(name.clone(), result)
.is_some()
{
return Err(error_unknown_symbol());
}
current_scope.ident_map.insert(
result,
IdentEntry {
name: Some(name),
type_space,
},
);
Ok(result)
}
fn get(&mut self, name: &str) -> Result<SpirvWord, TranslateError> {
self.scopes
.iter()
.rev()
.find_map(|resolver| resolver.name_to_ident.get(name).copied())
.ok_or_else(|| error_unreachable())
}
fn get_in_current_scope(&self, label: &'input str) -> Result<SpirvWord, TranslateError> {
let current_scope = self.scopes.last().unwrap();
current_scope
.name_to_ident
.get(label)
.copied()
.ok_or_else(|| error_unreachable())
}
}
struct ScopeMarker<'input> {
ident_map: FxHashMap<SpirvWord, IdentEntry<'input>>,
name_to_ident: FxHashMap<Cow<'input, str>, SpirvWord>,
}
impl<'input> ScopeMarker<'input> {
fn new() -> Self {
Self {
ident_map: FxHashMap::default(),
name_to_ident: FxHashMap::default(),
}
}
fn flush(self, resolver: &mut GlobalStringIdentResolver2<'input>) {
resolver.ident_map.extend(self.ident_map);
}
}
struct SpecialRegistersMap2 {
reg_to_id: FxHashMap<PtxSpecialRegister, SpirvWord>,
id_to_reg: FxHashMap<SpirvWord, PtxSpecialRegister>,
}
impl SpecialRegistersMap2 {
fn new(resolver: &mut ScopedResolver) -> Result<Self, TranslateError> {
let mut result = SpecialRegistersMap2 {
reg_to_id: FxHashMap::default(),
id_to_reg: FxHashMap::default(),
};
for sreg in PtxSpecialRegister::iter() {
let text = sreg.as_str();
let id = resolver.add(
Cow::Borrowed(text),
Some((sreg.get_type(), ast::StateSpace::Reg)),
)?;
result.reg_to_id.insert(sreg, id);
result.id_to_reg.insert(id, sreg);
}
Ok(result)
}
fn get(&self, id: SpirvWord) -> Option<PtxSpecialRegister> {
self.id_to_reg.get(&id).copied()
}
fn get_or_add(&mut self, current_id: &mut SpirvWord, reg: PtxSpecialRegister) -> SpirvWord {
match self.reg_to_id.entry(reg) {
hash_map::Entry::Occupied(e) => *e.get(),
hash_map::Entry::Vacant(e) => {
let numeric_id = SpirvWord(current_id.0);
current_id.0 += 1;
e.insert(numeric_id);
self.id_to_reg.insert(numeric_id, reg);
numeric_id
}
}
}
fn generate_declarations<'a, 'input>(
resolver: &'a mut GlobalStringIdentResolver2<'input>,
) -> impl ExactSizeIterator<
Item = (
PtxSpecialRegister,
ast::MethodDeclaration<'input, SpirvWord>,
),
> + 'a {
PtxSpecialRegister::iter().map(|sreg| {
let external_fn_name = [ZLUDA_PTX_PREFIX, sreg.get_unprefixed_function_name()].concat();
let name =
ast::MethodName::Func(resolver.register_named(Cow::Owned(external_fn_name), None));
let return_type = sreg.get_function_return_type();
let input_type = sreg.get_function_return_type();
(
sreg,
ast::MethodDeclaration {
return_arguments: vec![ast::Variable {
align: None,
v_type: return_type.into(),
state_space: ast::StateSpace::Reg,
name: resolver
.register_unnamed(Some((return_type.into(), ast::StateSpace::Reg))),
array_init: Vec::new(),
}],
name: name,
input_arguments: vec![ast::Variable {
align: None,
v_type: input_type.into(),
state_space: ast::StateSpace::Reg,
name: resolver
.register_unnamed(Some((input_type.into(), ast::StateSpace::Reg))),
array_init: Vec::new(),
}],
shared_mem: None,
},
)
})
}
}
pub struct VectorAccess {
scalar_type: ast::ScalarType,
vector_width: u8,
dst: SpirvWord,
src: SpirvWord,
member: u8,
}

View file

@ -2,21 +2,21 @@ use super::*;
use ptx_parser as ast;
use rustc_hash::FxHashMap;
pub(crate) fn run<'input>(
fn_defs: &mut GlobalStringIdentResolver2<'input>,
pub(crate) fn run<'input, 'b>(
resolver: &mut ScopedResolver<'input, 'b>,
directives: Vec<ast::Directive<'input, ast::ParsedOperand<&'input str>>>,
) -> Result<Vec<NormalizedDirective2<'input>>, TranslateError> {
let mut resolver = NameResolver::new(fn_defs);
resolver.start_scope();
let result = directives
.into_iter()
.map(|directive| run_directive(&mut resolver, directive))
.map(|directive| run_directive(resolver, directive))
.collect::<Result<Vec<_>, _>>()?;
resolver.end_scope();
Ok(result)
}
fn run_directive<'input, 'b>(
resolver: &mut NameResolver<'input, 'b>,
resolver: &mut ScopedResolver<'input, 'b>,
directive: ast::Directive<'input, ast::ParsedOperand<&'input str>>,
) -> Result<NormalizedDirective2<'input>, TranslateError> {
Ok(match directive {
@ -30,7 +30,7 @@ fn run_directive<'input, 'b>(
}
fn run_method<'input, 'b>(
resolver: &mut NameResolver<'input, 'b>,
resolver: &mut ScopedResolver<'input, 'b>,
linkage: ast::LinkingDirective,
method: ast::Function<'input, &'input str, ast::Statement<ast::ParsedOperand<&'input str>>>,
) -> Result<NormalizedFunction2<'input>, TranslateError> {
@ -41,11 +41,7 @@ fn run_method<'input, 'b>(
}
};
resolver.start_scope();
let func_decl = Rc::new(RefCell::new(run_function_decl(
resolver,
method.func_directive,
name,
)?));
let func_decl = run_function_decl(resolver, method.func_directive, name)?;
let body = method
.body
.map(|statements| {
@ -66,7 +62,7 @@ fn run_method<'input, 'b>(
}
fn run_function_decl<'input, 'b>(
resolver: &mut NameResolver<'input, 'b>,
resolver: &mut ScopedResolver<'input, 'b>,
func_directive: ast::MethodDeclaration<'input, &'input str>,
name: ast::MethodName<'input, SpirvWord>,
) -> Result<ast::MethodDeclaration<'input, SpirvWord>, TranslateError> {
@ -90,7 +86,7 @@ fn run_function_decl<'input, 'b>(
}
fn run_variable<'input, 'b>(
resolver: &mut NameResolver<'input, 'b>,
resolver: &mut ScopedResolver<'input, 'b>,
variable: ast::Variable<&'input str>,
) -> Result<ast::Variable<SpirvWord>, TranslateError> {
Ok(ast::Variable {
@ -106,7 +102,7 @@ fn run_variable<'input, 'b>(
}
fn run_statements<'input, 'b>(
resolver: &mut NameResolver<'input, 'b>,
resolver: &mut ScopedResolver<'input, 'b>,
result: &mut Vec<NormalizedStatement>,
statements: Vec<ast::Statement<ast::ParsedOperand<&'input str>>>,
) -> Result<(), TranslateError> {
@ -148,7 +144,7 @@ fn run_statements<'input, 'b>(
}
fn run_instruction<'input, 'b>(
resolver: &mut NameResolver<'input, 'b>,
resolver: &mut ScopedResolver<'input, 'b>,
instruction: ast::Instruction<ast::ParsedOperand<&'input str>>,
) -> Result<ast::Instruction<ast::ParsedOperand<SpirvWord>>, TranslateError> {
ast::visit_map(instruction, &mut |name: &'input str,
@ -163,7 +159,7 @@ fn run_instruction<'input, 'b>(
}
fn run_multivariable<'input, 'b>(
resolver: &mut NameResolver<'input, 'b>,
resolver: &mut ScopedResolver<'input, 'b>,
result: &mut Vec<NormalizedStatement>,
variable: ast::MultiVariable<&'input str>,
) -> Result<(), TranslateError> {
@ -201,86 +197,3 @@ fn run_multivariable<'input, 'b>(
}
Ok(())
}
struct NameResolver<'input, 'b> {
flat_resolver: &'b mut GlobalStringIdentResolver2<'input>,
scopes: Vec<ScopeStringIdentResolver<'input>>,
}
impl<'input, 'b> NameResolver<'input, 'b> {
fn new(flat_resolver: &'b mut GlobalStringIdentResolver2<'input>) -> Self {
Self {
flat_resolver,
scopes: vec![ScopeStringIdentResolver::new()],
}
}
fn start_scope(&mut self) {
self.scopes.push(ScopeStringIdentResolver::new());
}
fn end_scope(&mut self) {
let scope = self.scopes.pop().unwrap();
scope.flush(self.flat_resolver);
}
fn add(
&mut self,
name: Cow<'input, str>,
type_space: Option<(ast::Type, ast::StateSpace)>,
) -> Result<SpirvWord, TranslateError> {
let result = self.flat_resolver.current_id;
self.flat_resolver.current_id.0 += 1;
let current_scope = self.scopes.last_mut().unwrap();
if current_scope
.name_to_ident
.insert(name.clone(), result)
.is_some()
{
return Err(error_unknown_symbol());
}
current_scope.ident_map.insert(
result,
IdentEntry {
name: Some(name),
type_space,
},
);
Ok(result)
}
fn get(&mut self, name: &str) -> Result<SpirvWord, TranslateError> {
self.scopes
.iter()
.rev()
.find_map(|resolver| resolver.name_to_ident.get(name).copied())
.ok_or_else(|| error_unreachable())
}
fn get_in_current_scope(&self, label: &'input str) -> Result<SpirvWord, TranslateError> {
let current_scope = self.scopes.last().unwrap();
current_scope
.name_to_ident
.get(label)
.copied()
.ok_or_else(|| error_unreachable())
}
}
struct ScopeStringIdentResolver<'input> {
ident_map: FxHashMap<SpirvWord, IdentEntry<'input>>,
name_to_ident: FxHashMap<Cow<'input, str>, SpirvWord>,
}
impl<'input> ScopeStringIdentResolver<'input> {
fn new() -> Self {
Self {
ident_map: FxHashMap::default(),
name_to_ident: FxHashMap::default(),
}
}
fn flush(self, resolver: &mut GlobalStringIdentResolver2<'input>) {
resolver.ident_map.extend(self.ident_map);
}
}

View file

@ -26,6 +26,7 @@ pub(super) fn run(
| Statement::Constant(..)
| Statement::Label(..)
| Statement::PtrAccess { .. }
| Statement::VectorAccess { .. }
| Statement::RepackVector(..)
| Statement::FunctionPointer(..) => {}
}

View file

@ -55,8 +55,8 @@ fn run_statement<'input>(
Statement::Variable(var) => result.push(Statement::Variable(var)),
Statement::Instruction((predicate, instruction)) => {
if let Some(pred) = predicate {
let if_true = resolver.register_intermediate(None);
let if_false = resolver.register_intermediate(None);
let if_true = resolver.register_unnamed(None);
let if_false = resolver.register_unnamed(None);
let folded_bra = match &instruction {
ast::Instruction::Bra { arguments, .. } => Some(arguments.src),
_ => None,

View file

@ -20,7 +20,7 @@ fn run_directive<'input>(
var @ Directive2::Variable(..) => var,
Directive2::Method(method) => {
{
let func_decl = method.func_decl.borrow();
let func_decl = &method.func_decl;
match func_decl.name {
ptx_parser::MethodName::Kernel(_) => {}
ptx_parser::MethodName::Func(name) => {