diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index dac31a5..2b4f593 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -1364,12 +1364,13 @@ fn fix_builtins( value: ast::ImmediateValue::U64(0), })); } else { - let src_type = match numeric_id_defs.special_registers.get(details.arg.src) { + let sreg_and_type = match numeric_id_defs.special_registers.get(details.arg.src) + { Some(reg) => get_sreg_id_scalar_type(numeric_id_defs, reg), None => None, }; - let (sreg_src, scalar_typ, vector_width) = match src_type { - Some(x) => x, + let (sreg_src, scalar_typ, vector_width) = match sreg_and_type { + Some(sreg_and_type) => sreg_and_type, None => { result.push(Statement::LoadVar(details)); continue; @@ -1409,11 +1410,9 @@ fn get_sreg_id_scalar_type( ) -> Option<(spirv::Word, ast::ScalarType, u8)> { match sreg.normalized_sreg_and_type() { Some((normalized_sreg, typ, vec_width)) => Some(( - numeric_id_defs.special_registers.replace( - numeric_id_defs.current_id, - sreg, - normalized_sreg, - ), + numeric_id_defs + .special_registers + .get_or_add(numeric_id_defs.current_id, normalized_sreg), typ, vec_width, )), @@ -4755,11 +4754,26 @@ impl SpecialRegistersMap { } fn builtins<'a>(&'a self) -> impl Iterator + 'a { - self.reg_to_id.iter().map(|(reg, id)| (*reg, *id)) + self.reg_to_id.iter().filter_map(|(sreg, id)| { + if sreg.normalized_sreg_and_type().is_none() { + Some((*sreg, *id)) + } else { + None + } + }) } fn interface(&self) -> Vec { - self.id_to_reg.iter().map(|(id, _)| *id).collect::>() + self.reg_to_id + .iter() + .filter_map(|(sreg, id)| { + if sreg.normalized_sreg_and_type().is_none() { + Some(*id) + } else { + None + } + }) + .collect::>() } fn get(&self, id: spirv::Word) -> Option { @@ -4778,34 +4792,6 @@ impl SpecialRegistersMap { } } } - - fn replace( - &mut self, - current_id: &mut spirv::Word, - old: PtxSpecialRegister, - new: PtxSpecialRegister, - ) -> spirv::Word { - match self.reg_to_id.entry(old) { - hash_map::Entry::Occupied(e) => { - let id = e.remove(); - self.reg_to_id.insert(new, id); - id - } - hash_map::Entry::Vacant(e) => { - drop(e); - match self.reg_to_id.entry(new) { - hash_map::Entry::Occupied(e) => *e.get(), - hash_map::Entry::Vacant(e) => { - let numeric_id = *current_id; - *current_id += 1; - e.insert(numeric_id); - self.id_to_reg.insert(numeric_id, new); - numeric_id - } - } - } - } - } } struct GlobalStringIdResolver<'input> {