Port remaining two passes

Andrzej Janik 2024-08-28 01:52:54 +02:00
parent c088cc2171
commit 144f8bd5ed
5 changed files with 791 additions and 59 deletions

View file

@@ -0,0 +1,282 @@
use super::*;
pub(super) fn run<'input, 'b>(
sorted_statements: Vec<ExpandedStatement>,
ptx_impl_imports: &mut HashMap<String, Directive>,
id_def: &mut NumericIdResolver,
) -> Result<(Vec<ExpandedStatement>, Vec<ast::Variable<SpirvWord>>), TranslateError> {
let mut local = Vec::with_capacity(sorted_statements.len());
let mut global = Vec::new();
for statement in sorted_statements {
match statement {
Statement::Variable(
var @ ast::Variable {
state_space: ast::StateSpace::Shared,
..
},
)
| Statement::Variable(
var @ ast::Variable {
state_space: ast::StateSpace::Global,
..
},
) => global.push(var),
Statement::Instruction(ast::Instruction::Bfe { data, arguments }) => {
let fn_name = [ZLUDA_PTX_PREFIX, "bfe_", scalar_to_ptx_name(data)].concat();
local.push(instruction_to_fn_call(
id_def,
ptx_impl_imports,
ast::Instruction::Bfe { data, arguments },
fn_name,
)?);
}
Statement::Instruction(ast::Instruction::Bfi { data, arguments }) => {
let fn_name = [ZLUDA_PTX_PREFIX, "bfi_", scalar_to_ptx_name(data)].concat();
local.push(instruction_to_fn_call(
id_def,
ptx_impl_imports,
ast::Instruction::Bfi { data, arguments },
fn_name,
)?);
}
Statement::Instruction(ast::Instruction::Brev { data, arguments }) => {
let fn_name: String =
[ZLUDA_PTX_PREFIX, "brev_", scalar_to_ptx_name(data)].concat();
local.push(instruction_to_fn_call(
id_def,
ptx_impl_imports,
ast::Instruction::Brev { data, arguments },
fn_name,
)?);
}
Statement::Instruction(ast::Instruction::Activemask { arguments }) => {
let fn_name = [ZLUDA_PTX_PREFIX, "activemask"].concat();
local.push(instruction_to_fn_call(
id_def,
ptx_impl_imports,
ast::Instruction::Activemask { arguments },
fn_name,
)?);
}
Statement::Instruction(ast::Instruction::Atom {
data:
data @ ast::AtomDetails {
op: ast::AtomicOp::IncrementWrap,
semantics,
scope,
space,
..
},
arguments,
}) => {
let fn_name = [
ZLUDA_PTX_PREFIX,
"atom_",
semantics_to_ptx_name(semantics),
"_",
scope_to_ptx_name(scope),
"_",
space_to_ptx_name(space),
"_inc",
]
.concat();
local.push(instruction_to_fn_call(
id_def,
ptx_impl_imports,
ast::Instruction::Atom { data, arguments },
fn_name,
)?);
}
Statement::Instruction(ast::Instruction::Atom {
data:
data @ ast::AtomDetails {
op: ast::AtomicOp::DecrementWrap,
semantics,
scope,
space,
..
},
arguments,
}) => {
let fn_name = [
ZLUDA_PTX_PREFIX,
"atom_",
semantics_to_ptx_name(semantics),
"_",
scope_to_ptx_name(scope),
"_",
space_to_ptx_name(space),
"_dec",
]
.concat();
local.push(instruction_to_fn_call(
id_def,
ptx_impl_imports,
ast::Instruction::Atom { data, arguments },
fn_name,
)?);
}
Statement::Instruction(ast::Instruction::Atom {
data:
data @ ast::AtomDetails {
op: ast::AtomicOp::FloatAdd,
semantics,
scope,
space,
..
},
arguments,
}) => {
let scalar_type = match data.type_ {
ptx_parser::Type::Scalar(scalar) => scalar,
_ => return Err(error_unreachable()),
};
let fn_name = [
ZLUDA_PTX_PREFIX,
"atom_",
semantics_to_ptx_name(semantics),
"_",
scope_to_ptx_name(scope),
"_",
space_to_ptx_name(space),
"_add_",
scalar_to_ptx_name(scalar_type),
]
.concat();
local.push(instruction_to_fn_call(
id_def,
ptx_impl_imports,
ast::Instruction::Atom { data, arguments },
fn_name,
)?);
}
s => local.push(s),
}
}
Ok((local, global))
}
fn instruction_to_fn_call(
id_defs: &mut NumericIdResolver,
ptx_impl_imports: &mut HashMap<String, Directive>,
inst: ast::Instruction<SpirvWord>,
fn_name: String,
) -> Result<ExpandedStatement, TranslateError> {
let mut arguments = Vec::new();
ast::visit_map(inst, &mut |operand,
type_space: Option<(
&ast::Type,
ast::StateSpace,
)>,
is_dst,
_| {
let (typ, space) = match type_space {
Some((typ, space)) => (typ.clone(), space),
None => return Err(error_unreachable()),
};
arguments.push((operand, is_dst, typ, space));
Ok(SpirvWord(0))
})?;
let return_arguments_count = arguments
.iter()
.position(|(desc, is_dst, _, _)| !is_dst)
.unwrap_or(arguments.len());
let (return_arguments, input_arguments) = arguments.split_at(return_arguments_count);
let fn_id = register_external_fn_call(
id_defs,
ptx_impl_imports,
fn_name,
return_arguments
.iter()
.map(|(_, _, typ, state)| (typ, *state)),
input_arguments
.iter()
.map(|(_, _, typ, state)| (typ, *state)),
)?;
Ok(Statement::Instruction(ast::Instruction::Call {
data: ast::CallDetails {
uniform: false,
return_arguments: return_arguments
.iter()
.map(|(_, _, typ, state)| (typ.clone(), *state))
.collect::<Vec<_>>(),
input_arguments: input_arguments
.iter()
.map(|(_, _, typ, state)| (typ.clone(), *state))
.collect::<Vec<_>>(),
},
arguments: ast::CallArgs {
return_arguments: return_arguments
.iter()
.map(|(name, _, _, _)| *name)
.collect::<Vec<_>>(),
func: fn_id,
input_arguments: input_arguments
.iter()
.map(|(name, _, _, _)| *name)
.collect::<Vec<_>>(),
},
}))
}
fn scalar_to_ptx_name(this: ast::ScalarType) -> &'static str {
match this {
ast::ScalarType::B8 => "b8",
ast::ScalarType::B16 => "b16",
ast::ScalarType::B32 => "b32",
ast::ScalarType::B64 => "b64",
ast::ScalarType::B128 => "b128",
ast::ScalarType::U8 => "u8",
ast::ScalarType::U16 => "u16",
ast::ScalarType::U16x2 => "u16x2",
ast::ScalarType::U32 => "u32",
ast::ScalarType::U64 => "u64",
ast::ScalarType::S8 => "s8",
ast::ScalarType::S16 => "s16",
ast::ScalarType::S16x2 => "s16x2",
ast::ScalarType::S32 => "s32",
ast::ScalarType::S64 => "s64",
ast::ScalarType::F16 => "f16",
ast::ScalarType::F16x2 => "f16x2",
ast::ScalarType::F32 => "f32",
ast::ScalarType::F64 => "f64",
ast::ScalarType::BF16 => "bf16",
ast::ScalarType::BF16x2 => "bf16x2",
ast::ScalarType::Pred => "pred",
}
}
fn semantics_to_ptx_name(this: ast::AtomSemantics) -> &'static str {
match this {
ast::AtomSemantics::Relaxed => "relaxed",
ast::AtomSemantics::Acquire => "acquire",
ast::AtomSemantics::Release => "release",
ast::AtomSemantics::AcqRel => "acq_rel",
}
}
fn scope_to_ptx_name(this: ast::MemScope) -> &'static str {
match this {
ast::MemScope::Cta => "cta",
ast::MemScope::Gpu => "gpu",
ast::MemScope::Sys => "sys",
ast::MemScope::Cluster => "cluster",
}
}
fn space_to_ptx_name(this: ast::StateSpace) -> &'static str {
match this {
ast::StateSpace::Generic => "generic",
ast::StateSpace::Global => "global",
ast::StateSpace::Shared => "shared",
ast::StateSpace::Reg => "reg",
ast::StateSpace::Const => "const",
ast::StateSpace::Local => "local",
ast::StateSpace::Param => "param",
ast::StateSpace::Sreg => "sreg",
ast::StateSpace::SharedCluster => "shared_cluster",
ast::StateSpace::ParamEntry => "param_entry",
ast::StateSpace::SharedCta => "shared_cta",
ast::StateSpace::ParamFunc => "param_func",
}
}
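For orientation, here is a minimal, self-contained sketch of the state-space split this pass performs: shared and global declarations are hoisted out of the function body while every other statement stays local. The Stmt, Variable and StateSpace types below are simplified stand-ins rather than the ptx_parser definitions, and the instruction-to-call rewriting (bfe, bfi, brev, activemask, atom) done above is deliberately omitted.

// Simplified stand-ins for the real AST types; names and fields are illustrative only.
#[derive(Debug, Clone, Copy, PartialEq)]
enum StateSpace {
    Reg,
    Shared,
    Global,
}

#[derive(Debug, Clone)]
struct Variable {
    name: String,
    state_space: StateSpace,
}

#[derive(Debug, Clone)]
enum Stmt {
    Variable(Variable),
    Other(String),
}

// Split a function body the way extract_globals::run does: .shared/.global
// declarations move to module scope, everything else stays in the body.
fn split_globals(body: Vec<Stmt>) -> (Vec<Stmt>, Vec<Variable>) {
    let mut local = Vec::with_capacity(body.len());
    let mut global = Vec::new();
    for stmt in body {
        match stmt {
            Stmt::Variable(var)
                if matches!(var.state_space, StateSpace::Shared | StateSpace::Global) =>
            {
                global.push(var)
            }
            other => local.push(other),
        }
    }
    (local, global)
}

fn main() {
    let body = vec![
        Stmt::Variable(Variable {
            name: "shared_buf".into(),
            state_space: StateSpace::Shared,
        }),
        Stmt::Variable(Variable {
            name: "tmp".into(),
            state_space: StateSpace::Reg,
        }),
        Stmt::Other("add.s32 ...".into()),
    ];
    let (local, global) = split_globals(body);
    assert_eq!(local.len(), 2);
    assert_eq!(global.len(), 1);
    assert_eq!(global[0].name, "shared_buf");
    println!("local: {:?}\nglobal: {:?}", local, global);
}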

View file

@@ -128,56 +128,3 @@ impl<'a, 'b, 'input> SpecialRegisterResolver<'a, 'b, 'input> {
}
}
}
fn register_external_fn_call<'a>(
id_defs: &mut NumericIdResolver,
ptx_impl_imports: &mut HashMap<String, Directive>,
name: String,
return_arguments: impl Iterator<Item = (&'a ast::Type, ast::StateSpace)>,
input_arguments: impl Iterator<Item = (&'a ast::Type, ast::StateSpace)>,
) -> Result<SpirvWord, TranslateError> {
match ptx_impl_imports.entry(name) {
hash_map::Entry::Vacant(entry) => {
let fn_id = id_defs.register_intermediate(None);
let return_arguments = fn_arguments_to_variables(id_defs, return_arguments);
let input_arguments = fn_arguments_to_variables(id_defs, input_arguments);
let func_decl = ast::MethodDeclaration::<SpirvWord> {
return_arguments,
name: ast::MethodName::Func(fn_id),
input_arguments,
shared_mem: None,
};
let func = Function {
func_decl: Rc::new(RefCell::new(func_decl)),
globals: Vec::new(),
body: None,
import_as: Some(entry.key().clone()),
tuning: Vec::new(),
linkage: ast::LinkingDirective::EXTERN,
};
entry.insert(Directive::Method(func));
Ok(fn_id)
}
hash_map::Entry::Occupied(entry) => match entry.get() {
Directive::Method(Function { func_decl, .. }) => match (**func_decl).borrow().name {
ast::MethodName::Func(fn_id) => Ok(fn_id),
ast::MethodName::Kernel(_) => Err(error_unreachable()),
},
_ => Err(error_unreachable()),
},
}
}
fn fn_arguments_to_variables<'a>(
id_defs: &mut NumericIdResolver,
args: impl Iterator<Item = (&'a ast::Type, ast::StateSpace)>,
) -> Vec<ast::Variable<SpirvWord>> {
args.map(|(typ, space)| ast::Variable {
align: None,
v_type: typ.clone(),
state_space: space,
name: id_defs.register_intermediate(None),
array_init: Vec::new(),
})
.collect::<Vec<_>>()
}

View file

@@ -0,0 +1,402 @@
use std::mem;
use super::*;
use ptx_parser as ast;
/*
There are several kinds of implicit conversions in PTX:
* auto-bitcast: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#type-information-for-instructions-and-operands
* special ld/st/cvt conversion rules: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size
- ld.param: not documented, but for instruction `ld.param.<type> x, [y]`,
semantics are to first zext/chop/bitcast `y` as needed and then do
documented special ld/st/cvt conversion rules for destination operands
- st.param [x], y (used for function return arguments): the same rule as above applies
- generic/global ld: for instruction `ld x, [y]`, y must be of type
b64/u64/s64, which is bitcast to a pointer, dereferenced and then
documented special ld/st/cvt conversion rules are applied to dst
- generic/global st: for instruction `st [x], y`, x must be of type
b64/u64/s64, which is bitcast to a pointer
*/
pub(super) fn run(
func: Vec<ExpandedStatement>,
id_def: &mut MutableNumericIdResolver,
) -> Result<Vec<ExpandedStatement>, TranslateError> {
let mut result = Vec::with_capacity(func.len());
for s in func.into_iter() {
match s {
Statement::Instruction(inst) => {
insert_implicit_conversions_impl(
&mut result,
id_def,
Statement::Instruction(inst),
)?;
}
Statement::PtrAccess(access) => {
insert_implicit_conversions_impl(
&mut result,
id_def,
Statement::PtrAccess(access),
)?;
}
Statement::RepackVector(repack) => {
insert_implicit_conversions_impl(
&mut result,
id_def,
Statement::RepackVector(repack),
)?;
}
s @ Statement::Conditional(_)
| s @ Statement::Conversion(_)
| s @ Statement::Label(_)
| s @ Statement::Constant(_)
| s @ Statement::Variable(_)
| s @ Statement::LoadVar(..)
| s @ Statement::StoreVar(..)
| s @ Statement::RetValue(..)
| s @ Statement::FunctionPointer(..) => result.push(s),
}
}
Ok(result)
}
fn insert_implicit_conversions_impl(
func: &mut Vec<ExpandedStatement>,
id_def: &mut MutableNumericIdResolver,
stmt: ExpandedStatement,
) -> Result<(), TranslateError> {
let mut post_conv = Vec::new();
let statement = stmt.visit_map::<SpirvWord, TranslateError>(
&mut |operand,
type_state: Option<(&ast::Type, ast::StateSpace)>,
is_dst,
relaxed_type_check| {
let (instr_type, instruction_space) = match type_state {
None => return Ok(operand),
Some(t) => t,
};
let (operand_type, operand_space) = id_def.get_typed(operand)?;
let conversion_fn = if relaxed_type_check {
if is_dst {
should_convert_relaxed_dst_wrapper
} else {
should_convert_relaxed_src_wrapper
}
} else {
default_implicit_conversion
};
match conversion_fn(
(operand_space, &operand_type),
(instruction_space, instr_type),
)? {
Some(conv_kind) => {
let conv_output = if is_dst { &mut post_conv } else { &mut *func };
let mut from_type = instr_type.clone();
let mut from_space = instruction_space;
let mut to_type = operand_type;
let mut to_space = operand_space;
let mut src =
id_def.register_intermediate(instr_type.clone(), instruction_space);
let mut dst = operand;
let result = Ok::<_, TranslateError>(src);
if !is_dst {
mem::swap(&mut src, &mut dst);
mem::swap(&mut from_type, &mut to_type);
mem::swap(&mut from_space, &mut to_space);
}
conv_output.push(Statement::Conversion(ImplicitConversion {
src,
dst,
from_type,
from_space,
to_type,
to_space,
kind: conv_kind,
}));
result
}
None => Ok(operand),
}
},
)?;
func.push(statement);
func.append(&mut post_conv);
Ok(())
}
fn default_implicit_conversion(
(operand_space, operand_type): (ast::StateSpace, &ast::Type),
(instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
) -> Result<Option<ConversionKind>, TranslateError> {
if !state_is_compatible(instruction_space, operand_space) {
default_implicit_conversion_space(
(operand_space, operand_type),
(instruction_space, instruction_type),
)
} else if instruction_type != operand_type {
default_implicit_conversion_type(instruction_space, operand_type, instruction_type)
} else {
Ok(None)
}
}
// Space is different
fn default_implicit_conversion_space(
(operand_space, operand_type): (ast::StateSpace, &ast::Type),
(instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
) -> Result<Option<ConversionKind>, TranslateError> {
if (instruction_space == ast::StateSpace::Generic && coerces_to_generic(operand_space))
|| (operand_space == ast::StateSpace::Generic && coerces_to_generic(instruction_space))
{
Ok(Some(ConversionKind::PtrToPtr))
} else if state_is_compatible(operand_space, ast::StateSpace::Reg) {
match operand_type {
ast::Type::Pointer(operand_ptr_type, operand_ptr_space)
if *operand_ptr_space == instruction_space =>
{
if instruction_type != &ast::Type::Scalar(*operand_ptr_type) {
Ok(Some(ConversionKind::PtrToPtr))
} else {
Ok(None)
}
}
// TODO: 32 bit
ast::Type::Scalar(ast::ScalarType::B64)
| ast::Type::Scalar(ast::ScalarType::U64)
| ast::Type::Scalar(ast::ScalarType::S64) => match instruction_space {
ast::StateSpace::Global
| ast::StateSpace::Generic
| ast::StateSpace::Const
| ast::StateSpace::Local
| ast::StateSpace::Shared => Ok(Some(ConversionKind::BitToPtr)),
_ => Err(TranslateError::MismatchedType),
},
ast::Type::Scalar(ast::ScalarType::B32)
| ast::Type::Scalar(ast::ScalarType::U32)
| ast::Type::Scalar(ast::ScalarType::S32) => match instruction_space {
ast::StateSpace::Const | ast::StateSpace::Local | ast::StateSpace::Shared => {
Ok(Some(ConversionKind::BitToPtr))
}
_ => Err(TranslateError::MismatchedType),
},
_ => Err(TranslateError::MismatchedType),
}
} else if state_is_compatible(instruction_space, ast::StateSpace::Reg) {
match instruction_type {
ast::Type::Pointer(instruction_ptr_type, instruction_ptr_space)
if operand_space == *instruction_ptr_space =>
{
if operand_type != &ast::Type::Scalar(*instruction_ptr_type) {
Ok(Some(ConversionKind::PtrToPtr))
} else {
Ok(None)
}
}
_ => Err(TranslateError::MismatchedType),
}
} else {
Err(TranslateError::MismatchedType)
}
}
// Space is same, but type is different
fn default_implicit_conversion_type(
space: ast::StateSpace,
operand_type: &ast::Type,
instruction_type: &ast::Type,
) -> Result<Option<ConversionKind>, TranslateError> {
if state_is_compatible(space, ast::StateSpace::Reg) {
if should_bitcast(instruction_type, operand_type) {
Ok(Some(ConversionKind::Default))
} else {
Err(TranslateError::MismatchedType)
}
} else {
Ok(Some(ConversionKind::PtrToPtr))
}
}
fn coerces_to_generic(this: ast::StateSpace) -> bool {
match this {
ast::StateSpace::Global
| ast::StateSpace::Const
| ast::StateSpace::Local
| ast::StateSpace::SharedCta
| ast::StateSpace::SharedCluster
| ast::StateSpace::Shared => true,
ast::StateSpace::Reg
| ast::StateSpace::Param
| ast::StateSpace::ParamEntry
| ast::StateSpace::ParamFunc
| ast::StateSpace::Generic
| ast::StateSpace::Sreg => false,
}
}
fn should_bitcast(instr: &ast::Type, operand: &ast::Type) -> bool {
match (instr, operand) {
(ast::Type::Scalar(inst), ast::Type::Scalar(operand)) => {
if inst.size_of() != operand.size_of() {
return false;
}
match inst.kind() {
ast::ScalarKind::Bit => operand.kind() != ast::ScalarKind::Bit,
ast::ScalarKind::Float => operand.kind() == ast::ScalarKind::Bit,
ast::ScalarKind::Signed => {
operand.kind() == ast::ScalarKind::Bit
|| operand.kind() == ast::ScalarKind::Unsigned
}
ast::ScalarKind::Unsigned => {
operand.kind() == ast::ScalarKind::Bit
|| operand.kind() == ast::ScalarKind::Signed
}
ast::ScalarKind::Pred => false,
}
}
(ast::Type::Vector(inst, _), ast::Type::Vector(operand, _))
| (ast::Type::Array(inst, _), ast::Type::Array(operand, _)) => {
should_bitcast(&ast::Type::Scalar(*inst), &ast::Type::Scalar(*operand))
}
_ => false,
}
}
fn should_convert_relaxed_dst_wrapper(
(operand_space, operand_type): (ast::StateSpace, &ast::Type),
(instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
) -> Result<Option<ConversionKind>, TranslateError> {
if !state_is_compatible(operand_space, instruction_space) {
return Err(TranslateError::MismatchedType);
}
if operand_type == instruction_type {
return Ok(None);
}
match should_convert_relaxed_dst(operand_type, instruction_type) {
conv @ Some(_) => Ok(conv),
None => Err(TranslateError::MismatchedType),
}
}
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size__relaxed-type-checking-rules-destination-operands
fn should_convert_relaxed_dst(
dst_type: &ast::Type,
instr_type: &ast::Type,
) -> Option<ConversionKind> {
if dst_type == instr_type {
return None;
}
match (dst_type, instr_type) {
(ast::Type::Scalar(dst_type), ast::Type::Scalar(instr_type)) => match instr_type.kind() {
ast::ScalarKind::Bit => {
if instr_type.size_of() <= dst_type.size_of() {
Some(ConversionKind::Default)
} else {
None
}
}
ast::ScalarKind::Signed => {
if dst_type.kind() != ast::ScalarKind::Float {
if instr_type.size_of() == dst_type.size_of() {
Some(ConversionKind::Default)
} else if instr_type.size_of() < dst_type.size_of() {
Some(ConversionKind::SignExtend)
} else {
None
}
} else {
None
}
}
ast::ScalarKind::Unsigned => {
if instr_type.size_of() <= dst_type.size_of()
&& dst_type.kind() != ast::ScalarKind::Float
{
Some(ConversionKind::Default)
} else {
None
}
}
ast::ScalarKind::Float => {
if instr_type.size_of() <= dst_type.size_of()
&& dst_type.kind() == ast::ScalarKind::Bit
{
Some(ConversionKind::Default)
} else {
None
}
}
ast::ScalarKind::Pred => None,
},
(ast::Type::Vector(dst_type, _), ast::Type::Vector(instr_type, _))
| (ast::Type::Array(dst_type, _), ast::Type::Array(instr_type, _)) => {
should_convert_relaxed_dst(
&ast::Type::Scalar(*dst_type),
&ast::Type::Scalar(*instr_type),
)
}
_ => None,
}
}
fn should_convert_relaxed_src_wrapper(
(operand_space, operand_type): (ast::StateSpace, &ast::Type),
(instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
) -> Result<Option<ConversionKind>, TranslateError> {
if !state_is_compatible(operand_space, instruction_space) {
return Err(TranslateError::MismatchedType);
}
if operand_type == instruction_type {
return Ok(None);
}
match should_convert_relaxed_src(operand_type, instruction_type) {
conv @ Some(_) => Ok(conv),
None => Err(TranslateError::MismatchedType),
}
}
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size__relaxed-type-checking-rules-source-operands
fn should_convert_relaxed_src(
src_type: &ast::Type,
instr_type: &ast::Type,
) -> Option<ConversionKind> {
if src_type == instr_type {
return None;
}
match (src_type, instr_type) {
(ast::Type::Scalar(src_type), ast::Type::Scalar(instr_type)) => match instr_type.kind() {
ast::ScalarKind::Bit => {
if instr_type.size_of() <= src_type.size_of() {
Some(ConversionKind::Default)
} else {
None
}
}
ast::ScalarKind::Signed | ast::ScalarKind::Unsigned => {
if instr_type.size_of() <= src_type.size_of()
&& src_type.kind() != ast::ScalarKind::Float
{
Some(ConversionKind::Default)
} else {
None
}
}
ast::ScalarKind::Float => {
if instr_type.size_of() <= src_type.size_of()
&& src_type.kind() == ast::ScalarKind::Bit
{
Some(ConversionKind::Default)
} else {
None
}
}
ast::ScalarKind::Pred => None,
},
(ast::Type::Vector(src_type, _), ast::Type::Vector(instr_type, _))
| (ast::Type::Array(src_type, _), ast::Type::Array(instr_type, _)) => {
should_convert_relaxed_src(
&ast::Type::Scalar(*src_type),
&ast::Type::Scalar(*instr_type),
)
}
_ => None,
}
}
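The auto-bitcast rule encoded by should_bitcast can be shown in isolation. Below is a sketch using simplified Kind and Scalar stand-ins (not the ptx_parser types); it mirrors the same compatibility table: an implicit bitcast is only allowed between same-sized scalars whose kinds pair up in the permitted ways.

// Simplified stand-ins for scalar type information; size_of is in bytes.
#[derive(Clone, Copy, PartialEq, Debug)]
enum Kind {
    Bit,
    Signed,
    Unsigned,
    Float,
    Pred,
}

#[derive(Clone, Copy, PartialEq, Debug)]
struct Scalar {
    kind: Kind,
    size_of: u8,
}

// Mirrors the shape of should_bitcast above: same size is required, then the
// instruction kind decides which operand kinds may be freely reinterpreted.
fn can_auto_bitcast(instr: Scalar, operand: Scalar) -> bool {
    if instr.size_of != operand.size_of {
        return false;
    }
    match instr.kind {
        Kind::Bit => operand.kind != Kind::Bit,
        Kind::Float => operand.kind == Kind::Bit,
        Kind::Signed => matches!(operand.kind, Kind::Bit | Kind::Unsigned),
        Kind::Unsigned => matches!(operand.kind, Kind::Bit | Kind::Signed),
        Kind::Pred => false,
    }
}

fn main() {
    let b32 = Scalar { kind: Kind::Bit, size_of: 4 };
    let f32_ = Scalar { kind: Kind::Float, size_of: 4 };
    let s32 = Scalar { kind: Kind::Signed, size_of: 4 };
    let u64_ = Scalar { kind: Kind::Unsigned, size_of: 8 };
    assert!(can_auto_bitcast(f32_, b32)); // float instruction, bit-typed operand
    assert!(can_auto_bitcast(s32, b32)); // signed instruction, bit-typed operand
    assert!(!can_auto_bitcast(f32_, s32)); // float vs signed needs cvt, not a bitcast
    assert!(!can_auto_bitcast(s32, u64_)); // size mismatch is never a bitcast
    println!("bitcast-compatibility checks passed");
}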

View file

@@ -16,6 +16,9 @@ mod fix_special_registers;
mod insert_mem_ssa_statements;
mod normalize_identifiers;
mod normalize_predicates;
mod insert_implicit_conversions;
mod normalize_labels;
mod extract_globals;
static ZLUDA_PTX_IMPL_INTEL: &'static [u8] = include_bytes!("../../lib/zluda_ptx_impl.spv");
static ZLUDA_PTX_IMPL_AMD: &'static [u8] = include_bytes!("../../lib/zluda_ptx_impl.bc");
@@ -184,14 +187,12 @@ fn to_ssa<'input, 'b>(
     )?;
     let mut numeric_id_defs = numeric_id_defs.finish();
     let expanded_statements = expand_arguments::run(ssa_statements, &mut numeric_id_defs)?;
-    todo!()
-    /*
     let expanded_statements =
-        insert_implicit_conversions(expanded_statements, &mut numeric_id_defs)?;
+        insert_implicit_conversions::run(expanded_statements, &mut numeric_id_defs)?;
     let mut numeric_id_defs = numeric_id_defs.unmut();
-    let labeled_statements = normalize_labels(expanded_statements, &mut numeric_id_defs);
+    let labeled_statements = normalize_labels::run(expanded_statements, &mut numeric_id_defs);
     let (f_body, globals) =
-        extract_globals(labeled_statements, ptx_impl_imports, &mut numeric_id_defs)?;
+        extract_globals::run(labeled_statements, ptx_impl_imports, &mut numeric_id_defs)?;
     Ok(Function {
         func_decl: func_decl,
         globals: globals,
@@ -200,7 +201,6 @@ fn to_ssa<'input, 'b>(
         tuning,
         linkage,
     })
-    */
 }
pub struct Module {
@@ -1220,3 +1220,56 @@ fn state_is_compatible(this: ast::StateSpace, other: ast::StateSpace) -> bool {
|| this == ast::StateSpace::Reg && other == ast::StateSpace::Sreg
|| this == ast::StateSpace::Sreg && other == ast::StateSpace::Reg
}
fn register_external_fn_call<'a>(
id_defs: &mut NumericIdResolver,
ptx_impl_imports: &mut HashMap<String, Directive>,
name: String,
return_arguments: impl Iterator<Item = (&'a ast::Type, ast::StateSpace)>,
input_arguments: impl Iterator<Item = (&'a ast::Type, ast::StateSpace)>,
) -> Result<SpirvWord, TranslateError> {
match ptx_impl_imports.entry(name) {
hash_map::Entry::Vacant(entry) => {
let fn_id = id_defs.register_intermediate(None);
let return_arguments = fn_arguments_to_variables(id_defs, return_arguments);
let input_arguments = fn_arguments_to_variables(id_defs, input_arguments);
let func_decl = ast::MethodDeclaration::<SpirvWord> {
return_arguments,
name: ast::MethodName::Func(fn_id),
input_arguments,
shared_mem: None,
};
let func = Function {
func_decl: Rc::new(RefCell::new(func_decl)),
globals: Vec::new(),
body: None,
import_as: Some(entry.key().clone()),
tuning: Vec::new(),
linkage: ast::LinkingDirective::EXTERN,
};
entry.insert(Directive::Method(func));
Ok(fn_id)
}
hash_map::Entry::Occupied(entry) => match entry.get() {
Directive::Method(Function { func_decl, .. }) => match (**func_decl).borrow().name {
ast::MethodName::Func(fn_id) => Ok(fn_id),
ast::MethodName::Kernel(_) => Err(error_unreachable()),
},
_ => Err(error_unreachable()),
},
}
}
fn fn_arguments_to_variables<'a>(
id_defs: &mut NumericIdResolver,
args: impl Iterator<Item = (&'a ast::Type, ast::StateSpace)>,
) -> Vec<ast::Variable<SpirvWord>> {
args.map(|(typ, space)| ast::Variable {
align: None,
v_type: typ.clone(),
state_space: space,
name: id_defs.register_intermediate(None),
array_init: Vec::new(),
})
.collect::<Vec<_>>()
}
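register_external_fn_call leans on the hash_map::Entry API so that each imported helper is declared exactly once: the first request registers an id and records an extern declaration, later requests return the id that was already assigned. A toy version of that register-or-reuse shape (the helper name below is illustrative, not an actual ZLUDA symbol) could look like this:

use std::collections::{hash_map, HashMap};

// First request for a name allocates a fresh id; repeated requests reuse it.
fn get_or_register(imports: &mut HashMap<String, u32>, name: &str, next_id: &mut u32) -> u32 {
    match imports.entry(name.to_string()) {
        hash_map::Entry::Vacant(entry) => {
            *next_id += 1;
            *entry.insert(*next_id)
        }
        hash_map::Entry::Occupied(entry) => *entry.get(),
    }
}

fn main() {
    let mut imports = HashMap::new();
    let mut next_id = 0;
    let first = get_or_register(&mut imports, "hypothetical_bfe_u32", &mut next_id);
    let second = get_or_register(&mut imports, "hypothetical_bfe_u32", &mut next_id);
    assert_eq!(first, second); // the same extern is only declared once
    println!("id = {first}");
}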

View file

@@ -0,0 +1,48 @@
use std::{collections::HashSet, iter};
use super::*;
pub(super) fn run(
func: Vec<ExpandedStatement>,
id_def: &mut NumericIdResolver,
) -> Vec<ExpandedStatement> {
let mut labels_in_use = HashSet::new();
for s in func.iter() {
match s {
Statement::Instruction(i) => {
if let Some(target) = jump_target(i) {
labels_in_use.insert(target);
}
}
Statement::Conditional(cond) => {
labels_in_use.insert(cond.if_true);
labels_in_use.insert(cond.if_false);
}
Statement::Variable(..)
| Statement::LoadVar(..)
| Statement::StoreVar(..)
| Statement::RetValue(..)
| Statement::Conversion(..)
| Statement::Constant(..)
| Statement::Label(..)
| Statement::PtrAccess { .. }
| Statement::RepackVector(..)
| Statement::FunctionPointer(..) => {}
}
}
iter::once(Statement::Label(id_def.register_intermediate(None)))
.chain(func.into_iter().filter(|s| match s {
Statement::Label(i) => labels_in_use.contains(i),
_ => true,
}))
.collect::<Vec<_>>()
}
fn jump_target<T: ast::Operand<Ident = SpirvWord>>(
this: &ast::Instruction<T>,
) -> Option<SpirvWord> {
match this {
ast::Instruction::Bra { arguments } => Some(arguments.src),
_ => None,
}
}
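The label normalization above is easy to demonstrate on a reduced statement type: collect every branch target, drop labels nothing jumps to, and prepend a fresh entry label. Stmt and the plain u32 ids below are stand-ins for Statement and SpirvWord, assuming the same overall shape.

use std::collections::HashSet;

// Reduced statement form: labels and branch targets are plain u32 ids.
#[derive(Debug, Clone, PartialEq)]
enum Stmt {
    Label(u32),
    Branch { target: u32 },
    Other(&'static str),
}

// Keep only labels that some branch references and start with a fresh entry
// label, mirroring the structure of run() above.
fn normalize_labels(body: Vec<Stmt>, next_id: &mut u32) -> Vec<Stmt> {
    let used: HashSet<u32> = body
        .iter()
        .filter_map(|s| match s {
            Stmt::Branch { target } => Some(*target),
            _ => None,
        })
        .collect();
    *next_id += 1;
    let entry = *next_id;
    std::iter::once(Stmt::Label(entry))
        .chain(body.into_iter().filter(|s| match s {
            Stmt::Label(id) => used.contains(id),
            _ => true,
        }))
        .collect()
}

fn main() {
    let mut next_id = 100;
    let body = vec![
        Stmt::Label(1), // dead: nothing branches here
        Stmt::Other("mov.u32 ..."),
        Stmt::Branch { target: 2 },
        Stmt::Label(2), // live: referenced by the branch above
        Stmt::Other("ret"),
    ];
    let out = normalize_labels(body, &mut next_id);
    assert_eq!(out[0], Stmt::Label(101)); // fresh entry label
    assert!(!out.contains(&Stmt::Label(1)));
    assert!(out.contains(&Stmt::Label(2)));
    println!("{:#?}", out);
}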