mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-04-20 00:14:45 +00:00
Start fixing builtins
This commit is contained in:
parent
770a379452
commit
5976e55b2b
1 changed files with 178 additions and 31 deletions
|
@ -1030,7 +1030,7 @@ fn emit_builtins(
|
|||
map: &mut TypeWordMap,
|
||||
id_defs: &GlobalStringIdResolver,
|
||||
) {
|
||||
for (reg, id) in id_defs.special_registers.iter() {
|
||||
for (reg, id) in id_defs.special_registers.builtins() {
|
||||
let result_type = map.get_or_add(
|
||||
builder,
|
||||
SpirvType::Pointer(
|
||||
|
@ -1038,9 +1038,9 @@ fn emit_builtins(
|
|||
spirv::StorageClass::Input,
|
||||
),
|
||||
);
|
||||
builder.variable(result_type, Some(*id), spirv::StorageClass::Input, None);
|
||||
builder.variable(result_type, Some(id), spirv::StorageClass::Input, None);
|
||||
builder.decorate(
|
||||
*id,
|
||||
id,
|
||||
spirv::Decoration::BuiltIn,
|
||||
&[dr::Operand::BuiltIn(reg.get_builtin())],
|
||||
);
|
||||
|
@ -1086,11 +1086,7 @@ fn emit_function_header<'a>(
|
|||
.iter()
|
||||
.filter_map(|(k, t)| t.as_ref().map(|_| *k))
|
||||
.collect::<Vec<_>>();
|
||||
let mut interface = defined_globals
|
||||
.special_registers
|
||||
.iter()
|
||||
.map(|(_, id)| *id)
|
||||
.collect::<Vec<_>>();
|
||||
let mut interface = defined_globals.special_registers.interface();
|
||||
for ast::Variable { name, .. } in synthetic_globals {
|
||||
interface.push(*name);
|
||||
}
|
||||
|
@ -1320,6 +1316,7 @@ fn to_ssa<'input, 'b>(
|
|||
convert_to_typed_statements(unadorned_statements, &fn_defs, &mut numeric_id_defs)?;
|
||||
let typed_statements =
|
||||
convert_to_stateful_memory_access(&mut spirv_decl, typed_statements, &mut numeric_id_defs)?;
|
||||
let typed_statements = fix_builtins(typed_statements, &mut numeric_id_defs)?;
|
||||
let ssa_statements = insert_mem_ssa_statements(
|
||||
typed_statements,
|
||||
&mut numeric_id_defs,
|
||||
|
@ -1343,6 +1340,75 @@ fn to_ssa<'input, 'b>(
|
|||
})
|
||||
}
|
||||
|
||||
fn fix_builtins(
|
||||
typed_statements: Vec<TypedStatement>,
|
||||
numeric_id_defs: &mut NumericIdResolver,
|
||||
) -> Result<Vec<TypedStatement>, TranslateError> {
|
||||
let mut result = Vec::with_capacity(typed_statements.len());
|
||||
let mut visitor = FixBuiltinsVisitor {};
|
||||
for s in typed_statements {
|
||||
match s {
|
||||
Statement::Instruction(inst) => {
|
||||
result.push(Statement::Instruction(inst.map(&mut visitor)?))
|
||||
}
|
||||
s => result.push(s),
|
||||
}
|
||||
}
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
struct FixBuiltinsVisitor {}
|
||||
|
||||
impl ArgumentMapVisitor<TypedArgParams, TypedArgParams> for FixBuiltinsVisitor {
|
||||
fn id(
|
||||
&mut self,
|
||||
desc: ArgumentDescriptor<spirv::Word>,
|
||||
typ: Option<&ast::Type>,
|
||||
) -> Result<spirv::Word, TranslateError> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn operand(
|
||||
&mut self,
|
||||
desc: ArgumentDescriptor<ast::Operand<spirv::Word>>,
|
||||
typ: &ast::Type,
|
||||
) -> Result<ast::Operand<spirv::Word>, TranslateError> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn id_or_vector(
|
||||
&mut self,
|
||||
desc: ArgumentDescriptor<ast::IdOrVector<spirv::Word>>,
|
||||
typ: &ast::Type,
|
||||
) -> Result<ast::IdOrVector<spirv::Word>, TranslateError> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn operand_or_vector(
|
||||
&mut self,
|
||||
desc: ArgumentDescriptor<ast::OperandOrVector<spirv::Word>>,
|
||||
typ: &ast::Type,
|
||||
) -> Result<ast::OperandOrVector<spirv::Word>, TranslateError> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn src_call_operand(
|
||||
&mut self,
|
||||
desc: ArgumentDescriptor<ast::CallOperand<spirv::Word>>,
|
||||
typ: &ast::Type,
|
||||
) -> Result<ast::CallOperand<spirv::Word>, TranslateError> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn src_member_operand(
|
||||
&mut self,
|
||||
desc: ArgumentDescriptor<(spirv::Word, u8)>,
|
||||
typ: (ast::ScalarType, u8),
|
||||
) -> Result<(spirv::Word, u8), TranslateError> {
|
||||
todo!()
|
||||
}
|
||||
}
|
||||
|
||||
fn extract_globals<'input, 'b>(
|
||||
sorted_statements: Vec<ExpandedStatement>,
|
||||
ptx_impl_imports: &mut HashMap<String, Directive>,
|
||||
|
@ -4599,9 +4665,13 @@ fn is_64_bit_integer(id_defs: &NumericIdResolver, id: spirv::Word) -> bool {
|
|||
#[derive(Ord, PartialOrd, Eq, PartialEq, Hash, Copy, Clone)]
|
||||
enum PtxSpecialRegister {
|
||||
Tid,
|
||||
Tid64,
|
||||
Ntid,
|
||||
Ntid64,
|
||||
Ctaid,
|
||||
Ctaid64,
|
||||
Nctaid,
|
||||
Nctaid64,
|
||||
}
|
||||
|
||||
impl PtxSpecialRegister {
|
||||
|
@ -4618,27 +4688,116 @@ impl PtxSpecialRegister {
|
|||
fn get_type(self) -> ast::Type {
|
||||
match self {
|
||||
PtxSpecialRegister::Tid => ast::Type::Vector(ast::ScalarType::U32, 4),
|
||||
PtxSpecialRegister::Tid64 => ast::Type::Vector(ast::ScalarType::U64, 3),
|
||||
PtxSpecialRegister::Ntid => ast::Type::Vector(ast::ScalarType::U32, 4),
|
||||
PtxSpecialRegister::Ntid64 => ast::Type::Vector(ast::ScalarType::U64, 3),
|
||||
PtxSpecialRegister::Ctaid => ast::Type::Vector(ast::ScalarType::U32, 4),
|
||||
PtxSpecialRegister::Ctaid64 => ast::Type::Vector(ast::ScalarType::U64, 3),
|
||||
PtxSpecialRegister::Nctaid => ast::Type::Vector(ast::ScalarType::U32, 4),
|
||||
PtxSpecialRegister::Nctaid64 => ast::Type::Vector(ast::ScalarType::U64, 3),
|
||||
}
|
||||
}
|
||||
|
||||
fn get_builtin(self) -> spirv::BuiltIn {
|
||||
match self {
|
||||
PtxSpecialRegister::Tid => spirv::BuiltIn::LocalInvocationId,
|
||||
PtxSpecialRegister::Ntid => spirv::BuiltIn::WorkgroupSize,
|
||||
PtxSpecialRegister::Ctaid => spirv::BuiltIn::WorkgroupId,
|
||||
PtxSpecialRegister::Nctaid => spirv::BuiltIn::NumWorkgroups,
|
||||
PtxSpecialRegister::Tid | PtxSpecialRegister::Tid64 => {
|
||||
spirv::BuiltIn::LocalInvocationId
|
||||
}
|
||||
PtxSpecialRegister::Ntid | PtxSpecialRegister::Ntid64 => spirv::BuiltIn::WorkgroupSize,
|
||||
PtxSpecialRegister::Ctaid | PtxSpecialRegister::Ctaid64 => spirv::BuiltIn::WorkgroupId,
|
||||
PtxSpecialRegister::Nctaid | PtxSpecialRegister::Nctaid64 => {
|
||||
spirv::BuiltIn::NumWorkgroups
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn push_conversion(
|
||||
self,
|
||||
id_def: &mut NumericIdResolver,
|
||||
func: &mut Vec<TypedStatement>,
|
||||
mut composite_read: CompositeRead,
|
||||
) {
|
||||
todo!()
|
||||
/*
|
||||
match self {
|
||||
PtxSpecialRegister::Tid
|
||||
| PtxSpecialRegister::Ntid
|
||||
| PtxSpecialRegister::Ctaid
|
||||
| PtxSpecialRegister::Nctaid => {
|
||||
if composite_read.src_index == 3 {
|
||||
func.push(Statement::Constant(ConstantDefinition {
|
||||
dst: composite_read.dst,
|
||||
typ: ast::ScalarType::U32,
|
||||
value: ast::ImmediateValue::U64(0),
|
||||
}));
|
||||
} else {
|
||||
let dst = composite_read.dst;
|
||||
let temp_dst =
|
||||
id_def.new_non_variable(Some(ast::Type::Scalar(ast::ScalarType::U64)));
|
||||
composite_read.dst = temp_dst;
|
||||
func.push(Statement::Composite(composite_read));
|
||||
func.push(Statement::Conversion(ImplicitConversion {
|
||||
src: temp_dst,
|
||||
dst: dst,
|
||||
from: ast::Type::Scalar(ast::ScalarType::U64),
|
||||
to: ast::Type::Scalar(ast::ScalarType::U32),
|
||||
kind: ConversionKind::Default,
|
||||
src_sema: ArgumentSemantics::Default,
|
||||
dst_sema: ArgumentSemantics::Default,
|
||||
}));
|
||||
}
|
||||
}
|
||||
}
|
||||
*/
|
||||
}
|
||||
}
|
||||
|
||||
struct SpecialRegistersMap {
|
||||
reg_to_id: HashMap<PtxSpecialRegister, spirv::Word>,
|
||||
id_to_reg: HashMap<spirv::Word, PtxSpecialRegister>,
|
||||
}
|
||||
|
||||
impl SpecialRegistersMap {
|
||||
fn new() -> Self {
|
||||
SpecialRegistersMap {
|
||||
reg_to_id: HashMap::new(),
|
||||
id_to_reg: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn builtins<'a>(&'a self) -> impl Iterator<Item = (PtxSpecialRegister, spirv::Word)> + 'a {
|
||||
self.reg_to_id.iter().map(|(reg, id)| (*reg, *id))
|
||||
}
|
||||
|
||||
fn interface(&self) -> Vec<spirv::Word> {
|
||||
self.id_to_reg.iter().map(|(id, _)| *id).collect::<Vec<_>>()
|
||||
}
|
||||
|
||||
fn get(&self, id: spirv::Word) -> Option<PtxSpecialRegister> {
|
||||
self.id_to_reg.get(&id).copied()
|
||||
}
|
||||
|
||||
fn get_or_add(&mut self, current_id: &mut spirv::Word, reg: PtxSpecialRegister) -> spirv::Word {
|
||||
match self.reg_to_id.entry(reg) {
|
||||
hash_map::Entry::Occupied(e) => *e.get(),
|
||||
hash_map::Entry::Vacant(e) => {
|
||||
let numeric_id = *current_id;
|
||||
*current_id += 1;
|
||||
e.insert(numeric_id);
|
||||
self.id_to_reg.insert(numeric_id, reg);
|
||||
numeric_id
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn replace(&mut self) {}
|
||||
}
|
||||
|
||||
struct GlobalStringIdResolver<'input> {
|
||||
current_id: spirv::Word,
|
||||
variables: HashMap<Cow<'input, str>, spirv::Word>,
|
||||
variables_type_check: HashMap<u32, Option<(ast::Type, bool)>>,
|
||||
special_registers: HashMap<PtxSpecialRegister, spirv::Word>,
|
||||
special_registers: SpecialRegistersMap,
|
||||
fns: HashMap<spirv::Word, FnDecl>,
|
||||
}
|
||||
|
||||
|
@ -4653,7 +4812,7 @@ impl<'a> GlobalStringIdResolver<'a> {
|
|||
current_id: start_id,
|
||||
variables: HashMap::new(),
|
||||
variables_type_check: HashMap::new(),
|
||||
special_registers: HashMap::new(),
|
||||
special_registers: SpecialRegistersMap::new(),
|
||||
fns: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
@ -4768,7 +4927,7 @@ struct FnStringIdResolver<'input, 'b> {
|
|||
current_id: &'b mut spirv::Word,
|
||||
global_variables: &'b HashMap<Cow<'input, str>, spirv::Word>,
|
||||
global_type_check: &'b HashMap<u32, Option<(ast::Type, bool)>>,
|
||||
special_registers: &'b mut HashMap<PtxSpecialRegister, spirv::Word>,
|
||||
special_registers: &'b mut SpecialRegistersMap,
|
||||
variables: Vec<HashMap<Cow<'input, str>, spirv::Word>>,
|
||||
type_check: HashMap<u32, Option<(ast::Type, bool)>>,
|
||||
}
|
||||
|
@ -4779,11 +4938,7 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> {
|
|||
current_id: self.current_id,
|
||||
global_type_check: self.global_type_check,
|
||||
type_check: self.type_check,
|
||||
special_registers: self
|
||||
.special_registers
|
||||
.iter()
|
||||
.map(|(reg, id)| (*id, *reg))
|
||||
.collect(),
|
||||
special_registers: self.special_registers,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -4807,15 +4962,7 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> {
|
|||
None => {
|
||||
let sreg =
|
||||
PtxSpecialRegister::try_parse(id).ok_or(TranslateError::UnknownSymbol)?;
|
||||
match self.special_registers.entry(sreg) {
|
||||
hash_map::Entry::Occupied(e) => Ok(*e.get()),
|
||||
hash_map::Entry::Vacant(e) => {
|
||||
let numeric_id = *self.current_id;
|
||||
*self.current_id += 1;
|
||||
e.insert(numeric_id);
|
||||
Ok(numeric_id)
|
||||
}
|
||||
}
|
||||
Ok(self.special_registers.get_or_add(self.current_id, sreg))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -4858,7 +5005,7 @@ struct NumericIdResolver<'b> {
|
|||
current_id: &'b mut spirv::Word,
|
||||
global_type_check: &'b HashMap<u32, Option<(ast::Type, bool)>>,
|
||||
type_check: HashMap<u32, Option<(ast::Type, bool)>>,
|
||||
special_registers: HashMap<spirv::Word, PtxSpecialRegister>,
|
||||
special_registers: &'b mut SpecialRegistersMap,
|
||||
}
|
||||
|
||||
impl<'b> NumericIdResolver<'b> {
|
||||
|
@ -4870,7 +5017,7 @@ impl<'b> NumericIdResolver<'b> {
|
|||
match self.type_check.get(&id) {
|
||||
Some(Some(x)) => Ok(x.clone()),
|
||||
Some(None) => Err(TranslateError::UntypedSymbol),
|
||||
None => match self.special_registers.get(&id) {
|
||||
None => match self.special_registers.get(id) {
|
||||
Some(x) => Ok((x.get_type(), true)),
|
||||
None => match self.global_type_check.get(&id) {
|
||||
Some(Some(result)) => Ok(result.clone()),
|
||||
|
|
Loading…
Add table
Reference in a new issue