diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs index ce9a596..5d43a26 100644 --- a/ptx/src/ast.rs +++ b/ptx/src/ast.rs @@ -200,7 +200,7 @@ pub struct LdData { pub typ: ScalarType, } -#[derive(PartialEq, Eq)] +#[derive(Copy, Clone, PartialEq, Eq)] pub enum LdStQualifier { Weak, Volatile, @@ -208,14 +208,14 @@ pub enum LdStQualifier { Acquire(LdScope), } -#[derive(PartialEq, Eq)] +#[derive(Copy, Clone, PartialEq, Eq)] pub enum LdScope { Cta, Gpu, Sys, } -#[derive(PartialEq, Eq)] +#[derive(Copy, Clone, PartialEq, Eq)] pub enum LdStateSpace { Generic, Const, @@ -225,7 +225,7 @@ pub enum LdStateSpace { Shared, } -#[derive(PartialEq, Eq)] +#[derive(Copy, Clone, PartialEq, Eq)] pub enum LdCacheOperator { Cached, L2Only, diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index a186772..ad87af8 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -5,7 +5,7 @@ use std::cell::RefCell; use std::collections::{BTreeMap, HashMap, HashSet}; use std::{borrow::Cow, fmt}; -use rspirv::binary::Assemble; +use rspirv::binary::{Assemble, Disassemble}; #[derive(PartialEq, Eq, Hash, Clone, Copy)] enum SpirvType { @@ -86,6 +86,7 @@ pub fn to_spirv(ast: ast::Module) -> Result, dr::Error> { emit_function(&mut builder, &mut map, f)?; } let module = builder.module(); + dbg!(print!("{}", module.disassemble())); Ok(module.assemble()) } @@ -122,19 +123,44 @@ fn emit_function<'a>( if f.kernel { builder.entry_point(spirv::ExecutionModel::Kernel, func_id, f.name, &[]); } + let (mut func_body, bbs, _, unique_ids) = to_ssa(&f.args, f.body); + let id_offset = builder.reserve_ids(unique_ids); + emit_function_args(builder, id_offset, map, &f.args); + apply_id_offset(&mut func_body, id_offset); + emit_function_body_ops(builder, map, &func_body, &bbs)?; + builder.end_function()?; + Ok(func_id) +} + +fn apply_id_offset(func_body: &mut Vec, id_offset: u32) { + for s in func_body { + s.visit_id_mut(&mut |_, id| *id += id_offset); + } +} + +fn to_ssa<'a>( + f_args: &[ast::Argument], + f_body: Vec>, +) -> ( + Vec, + Vec, + Vec>, + spirv::Word, +) { let mut contant_ids = HashMap::new(); - collect_arg_ids(&mut contant_ids, &f.args); - collect_label_ids(&mut contant_ids, &f.body); - let registers = collect_registers(&f.body); - let (normalized_ids, unique_ids, type_check) = - normalize_identifiers(f.body, &contant_ids, registers); + let mut type_check = HashMap::new(); + collect_arg_ids(&mut contant_ids, &mut type_check, &f_args); + collect_label_ids(&mut contant_ids, &f_body); + let registers = collect_var_definitions(&f_args, &f_body); + let (normalized_ids, unique_ids) = + normalize_identifiers(f_body, &contant_ids, &mut type_check, registers); let (mut func_body, unique_ids) = - insert_implicit_conversion(normalized_ids, unique_ids, &|x| type_check[&x]); + insert_implicit_conversions(normalized_ids, unique_ids, &|x| type_check[&x]); let bbs = get_basic_blocks(&func_body); let rpostorder = to_reverse_postorder(&bbs); let doms = immediate_dominators(&bbs, &rpostorder); let dom_fronts = dominance_frontiers(&bbs, &doms); - let (_, unique_ids) = ssa_legalize( + let (phis, unique_ids) = ssa_legalize( &mut func_body, contant_ids.len() as u32, unique_ids, @@ -142,15 +168,17 @@ fn emit_function<'a>( &doms, &dom_fronts, ); - let id_offset = builder.reserve_ids(unique_ids); - emit_function_args(builder, id_offset, map, &f.args); - emit_function_body_ops(builder, id_offset, map, &func_body, &bbs)?; - builder.end_function()?; - Ok(func_id) + (func_body, bbs, phis, unique_ids) } -fn collect_registers<'a>(body: &[ast::Statement<&'a str>]) -> HashMap, ast::Type> { +fn collect_var_definitions<'a>( + args: &[ast::Argument<'a>], + body: &[ast::Statement<&'a str>], +) -> HashMap, ast::Type> { let mut result = HashMap::new(); + for param in args { + result.insert(Cow::Borrowed(param.name), ast::Type::Scalar(param.a_type)); + } for s in body { match s { ast::Statement::Variable(var) => match var.count { @@ -170,12 +198,19 @@ fn collect_registers<'a>(body: &[ast::Statement<&'a str>]) -> HashMap. x, [y]` semantics are x = *(*)y + - ld.param: not documented, but for instruction `ld.param. x, [y]`, + semantics are to first zext/chop/bitcast `y` as needed and then do + documented special ld/st/cvt conversion rules for destination operands + - generic ld: for instruction `ld x, [y]`, y must be of type b64/u64/s64, + which is bitcast to a pointer, dereferenced and then documented special + ld/st/cvt conversion rules are applied + - generic ld: for instruction `ld [x], y`, x must be of type b64/u64/s64, + which is bitcast to a pointer */ -fn insert_implicit_conversion ast::Type>( +fn insert_implicit_conversions ast::Type>( normalized_ids: Vec, unique_ids: spirv::Word, type_check: &TypeCheck, @@ -190,16 +225,42 @@ fn insert_implicit_conversion ast::Type>( for s in normalized_ids.into_iter() { match s { Statement::Instruction(inst) => match inst { - ast::Instruction::Add(add, arg) => { - arg.insert_implicit_conversions( + ast::Instruction::Ld(ld, mut arg) => { + let new_arg_src = arg.src.map_id(&mut |arg_src| { + insert_implicit_conversions_ld_src( + &mut result, + ast::Type::Scalar(ld.typ), + type_check, + new_id, + |instr, op| ld.state_space.should_convert(instr, op), + arg_src, + ) + }); + arg.src = new_arg_src; + insert_implicit_bitcasts( + false, + true, &mut result, - ast::Type::Scalar(add.typ), type_check, new_id, - |arg| Statement::Instruction(ast::Instruction::Add(add, arg)), + ast::Instruction::Ld(ld, arg), ); } - _ => todo!(), + ast::Instruction::St(st, mut arg) => { + let arg_dst_type = type_check(arg.dst); + let new_dst = new_id(); + result.push(Statement::Converison(ImplicitConversion{ + src: arg.dst, + dst: new_dst, + from: arg_dst_type, + to: ast::Type::Scalar(st.typ), + kind: ConversionKind::Ptr + })); + arg.dst = new_dst; + } + inst @ _ => { + insert_implicit_bitcasts(true, true, &mut result, type_check, new_id, inst) + } }, s @ Statement::Conditional(_) | s @ Statement::Label(_) => result.push(s), Statement::Converison(_) => unreachable!(), @@ -236,10 +297,15 @@ fn emit_function_args( } } -fn collect_arg_ids<'a>(result: &mut HashMap<&'a str, spirv::Word>, args: &'a [ast::Argument<'a>]) { +fn collect_arg_ids<'a>( + result: &mut HashMap<&'a str, spirv::Word>, + type_check: &mut HashMap, + args: &'a [ast::Argument<'a>], +) { let mut id = result.len() as u32; for arg in args { result.insert(arg.name, id); + type_check.insert(id, ast::Type::Scalar(arg.a_type)); id += 1; } } @@ -263,7 +329,6 @@ fn collect_label_ids<'a>( fn emit_function_body_ops( builder: &mut dr::Builder, - id_offset: spirv::Word, map: &mut TypeWordMap, func: &[Statement], cfg: &[BasicBlock], @@ -276,56 +341,40 @@ fn emit_function_body_ops( continue; } let header_id = if let Statement::Label(id) = body[0] { - Some(id_offset + id) + Some(id) } else { None }; builder.begin_block(header_id)?; for s in body { match s { - // If block startd with a label it has already been emitted, + // If block starts with a label it has already been emitted, // all other labels in the block are unused Statement::Label(_) => (), - Statement::Converison(_) => todo!(), + Statement::Converison(cv) => emit_implicit_conversion(builder, map, cv)?, Statement::Conditional(bra) => { builder.branch_conditional(bra.predicate, bra.if_true, bra.if_false, [])?; } Statement::Instruction(inst) => match inst { // SPIR-V does not support marking jumps as guaranteed-converged ast::Instruction::Bra(_, arg) => { - builder.branch(arg.src + id_offset)?; + builder.branch(arg.src)?; } ast::Instruction::Ld(data, arg) => { if data.qualifier != ast::LdStQualifier::Weak || data.vector.is_some() { todo!() } let src = match arg.src { - ast::Operand::Reg(id) => id + id_offset, + ast::Operand::Reg(id) => id, _ => todo!(), }; let result_type = map.get_or_add_scalar(builder, data.typ); match data.state_space { ast::LdStateSpace::Generic => { - // TODO: make the cast optional - let ptr_result_type = map.get_or_add( - builder, - SpirvType::Pointer( - data.typ, - spirv::StorageClass::CrossWorkgroup, - ), - ); - let bitcast = - builder.convert_u_to_ptr(ptr_result_type, None, src)?; - builder.load( - result_type, - Some(arg.dst + id_offset), - bitcast, - None, - [], - )?; + builder.load(result_type, Some(arg.dst), src, None, [])?; } ast::LdStateSpace::Param => { - builder.copy_object(result_type, Some(arg.dst + id_offset), src)?; + builder.copy_object(result_type, Some(arg.dst), src)?; } _ => todo!(), } @@ -338,17 +387,10 @@ fn emit_function_body_ops( todo!() } let src = match arg.src { - ast::Operand::Reg(id) => id + id_offset, + ast::Operand::Reg(id) => id, _ => todo!(), }; - // TODO make cast optional - let ptr_result_type = map.get_or_add( - builder, - SpirvType::Pointer(data.typ, spirv::StorageClass::CrossWorkgroup), - ); - let bitcast = - builder.convert_u_to_ptr(ptr_result_type, None, arg.dst + id_offset)?; - builder.store(bitcast, src, None, &[])?; + builder.store(arg.dst, src, None, &[])?; } // SPIR-V does not support ret as guaranteed-converged ast::Instruction::Ret(_) => builder.ret()?, @@ -360,12 +402,76 @@ fn emit_function_body_ops( Ok(()) } +fn emit_implicit_conversion( + builder: &mut dr::Builder, + map: &mut TypeWordMap, + cv: &ImplicitConversion, +) -> Result<(), dr::Error> { + let (from_type, to_type) = match (cv.from, cv.to) { + (ast::Type::Scalar(from), ast::Type::Scalar(to)) => (from, to), + _ => todo!(), + }; + match cv.kind { + ConversionKind::Ptr => { + let dst_type = map.get_or_add( + builder, + SpirvType::Pointer(to_type, spirv_headers::StorageClass::Generic), + ); + builder.bitcast(dst_type, Some(cv.dst), cv.src)?; + } + ConversionKind::Default => { + if from_type.width() == to_type.width() { + if from_type.kind() == ScalarKind::Unsigned && to_type.kind() == ScalarKind::Byte + || from_type.kind() == ScalarKind::Byte + && to_type.kind() == ScalarKind::Unsigned + { + return Ok(()); + } + let dst_type = map.get_or_add_scalar(builder, to_type); + builder.bitcast(dst_type, Some(cv.dst), cv.src)?; + } else { + let as_unsigned_type = map.get_or_add_scalar( + builder, + ast::ScalarType::from_parts(from_type.width(), ScalarKind::Unsigned), + ); + let as_unsigned = builder.bitcast(as_unsigned_type, None, cv.src)?; + let as_unsigned_wide_type = + ast::ScalarType::from_parts(to_type.width(), ScalarKind::Unsigned); + let as_unsigned_wide_spirv = map.get_or_add_scalar( + builder, + ast::ScalarType::from_parts(to_type.width(), ScalarKind::Unsigned), + ); + if to_type.kind() == ScalarKind::Unsigned || to_type.kind() == ScalarKind::Byte { + builder.u_convert(as_unsigned_wide_spirv, Some(cv.dst), as_unsigned)?; + } else { + let as_unsigned_wide = + builder.u_convert(as_unsigned_wide_spirv, None, as_unsigned)?; + emit_implicit_conversion( + builder, + map, + &ImplicitConversion { + src: as_unsigned_wide, + dst: cv.dst, + from: ast::Type::Scalar(as_unsigned_wide_type), + to: cv.to, + kind: ConversionKind::Default, + }, + )?; + } + } + } + ConversionKind::SignExtend => todo!(), + } + Ok(()) +} + // TODO: support scopes fn normalize_identifiers<'a>( func: Vec>, constant_identifiers: &HashMap<&'a str, spirv::Word>, // arguments and labels can't be redefined + type_map: &mut HashMap, types: HashMap, ast::Type>, -) -> (Vec, spirv::Word, HashMap) { +) -> (Vec, spirv::Word) { let mut result = Vec::with_capacity(func.len()); let mut id: u32 = constant_identifiers.len() as u32; let mut remapped_ids = HashMap::new(); @@ -389,11 +495,12 @@ fn normalize_identifiers<'a>( for s in func { Statement::from_ast(s, &mut result, &mut get_or_add); } - let mut type_map = HashMap::with_capacity(types.len()); - for (old_id, new_id) in remapped_ids { - type_map.insert(new_id, types[old_id]); - } - (result, id, type_map) + type_map.extend( + remapped_ids + .into_iter() + .map(|(old_id, new_id)| (new_id, types[old_id])), + ); + (result, id) } fn ssa_legalize( @@ -911,10 +1018,17 @@ impl BrachCondition { } struct ImplicitConversion { - dst: spirv::Word, src: spirv::Word, + dst: spirv::Word, from: ast::Type, to: ast::Type, + kind: ConversionKind, +} + +enum ConversionKind { + Default, // zero-extend/chop/bitcast depending on types + SignExtend, + Ptr, } impl ImplicitConversion { @@ -1050,6 +1164,16 @@ impl ast::Instruction { ast::Instruction::Ret(_) => (), } } + + fn get_type(&self) -> Option { + match self { + ast::Instruction::Add(add, _) => Some(ast::Type::Scalar(add.typ)), + ast::Instruction::Ret(_) => None, + ast::Instruction::Ld(ld, _) => Some(ast::Type::Scalar(ld.typ)), + ast::Instruction::St(st, _) => Some(ast::Type::Scalar(st.typ)), + _ => todo!(), + } + } } impl ast::Instruction { @@ -1162,31 +1286,6 @@ impl ast::Arg3 { } } -impl ast::Arg3 { - fn insert_implicit_conversions< - TypeCheck: Fn(spirv::Word) -> ast::Type, - NewId: FnMut() -> spirv::Word, - NewStatement: FnOnce(Self) -> Statement, - >( - self, - func: &mut Vec, - op_type: ast::Type, - type_check: &TypeCheck, - new_id: &mut NewId, - new_statement: NewStatement, - ) { - let src1 = self - .src1 - .insert_implicit_conversion(func, op_type, type_check, new_id); - let src2 = self - .src2 - .insert_implicit_conversion(func, op_type, type_check, new_id); - insert_implicit_conversion_dst(func, op_type, type_check, new_id, self.dst, |dst| { - new_statement(Self { dst, src1, src2 }) - }); - } -} - impl ast::Arg4 { fn map_id U>(self, f: &mut F) -> ast::Arg4 { ast::Arg4 { @@ -1266,37 +1365,6 @@ impl ast::Operand { } } -impl ast::Operand { - fn insert_implicit_conversion< - TypeCheck: Fn(spirv::Word) -> ast::Type, - NewId: FnMut() -> spirv::Word, - >( - self, - func: &mut Vec, - op_type: ast::Type, - type_check: &TypeCheck, - new_id: &mut NewId, - ) -> Self { - match self { - ast::Operand::Reg(src) => { - if type_check(src) == op_type { - return self; - } - let new_src = new_id(); - func.push(Statement::Converison(ImplicitConversion { - src: src, - dst: new_src, - from: type_check(src), - to: op_type, - })); - ast::Operand::Reg(new_src) - } - o @ ast::Operand::Imm(_) => o, - ast::Operand::RegOffset(_, _) => todo!(), - } - } -} - impl ast::MovOperand { fn map_id U>(self, f: &mut F) -> ast::MovOperand { match self { @@ -1320,29 +1388,266 @@ impl ast::MovOperand { } } -fn insert_implicit_conversion_dst< +#[derive(Clone, Copy, PartialEq)] +enum ScalarKind { + Byte, + Unsigned, + Signed, + Float, +} + +impl ast::ScalarType { + fn width(self) -> u8 { + match self { + ast::ScalarType::U8 => 1, + ast::ScalarType::S8 => 1, + ast::ScalarType::B8 => 1, + ast::ScalarType::U16 => 2, + ast::ScalarType::S16 => 2, + ast::ScalarType::B16 => 2, + ast::ScalarType::F16 => 2, + ast::ScalarType::U32 => 4, + ast::ScalarType::S32 => 4, + ast::ScalarType::B32 => 4, + ast::ScalarType::F32 => 4, + ast::ScalarType::U64 => 8, + ast::ScalarType::S64 => 8, + ast::ScalarType::B64 => 8, + ast::ScalarType::F64 => 8, + } + } + + fn kind(self) -> ScalarKind { + match self { + ast::ScalarType::U8 => ScalarKind::Unsigned, + ast::ScalarType::U16 => ScalarKind::Unsigned, + ast::ScalarType::U32 => ScalarKind::Unsigned, + ast::ScalarType::U64 => ScalarKind::Unsigned, + ast::ScalarType::S8 => ScalarKind::Signed, + ast::ScalarType::S16 => ScalarKind::Signed, + ast::ScalarType::S32 => ScalarKind::Signed, + ast::ScalarType::S64 => ScalarKind::Signed, + ast::ScalarType::B8 => ScalarKind::Byte, + ast::ScalarType::B16 => ScalarKind::Byte, + ast::ScalarType::B32 => ScalarKind::Byte, + ast::ScalarType::B64 => ScalarKind::Byte, + ast::ScalarType::F16 => ScalarKind::Float, + ast::ScalarType::F32 => ScalarKind::Float, + ast::ScalarType::F64 => ScalarKind::Float, + } + } + + fn from_parts(width: u8, kind: ScalarKind) -> Self { + match kind { + ScalarKind::Float => match width { + 2 => ast::ScalarType::F16, + 4 => ast::ScalarType::F32, + 8 => ast::ScalarType::F64, + _ => unreachable!(), + }, + ScalarKind::Byte => match width { + 1 => ast::ScalarType::B8, + 2 => ast::ScalarType::B16, + 4 => ast::ScalarType::B32, + 8 => ast::ScalarType::B64, + _ => unreachable!(), + }, + ScalarKind::Signed => match width { + 1 => ast::ScalarType::S8, + 2 => ast::ScalarType::S16, + 4 => ast::ScalarType::S32, + 8 => ast::ScalarType::S64, + _ => unreachable!(), + }, + ScalarKind::Unsigned => match width { + 1 => ast::ScalarType::U8, + 2 => ast::ScalarType::U16, + 4 => ast::ScalarType::U32, + 8 => ast::ScalarType::U64, + _ => unreachable!(), + }, + } + } +} + +fn should_bitcast(instr: ast::Type, operand: ast::Type) -> bool { + match (instr, operand) { + (ast::Type::Scalar(inst), ast::Type::Scalar(operand)) => { + if inst.width() != operand.width() { + return false; + } + match inst.kind() { + ScalarKind::Byte => operand.kind() != ScalarKind::Byte, + ScalarKind::Float => operand.kind() == ScalarKind::Byte, + ScalarKind::Signed => { + operand.kind() == ScalarKind::Byte || operand.kind() == ScalarKind::Unsigned + } + ScalarKind::Unsigned => { + operand.kind() == ScalarKind::Byte || operand.kind() == ScalarKind::Signed + } + } + } + _ => false, + } +} + +impl ast::LdStateSpace { + fn should_convert(self, instr_type: ast::Type, op_type: ast::Type) -> Option { + match self { + ast::LdStateSpace::Param => { + if instr_type != op_type { + Some(ConversionKind::Default) + } else { + None + } + } + ast::LdStateSpace::Generic => Some(ConversionKind::Ptr), + _ => todo!(), + } + } +} + +fn insert_forced_bitcast_src< TypeCheck: Fn(spirv::Word) -> ast::Type, NewId: FnMut() -> spirv::Word, - NewStatement: FnOnce(spirv::Word) -> Statement, >( func: &mut Vec, op_type: ast::Type, type_check: &TypeCheck, new_id: &mut NewId, - dst: spirv::Word, - new_statement: NewStatement, -) { - if type_check(dst) == op_type { - func.push(new_statement(dst)); - } else { - let new_dst = new_id(); - func.push(new_statement(new_dst)); + src: spirv::Word, +) -> spirv::Word { + let src_type = type_check(src); + if src_type == op_type { + return src; + } + let new_src = new_id(); + func.push(Statement::Converison(ImplicitConversion { + src: src, + dst: new_src, + from: src_type, + to: op_type, + kind: ConversionKind::Default, + })); + new_src +} + +fn insert_implicit_conversions_ld_src< + TypeCheck: Fn(spirv::Word) -> ast::Type, + NewId: FnMut() -> spirv::Word, + ShouldConvert: Fn(ast::Type, ast::Type) -> Option, +>( + func: &mut Vec, + instr_type: ast::Type, + type_check: &TypeCheck, + new_id: &mut NewId, + should_convert: ShouldConvert, + src: spirv::Word, +) -> spirv::Word { + let src_type = type_check(src); + if let Some(conv_kind) = should_convert(src_type, instr_type) { + let new_src = new_id(); func.push(Statement::Converison(ImplicitConversion { - src: new_dst, - dst: dst, - from: type_check(new_dst), - to: op_type, + src: src, + dst: new_src, + from: src_type, + to: instr_type, + kind: conv_kind, })); + new_src + } else { + src + } +} + +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size__relaxed-type-checking-rules-destination-operands +fn should_convert_relaxed_dst( + dst_type: ast::Type, + instr_type: ast::ScalarType, +) -> Option { + if dst_type == ast::Type::Scalar(instr_type) { + return None; + } + match dst_type { + ast::Type::Scalar(dst_type) => match instr_type.kind() { + ScalarKind::Byte => { + if instr_type.width() <= dst_type.width() { + Some(ConversionKind::Default) + } else { + None + } + } + ScalarKind::Signed => { + if instr_type.width() <= dst_type.width() && dst_type.kind() != ScalarKind::Float { + Some(ConversionKind::SignExtend) + } else { + None + } + } + ScalarKind::Unsigned => { + if instr_type.width() <= dst_type.width() && dst_type.kind() != ScalarKind::Float { + Some(ConversionKind::Default) + } else { + None + } + } + ScalarKind::Float => { + if instr_type.width() <= dst_type.width() && dst_type.kind() == ScalarKind::Float { + Some(ConversionKind::Default) + } else { + None + } + } + }, + _ => None, + } +} + +fn insert_implicit_bitcasts< + TypeCheck: Fn(spirv::Word) -> ast::Type, + NewId: FnMut() -> spirv::Word, +>( + do_src_bitcast: bool, + do_dst_bitcast: bool, + func: &mut Vec, + type_check: &TypeCheck, + new_id: &mut NewId, + mut instr: ast::Instruction, +) { + let mut dst_coercion = None; + if let Some(instr_type) = instr.get_type() { + instr.visit_id_mut(&mut |is_dst, id| { + if (is_dst && !do_dst_bitcast) || (!is_dst && !do_src_bitcast) { + return; + } + let id_type = type_check(*id); + if should_bitcast(instr_type, type_check(*id)) { + let replacement_id = new_id(); + if is_dst { + dst_coercion = Some(ImplicitConversion { + src: replacement_id, + dst: *id, + from: instr_type, + to: id_type, + kind: ConversionKind::Default, + }); + *id = replacement_id; + } else { + func.push(Statement::Converison(ImplicitConversion { + src: *id, + dst: replacement_id, + from: id_type, + to: instr_type, + kind: ConversionKind::Default, + })); + *id = replacement_id; + } + } + }); + } + func.push(Statement::Instruction(instr)); + if let Some(cond) = dst_coercion { + func.push(Statement::Converison(cond)); } } @@ -1678,6 +1983,12 @@ mod tests { // page 403 const FIG_19_4: &'static str = "{ + .reg.u32 i; + .reg.u32 j; + .reg.u32 k; + .reg.pred p; + .reg.pred q; + mov.u32 i, 1; mov.u32 j, 1; mov.u32 k, 0; @@ -1710,7 +2021,9 @@ mod tests { assert_eq!(errors.len(), 0); let mut constant_ids = HashMap::new(); collect_label_ids(&mut constant_ids, &ast); - let (normalized_ids, _) = normalize_identifiers(ast, &constant_ids); + let registers = collect_var_definitions(&[], &ast); + let (normalized_ids, _) = + normalize_identifiers(ast, &constant_ids, &mut HashMap::new(), registers); let mut bbs = get_basic_blocks(&normalized_ids); bbs.iter_mut().for_each(sort_pred_succ); assert_eq!( @@ -1857,7 +2170,9 @@ mod tests { let mut constant_ids = HashMap::new(); collect_label_ids(&mut constant_ids, &fn_ast); assert_eq!(constant_ids.len(), 4); - let (normalized_ids, max_id) = normalize_identifiers(fn_ast, &constant_ids); + let registers = collect_var_definitions(&[], &fn_ast); + let (normalized_ids, max_id) = + normalize_identifiers(fn_ast, &constant_ids, &mut HashMap::new(), registers); let bbs = get_basic_blocks(&normalized_ids); let rpostorder = to_reverse_postorder(&bbs); let doms = immediate_dominators(&bbs, &rpostorder); @@ -1895,21 +2210,7 @@ mod tests { .parse(&mut errors, func) .unwrap(); assert_eq!(errors.len(), 0); - let mut constant_ids = HashMap::new(); - collect_label_ids(&mut constant_ids, &fn_ast); - let (mut func, unique_ids) = normalize_identifiers(fn_ast, &constant_ids); - let bbs = get_basic_blocks(&func); - let rpostorder = to_reverse_postorder(&bbs); - let doms = immediate_dominators(&bbs, &rpostorder); - let dom_fronts = dominance_frontiers(&bbs, &doms); - let (mut ssa_phis, _) = ssa_legalize( - &mut func, - constant_ids.len() as u32, - unique_ids, - &bbs, - &doms, - &dom_fronts, - ); + let (func, _, mut ssa_phis, unique_ids) = to_ssa(&[], fn_ast); assert_phi_dst_id(unique_ids, &ssa_phis); assert_dst_unique(&func, &ssa_phis); sort_phi(&mut ssa_phis);