Refactor implicit conversions, explicit ld/st and global hoisting

This commit is contained in:
Andrzej Janik 2024-09-23 06:02:28 +02:00
parent 7bd4179d1d
commit 78a9f22cf7
6 changed files with 627 additions and 46 deletions

View file

@ -164,17 +164,16 @@ impl Deref for MemoryBuffer {
}
pub(super) fn run<'input>(
id_defs: &GlobalStringIdResolver<'input>,
call_map: MethodsCallMap<'input>,
directives: Vec<Directive<'input>>,
id_defs: GlobalStringIdentResolver2<'input>,
directives: Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>,
) -> Result<MemoryBuffer, TranslateError> {
let context = Context::new();
let module = Module::new(&context, LLVM_UNNAMED);
let mut emit_ctx = ModuleEmitContext::new(&context, &module, id_defs);
let mut emit_ctx = ModuleEmitContext::new(&context, &module, &id_defs);
for directive in directives {
match directive {
Directive::Variable(..) => todo!(),
Directive::Method(method) => emit_ctx.emit_method(method)?,
Directive2::Variable(..) => todo!(),
Directive2::Method(method) => emit_ctx.emit_method(method)?,
}
}
module.write_to_stderr();
@ -188,7 +187,7 @@ struct ModuleEmitContext<'a, 'input> {
context: LLVMContextRef,
module: LLVMModuleRef,
builder: Builder,
id_defs: &'a GlobalStringIdResolver<'input>,
id_defs: &'a GlobalStringIdentResolver2<'input>,
resolver: ResolveIdent,
}
@ -196,7 +195,7 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> {
fn new(
context: &Context,
module: &Module,
id_defs: &'a GlobalStringIdResolver<'input>,
id_defs: &'a GlobalStringIdentResolver2<'input>,
) -> Self {
ModuleEmitContext {
context: context.get(),
@ -215,20 +214,27 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> {
LLVMCallConv::LLVMCCallConv as u32
}
fn emit_method(&mut self, method: Function<'input>) -> Result<(), TranslateError> {
let func_decl = method.func_decl.borrow();
fn emit_method(
&mut self,
method: Function2<'input, ast::Instruction<SpirvWord>, SpirvWord>,
) -> Result<(), TranslateError> {
let func_decl = method.func_decl;
let name = method
.import_as
.as_deref()
.unwrap_or_else(|| match func_decl.name {
ast::MethodName::Kernel(name) => name,
ast::MethodName::Func(id) => self.id_defs.reverse_variables[&id],
});
.or_else(|| match func_decl.name {
ast::MethodName::Kernel(name) => Some(name),
ast::MethodName::Func(id) => self.id_defs.ident_map[&id].name.as_deref(),
})
.ok_or_else(|| error_unreachable())?;
let name = CString::new(name).map_err(|_| error_unreachable())?;
let fn_type = get_function_type(
self.context,
func_decl.return_arguments.iter().map(|v| &v.v_type),
func_decl.input_arguments.iter().map(|v| &v.v_type),
func_decl
.input_arguments
.iter()
.map(|v| get_input_argument_type(self.context, &v.v_type, v.state_space)),
)?;
let fn_ = unsafe { LLVMAddFunction(self.module, name.as_ptr(), fn_type) };
if let ast::MethodName::Func(name) = func_decl.name {
@ -239,6 +245,19 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> {
let name = self.resolver.get_or_add(param.name);
unsafe { LLVMSetValueName2(value, name.as_ptr().cast(), name.len()) };
self.resolver.register(param.name, value);
if func_decl.name.is_kernel() {
let attr_kind = unsafe {
LLVMGetEnumAttributeKindForName(b"byref".as_ptr().cast(), b"byref".len())
};
let attr = unsafe {
LLVMCreateTypeAttribute(
self.context,
attr_kind,
get_type(self.context, &param.v_type)?,
)
};
unsafe { LLVMAddAttributeAtIndex(fn_, i as u32 + 1, attr) };
}
}
let call_conv = if func_decl.name.is_kernel() {
Self::kernel_call_convention()
@ -264,12 +283,26 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> {
}
}
fn get_input_argument_type(
context: LLVMContextRef,
v_type: &ptx_parser::Type,
state_space: ptx_parser::StateSpace,
) -> Result<LLVMTypeRef, TranslateError> {
match state_space {
ptx_parser::StateSpace::ParamEntry => {
Ok(unsafe { LLVMPointerTypeInContext(context, get_state_space(state_space)?) })
}
ptx_parser::StateSpace::Reg => get_type(context, v_type),
_ => return Err(error_unreachable()),
}
}
struct MethodEmitContext<'a, 'input> {
context: LLVMContextRef,
module: LLVMModuleRef,
method: LLVMValueRef,
builder: LLVMBuilderRef,
id_defs: &'a GlobalStringIdResolver<'input>,
id_defs: &'a GlobalStringIdentResolver2<'input>,
variables_builder: Builder,
resolver: &'a mut ResolveIdent,
}
@ -533,7 +566,9 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> {
let type_ = get_function_type(
self.context,
data.return_arguments.iter().map(|(type_, space)| type_),
data.input_arguments.iter().map(|(type_, space)| type_),
data.input_arguments
.iter()
.map(|(type_, space)| get_input_argument_type(self.context, &type_, *space)),
)?;
let mut input_arguments = arguments
.input_arguments
@ -633,11 +668,10 @@ fn get_scalar_type(context: LLVMContextRef, type_: ast::ScalarType) -> LLVMTypeR
fn get_function_type<'a>(
context: LLVMContextRef,
mut return_args: impl ExactSizeIterator<Item = &'a ast::Type>,
input_args: impl ExactSizeIterator<Item = &'a ast::Type>,
input_args: impl ExactSizeIterator<Item = Result<LLVMTypeRef, TranslateError>>,
) -> Result<LLVMTypeRef, TranslateError> {
let mut input_args: Vec<*mut llvm_zluda::LLVMType> = input_args
.map(|type_| get_type(context, type_))
.collect::<Result<Vec<_>, _>>()?;
let mut input_args: Vec<*mut llvm_zluda::LLVMType> =
input_args.collect::<Result<Vec<_>, _>>()?;
let return_type = match return_args.len() {
0 => unsafe { LLVMVoidTypeInContext(context) },
1 => get_type(context, return_args.next().unwrap())?,
@ -658,7 +692,7 @@ fn get_state_space(space: ast::StateSpace) -> Result<u32, TranslateError> {
ast::StateSpace::Reg => Ok(PRIVATE_ADDRESS_SPACE),
ast::StateSpace::Generic => Ok(GENERIC_ADDRESS_SPACE),
ast::StateSpace::Param => Err(TranslateError::Todo),
ast::StateSpace::ParamEntry => Err(TranslateError::Todo),
ast::StateSpace::ParamEntry => Ok(CONSTANT_ADDRESS_SPACE),
ast::StateSpace::ParamFunc => Err(TranslateError::Todo),
ast::StateSpace::Local => Ok(PRIVATE_ADDRESS_SPACE),
ast::StateSpace::Global => Ok(GLOBAL_ADDRESS_SPACE),
@ -675,7 +709,7 @@ struct ResolveIdent {
}
impl ResolveIdent {
fn new<'input>(_id_defs: &GlobalStringIdResolver<'input>) -> Self {
fn new<'input>(_id_defs: &GlobalStringIdentResolver2<'input>) -> Self {
ResolveIdent {
words: HashMap::new(),
values: HashMap::new(),

View file

@ -0,0 +1,45 @@
use super::*;
pub(super) fn run<'input>(
directives: Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>,
) -> Result<Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
let mut result = Vec::with_capacity(directives.len());
for mut directive in directives.into_iter() {
run_directive(&mut result, &mut directive);
result.push(directive);
}
Ok(result)
}
fn run_directive<'input>(
result: &mut Vec<Directive2<'input, ptx_parser::Instruction<SpirvWord>, SpirvWord>>,
directive: &mut Directive2<'input, ptx_parser::Instruction<SpirvWord>, SpirvWord>,
) -> Result<(), TranslateError> {
match directive {
Directive2::Variable(..) => {}
Directive2::Method(function2) => run_function(result, function2),
}
Ok(())
}
fn run_function<'input>(
result: &mut Vec<Directive2<'input, ptx_parser::Instruction<SpirvWord>, SpirvWord>>,
function: &mut Function2<'input, ptx_parser::Instruction<SpirvWord>, SpirvWord>,
) {
function.body = function.body.take().map(|statements| {
statements
.into_iter()
.filter_map(|statement| match statement {
Statement::Variable(var @ ast::Variable {
state_space:
ast::StateSpace::Global | ast::StateSpace::Const | ast::StateSpace::Shared,
..
}) => {
result.push(Directive2::Variable(ast::LinkingDirective::NONE, var));
None
}
s => Some(s),
})
.collect()
});
}

View file

@ -41,10 +41,9 @@ fn run_method<'a, 'input>(
) -> Result<Function2<'input, ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
let mut func_decl = method.func_decl;
for arg in func_decl.return_arguments.iter_mut() {
visitor.visit_variable(arg);
visitor.visit_variable(arg)?;
}
let is_kernel = func_decl.name.is_kernel();
// let mut prelude = Vec::with_capacity(method.body.as_ref().map(Vec::len).unwrap_or(0));
if is_kernel {
for arg in func_decl.input_arguments.iter_mut() {
let old_name = arg.name;
@ -85,23 +84,29 @@ fn run_statement<'a, 'input>(
) -> Result<(), TranslateError> {
match statement {
Statement::Variable(mut var) => {
visitor.visit_variable(&mut var);
visitor.visit_variable(&mut var)?;
result.push(Statement::Variable(var));
}
Statement::Instruction(ast::Instruction::Ld { data, arguments }) => {
let instruction = visitor.visit_ld(data, arguments)?;
let instruction = ast::visit_map(instruction, visitor)?;
result.extend(visitor.pre.drain(..).map(Statement::Instruction));
result.push(Statement::Instruction(instruction));
result.extend(visitor.post.drain(..).map(Statement::Instruction));
}
Statement::Instruction(ast::Instruction::St {
data,
mut arguments,
}) => {
Statement::Instruction(ast::Instruction::St { data, arguments }) => {
let instruction = visitor.visit_st(data, arguments)?;
let instruction = ast::visit_map(instruction, visitor)?;
result.extend(visitor.pre.drain(..).map(Statement::Instruction));
result.push(Statement::Instruction(instruction));
result.extend(visitor.post.drain(..).map(Statement::Instruction));
}
s => {
let new_statement = s.visit_map(visitor)?;
result.extend(visitor.pre.drain(..).map(Statement::Instruction));
result.push(new_statement);
result.extend(visitor.post.drain(..).map(Statement::Instruction));
}
s => result.push(s.visit_map(visitor)?),
}
Ok(())
}
@ -109,6 +114,8 @@ fn run_statement<'a, 'input>(
struct InsertMemSSAVisitor<'a, 'input> {
resolver: &'a mut GlobalStringIdentResolver2<'input>,
variables: FxHashMap<SpirvWord, RemapAction>,
pre: Vec<ast::Instruction<SpirvWord>>,
post: Vec<ast::Instruction<SpirvWord>>,
}
impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> {
@ -116,6 +123,8 @@ impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> {
Self {
resolver,
variables: FxHashMap::default(),
pre: Vec::new(),
post: Vec::new(),
}
}
@ -141,14 +150,20 @@ impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> {
fn variable(
&mut self,
type_: &ast::Type,
old_name: SpirvWord,
new_name: SpirvWord,
old_space: ast::StateSpace,
) -> Result<(), TranslateError> {
Ok(match old_space {
ast::StateSpace::Reg => {
self.variables
.insert(old_name, RemapAction::PreLdPostSt(new_name));
self.variables.insert(
old_name,
RemapAction::PreLdPostSt {
name: new_name,
type_: type_.clone(),
},
);
}
ast::StateSpace::Param => {
self.variables.insert(
@ -182,7 +197,7 @@ impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> {
) -> Result<ast::Instruction<SpirvWord>, TranslateError> {
if let Some(remap) = self.variables.get(&arguments.src1) {
match remap {
RemapAction::PreLdPostSt(_) => return Err(error_mismatched_type()),
RemapAction::PreLdPostSt { .. } => {}
RemapAction::LDStSpaceChange {
old_space,
new_space,
@ -206,7 +221,7 @@ impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> {
) -> Result<ast::Instruction<SpirvWord>, TranslateError> {
if let Some(remap) = self.variables.get(&arguments.src) {
match remap {
RemapAction::PreLdPostSt(_) => return Err(error_mismatched_type()),
RemapAction::PreLdPostSt { .. } => {}
RemapAction::LDStSpaceChange {
old_space,
new_space,
@ -223,7 +238,7 @@ impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> {
Ok(ast::Instruction::Ld { data, arguments })
}
fn visit_variable(&mut self, var: &mut ast::Variable<SpirvWord>) {
fn visit_variable(&mut self, var: &mut ast::Variable<SpirvWord>) -> Result<(), TranslateError> {
if var.state_space != ast::StateSpace::Local {
let old_name = var.name;
let old_space = var.state_space;
@ -231,10 +246,11 @@ impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> {
let new_name = self
.resolver
.register_unnamed(Some((var.v_type.clone(), new_space)));
self.variable(old_name, new_name, old_space);
self.variable(&var.v_type, old_name, new_name, old_space)?;
var.name = new_name;
var.state_space = new_space;
}
Ok(())
}
}
@ -243,12 +259,58 @@ impl<'a, 'input> ast::VisitorMap<SpirvWord, SpirvWord, TranslateError>
{
fn visit(
&mut self,
args: SpirvWord,
ident: SpirvWord,
type_space: Option<(&ast::Type, ast::StateSpace)>,
is_dst: bool,
relaxed_type_check: bool,
) -> Result<SpirvWord, TranslateError> {
todo!()
if let Some(remap) = self.variables.get(&ident) {
match remap {
RemapAction::PreLdPostSt { name, type_ } => {
if is_dst {
let temp = self
.resolver
.register_unnamed(Some((type_.clone(), ast::StateSpace::Reg)));
self.post.push(ast::Instruction::St {
data: ast::StData {
state_space: ast::StateSpace::Local,
qualifier: ast::LdStQualifier::Weak,
caching: ast::StCacheOperator::Writethrough,
typ: type_.clone(),
},
arguments: ast::StArgs {
src1: *name,
src2: temp,
},
});
Ok(temp)
} else {
let temp = self
.resolver
.register_unnamed(Some((type_.clone(), ast::StateSpace::Reg)));
self.pre.push(ast::Instruction::Ld {
data: ast::LdDetails {
state_space: ast::StateSpace::Local,
qualifier: ast::LdStQualifier::Weak,
caching: ast::LdCacheOperator::Cached,
typ: type_.clone(),
non_coherent: false,
},
arguments: ast::LdArgs {
dst: temp,
src: *name,
},
});
Ok(temp)
}
}
RemapAction::LDStSpaceChange { .. } => {
return Err(error_mismatched_type());
}
}
} else {
Ok(ident)
}
}
fn visit_ident(
@ -262,9 +324,12 @@ impl<'a, 'input> ast::VisitorMap<SpirvWord, SpirvWord, TranslateError>
}
}
#[derive(Clone, Copy)]
#[derive(Clone)]
enum RemapAction {
PreLdPostSt(SpirvWord),
PreLdPostSt {
name: SpirvWord,
type_: ast::Type,
},
LDStSpaceChange {
old_space: ast::StateSpace,
new_space: ast::StateSpace,

View file

@ -0,0 +1,426 @@
use std::mem;
use super::*;
use ptx_parser as ast;
/*
There are several kinds of implicit conversions in PTX:
* auto-bitcast: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#type-information-for-instructions-and-operands
* special ld/st/cvt conversion rules: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size
- ld.param: not documented, but for instruction `ld.param.<type> x, [y]`,
semantics are to first zext/chop/bitcast `y` as needed and then do
documented special ld/st/cvt conversion rules for destination operands
- st.param [x] y (used as function return arguments) same rule as above applies
- generic/global ld: for instruction `ld x, [y]`, y must be of type
b64/u64/s64, which is bitcast to a pointer, dereferenced and then
documented special ld/st/cvt conversion rules are applied to dst
- generic/global st: for instruction `st [x], y`, x must be of type
b64/u64/s64, which is bitcast to a pointer
*/
pub(super) fn run<'input>(
resolver: &mut GlobalStringIdentResolver2<'input>,
directives: Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>,
) -> Result<Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
directives
.into_iter()
.map(|directive| run_directive(resolver, directive))
.collect::<Result<Vec<_>, _>>()
}
fn run_directive<'a, 'input>(
resolver: &mut GlobalStringIdentResolver2<'input>,
directive: Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>,
) -> Result<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
Ok(match directive {
var @ Directive2::Variable(..) => var,
Directive2::Method(mut method) => {
method.body = method
.body
.map(|statements| run_statements(resolver, statements))
.transpose()?;
Directive2::Method(method)
}
})
}
fn run_statements<'input>(
resolver: &mut GlobalStringIdentResolver2<'input>,
func: Vec<ExpandedStatement>,
) -> Result<Vec<ExpandedStatement>, TranslateError> {
let mut result = Vec::with_capacity(func.len());
for s in func.into_iter() {
insert_implicit_conversions_impl(resolver, &mut result, s)?;
}
Ok(result)
}
fn insert_implicit_conversions_impl<'input>(
resolver: &mut GlobalStringIdentResolver2<'input>,
func: &mut Vec<ExpandedStatement>,
stmt: ExpandedStatement,
) -> Result<(), TranslateError> {
let mut post_conv = Vec::new();
let statement = stmt.visit_map::<SpirvWord, TranslateError>(
&mut |operand,
type_state: Option<(&ast::Type, ast::StateSpace)>,
is_dst,
relaxed_type_check| {
let (instr_type, instruction_space) = match type_state {
None => return Ok(operand),
Some(t) => t,
};
let (operand_type, operand_space) = resolver.get_typed(operand)?;
let conversion_fn = if relaxed_type_check {
if is_dst {
should_convert_relaxed_dst_wrapper
} else {
should_convert_relaxed_src_wrapper
}
} else {
default_implicit_conversion
};
match conversion_fn(
(*operand_space, &operand_type),
(instruction_space, instr_type),
)? {
Some(conv_kind) => {
let conv_output = if is_dst { &mut post_conv } else { &mut *func };
let mut from_type = instr_type.clone();
let mut from_space = instruction_space;
let mut to_type = operand_type.clone();
let mut to_space = *operand_space;
let mut src =
resolver.register_unnamed(Some((instr_type.clone(), instruction_space)));
let mut dst = operand;
let result = Ok::<_, TranslateError>(src);
if !is_dst {
mem::swap(&mut src, &mut dst);
mem::swap(&mut from_type, &mut to_type);
mem::swap(&mut from_space, &mut to_space);
}
conv_output.push(Statement::Conversion(ImplicitConversion {
src,
dst,
from_type,
from_space,
to_type,
to_space,
kind: conv_kind,
}));
result
}
None => Ok(operand),
}
},
)?;
func.push(statement);
func.append(&mut post_conv);
Ok(())
}
pub(crate) fn default_implicit_conversion(
(operand_space, operand_type): (ast::StateSpace, &ast::Type),
(instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
) -> Result<Option<ConversionKind>, TranslateError> {
if instruction_space == ast::StateSpace::Reg {
if operand_space == ast::StateSpace::Reg {
if let (ast::Type::Vector(vec_len, vec_underlying_type), ast::Type::Scalar(scalar)) =
(operand_type, instruction_type)
{
if scalar.kind() == ast::ScalarKind::Bit
&& scalar.size_of() == (vec_underlying_type.size_of() * vec_len)
{
return Ok(Some(ConversionKind::Default));
}
}
} else if is_addressable(operand_space) {
return Ok(Some(ConversionKind::AddressOf));
}
}
if instruction_space != operand_space {
default_implicit_conversion_space(
(operand_space, operand_type),
(instruction_space, instruction_type),
)
} else if instruction_type != operand_type {
default_implicit_conversion_type(instruction_space, operand_type, instruction_type)
} else {
Ok(None)
}
}
fn is_addressable(this: ast::StateSpace) -> bool {
match this {
ast::StateSpace::Const
| ast::StateSpace::Generic
| ast::StateSpace::Global
| ast::StateSpace::Local
| ast::StateSpace::Shared => true,
ast::StateSpace::Param | ast::StateSpace::Reg => false,
ast::StateSpace::SharedCluster
| ast::StateSpace::SharedCta
| ast::StateSpace::ParamEntry
| ast::StateSpace::ParamFunc => todo!(),
}
}
// Space is different
fn default_implicit_conversion_space(
(operand_space, operand_type): (ast::StateSpace, &ast::Type),
(instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
) -> Result<Option<ConversionKind>, TranslateError> {
if (instruction_space == ast::StateSpace::Generic && coerces_to_generic(operand_space))
|| (operand_space == ast::StateSpace::Generic && coerces_to_generic(instruction_space))
{
Ok(Some(ConversionKind::PtrToPtr))
} else if operand_space == ast::StateSpace::Reg {
match operand_type {
ast::Type::Pointer(operand_ptr_type, operand_ptr_space)
if *operand_ptr_space == instruction_space =>
{
if instruction_type != &ast::Type::Scalar(*operand_ptr_type) {
Ok(Some(ConversionKind::PtrToPtr))
} else {
Ok(None)
}
}
// TODO: 32 bit
ast::Type::Scalar(ast::ScalarType::B64)
| ast::Type::Scalar(ast::ScalarType::U64)
| ast::Type::Scalar(ast::ScalarType::S64) => match instruction_space {
ast::StateSpace::Global
| ast::StateSpace::Generic
| ast::StateSpace::Const
| ast::StateSpace::Local
| ast::StateSpace::Shared => Ok(Some(ConversionKind::BitToPtr)),
_ => Err(error_mismatched_type()),
},
ast::Type::Scalar(ast::ScalarType::B32)
| ast::Type::Scalar(ast::ScalarType::U32)
| ast::Type::Scalar(ast::ScalarType::S32) => match instruction_space {
ast::StateSpace::Const | ast::StateSpace::Local | ast::StateSpace::Shared => {
Ok(Some(ConversionKind::BitToPtr))
}
_ => Err(error_mismatched_type()),
},
_ => Err(error_mismatched_type()),
}
} else if instruction_space == ast::StateSpace::Reg {
match instruction_type {
ast::Type::Pointer(instruction_ptr_type, instruction_ptr_space)
if operand_space == *instruction_ptr_space =>
{
if operand_type != &ast::Type::Scalar(*instruction_ptr_type) {
Ok(Some(ConversionKind::PtrToPtr))
} else {
Ok(None)
}
}
_ => Err(error_mismatched_type()),
}
} else {
Err(error_mismatched_type())
}
}
// Space is same, but type is different
fn default_implicit_conversion_type(
space: ast::StateSpace,
operand_type: &ast::Type,
instruction_type: &ast::Type,
) -> Result<Option<ConversionKind>, TranslateError> {
if space == ast::StateSpace::Reg {
if should_bitcast(instruction_type, operand_type) {
Ok(Some(ConversionKind::Default))
} else {
Err(TranslateError::MismatchedType)
}
} else {
Ok(Some(ConversionKind::PtrToPtr))
}
}
fn coerces_to_generic(this: ast::StateSpace) -> bool {
match this {
ast::StateSpace::Global
| ast::StateSpace::Const
| ast::StateSpace::Local
| ptx_parser::StateSpace::SharedCta
| ast::StateSpace::SharedCluster
| ast::StateSpace::Shared => true,
ast::StateSpace::Reg
| ast::StateSpace::Param
| ast::StateSpace::ParamEntry
| ast::StateSpace::ParamFunc
| ast::StateSpace::Generic => false,
}
}
fn should_bitcast(instr: &ast::Type, operand: &ast::Type) -> bool {
match (instr, operand) {
(ast::Type::Scalar(inst), ast::Type::Scalar(operand)) => {
if inst.size_of() != operand.size_of() {
return false;
}
match inst.kind() {
ast::ScalarKind::Bit => operand.kind() != ast::ScalarKind::Bit,
ast::ScalarKind::Float => operand.kind() == ast::ScalarKind::Bit,
ast::ScalarKind::Signed => {
operand.kind() == ast::ScalarKind::Bit
|| operand.kind() == ast::ScalarKind::Unsigned
}
ast::ScalarKind::Unsigned => {
operand.kind() == ast::ScalarKind::Bit
|| operand.kind() == ast::ScalarKind::Signed
}
ast::ScalarKind::Pred => false,
}
}
(ast::Type::Vector(_, inst), ast::Type::Vector(_, operand))
| (ast::Type::Array(_, inst, _), ast::Type::Array(_, operand, _)) => {
should_bitcast(&ast::Type::Scalar(*inst), &ast::Type::Scalar(*operand))
}
_ => false,
}
}
pub(crate) fn should_convert_relaxed_dst_wrapper(
(operand_space, operand_type): (ast::StateSpace, &ast::Type),
(instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
) -> Result<Option<ConversionKind>, TranslateError> {
if operand_space != instruction_space {
return Err(TranslateError::MismatchedType);
}
if operand_type == instruction_type {
return Ok(None);
}
match should_convert_relaxed_dst(operand_type, instruction_type) {
conv @ Some(_) => Ok(conv),
None => Err(TranslateError::MismatchedType),
}
}
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size__relaxed-type-checking-rules-destination-operands
fn should_convert_relaxed_dst(
dst_type: &ast::Type,
instr_type: &ast::Type,
) -> Option<ConversionKind> {
if dst_type == instr_type {
return None;
}
match (dst_type, instr_type) {
(ast::Type::Scalar(dst_type), ast::Type::Scalar(instr_type)) => match instr_type.kind() {
ast::ScalarKind::Bit => {
if instr_type.size_of() <= dst_type.size_of() {
Some(ConversionKind::Default)
} else {
None
}
}
ast::ScalarKind::Signed => {
if dst_type.kind() != ast::ScalarKind::Float {
if instr_type.size_of() == dst_type.size_of() {
Some(ConversionKind::Default)
} else if instr_type.size_of() < dst_type.size_of() {
Some(ConversionKind::SignExtend)
} else {
None
}
} else {
None
}
}
ast::ScalarKind::Unsigned => {
if instr_type.size_of() <= dst_type.size_of()
&& dst_type.kind() != ast::ScalarKind::Float
{
Some(ConversionKind::Default)
} else {
None
}
}
ast::ScalarKind::Float => {
if instr_type.size_of() <= dst_type.size_of()
&& dst_type.kind() == ast::ScalarKind::Bit
{
Some(ConversionKind::Default)
} else {
None
}
}
ast::ScalarKind::Pred => None,
},
(ast::Type::Vector(_, dst_type), ast::Type::Vector(_, instr_type))
| (ast::Type::Array(_, dst_type, _), ast::Type::Array(_, instr_type, _)) => {
should_convert_relaxed_dst(
&ast::Type::Scalar(*dst_type),
&ast::Type::Scalar(*instr_type),
)
}
_ => None,
}
}
pub(crate) fn should_convert_relaxed_src_wrapper(
(operand_space, operand_type): (ast::StateSpace, &ast::Type),
(instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
) -> Result<Option<ConversionKind>, TranslateError> {
if operand_space != instruction_space {
return Err(error_mismatched_type());
}
if operand_type == instruction_type {
return Ok(None);
}
match should_convert_relaxed_src(operand_type, instruction_type) {
conv @ Some(_) => Ok(conv),
None => Err(error_mismatched_type()),
}
}
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size__relaxed-type-checking-rules-source-operands
fn should_convert_relaxed_src(
src_type: &ast::Type,
instr_type: &ast::Type,
) -> Option<ConversionKind> {
if src_type == instr_type {
return None;
}
match (src_type, instr_type) {
(ast::Type::Scalar(src_type), ast::Type::Scalar(instr_type)) => match instr_type.kind() {
ast::ScalarKind::Bit => {
if instr_type.size_of() <= src_type.size_of() {
Some(ConversionKind::Default)
} else {
None
}
}
ast::ScalarKind::Signed | ast::ScalarKind::Unsigned => {
if instr_type.size_of() <= src_type.size_of()
&& src_type.kind() != ast::ScalarKind::Float
{
Some(ConversionKind::Default)
} else {
None
}
}
ast::ScalarKind::Float => {
if instr_type.size_of() <= src_type.size_of()
&& src_type.kind() == ast::ScalarKind::Bit
{
Some(ConversionKind::Default)
} else {
None
}
}
ast::ScalarKind::Pred => None,
},
(ast::Type::Vector(_, dst_type), ast::Type::Vector(_, instr_type))
| (ast::Type::Array(_, dst_type, _), ast::Type::Array(_, instr_type, _)) => {
should_convert_relaxed_src(
&ast::Type::Scalar(*dst_type),
&ast::Type::Scalar(*instr_type),
)
}
_ => None,
}
}

View file

@ -27,8 +27,10 @@ mod expand_operands;
mod extract_globals;
mod fix_special_registers;
mod fix_special_registers2;
mod hoist_globals;
mod insert_explicit_load_store;
mod insert_implicit_conversions;
mod insert_implicit_conversions2;
mod insert_mem_ssa_statements;
mod normalize_identifiers;
mod normalize_identifiers2;
@ -67,11 +69,13 @@ pub fn to_llvm_module<'input>(ast: ast::Module<'input>) -> Result<Module, Transl
})?;
normalize_variable_decls(&mut directives);
let denorm_information = compute_denorm_information(&directives);
let llvm_ir = emit_llvm::run(&id_defs, call_map, directives)?;
todo!()
/*
let llvm_ir: emit_llvm::MemoryBuffer = emit_llvm::run(&id_defs, call_map, directives)?;
Ok(Module {
llvm_ir,
kernel_info: HashMap::new(),
})
}) */
}
pub fn to_llvm_module2<'input>(ast: ast::Module<'input>) -> Result<Module, TranslateError> {
@ -82,10 +86,17 @@ pub fn to_llvm_module2<'input>(ast: ast::Module<'input>) -> Result<Module, Trans
let directives = normalize_predicates2::run(&mut flat_resolver, directives)?;
let directives = resolve_function_pointers::run(directives)?;
let directives = fix_special_registers2::run(&mut flat_resolver, &sreg_map, directives)?;
let directives = expand_operands::run(&mut flat_resolver, directives)?;
let directives: Vec<Directive2<'_, ptx_parser::Instruction<SpirvWord>, SpirvWord>> =
expand_operands::run(&mut flat_resolver, directives)?;
let directives = deparamize_functions::run(&mut flat_resolver, directives)?;
let directives = insert_explicit_load_store::run(&mut flat_resolver, directives)?;
todo!()
let directives = insert_implicit_conversions2::run(&mut flat_resolver, directives)?;
let directives = hoist_globals::run(directives)?;
let llvm_ir = emit_llvm::run(flat_resolver, directives)?;
Ok(Module {
llvm_ir,
kernel_info: HashMap::new(),
})
}
fn translate_directive<'input, 'a>(

View file

@ -236,7 +236,7 @@ fn test_hip_assert<
output: &mut [Output],
) -> Result<(), Box<dyn error::Error + 'a>> {
let ast = ptx_parser::parse_module_checked(ptx_text).unwrap();
let llvm_ir = pass::to_llvm_module(ast).unwrap();
let llvm_ir = pass::to_llvm_module2(ast).unwrap();
let name = CString::new(name)?;
let result =
run_hip(name.as_c_str(), llvm_ir, input, output).map_err(|err| DisplayError { err })?;