mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-04-20 00:14:45 +00:00
Port remaining two passes
This commit is contained in:
parent
c088cc2171
commit
144f8bd5ed
5 changed files with 791 additions and 59 deletions
282
ptx/src/pass/extract_globals.rs
Normal file
282
ptx/src/pass/extract_globals.rs
Normal file
|
@ -0,0 +1,282 @@
|
|||
use super::*;
|
||||
|
||||
pub(super) fn run<'input, 'b>(
|
||||
sorted_statements: Vec<ExpandedStatement>,
|
||||
ptx_impl_imports: &mut HashMap<String, Directive>,
|
||||
id_def: &mut NumericIdResolver,
|
||||
) -> Result<(Vec<ExpandedStatement>, Vec<ast::Variable<SpirvWord>>), TranslateError> {
|
||||
let mut local = Vec::with_capacity(sorted_statements.len());
|
||||
let mut global = Vec::new();
|
||||
for statement in sorted_statements {
|
||||
match statement {
|
||||
Statement::Variable(
|
||||
var @ ast::Variable {
|
||||
state_space: ast::StateSpace::Shared,
|
||||
..
|
||||
},
|
||||
)
|
||||
| Statement::Variable(
|
||||
var @ ast::Variable {
|
||||
state_space: ast::StateSpace::Global,
|
||||
..
|
||||
},
|
||||
) => global.push(var),
|
||||
Statement::Instruction(ast::Instruction::Bfe { data, arguments }) => {
|
||||
let fn_name = [ZLUDA_PTX_PREFIX, "bfe_", scalar_to_ptx_name(data)].concat();
|
||||
local.push(instruction_to_fn_call(
|
||||
id_def,
|
||||
ptx_impl_imports,
|
||||
ast::Instruction::Bfe { data, arguments },
|
||||
fn_name,
|
||||
)?);
|
||||
}
|
||||
Statement::Instruction(ast::Instruction::Bfi { data, arguments }) => {
|
||||
let fn_name = [ZLUDA_PTX_PREFIX, "bfi_", scalar_to_ptx_name(data)].concat();
|
||||
local.push(instruction_to_fn_call(
|
||||
id_def,
|
||||
ptx_impl_imports,
|
||||
ast::Instruction::Bfi { data, arguments },
|
||||
fn_name,
|
||||
)?);
|
||||
}
|
||||
Statement::Instruction(ast::Instruction::Brev { data, arguments }) => {
|
||||
let fn_name: String =
|
||||
[ZLUDA_PTX_PREFIX, "brev_", scalar_to_ptx_name(data)].concat();
|
||||
local.push(instruction_to_fn_call(
|
||||
id_def,
|
||||
ptx_impl_imports,
|
||||
ast::Instruction::Brev { data, arguments },
|
||||
fn_name,
|
||||
)?);
|
||||
}
|
||||
Statement::Instruction(ast::Instruction::Activemask { arguments }) => {
|
||||
let fn_name = [ZLUDA_PTX_PREFIX, "activemask"].concat();
|
||||
local.push(instruction_to_fn_call(
|
||||
id_def,
|
||||
ptx_impl_imports,
|
||||
ast::Instruction::Activemask { arguments },
|
||||
fn_name,
|
||||
)?);
|
||||
}
|
||||
Statement::Instruction(ast::Instruction::Atom {
|
||||
data:
|
||||
data @ ast::AtomDetails {
|
||||
op: ast::AtomicOp::IncrementWrap,
|
||||
semantics,
|
||||
scope,
|
||||
space,
|
||||
..
|
||||
},
|
||||
arguments,
|
||||
}) => {
|
||||
let fn_name = [
|
||||
ZLUDA_PTX_PREFIX,
|
||||
"atom_",
|
||||
semantics_to_ptx_name(semantics),
|
||||
"_",
|
||||
scope_to_ptx_name(scope),
|
||||
"_",
|
||||
space_to_ptx_name(space),
|
||||
"_inc",
|
||||
]
|
||||
.concat();
|
||||
local.push(instruction_to_fn_call(
|
||||
id_def,
|
||||
ptx_impl_imports,
|
||||
ast::Instruction::Atom { data, arguments },
|
||||
fn_name,
|
||||
)?);
|
||||
}
|
||||
Statement::Instruction(ast::Instruction::Atom {
|
||||
data:
|
||||
data @ ast::AtomDetails {
|
||||
op: ast::AtomicOp::DecrementWrap,
|
||||
semantics,
|
||||
scope,
|
||||
space,
|
||||
..
|
||||
},
|
||||
arguments,
|
||||
}) => {
|
||||
let fn_name = [
|
||||
ZLUDA_PTX_PREFIX,
|
||||
"atom_",
|
||||
semantics_to_ptx_name(semantics),
|
||||
"_",
|
||||
scope_to_ptx_name(scope),
|
||||
"_",
|
||||
space_to_ptx_name(space),
|
||||
"_dec",
|
||||
]
|
||||
.concat();
|
||||
local.push(instruction_to_fn_call(
|
||||
id_def,
|
||||
ptx_impl_imports,
|
||||
ast::Instruction::Atom { data, arguments },
|
||||
fn_name,
|
||||
)?);
|
||||
}
|
||||
Statement::Instruction(ast::Instruction::Atom {
|
||||
data:
|
||||
data @ ast::AtomDetails {
|
||||
op: ast::AtomicOp::FloatAdd,
|
||||
semantics,
|
||||
scope,
|
||||
space,
|
||||
..
|
||||
},
|
||||
arguments,
|
||||
}) => {
|
||||
let scalar_type = match data.type_ {
|
||||
ptx_parser::Type::Scalar(scalar) => scalar,
|
||||
_ => return Err(error_unreachable()),
|
||||
};
|
||||
let fn_name = [
|
||||
ZLUDA_PTX_PREFIX,
|
||||
"atom_",
|
||||
semantics_to_ptx_name(semantics),
|
||||
"_",
|
||||
scope_to_ptx_name(scope),
|
||||
"_",
|
||||
space_to_ptx_name(space),
|
||||
"_add_",
|
||||
scalar_to_ptx_name(scalar_type),
|
||||
]
|
||||
.concat();
|
||||
local.push(instruction_to_fn_call(
|
||||
id_def,
|
||||
ptx_impl_imports,
|
||||
ast::Instruction::Atom { data, arguments },
|
||||
fn_name,
|
||||
)?);
|
||||
}
|
||||
s => local.push(s),
|
||||
}
|
||||
}
|
||||
Ok((local, global))
|
||||
}
|
||||
|
||||
fn instruction_to_fn_call(
|
||||
id_defs: &mut NumericIdResolver,
|
||||
ptx_impl_imports: &mut HashMap<String, Directive>,
|
||||
inst: ast::Instruction<SpirvWord>,
|
||||
fn_name: String,
|
||||
) -> Result<ExpandedStatement, TranslateError> {
|
||||
let mut arguments = Vec::new();
|
||||
ast::visit_map(inst, &mut |operand,
|
||||
type_space: Option<(
|
||||
&ast::Type,
|
||||
ast::StateSpace,
|
||||
)>,
|
||||
is_dst,
|
||||
_| {
|
||||
let (typ, space) = match type_space {
|
||||
Some((typ, space)) => (typ.clone(), space),
|
||||
None => return Err(error_unreachable()),
|
||||
};
|
||||
arguments.push((operand, is_dst, typ, space));
|
||||
Ok(SpirvWord(0))
|
||||
})?;
|
||||
let return_arguments_count = arguments
|
||||
.iter()
|
||||
.position(|(desc, is_dst, _, _)| !is_dst)
|
||||
.unwrap_or(arguments.len());
|
||||
let (return_arguments, input_arguments) = arguments.split_at(return_arguments_count);
|
||||
let fn_id = register_external_fn_call(
|
||||
id_defs,
|
||||
ptx_impl_imports,
|
||||
fn_name,
|
||||
return_arguments
|
||||
.iter()
|
||||
.map(|(_, _, typ, state)| (typ, *state)),
|
||||
input_arguments
|
||||
.iter()
|
||||
.map(|(_, _, typ, state)| (typ, *state)),
|
||||
)?;
|
||||
Ok(Statement::Instruction(ast::Instruction::Call {
|
||||
data: ast::CallDetails {
|
||||
uniform: false,
|
||||
return_arguments: return_arguments
|
||||
.iter()
|
||||
.map(|(_, _, typ, state)| (typ.clone(), *state))
|
||||
.collect::<Vec<_>>(),
|
||||
input_arguments: input_arguments
|
||||
.iter()
|
||||
.map(|(_, _, typ, state)| (typ.clone(), *state))
|
||||
.collect::<Vec<_>>(),
|
||||
},
|
||||
arguments: ast::CallArgs {
|
||||
return_arguments: return_arguments
|
||||
.iter()
|
||||
.map(|(name, _, _, _)| *name)
|
||||
.collect::<Vec<_>>(),
|
||||
func: fn_id,
|
||||
input_arguments: input_arguments
|
||||
.iter()
|
||||
.map(|(name, _, _, _)| *name)
|
||||
.collect::<Vec<_>>(),
|
||||
},
|
||||
}))
|
||||
}
|
||||
|
||||
fn scalar_to_ptx_name(this: ast::ScalarType) -> &'static str {
|
||||
match this {
|
||||
ast::ScalarType::B8 => "b8",
|
||||
ast::ScalarType::B16 => "b16",
|
||||
ast::ScalarType::B32 => "b32",
|
||||
ast::ScalarType::B64 => "b64",
|
||||
ast::ScalarType::B128 => "b128",
|
||||
ast::ScalarType::U8 => "u8",
|
||||
ast::ScalarType::U16 => "u16",
|
||||
ast::ScalarType::U16x2 => "u16x2",
|
||||
ast::ScalarType::U32 => "u32",
|
||||
ast::ScalarType::U64 => "u64",
|
||||
ast::ScalarType::S8 => "s8",
|
||||
ast::ScalarType::S16 => "s16",
|
||||
ast::ScalarType::S16x2 => "s16x2",
|
||||
ast::ScalarType::S32 => "s32",
|
||||
ast::ScalarType::S64 => "s64",
|
||||
ast::ScalarType::F16 => "f16",
|
||||
ast::ScalarType::F16x2 => "f16x2",
|
||||
ast::ScalarType::F32 => "f32",
|
||||
ast::ScalarType::F64 => "f64",
|
||||
ast::ScalarType::BF16 => "bf16",
|
||||
ast::ScalarType::BF16x2 => "bf16x2",
|
||||
ast::ScalarType::Pred => "pred",
|
||||
}
|
||||
}
|
||||
|
||||
fn semantics_to_ptx_name(this: ast::AtomSemantics) -> &'static str {
|
||||
match this {
|
||||
ast::AtomSemantics::Relaxed => "relaxed",
|
||||
ast::AtomSemantics::Acquire => "acquire",
|
||||
ast::AtomSemantics::Release => "release",
|
||||
ast::AtomSemantics::AcqRel => "acq_rel",
|
||||
}
|
||||
}
|
||||
|
||||
fn scope_to_ptx_name(this: ast::MemScope) -> &'static str {
|
||||
match this {
|
||||
ast::MemScope::Cta => "cta",
|
||||
ast::MemScope::Gpu => "gpu",
|
||||
ast::MemScope::Sys => "sys",
|
||||
ast::MemScope::Cluster => "cluster",
|
||||
}
|
||||
}
|
||||
|
||||
fn space_to_ptx_name(this: ast::StateSpace) -> &'static str {
|
||||
match this {
|
||||
ast::StateSpace::Generic => "generic",
|
||||
ast::StateSpace::Global => "global",
|
||||
ast::StateSpace::Shared => "shared",
|
||||
ast::StateSpace::Reg => "reg",
|
||||
ast::StateSpace::Const => "const",
|
||||
ast::StateSpace::Local => "local",
|
||||
ast::StateSpace::Param => "param",
|
||||
ast::StateSpace::Sreg => "sreg",
|
||||
ast::StateSpace::SharedCluster => "shared_cluster",
|
||||
ast::StateSpace::ParamEntry => "param_entry",
|
||||
ast::StateSpace::SharedCta => "shared_cta",
|
||||
ast::StateSpace::ParamFunc => "param_func",
|
||||
}
|
||||
}
|
|
@ -128,56 +128,3 @@ impl<'a, 'b, 'input> SpecialRegisterResolver<'a, 'b, 'input> {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn register_external_fn_call<'a>(
|
||||
id_defs: &mut NumericIdResolver,
|
||||
ptx_impl_imports: &mut HashMap<String, Directive>,
|
||||
name: String,
|
||||
return_arguments: impl Iterator<Item = (&'a ast::Type, ast::StateSpace)>,
|
||||
input_arguments: impl Iterator<Item = (&'a ast::Type, ast::StateSpace)>,
|
||||
) -> Result<SpirvWord, TranslateError> {
|
||||
match ptx_impl_imports.entry(name) {
|
||||
hash_map::Entry::Vacant(entry) => {
|
||||
let fn_id = id_defs.register_intermediate(None);
|
||||
let return_arguments = fn_arguments_to_variables(id_defs, return_arguments);
|
||||
let input_arguments = fn_arguments_to_variables(id_defs, input_arguments);
|
||||
let func_decl = ast::MethodDeclaration::<SpirvWord> {
|
||||
return_arguments,
|
||||
name: ast::MethodName::Func(fn_id),
|
||||
input_arguments,
|
||||
shared_mem: None,
|
||||
};
|
||||
let func = Function {
|
||||
func_decl: Rc::new(RefCell::new(func_decl)),
|
||||
globals: Vec::new(),
|
||||
body: None,
|
||||
import_as: Some(entry.key().clone()),
|
||||
tuning: Vec::new(),
|
||||
linkage: ast::LinkingDirective::EXTERN,
|
||||
};
|
||||
entry.insert(Directive::Method(func));
|
||||
Ok(fn_id)
|
||||
}
|
||||
hash_map::Entry::Occupied(entry) => match entry.get() {
|
||||
Directive::Method(Function { func_decl, .. }) => match (**func_decl).borrow().name {
|
||||
ast::MethodName::Func(fn_id) => Ok(fn_id),
|
||||
ast::MethodName::Kernel(_) => Err(error_unreachable()),
|
||||
},
|
||||
_ => Err(error_unreachable()),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
fn fn_arguments_to_variables<'a>(
|
||||
id_defs: &mut NumericIdResolver,
|
||||
args: impl Iterator<Item = (&'a ast::Type, ast::StateSpace)>,
|
||||
) -> Vec<ast::Variable<SpirvWord>> {
|
||||
args.map(|(typ, space)| ast::Variable {
|
||||
align: None,
|
||||
v_type: typ.clone(),
|
||||
state_space: space,
|
||||
name: id_defs.register_intermediate(None),
|
||||
array_init: Vec::new(),
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
}
|
402
ptx/src/pass/insert_implicit_conversions.rs
Normal file
402
ptx/src/pass/insert_implicit_conversions.rs
Normal file
|
@ -0,0 +1,402 @@
|
|||
use std::mem;
|
||||
|
||||
use super::*;
|
||||
use ptx_parser as ast;
|
||||
|
||||
/*
|
||||
There are several kinds of implicit conversions in PTX:
|
||||
* auto-bitcast: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#type-information-for-instructions-and-operands
|
||||
* special ld/st/cvt conversion rules: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size
|
||||
- ld.param: not documented, but for instruction `ld.param.<type> x, [y]`,
|
||||
semantics are to first zext/chop/bitcast `y` as needed and then do
|
||||
documented special ld/st/cvt conversion rules for destination operands
|
||||
- st.param [x] y (used as function return arguments) same rule as above applies
|
||||
- generic/global ld: for instruction `ld x, [y]`, y must be of type
|
||||
b64/u64/s64, which is bitcast to a pointer, dereferenced and then
|
||||
documented special ld/st/cvt conversion rules are applied to dst
|
||||
- generic/global st: for instruction `st [x], y`, x must be of type
|
||||
b64/u64/s64, which is bitcast to a pointer
|
||||
*/
|
||||
pub(super) fn run(
|
||||
func: Vec<ExpandedStatement>,
|
||||
id_def: &mut MutableNumericIdResolver,
|
||||
) -> Result<Vec<ExpandedStatement>, TranslateError> {
|
||||
let mut result = Vec::with_capacity(func.len());
|
||||
for s in func.into_iter() {
|
||||
match s {
|
||||
Statement::Instruction(inst) => {
|
||||
insert_implicit_conversions_impl(
|
||||
&mut result,
|
||||
id_def,
|
||||
Statement::Instruction(inst),
|
||||
)?;
|
||||
}
|
||||
Statement::PtrAccess(access) => {
|
||||
insert_implicit_conversions_impl(
|
||||
&mut result,
|
||||
id_def,
|
||||
Statement::PtrAccess(access),
|
||||
)?;
|
||||
}
|
||||
Statement::RepackVector(repack) => {
|
||||
insert_implicit_conversions_impl(
|
||||
&mut result,
|
||||
id_def,
|
||||
Statement::RepackVector(repack),
|
||||
)?;
|
||||
}
|
||||
s @ Statement::Conditional(_)
|
||||
| s @ Statement::Conversion(_)
|
||||
| s @ Statement::Label(_)
|
||||
| s @ Statement::Constant(_)
|
||||
| s @ Statement::Variable(_)
|
||||
| s @ Statement::LoadVar(..)
|
||||
| s @ Statement::StoreVar(..)
|
||||
| s @ Statement::RetValue(..)
|
||||
| s @ Statement::FunctionPointer(..) => result.push(s),
|
||||
}
|
||||
}
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
fn insert_implicit_conversions_impl(
|
||||
func: &mut Vec<ExpandedStatement>,
|
||||
id_def: &mut MutableNumericIdResolver,
|
||||
stmt: ExpandedStatement,
|
||||
) -> Result<(), TranslateError> {
|
||||
let mut post_conv = Vec::new();
|
||||
let statement = stmt.visit_map::<SpirvWord, TranslateError>(
|
||||
&mut |operand,
|
||||
type_state: Option<(&ast::Type, ast::StateSpace)>,
|
||||
is_dst,
|
||||
relaxed_type_check| {
|
||||
let (instr_type, instruction_space) = match type_state {
|
||||
None => return Ok(operand),
|
||||
Some(t) => t,
|
||||
};
|
||||
let (operand_type, operand_space) = id_def.get_typed(operand)?;
|
||||
let conversion_fn = if relaxed_type_check {
|
||||
if is_dst {
|
||||
should_convert_relaxed_dst_wrapper
|
||||
} else {
|
||||
should_convert_relaxed_src_wrapper
|
||||
}
|
||||
} else {
|
||||
default_implicit_conversion
|
||||
};
|
||||
match conversion_fn(
|
||||
(operand_space, &operand_type),
|
||||
(instruction_space, instr_type),
|
||||
)? {
|
||||
Some(conv_kind) => {
|
||||
let conv_output = if is_dst { &mut post_conv } else { &mut *func };
|
||||
let mut from_type = instr_type.clone();
|
||||
let mut from_space = instruction_space;
|
||||
let mut to_type = operand_type;
|
||||
let mut to_space = operand_space;
|
||||
let mut src =
|
||||
id_def.register_intermediate(instr_type.clone(), instruction_space);
|
||||
let mut dst = operand;
|
||||
let result = Ok::<_, TranslateError>(src);
|
||||
if !is_dst {
|
||||
mem::swap(&mut src, &mut dst);
|
||||
mem::swap(&mut from_type, &mut to_type);
|
||||
mem::swap(&mut from_space, &mut to_space);
|
||||
}
|
||||
conv_output.push(Statement::Conversion(ImplicitConversion {
|
||||
src,
|
||||
dst,
|
||||
from_type,
|
||||
from_space,
|
||||
to_type,
|
||||
to_space,
|
||||
kind: conv_kind,
|
||||
}));
|
||||
result
|
||||
}
|
||||
None => Ok(operand),
|
||||
}
|
||||
},
|
||||
)?;
|
||||
func.push(statement);
|
||||
func.append(&mut post_conv);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn default_implicit_conversion(
|
||||
(operand_space, operand_type): (ast::StateSpace, &ast::Type),
|
||||
(instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
|
||||
) -> Result<Option<ConversionKind>, TranslateError> {
|
||||
if !state_is_compatible(instruction_space, operand_space) {
|
||||
default_implicit_conversion_space(
|
||||
(operand_space, operand_type),
|
||||
(instruction_space, instruction_type),
|
||||
)
|
||||
} else if instruction_type != operand_type {
|
||||
default_implicit_conversion_type(instruction_space, operand_type, instruction_type)
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
// Space is different
|
||||
fn default_implicit_conversion_space(
|
||||
(operand_space, operand_type): (ast::StateSpace, &ast::Type),
|
||||
(instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
|
||||
) -> Result<Option<ConversionKind>, TranslateError> {
|
||||
if (instruction_space == ast::StateSpace::Generic && coerces_to_generic(operand_space))
|
||||
|| (operand_space == ast::StateSpace::Generic && coerces_to_generic(instruction_space))
|
||||
{
|
||||
Ok(Some(ConversionKind::PtrToPtr))
|
||||
} else if state_is_compatible(operand_space, ast::StateSpace::Reg) {
|
||||
match operand_type {
|
||||
ast::Type::Pointer(operand_ptr_type, operand_ptr_space)
|
||||
if *operand_ptr_space == instruction_space =>
|
||||
{
|
||||
if instruction_type != &ast::Type::Scalar(*operand_ptr_type) {
|
||||
Ok(Some(ConversionKind::PtrToPtr))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
// TODO: 32 bit
|
||||
ast::Type::Scalar(ast::ScalarType::B64)
|
||||
| ast::Type::Scalar(ast::ScalarType::U64)
|
||||
| ast::Type::Scalar(ast::ScalarType::S64) => match instruction_space {
|
||||
ast::StateSpace::Global
|
||||
| ast::StateSpace::Generic
|
||||
| ast::StateSpace::Const
|
||||
| ast::StateSpace::Local
|
||||
| ast::StateSpace::Shared => Ok(Some(ConversionKind::BitToPtr)),
|
||||
_ => Err(TranslateError::MismatchedType),
|
||||
},
|
||||
ast::Type::Scalar(ast::ScalarType::B32)
|
||||
| ast::Type::Scalar(ast::ScalarType::U32)
|
||||
| ast::Type::Scalar(ast::ScalarType::S32) => match instruction_space {
|
||||
ast::StateSpace::Const | ast::StateSpace::Local | ast::StateSpace::Shared => {
|
||||
Ok(Some(ConversionKind::BitToPtr))
|
||||
}
|
||||
_ => Err(TranslateError::MismatchedType),
|
||||
},
|
||||
_ => Err(TranslateError::MismatchedType),
|
||||
}
|
||||
} else if state_is_compatible(instruction_space, ast::StateSpace::Reg) {
|
||||
match instruction_type {
|
||||
ast::Type::Pointer(instruction_ptr_type, instruction_ptr_space)
|
||||
if operand_space == *instruction_ptr_space =>
|
||||
{
|
||||
if operand_type != &ast::Type::Scalar(*instruction_ptr_type) {
|
||||
Ok(Some(ConversionKind::PtrToPtr))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
_ => Err(TranslateError::MismatchedType),
|
||||
}
|
||||
} else {
|
||||
Err(TranslateError::MismatchedType)
|
||||
}
|
||||
}
|
||||
|
||||
// Space is same, but type is different
|
||||
fn default_implicit_conversion_type(
|
||||
space: ast::StateSpace,
|
||||
operand_type: &ast::Type,
|
||||
instruction_type: &ast::Type,
|
||||
) -> Result<Option<ConversionKind>, TranslateError> {
|
||||
if state_is_compatible(space, ast::StateSpace::Reg) {
|
||||
if should_bitcast(instruction_type, operand_type) {
|
||||
Ok(Some(ConversionKind::Default))
|
||||
} else {
|
||||
Err(TranslateError::MismatchedType)
|
||||
}
|
||||
} else {
|
||||
Ok(Some(ConversionKind::PtrToPtr))
|
||||
}
|
||||
}
|
||||
|
||||
/// Reports whether `this` is one of the state spaces treated as convertible
/// to and from `Generic` by `default_implicit_conversion_space`.
fn coerces_to_generic(this: ast::StateSpace) -> bool {
    match this {
        ast::StateSpace::Reg
        | ast::StateSpace::Sreg
        | ast::StateSpace::Generic
        | ast::StateSpace::Param
        | ast::StateSpace::ParamEntry
        | ast::StateSpace::ParamFunc => false,
        ast::StateSpace::Global
        | ast::StateSpace::Const
        | ast::StateSpace::Local
        | ast::StateSpace::Shared
        | ast::StateSpace::SharedCta
        | ast::StateSpace::SharedCluster => true,
    }
}
|
||||
|
||||
/// Reports whether an operand of type `operand` may be implicitly bitcast to
/// the instruction type `instr`.
///
/// Both types must have the same shape (scalar, vector or array); the decision
/// is then made on the element types only: they must have equal size and a
/// compatible kind pairing.
// NOTE(review): vector widths and array dimensions are not compared here —
// confirm that callers never reach this with differing lengths.
fn should_bitcast(instr: &ast::Type, operand: &ast::Type) -> bool {
    let (instr_scalar, operand_scalar) = match (instr, operand) {
        (ast::Type::Scalar(inst), ast::Type::Scalar(operand))
        | (ast::Type::Vector(inst, _), ast::Type::Vector(operand, _))
        | (ast::Type::Array(inst, _), ast::Type::Array(operand, _)) => (*inst, *operand),
        _ => return false,
    };
    if instr_scalar.size_of() != operand_scalar.size_of() {
        return false;
    }
    let operand_kind = operand_scalar.kind();
    match instr_scalar.kind() {
        ast::ScalarKind::Bit => operand_kind != ast::ScalarKind::Bit,
        ast::ScalarKind::Float => operand_kind == ast::ScalarKind::Bit,
        ast::ScalarKind::Signed => {
            operand_kind == ast::ScalarKind::Bit || operand_kind == ast::ScalarKind::Unsigned
        }
        ast::ScalarKind::Unsigned => {
            operand_kind == ast::ScalarKind::Bit || operand_kind == ast::ScalarKind::Signed
        }
        ast::ScalarKind::Pred => false,
    }
}
|
||||
|
||||
fn should_convert_relaxed_dst_wrapper(
|
||||
(operand_space, operand_type): (ast::StateSpace, &ast::Type),
|
||||
(instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
|
||||
) -> Result<Option<ConversionKind>, TranslateError> {
|
||||
if !state_is_compatible(operand_space, instruction_space) {
|
||||
return Err(TranslateError::MismatchedType);
|
||||
}
|
||||
if operand_type == instruction_type {
|
||||
return Ok(None);
|
||||
}
|
||||
match should_convert_relaxed_dst(operand_type, instruction_type) {
|
||||
conv @ Some(_) => Ok(conv),
|
||||
None => Err(TranslateError::MismatchedType),
|
||||
}
|
||||
}
|
||||
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size__relaxed-type-checking-rules-destination-operands
|
||||
fn should_convert_relaxed_dst(
|
||||
dst_type: &ast::Type,
|
||||
instr_type: &ast::Type,
|
||||
) -> Option<ConversionKind> {
|
||||
if dst_type == instr_type {
|
||||
return None;
|
||||
}
|
||||
match (dst_type, instr_type) {
|
||||
(ast::Type::Scalar(dst_type), ast::Type::Scalar(instr_type)) => match instr_type.kind() {
|
||||
ast::ScalarKind::Bit => {
|
||||
if instr_type.size_of() <= dst_type.size_of() {
|
||||
Some(ConversionKind::Default)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
ast::ScalarKind::Signed => {
|
||||
if dst_type.kind() != ast::ScalarKind::Float {
|
||||
if instr_type.size_of() == dst_type.size_of() {
|
||||
Some(ConversionKind::Default)
|
||||
} else if instr_type.size_of() < dst_type.size_of() {
|
||||
Some(ConversionKind::SignExtend)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
ast::ScalarKind::Unsigned => {
|
||||
if instr_type.size_of() <= dst_type.size_of()
|
||||
&& dst_type.kind() != ast::ScalarKind::Float
|
||||
{
|
||||
Some(ConversionKind::Default)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
ast::ScalarKind::Float => {
|
||||
if instr_type.size_of() <= dst_type.size_of()
|
||||
&& dst_type.kind() == ast::ScalarKind::Bit
|
||||
{
|
||||
Some(ConversionKind::Default)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
ast::ScalarKind::Pred => None,
|
||||
},
|
||||
(ast::Type::Vector(dst_type, _), ast::Type::Vector(instr_type, _))
|
||||
| (ast::Type::Array(dst_type, _), ast::Type::Array(instr_type, _)) => {
|
||||
should_convert_relaxed_dst(
|
||||
&ast::Type::Scalar(*dst_type),
|
||||
&ast::Type::Scalar(*instr_type),
|
||||
)
|
||||
}
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
fn should_convert_relaxed_src_wrapper(
|
||||
(operand_space, operand_type): (ast::StateSpace, &ast::Type),
|
||||
(instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
|
||||
) -> Result<Option<ConversionKind>, TranslateError> {
|
||||
if !state_is_compatible(operand_space, instruction_space) {
|
||||
return Err(TranslateError::MismatchedType);
|
||||
}
|
||||
if operand_type == instruction_type {
|
||||
return Ok(None);
|
||||
}
|
||||
match should_convert_relaxed_src(operand_type, instruction_type) {
|
||||
conv @ Some(_) => Ok(conv),
|
||||
None => Err(TranslateError::MismatchedType),
|
||||
}
|
||||
}
|
||||
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size__relaxed-type-checking-rules-source-operands
|
||||
fn should_convert_relaxed_src(
|
||||
src_type: &ast::Type,
|
||||
instr_type: &ast::Type,
|
||||
) -> Option<ConversionKind> {
|
||||
if src_type == instr_type {
|
||||
return None;
|
||||
}
|
||||
match (src_type, instr_type) {
|
||||
(ast::Type::Scalar(src_type), ast::Type::Scalar(instr_type)) => match instr_type.kind() {
|
||||
ast::ScalarKind::Bit => {
|
||||
if instr_type.size_of() <= src_type.size_of() {
|
||||
Some(ConversionKind::Default)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
ast::ScalarKind::Signed | ast::ScalarKind::Unsigned => {
|
||||
if instr_type.size_of() <= src_type.size_of()
|
||||
&& src_type.kind() != ast::ScalarKind::Float
|
||||
{
|
||||
Some(ConversionKind::Default)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
ast::ScalarKind::Float => {
|
||||
if instr_type.size_of() <= src_type.size_of()
|
||||
&& src_type.kind() == ast::ScalarKind::Bit
|
||||
{
|
||||
Some(ConversionKind::Default)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
ast::ScalarKind::Pred => None,
|
||||
},
|
||||
(ast::Type::Vector(dst_type, _), ast::Type::Vector(instr_type, _))
|
||||
| (ast::Type::Array(dst_type, _), ast::Type::Array(instr_type, _)) => {
|
||||
should_convert_relaxed_src(
|
||||
&ast::Type::Scalar(*dst_type),
|
||||
&ast::Type::Scalar(*instr_type),
|
||||
)
|
||||
}
|
||||
_ => None,
|
||||
}
|
||||
}
|
|
@ -16,6 +16,9 @@ mod fix_special_registers;
|
|||
mod insert_mem_ssa_statements;
|
||||
mod normalize_identifiers;
|
||||
mod normalize_predicates;
|
||||
mod insert_implicit_conversions;
|
||||
mod normalize_labels;
|
||||
mod extract_globals;
|
||||
|
||||
static ZLUDA_PTX_IMPL_INTEL: &'static [u8] = include_bytes!("../../lib/zluda_ptx_impl.spv");
|
||||
static ZLUDA_PTX_IMPL_AMD: &'static [u8] = include_bytes!("../../lib/zluda_ptx_impl.bc");
|
||||
|
@ -184,14 +187,12 @@ fn to_ssa<'input, 'b>(
|
|||
)?;
|
||||
let mut numeric_id_defs = numeric_id_defs.finish();
|
||||
let expanded_statements = expand_arguments::run(ssa_statements, &mut numeric_id_defs)?;
|
||||
todo!()
|
||||
/*
|
||||
let expanded_statements =
|
||||
insert_implicit_conversions(expanded_statements, &mut numeric_id_defs)?;
|
||||
insert_implicit_conversions::run(expanded_statements, &mut numeric_id_defs)?;
|
||||
let mut numeric_id_defs = numeric_id_defs.unmut();
|
||||
let labeled_statements = normalize_labels(expanded_statements, &mut numeric_id_defs);
|
||||
let labeled_statements = normalize_labels::run(expanded_statements, &mut numeric_id_defs);
|
||||
let (f_body, globals) =
|
||||
extract_globals(labeled_statements, ptx_impl_imports, &mut numeric_id_defs)?;
|
||||
extract_globals::run(labeled_statements, ptx_impl_imports, &mut numeric_id_defs)?;
|
||||
Ok(Function {
|
||||
func_decl: func_decl,
|
||||
globals: globals,
|
||||
|
@ -200,7 +201,6 @@ fn to_ssa<'input, 'b>(
|
|||
tuning,
|
||||
linkage,
|
||||
})
|
||||
*/
|
||||
}
|
||||
|
||||
pub struct Module {
|
||||
|
@ -1220,3 +1220,56 @@ fn state_is_compatible(this: ast::StateSpace, other: ast::StateSpace) -> bool {
|
|||
|| this == ast::StateSpace::Reg && other == ast::StateSpace::Sreg
|
||||
|| this == ast::StateSpace::Sreg && other == ast::StateSpace::Reg
|
||||
}
|
||||
|
||||
fn register_external_fn_call<'a>(
|
||||
id_defs: &mut NumericIdResolver,
|
||||
ptx_impl_imports: &mut HashMap<String, Directive>,
|
||||
name: String,
|
||||
return_arguments: impl Iterator<Item = (&'a ast::Type, ast::StateSpace)>,
|
||||
input_arguments: impl Iterator<Item = (&'a ast::Type, ast::StateSpace)>,
|
||||
) -> Result<SpirvWord, TranslateError> {
|
||||
match ptx_impl_imports.entry(name) {
|
||||
hash_map::Entry::Vacant(entry) => {
|
||||
let fn_id = id_defs.register_intermediate(None);
|
||||
let return_arguments = fn_arguments_to_variables(id_defs, return_arguments);
|
||||
let input_arguments = fn_arguments_to_variables(id_defs, input_arguments);
|
||||
let func_decl = ast::MethodDeclaration::<SpirvWord> {
|
||||
return_arguments,
|
||||
name: ast::MethodName::Func(fn_id),
|
||||
input_arguments,
|
||||
shared_mem: None,
|
||||
};
|
||||
let func = Function {
|
||||
func_decl: Rc::new(RefCell::new(func_decl)),
|
||||
globals: Vec::new(),
|
||||
body: None,
|
||||
import_as: Some(entry.key().clone()),
|
||||
tuning: Vec::new(),
|
||||
linkage: ast::LinkingDirective::EXTERN,
|
||||
};
|
||||
entry.insert(Directive::Method(func));
|
||||
Ok(fn_id)
|
||||
}
|
||||
hash_map::Entry::Occupied(entry) => match entry.get() {
|
||||
Directive::Method(Function { func_decl, .. }) => match (**func_decl).borrow().name {
|
||||
ast::MethodName::Func(fn_id) => Ok(fn_id),
|
||||
ast::MethodName::Kernel(_) => Err(error_unreachable()),
|
||||
},
|
||||
_ => Err(error_unreachable()),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
fn fn_arguments_to_variables<'a>(
|
||||
id_defs: &mut NumericIdResolver,
|
||||
args: impl Iterator<Item = (&'a ast::Type, ast::StateSpace)>,
|
||||
) -> Vec<ast::Variable<SpirvWord>> {
|
||||
args.map(|(typ, space)| ast::Variable {
|
||||
align: None,
|
||||
v_type: typ.clone(),
|
||||
state_space: space,
|
||||
name: id_defs.register_intermediate(None),
|
||||
array_init: Vec::new(),
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
}
|
||||
|
|
48
ptx/src/pass/normalize_labels.rs
Normal file
48
ptx/src/pass/normalize_labels.rs
Normal file
|
@ -0,0 +1,48 @@
|
|||
use std::{collections::HashSet, iter};
|
||||
|
||||
use super::*;
|
||||
|
||||
pub(super) fn run(
|
||||
func: Vec<ExpandedStatement>,
|
||||
id_def: &mut NumericIdResolver,
|
||||
) -> Vec<ExpandedStatement> {
|
||||
let mut labels_in_use = HashSet::new();
|
||||
for s in func.iter() {
|
||||
match s {
|
||||
Statement::Instruction(i) => {
|
||||
if let Some(target) = jump_target(i) {
|
||||
labels_in_use.insert(target);
|
||||
}
|
||||
}
|
||||
Statement::Conditional(cond) => {
|
||||
labels_in_use.insert(cond.if_true);
|
||||
labels_in_use.insert(cond.if_false);
|
||||
}
|
||||
Statement::Variable(..)
|
||||
| Statement::LoadVar(..)
|
||||
| Statement::StoreVar(..)
|
||||
| Statement::RetValue(..)
|
||||
| Statement::Conversion(..)
|
||||
| Statement::Constant(..)
|
||||
| Statement::Label(..)
|
||||
| Statement::PtrAccess { .. }
|
||||
| Statement::RepackVector(..)
|
||||
| Statement::FunctionPointer(..) => {}
|
||||
}
|
||||
}
|
||||
iter::once(Statement::Label(id_def.register_intermediate(None)))
|
||||
.chain(func.into_iter().filter(|s| match s {
|
||||
Statement::Label(i) => labels_in_use.contains(i),
|
||||
_ => true,
|
||||
}))
|
||||
.collect::<Vec<_>>()
|
||||
}
|
||||
|
||||
/// Returns the label a `bra` instruction jumps to; `None` for every other
/// instruction.
fn jump_target<T: ast::Operand<Ident = SpirvWord>>(
    this: &ast::Instruction<T>,
) -> Option<SpirvWord> {
    if let ast::Instruction::Bra { arguments } = this {
        Some(arguments.src)
    } else {
        None
    }
}
|
Loading…
Add table
Reference in a new issue