diff --git a/ptx/src/pass/extract_globals.rs b/ptx/src/pass/extract_globals.rs new file mode 100644 index 0000000..680a5ee --- /dev/null +++ b/ptx/src/pass/extract_globals.rs @@ -0,0 +1,282 @@ +use super::*; + +pub(super) fn run<'input, 'b>( + sorted_statements: Vec, + ptx_impl_imports: &mut HashMap, + id_def: &mut NumericIdResolver, +) -> Result<(Vec, Vec>), TranslateError> { + let mut local = Vec::with_capacity(sorted_statements.len()); + let mut global = Vec::new(); + for statement in sorted_statements { + match statement { + Statement::Variable( + var @ ast::Variable { + state_space: ast::StateSpace::Shared, + .. + }, + ) + | Statement::Variable( + var @ ast::Variable { + state_space: ast::StateSpace::Global, + .. + }, + ) => global.push(var), + Statement::Instruction(ast::Instruction::Bfe { data, arguments }) => { + let fn_name = [ZLUDA_PTX_PREFIX, "bfe_", scalar_to_ptx_name(data)].concat(); + local.push(instruction_to_fn_call( + id_def, + ptx_impl_imports, + ast::Instruction::Bfe { data, arguments }, + fn_name, + )?); + } + Statement::Instruction(ast::Instruction::Bfi { data, arguments }) => { + let fn_name = [ZLUDA_PTX_PREFIX, "bfi_", scalar_to_ptx_name(data)].concat(); + local.push(instruction_to_fn_call( + id_def, + ptx_impl_imports, + ast::Instruction::Bfi { data, arguments }, + fn_name, + )?); + } + Statement::Instruction(ast::Instruction::Brev { data, arguments }) => { + let fn_name: String = + [ZLUDA_PTX_PREFIX, "brev_", scalar_to_ptx_name(data)].concat(); + local.push(instruction_to_fn_call( + id_def, + ptx_impl_imports, + ast::Instruction::Brev { data, arguments }, + fn_name, + )?); + } + Statement::Instruction(ast::Instruction::Activemask { arguments }) => { + let fn_name = [ZLUDA_PTX_PREFIX, "activemask"].concat(); + local.push(instruction_to_fn_call( + id_def, + ptx_impl_imports, + ast::Instruction::Activemask { arguments }, + fn_name, + )?); + } + Statement::Instruction(ast::Instruction::Atom { + data: + data @ ast::AtomDetails { + op: ast::AtomicOp::IncrementWrap, + semantics, + scope, + space, + .. + }, + arguments, + }) => { + let fn_name = [ + ZLUDA_PTX_PREFIX, + "atom_", + semantics_to_ptx_name(semantics), + "_", + scope_to_ptx_name(scope), + "_", + space_to_ptx_name(space), + "_inc", + ] + .concat(); + local.push(instruction_to_fn_call( + id_def, + ptx_impl_imports, + ast::Instruction::Atom { data, arguments }, + fn_name, + )?); + } + Statement::Instruction(ast::Instruction::Atom { + data: + data @ ast::AtomDetails { + op: ast::AtomicOp::DecrementWrap, + semantics, + scope, + space, + .. + }, + arguments, + }) => { + let fn_name = [ + ZLUDA_PTX_PREFIX, + "atom_", + semantics_to_ptx_name(semantics), + "_", + scope_to_ptx_name(scope), + "_", + space_to_ptx_name(space), + "_dec", + ] + .concat(); + local.push(instruction_to_fn_call( + id_def, + ptx_impl_imports, + ast::Instruction::Atom { data, arguments }, + fn_name, + )?); + } + Statement::Instruction(ast::Instruction::Atom { + data: + data @ ast::AtomDetails { + op: ast::AtomicOp::FloatAdd, + semantics, + scope, + space, + .. + }, + arguments, + }) => { + let scalar_type = match data.type_ { + ptx_parser::Type::Scalar(scalar) => scalar, + _ => return Err(error_unreachable()), + }; + let fn_name = [ + ZLUDA_PTX_PREFIX, + "atom_", + semantics_to_ptx_name(semantics), + "_", + scope_to_ptx_name(scope), + "_", + space_to_ptx_name(space), + "_add_", + scalar_to_ptx_name(scalar_type), + ] + .concat(); + local.push(instruction_to_fn_call( + id_def, + ptx_impl_imports, + ast::Instruction::Atom { data, arguments }, + fn_name, + )?); + } + s => local.push(s), + } + } + Ok((local, global)) +} + +fn instruction_to_fn_call( + id_defs: &mut NumericIdResolver, + ptx_impl_imports: &mut HashMap, + inst: ast::Instruction, + fn_name: String, +) -> Result { + let mut arguments = Vec::new(); + ast::visit_map(inst, &mut |operand, + type_space: Option<( + &ast::Type, + ast::StateSpace, + )>, + is_dst, + _| { + let (typ, space) = match type_space { + Some((typ, space)) => (typ.clone(), space), + None => return Err(error_unreachable()), + }; + arguments.push((operand, is_dst, typ, space)); + Ok(SpirvWord(0)) + })?; + let return_arguments_count = arguments + .iter() + .position(|(desc, is_dst, _, _)| !is_dst) + .unwrap_or(arguments.len()); + let (return_arguments, input_arguments) = arguments.split_at(return_arguments_count); + let fn_id = register_external_fn_call( + id_defs, + ptx_impl_imports, + fn_name, + return_arguments + .iter() + .map(|(_, _, typ, state)| (typ, *state)), + input_arguments + .iter() + .map(|(_, _, typ, state)| (typ, *state)), + )?; + Ok(Statement::Instruction(ast::Instruction::Call { + data: ast::CallDetails { + uniform: false, + return_arguments: return_arguments + .iter() + .map(|(_, _, typ, state)| (typ.clone(), *state)) + .collect::>(), + input_arguments: input_arguments + .iter() + .map(|(_, _, typ, state)| (typ.clone(), *state)) + .collect::>(), + }, + arguments: ast::CallArgs { + return_arguments: return_arguments + .iter() + .map(|(name, _, _, _)| *name) + .collect::>(), + func: fn_id, + input_arguments: input_arguments + .iter() + .map(|(name, _, _, _)| *name) + .collect::>(), + }, + })) +} + +fn scalar_to_ptx_name(this: ast::ScalarType) -> &'static str { + match this { + ast::ScalarType::B8 => "b8", + ast::ScalarType::B16 => "b16", + ast::ScalarType::B32 => "b32", + ast::ScalarType::B64 => "b64", + ast::ScalarType::B128 => "b128", + ast::ScalarType::U8 => "u8", + ast::ScalarType::U16 => "u16", + ast::ScalarType::U16x2 => "u16x2", + ast::ScalarType::U32 => "u32", + ast::ScalarType::U64 => "u64", + ast::ScalarType::S8 => "s8", + ast::ScalarType::S16 => "s16", + ast::ScalarType::S16x2 => "s16x2", + ast::ScalarType::S32 => "s32", + ast::ScalarType::S64 => "s64", + ast::ScalarType::F16 => "f16", + ast::ScalarType::F16x2 => "f16x2", + ast::ScalarType::F32 => "f32", + ast::ScalarType::F64 => "f64", + ast::ScalarType::BF16 => "bf16", + ast::ScalarType::BF16x2 => "bf16x2", + ast::ScalarType::Pred => "pred", + } +} + +fn semantics_to_ptx_name(this: ast::AtomSemantics) -> &'static str { + match this { + ast::AtomSemantics::Relaxed => "relaxed", + ast::AtomSemantics::Acquire => "acquire", + ast::AtomSemantics::Release => "release", + ast::AtomSemantics::AcqRel => "acq_rel", + } +} + +fn scope_to_ptx_name(this: ast::MemScope) -> &'static str { + match this { + ast::MemScope::Cta => "cta", + ast::MemScope::Gpu => "gpu", + ast::MemScope::Sys => "sys", + ast::MemScope::Cluster => "cluster", + } +} + +fn space_to_ptx_name(this: ast::StateSpace) -> &'static str { + match this { + ast::StateSpace::Generic => "generic", + ast::StateSpace::Global => "global", + ast::StateSpace::Shared => "shared", + ast::StateSpace::Reg => "reg", + ast::StateSpace::Const => "const", + ast::StateSpace::Local => "local", + ast::StateSpace::Param => "param", + ast::StateSpace::Sreg => "sreg", + ast::StateSpace::SharedCluster => "shared_cluster", + ast::StateSpace::ParamEntry => "param_entry", + ast::StateSpace::SharedCta => "shared_cta", + ast::StateSpace::ParamFunc => "param_func", + } +} diff --git a/ptx/src/pass/fix_special_registers.rs b/ptx/src/pass/fix_special_registers.rs index 871537d..304bc61 100644 --- a/ptx/src/pass/fix_special_registers.rs +++ b/ptx/src/pass/fix_special_registers.rs @@ -128,56 +128,3 @@ impl<'a, 'b, 'input> SpecialRegisterResolver<'a, 'b, 'input> { } } } - -fn register_external_fn_call<'a>( - id_defs: &mut NumericIdResolver, - ptx_impl_imports: &mut HashMap, - name: String, - return_arguments: impl Iterator, - input_arguments: impl Iterator, -) -> Result { - match ptx_impl_imports.entry(name) { - hash_map::Entry::Vacant(entry) => { - let fn_id = id_defs.register_intermediate(None); - let return_arguments = fn_arguments_to_variables(id_defs, return_arguments); - let input_arguments = fn_arguments_to_variables(id_defs, input_arguments); - let func_decl = ast::MethodDeclaration:: { - return_arguments, - name: ast::MethodName::Func(fn_id), - input_arguments, - shared_mem: None, - }; - let func = Function { - func_decl: Rc::new(RefCell::new(func_decl)), - globals: Vec::new(), - body: None, - import_as: Some(entry.key().clone()), - tuning: Vec::new(), - linkage: ast::LinkingDirective::EXTERN, - }; - entry.insert(Directive::Method(func)); - Ok(fn_id) - } - hash_map::Entry::Occupied(entry) => match entry.get() { - Directive::Method(Function { func_decl, .. }) => match (**func_decl).borrow().name { - ast::MethodName::Func(fn_id) => Ok(fn_id), - ast::MethodName::Kernel(_) => Err(error_unreachable()), - }, - _ => Err(error_unreachable()), - }, - } -} - -fn fn_arguments_to_variables<'a>( - id_defs: &mut NumericIdResolver, - args: impl Iterator, -) -> Vec> { - args.map(|(typ, space)| ast::Variable { - align: None, - v_type: typ.clone(), - state_space: space, - name: id_defs.register_intermediate(None), - array_init: Vec::new(), - }) - .collect::>() -} \ No newline at end of file diff --git a/ptx/src/pass/insert_implicit_conversions.rs b/ptx/src/pass/insert_implicit_conversions.rs new file mode 100644 index 0000000..4a0dc8e --- /dev/null +++ b/ptx/src/pass/insert_implicit_conversions.rs @@ -0,0 +1,402 @@ +use std::mem; + +use super::*; +use ptx_parser as ast; + +/* + There are several kinds of implicit conversions in PTX: + * auto-bitcast: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#type-information-for-instructions-and-operands + * special ld/st/cvt conversion rules: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size + - ld.param: not documented, but for instruction `ld.param. x, [y]`, + semantics are to first zext/chop/bitcast `y` as needed and then do + documented special ld/st/cvt conversion rules for destination operands + - st.param [x] y (used as function return arguments) same rule as above applies + - generic/global ld: for instruction `ld x, [y]`, y must be of type + b64/u64/s64, which is bitcast to a pointer, dereferenced and then + documented special ld/st/cvt conversion rules are applied to dst + - generic/global st: for instruction `st [x], y`, x must be of type + b64/u64/s64, which is bitcast to a pointer +*/ +pub(super) fn run( + func: Vec, + id_def: &mut MutableNumericIdResolver, +) -> Result, TranslateError> { + let mut result = Vec::with_capacity(func.len()); + for s in func.into_iter() { + match s { + Statement::Instruction(inst) => { + insert_implicit_conversions_impl( + &mut result, + id_def, + Statement::Instruction(inst), + )?; + } + Statement::PtrAccess(access) => { + insert_implicit_conversions_impl( + &mut result, + id_def, + Statement::PtrAccess(access), + )?; + } + Statement::RepackVector(repack) => { + insert_implicit_conversions_impl( + &mut result, + id_def, + Statement::RepackVector(repack), + )?; + } + s @ Statement::Conditional(_) + | s @ Statement::Conversion(_) + | s @ Statement::Label(_) + | s @ Statement::Constant(_) + | s @ Statement::Variable(_) + | s @ Statement::LoadVar(..) + | s @ Statement::StoreVar(..) + | s @ Statement::RetValue(..) + | s @ Statement::FunctionPointer(..) => result.push(s), + } + } + Ok(result) +} + +fn insert_implicit_conversions_impl( + func: &mut Vec, + id_def: &mut MutableNumericIdResolver, + stmt: ExpandedStatement, +) -> Result<(), TranslateError> { + let mut post_conv = Vec::new(); + let statement = stmt.visit_map::( + &mut |operand, + type_state: Option<(&ast::Type, ast::StateSpace)>, + is_dst, + relaxed_type_check| { + let (instr_type, instruction_space) = match type_state { + None => return Ok(operand), + Some(t) => t, + }; + let (operand_type, operand_space) = id_def.get_typed(operand)?; + let conversion_fn = if relaxed_type_check { + if is_dst { + should_convert_relaxed_dst_wrapper + } else { + should_convert_relaxed_src_wrapper + } + } else { + default_implicit_conversion + }; + match conversion_fn( + (operand_space, &operand_type), + (instruction_space, instr_type), + )? { + Some(conv_kind) => { + let conv_output = if is_dst { &mut post_conv } else { &mut *func }; + let mut from_type = instr_type.clone(); + let mut from_space = instruction_space; + let mut to_type = operand_type; + let mut to_space = operand_space; + let mut src = + id_def.register_intermediate(instr_type.clone(), instruction_space); + let mut dst = operand; + let result = Ok::<_, TranslateError>(src); + if !is_dst { + mem::swap(&mut src, &mut dst); + mem::swap(&mut from_type, &mut to_type); + mem::swap(&mut from_space, &mut to_space); + } + conv_output.push(Statement::Conversion(ImplicitConversion { + src, + dst, + from_type, + from_space, + to_type, + to_space, + kind: conv_kind, + })); + result + } + None => Ok(operand), + } + }, + )?; + func.push(statement); + func.append(&mut post_conv); + Ok(()) +} + +fn default_implicit_conversion( + (operand_space, operand_type): (ast::StateSpace, &ast::Type), + (instruction_space, instruction_type): (ast::StateSpace, &ast::Type), +) -> Result, TranslateError> { + if !state_is_compatible(instruction_space, operand_space) { + default_implicit_conversion_space( + (operand_space, operand_type), + (instruction_space, instruction_type), + ) + } else if instruction_type != operand_type { + default_implicit_conversion_type(instruction_space, operand_type, instruction_type) + } else { + Ok(None) + } +} + +// Space is different +fn default_implicit_conversion_space( + (operand_space, operand_type): (ast::StateSpace, &ast::Type), + (instruction_space, instruction_type): (ast::StateSpace, &ast::Type), +) -> Result, TranslateError> { + if (instruction_space == ast::StateSpace::Generic && coerces_to_generic(operand_space)) + || (operand_space == ast::StateSpace::Generic && coerces_to_generic(instruction_space)) + { + Ok(Some(ConversionKind::PtrToPtr)) + } else if state_is_compatible(operand_space, ast::StateSpace::Reg) { + match operand_type { + ast::Type::Pointer(operand_ptr_type, operand_ptr_space) + if *operand_ptr_space == instruction_space => + { + if instruction_type != &ast::Type::Scalar(*operand_ptr_type) { + Ok(Some(ConversionKind::PtrToPtr)) + } else { + Ok(None) + } + } + // TODO: 32 bit + ast::Type::Scalar(ast::ScalarType::B64) + | ast::Type::Scalar(ast::ScalarType::U64) + | ast::Type::Scalar(ast::ScalarType::S64) => match instruction_space { + ast::StateSpace::Global + | ast::StateSpace::Generic + | ast::StateSpace::Const + | ast::StateSpace::Local + | ast::StateSpace::Shared => Ok(Some(ConversionKind::BitToPtr)), + _ => Err(TranslateError::MismatchedType), + }, + ast::Type::Scalar(ast::ScalarType::B32) + | ast::Type::Scalar(ast::ScalarType::U32) + | ast::Type::Scalar(ast::ScalarType::S32) => match instruction_space { + ast::StateSpace::Const | ast::StateSpace::Local | ast::StateSpace::Shared => { + Ok(Some(ConversionKind::BitToPtr)) + } + _ => Err(TranslateError::MismatchedType), + }, + _ => Err(TranslateError::MismatchedType), + } + } else if state_is_compatible(instruction_space, ast::StateSpace::Reg) { + match instruction_type { + ast::Type::Pointer(instruction_ptr_type, instruction_ptr_space) + if operand_space == *instruction_ptr_space => + { + if operand_type != &ast::Type::Scalar(*instruction_ptr_type) { + Ok(Some(ConversionKind::PtrToPtr)) + } else { + Ok(None) + } + } + _ => Err(TranslateError::MismatchedType), + } + } else { + Err(TranslateError::MismatchedType) + } +} + +// Space is same, but type is different +fn default_implicit_conversion_type( + space: ast::StateSpace, + operand_type: &ast::Type, + instruction_type: &ast::Type, +) -> Result, TranslateError> { + if state_is_compatible(space, ast::StateSpace::Reg) { + if should_bitcast(instruction_type, operand_type) { + Ok(Some(ConversionKind::Default)) + } else { + Err(TranslateError::MismatchedType) + } + } else { + Ok(Some(ConversionKind::PtrToPtr)) + } +} + +fn coerces_to_generic(this: ast::StateSpace) -> bool { + match this { + ast::StateSpace::Global + | ast::StateSpace::Const + | ast::StateSpace::Local + | ptx_parser::StateSpace::SharedCta + | ast::StateSpace::SharedCluster + | ast::StateSpace::Shared => true, + ast::StateSpace::Reg + | ast::StateSpace::Param + | ast::StateSpace::ParamEntry + | ast::StateSpace::ParamFunc + | ast::StateSpace::Generic + | ast::StateSpace::Sreg => false, + } +} + +fn should_bitcast(instr: &ast::Type, operand: &ast::Type) -> bool { + match (instr, operand) { + (ast::Type::Scalar(inst), ast::Type::Scalar(operand)) => { + if inst.size_of() != operand.size_of() { + return false; + } + match inst.kind() { + ast::ScalarKind::Bit => operand.kind() != ast::ScalarKind::Bit, + ast::ScalarKind::Float => operand.kind() == ast::ScalarKind::Bit, + ast::ScalarKind::Signed => { + operand.kind() == ast::ScalarKind::Bit + || operand.kind() == ast::ScalarKind::Unsigned + } + ast::ScalarKind::Unsigned => { + operand.kind() == ast::ScalarKind::Bit + || operand.kind() == ast::ScalarKind::Signed + } + ast::ScalarKind::Pred => false, + } + } + (ast::Type::Vector(inst, _), ast::Type::Vector(operand, _)) + | (ast::Type::Array(inst, _), ast::Type::Array(operand, _)) => { + should_bitcast(&ast::Type::Scalar(*inst), &ast::Type::Scalar(*operand)) + } + _ => false, + } +} + +fn should_convert_relaxed_dst_wrapper( + (operand_space, operand_type): (ast::StateSpace, &ast::Type), + (instruction_space, instruction_type): (ast::StateSpace, &ast::Type), +) -> Result, TranslateError> { + if !state_is_compatible(operand_space, instruction_space) { + return Err(TranslateError::MismatchedType); + } + if operand_type == instruction_type { + return Ok(None); + } + match should_convert_relaxed_dst(operand_type, instruction_type) { + conv @ Some(_) => Ok(conv), + None => Err(TranslateError::MismatchedType), + } +} + +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size__relaxed-type-checking-rules-destination-operands +fn should_convert_relaxed_dst( + dst_type: &ast::Type, + instr_type: &ast::Type, +) -> Option { + if dst_type == instr_type { + return None; + } + match (dst_type, instr_type) { + (ast::Type::Scalar(dst_type), ast::Type::Scalar(instr_type)) => match instr_type.kind() { + ast::ScalarKind::Bit => { + if instr_type.size_of() <= dst_type.size_of() { + Some(ConversionKind::Default) + } else { + None + } + } + ast::ScalarKind::Signed => { + if dst_type.kind() != ast::ScalarKind::Float { + if instr_type.size_of() == dst_type.size_of() { + Some(ConversionKind::Default) + } else if instr_type.size_of() < dst_type.size_of() { + Some(ConversionKind::SignExtend) + } else { + None + } + } else { + None + } + } + ast::ScalarKind::Unsigned => { + if instr_type.size_of() <= dst_type.size_of() + && dst_type.kind() != ast::ScalarKind::Float + { + Some(ConversionKind::Default) + } else { + None + } + } + ast::ScalarKind::Float => { + if instr_type.size_of() <= dst_type.size_of() + && dst_type.kind() == ast::ScalarKind::Bit + { + Some(ConversionKind::Default) + } else { + None + } + } + ast::ScalarKind::Pred => None, + }, + (ast::Type::Vector(dst_type, _), ast::Type::Vector(instr_type, _)) + | (ast::Type::Array(dst_type, _), ast::Type::Array(instr_type, _)) => { + should_convert_relaxed_dst( + &ast::Type::Scalar(*dst_type), + &ast::Type::Scalar(*instr_type), + ) + } + _ => None, + } +} + +fn should_convert_relaxed_src_wrapper( + (operand_space, operand_type): (ast::StateSpace, &ast::Type), + (instruction_space, instruction_type): (ast::StateSpace, &ast::Type), +) -> Result, TranslateError> { + if !state_is_compatible(operand_space, instruction_space) { + return Err(TranslateError::MismatchedType); + } + if operand_type == instruction_type { + return Ok(None); + } + match should_convert_relaxed_src(operand_type, instruction_type) { + conv @ Some(_) => Ok(conv), + None => Err(TranslateError::MismatchedType), + } +} + +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size__relaxed-type-checking-rules-source-operands +fn should_convert_relaxed_src( + src_type: &ast::Type, + instr_type: &ast::Type, +) -> Option { + if src_type == instr_type { + return None; + } + match (src_type, instr_type) { + (ast::Type::Scalar(src_type), ast::Type::Scalar(instr_type)) => match instr_type.kind() { + ast::ScalarKind::Bit => { + if instr_type.size_of() <= src_type.size_of() { + Some(ConversionKind::Default) + } else { + None + } + } + ast::ScalarKind::Signed | ast::ScalarKind::Unsigned => { + if instr_type.size_of() <= src_type.size_of() + && src_type.kind() != ast::ScalarKind::Float + { + Some(ConversionKind::Default) + } else { + None + } + } + ast::ScalarKind::Float => { + if instr_type.size_of() <= src_type.size_of() + && src_type.kind() == ast::ScalarKind::Bit + { + Some(ConversionKind::Default) + } else { + None + } + } + ast::ScalarKind::Pred => None, + }, + (ast::Type::Vector(dst_type, _), ast::Type::Vector(instr_type, _)) + | (ast::Type::Array(dst_type, _), ast::Type::Array(instr_type, _)) => { + should_convert_relaxed_src( + &ast::Type::Scalar(*dst_type), + &ast::Type::Scalar(*instr_type), + ) + } + _ => None, + } +} diff --git a/ptx/src/pass/mod.rs b/ptx/src/pass/mod.rs index 896a34a..1fdf3a6 100644 --- a/ptx/src/pass/mod.rs +++ b/ptx/src/pass/mod.rs @@ -16,6 +16,9 @@ mod fix_special_registers; mod insert_mem_ssa_statements; mod normalize_identifiers; mod normalize_predicates; +mod insert_implicit_conversions; +mod normalize_labels; +mod extract_globals; static ZLUDA_PTX_IMPL_INTEL: &'static [u8] = include_bytes!("../../lib/zluda_ptx_impl.spv"); static ZLUDA_PTX_IMPL_AMD: &'static [u8] = include_bytes!("../../lib/zluda_ptx_impl.bc"); @@ -184,14 +187,12 @@ fn to_ssa<'input, 'b>( )?; let mut numeric_id_defs = numeric_id_defs.finish(); let expanded_statements = expand_arguments::run(ssa_statements, &mut numeric_id_defs)?; - todo!() - /* let expanded_statements = - insert_implicit_conversions(expanded_statements, &mut numeric_id_defs)?; + insert_implicit_conversions::run(expanded_statements, &mut numeric_id_defs)?; let mut numeric_id_defs = numeric_id_defs.unmut(); - let labeled_statements = normalize_labels(expanded_statements, &mut numeric_id_defs); + let labeled_statements = normalize_labels::run(expanded_statements, &mut numeric_id_defs); let (f_body, globals) = - extract_globals(labeled_statements, ptx_impl_imports, &mut numeric_id_defs)?; + extract_globals::run(labeled_statements, ptx_impl_imports, &mut numeric_id_defs)?; Ok(Function { func_decl: func_decl, globals: globals, @@ -200,7 +201,6 @@ fn to_ssa<'input, 'b>( tuning, linkage, }) - */ } pub struct Module { @@ -1220,3 +1220,56 @@ fn state_is_compatible(this: ast::StateSpace, other: ast::StateSpace) -> bool { || this == ast::StateSpace::Reg && other == ast::StateSpace::Sreg || this == ast::StateSpace::Sreg && other == ast::StateSpace::Reg } + +fn register_external_fn_call<'a>( + id_defs: &mut NumericIdResolver, + ptx_impl_imports: &mut HashMap, + name: String, + return_arguments: impl Iterator, + input_arguments: impl Iterator, +) -> Result { + match ptx_impl_imports.entry(name) { + hash_map::Entry::Vacant(entry) => { + let fn_id = id_defs.register_intermediate(None); + let return_arguments = fn_arguments_to_variables(id_defs, return_arguments); + let input_arguments = fn_arguments_to_variables(id_defs, input_arguments); + let func_decl = ast::MethodDeclaration:: { + return_arguments, + name: ast::MethodName::Func(fn_id), + input_arguments, + shared_mem: None, + }; + let func = Function { + func_decl: Rc::new(RefCell::new(func_decl)), + globals: Vec::new(), + body: None, + import_as: Some(entry.key().clone()), + tuning: Vec::new(), + linkage: ast::LinkingDirective::EXTERN, + }; + entry.insert(Directive::Method(func)); + Ok(fn_id) + } + hash_map::Entry::Occupied(entry) => match entry.get() { + Directive::Method(Function { func_decl, .. }) => match (**func_decl).borrow().name { + ast::MethodName::Func(fn_id) => Ok(fn_id), + ast::MethodName::Kernel(_) => Err(error_unreachable()), + }, + _ => Err(error_unreachable()), + }, + } +} + +fn fn_arguments_to_variables<'a>( + id_defs: &mut NumericIdResolver, + args: impl Iterator, +) -> Vec> { + args.map(|(typ, space)| ast::Variable { + align: None, + v_type: typ.clone(), + state_space: space, + name: id_defs.register_intermediate(None), + array_init: Vec::new(), + }) + .collect::>() +} diff --git a/ptx/src/pass/normalize_labels.rs b/ptx/src/pass/normalize_labels.rs new file mode 100644 index 0000000..097d87c --- /dev/null +++ b/ptx/src/pass/normalize_labels.rs @@ -0,0 +1,48 @@ +use std::{collections::HashSet, iter}; + +use super::*; + +pub(super) fn run( + func: Vec, + id_def: &mut NumericIdResolver, +) -> Vec { + let mut labels_in_use = HashSet::new(); + for s in func.iter() { + match s { + Statement::Instruction(i) => { + if let Some(target) = jump_target(i) { + labels_in_use.insert(target); + } + } + Statement::Conditional(cond) => { + labels_in_use.insert(cond.if_true); + labels_in_use.insert(cond.if_false); + } + Statement::Variable(..) + | Statement::LoadVar(..) + | Statement::StoreVar(..) + | Statement::RetValue(..) + | Statement::Conversion(..) + | Statement::Constant(..) + | Statement::Label(..) + | Statement::PtrAccess { .. } + | Statement::RepackVector(..) + | Statement::FunctionPointer(..) => {} + } + } + iter::once(Statement::Label(id_def.register_intermediate(None))) + .chain(func.into_iter().filter(|s| match s { + Statement::Label(i) => labels_in_use.contains(i), + _ => true, + })) + .collect::>() +} + +fn jump_target>( + this: &ast::Instruction, +) -> Option { + match this { + ast::Instruction::Bra { arguments } => Some(arguments.src), + _ => None, + } +}