Port first pass

This commit is contained in:
Andrzej Janik 2024-08-23 03:03:57 +02:00
parent 1ec1ca0c30
commit 12ef8dbc90
5 changed files with 421 additions and 108 deletions

View file

@ -24,11 +24,10 @@ lalrpop_mod!(
);
pub mod ast;
mod pass;
pub(crate) mod pass;
#[cfg(test)]
mod test;
mod translate;
mod translate2;
use std::fmt;

View file

@ -1,12 +1,347 @@
use ptx_parser as ast;
use rspirv::{binary::Assemble, dr};
use std::{
borrow::Cow,
cell::RefCell,
collections::{hash_map, HashMap},
ffi::CString,
rc::Rc,
};
mod normalize;
pub(crate) mod normalize;
static ZLUDA_PTX_IMPL_INTEL: &'static [u8] = include_bytes!("../../lib/zluda_ptx_impl.spv");
static ZLUDA_PTX_IMPL_AMD: &'static [u8] = include_bytes!("../../lib/zluda_ptx_impl.bc");
const ZLUDA_PTX_PREFIX: &'static str = "__zluda_ptx_impl__";
pub fn to_spirv_module<'input>(ast: ast::Module<'input>) -> Result<Module, TranslateError> {
let mut id_defs = GlobalStringIdResolver::<'input>::new(SpirvWord(1));
let mut ptx_impl_imports = HashMap::new();
let directives = ast
.directives
.into_iter()
.filter_map(|directive| {
translate_directive(&mut id_defs, &mut ptx_impl_imports, directive).transpose()
})
.collect::<Result<Vec<_>, _>>()?;
/*
let directives = hoist_function_globals(directives);
let must_link_ptx_impl = ptx_impl_imports.len() > 0;
let mut directives = ptx_impl_imports
.into_iter()
.map(|(_, v)| v)
.chain(directives.into_iter())
.collect::<Vec<_>>();
let mut builder = dr::Builder::new();
builder.reserve_ids(id_defs.current_id());
let call_map = MethodsCallMap::new(&directives);
let mut directives =
convert_dynamic_shared_memory_usage(directives, &call_map, &mut || builder.id());
normalize_variable_decls(&mut directives);
let denorm_information = compute_denorm_information(&directives);
// https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_a_id_logicallayout_a_logical_layout_of_a_module
builder.set_version(1, 3);
emit_capabilities(&mut builder);
emit_extensions(&mut builder);
let opencl_id = emit_opencl_import(&mut builder);
emit_memory_model(&mut builder);
let mut map = TypeWordMap::new(&mut builder);
//emit_builtins(&mut builder, &mut map, &id_defs);
let mut kernel_info = HashMap::new();
let (build_options, should_flush_denorms) =
emit_denorm_build_string(&call_map, &denorm_information);
let (directives, globals_use_map) = get_globals_use_map(directives);
emit_directives(
&mut builder,
&mut map,
&id_defs,
opencl_id,
should_flush_denorms,
&call_map,
globals_use_map,
directives,
&mut kernel_info,
)?;
let spirv = builder.module();
Ok(Module {
spirv,
kernel_info,
should_link_ptx_impl: if must_link_ptx_impl {
Some((ZLUDA_PTX_IMPL_INTEL, ZLUDA_PTX_IMPL_AMD))
} else {
None
},
build_options,
})
*/
todo!()
}
fn translate_directive<'input, 'a>(
id_defs: &'a mut GlobalStringIdResolver<'input>,
ptx_impl_imports: &'a mut HashMap<String, Directive<'input>>,
d: ast::Directive<'input, ast::ParsedOperand<&'input str>>,
) -> Result<Option<Directive<'input>>, TranslateError> {
Ok(match d {
ast::Directive::Variable(linking, var) => Some(Directive::Variable(
linking,
ast::Variable {
align: var.align,
v_type: var.v_type.clone(),
state_space: var.state_space,
name: id_defs.get_or_add_def_typed(var.name, var.v_type, var.state_space, true),
array_init: var.array_init,
},
)),
ast::Directive::Method(linkage, f) => {
translate_function(id_defs, ptx_impl_imports, linkage, f)?.map(Directive::Method)
}
})
}
type ParsedFunction<'a> = ast::Function<'a, &'a str, ast::Statement<ast::ParsedOperand<&'a str>>>;
fn translate_function<'input, 'a>(
id_defs: &'a mut GlobalStringIdResolver<'input>,
ptx_impl_imports: &'a mut HashMap<String, Directive<'input>>,
linkage: ast::LinkingDirective,
f: ParsedFunction<'input>,
) -> Result<Option<Function<'input>>, TranslateError> {
let import_as = match &f.func_directive {
ast::MethodDeclaration {
name: ast::MethodName::Func(func_name),
..
} if *func_name == "__assertfail" || *func_name == "vprintf" => {
Some([ZLUDA_PTX_PREFIX, func_name].concat())
}
_ => None,
};
let (str_resolver, fn_resolver, fn_decl) = id_defs.start_fn(&f.func_directive)?;
let mut func = to_ssa(
ptx_impl_imports,
str_resolver,
fn_resolver,
fn_decl,
f.body,
f.tuning,
linkage,
)?;
func.import_as = import_as;
if func.import_as.is_some() {
ptx_impl_imports.insert(
func.import_as.as_ref().unwrap().clone(),
Directive::Method(func),
);
Ok(None)
} else {
Ok(Some(func))
}
}
fn to_ssa<'input, 'b>(
ptx_impl_imports: &'b mut HashMap<String, Directive<'input>>,
mut id_defs: FnStringIdResolver<'input, 'b>,
fn_defs: GlobalFnDeclResolver<'input, 'b>,
func_decl: Rc<RefCell<ast::MethodDeclaration<'input, SpirvWord>>>,
f_body: Option<Vec<ast::Statement<ast::ParsedOperand<&'input str>>>>,
tuning: Vec<ast::TuningDirective>,
linkage: ast::LinkingDirective,
) -> Result<Function<'input>, TranslateError> {
//deparamize_function_decl(&func_decl)?;
let f_body = match f_body {
Some(vec) => vec,
None => {
return Ok(Function {
func_decl: func_decl,
body: None,
globals: Vec::new(),
import_as: None,
tuning,
linkage,
})
}
};
let normalized_ids = normalize::run(&mut id_defs, &fn_defs, f_body)?;
todo!()
/*
let mut numeric_id_defs = id_defs.finish();
let unadorned_statements = normalize_predicates(normalized_ids, &mut numeric_id_defs)?;
let typed_statements =
convert_to_typed_statements(unadorned_statements, &fn_defs, &mut numeric_id_defs)?;
let typed_statements =
fix_special_registers2(ptx_impl_imports, typed_statements, &mut numeric_id_defs)?;
let (func_decl, typed_statements) =
convert_to_stateful_memory_access(func_decl, typed_statements, &mut numeric_id_defs)?;
let ssa_statements = insert_mem_ssa_statements(
typed_statements,
&mut numeric_id_defs,
&mut (*func_decl).borrow_mut(),
)?;
let mut numeric_id_defs = numeric_id_defs.finish();
let expanded_statements = expand_arguments(ssa_statements, &mut numeric_id_defs)?;
let expanded_statements =
insert_implicit_conversions(expanded_statements, &mut numeric_id_defs)?;
let mut numeric_id_defs = numeric_id_defs.unmut();
let labeled_statements = normalize_labels(expanded_statements, &mut numeric_id_defs);
let (f_body, globals) =
extract_globals(labeled_statements, ptx_impl_imports, &mut numeric_id_defs)?;
Ok(Function {
func_decl: func_decl,
globals: globals,
body: Some(f_body),
import_as: None,
tuning,
linkage,
})
*/
}
pub struct Module {
pub spirv: dr::Module,
pub kernel_info: HashMap<String, KernelInfo>,
pub should_link_ptx_impl: Option<(&'static [u8], &'static [u8])>,
pub build_options: CString,
}
impl Module {
pub fn assemble(&self) -> Vec<u32> {
self.spirv.assemble()
}
}
struct GlobalStringIdResolver<'input> {
current_id: SpirvWord,
variables: HashMap<Cow<'input, str>, SpirvWord>,
reverse_variables: HashMap<SpirvWord, &'input str>,
variables_type_check: HashMap<SpirvWord, Option<(ast::Type, ast::StateSpace, bool)>>,
special_registers: SpecialRegistersMap,
fns: HashMap<SpirvWord, FnSigMapper<'input>>,
}
impl<'input> GlobalStringIdResolver<'input> {
fn new(start_id: SpirvWord) -> Self {
Self {
current_id: start_id,
variables: HashMap::new(),
reverse_variables: HashMap::new(),
variables_type_check: HashMap::new(),
special_registers: SpecialRegistersMap::new(),
fns: HashMap::new(),
}
}
fn get_or_add_def(&mut self, id: &'input str) -> SpirvWord {
self.get_or_add_impl(id, None)
}
fn get_or_add_def_typed(
&mut self,
id: &'input str,
typ: ast::Type,
state_space: ast::StateSpace,
is_variable: bool,
) -> SpirvWord {
self.get_or_add_impl(id, Some((typ, state_space, is_variable)))
}
fn get_or_add_impl(
&mut self,
id: &'input str,
typ: Option<(ast::Type, ast::StateSpace, bool)>,
) -> SpirvWord {
let id = match self.variables.entry(Cow::Borrowed(id)) {
hash_map::Entry::Occupied(e) => *(e.get()),
hash_map::Entry::Vacant(e) => {
let numeric_id = self.current_id;
e.insert(numeric_id);
self.reverse_variables.insert(numeric_id, id);
self.current_id.0 += 1;
numeric_id
}
};
self.variables_type_check.insert(id, typ);
id
}
fn get_id(&self, id: &str) -> Result<SpirvWord, TranslateError> {
self.variables
.get(id)
.copied()
.ok_or_else(error_unknown_symbol)
}
fn current_id(&self) -> SpirvWord {
self.current_id
}
fn start_fn<'b>(
&'b mut self,
header: &'b ast::MethodDeclaration<'input, &'input str>,
) -> Result<
(
FnStringIdResolver<'input, 'b>,
GlobalFnDeclResolver<'input, 'b>,
Rc<RefCell<ast::MethodDeclaration<'input, SpirvWord>>>,
),
TranslateError,
> {
// In case a function decl was inserted earlier we want to use its id
let name_id = self.get_or_add_def(header.name());
let mut fn_resolver = FnStringIdResolver {
current_id: &mut self.current_id,
global_variables: &self.variables,
global_type_check: &self.variables_type_check,
special_registers: &mut self.special_registers,
variables: vec![HashMap::new(); 1],
type_check: HashMap::new(),
};
let return_arguments = rename_fn_params(&mut fn_resolver, &header.return_arguments);
let input_arguments = rename_fn_params(&mut fn_resolver, &header.input_arguments);
let name = match header.name {
ast::MethodName::Kernel(name) => ast::MethodName::Kernel(name),
ast::MethodName::Func(_) => ast::MethodName::Func(name_id),
};
let fn_decl = ast::MethodDeclaration {
return_arguments,
name,
input_arguments,
shared_mem: None,
};
let new_fn_decl = if !matches!(fn_decl.name, ast::MethodName::Kernel(_)) {
let resolver = FnSigMapper::remap_to_spirv_repr(fn_decl);
let new_fn_decl = resolver.func_decl.clone();
self.fns.insert(name_id, resolver);
new_fn_decl
} else {
Rc::new(RefCell::new(fn_decl))
};
Ok((
fn_resolver,
GlobalFnDeclResolver { fns: &self.fns },
new_fn_decl,
))
}
}
fn rename_fn_params<'a, 'b>(
fn_resolver: &mut FnStringIdResolver<'a, 'b>,
args: &'b [ast::Variable<&'a str>],
) -> Vec<ast::Variable<SpirvWord>> {
args.iter()
.map(|a| ast::Variable {
name: fn_resolver.add_def(a.name, Some((a.v_type.clone(), a.state_space)), true),
v_type: a.v_type.clone(),
state_space: a.state_space,
align: a.align,
array_init: a.array_init.clone(),
})
.collect()
}
pub struct KernelInfo {
pub arguments_sizes: Vec<(usize, bool)>,
pub uses_shared_mem: bool,
}
#[derive(Ord, PartialOrd, Eq, PartialEq, Hash, Copy, Clone)]
enum PtxSpecialRegister {
@ -108,10 +443,10 @@ impl SpecialRegistersMap {
struct FnStringIdResolver<'input, 'b> {
current_id: &'b mut SpirvWord,
global_variables: &'b HashMap<Cow<'input, str>, SpirvWord>,
global_type_check: &'b HashMap<u32, Option<(ast::Type, ast::StateSpace, bool)>>,
global_type_check: &'b HashMap<SpirvWord, Option<(ast::Type, ast::StateSpace, bool)>>,
special_registers: &'b mut SpecialRegistersMap,
variables: Vec<HashMap<Cow<'input, str>, SpirvWord>>,
type_check: HashMap<u32, Option<(ast::Type, ast::StateSpace, bool)>>,
type_check: HashMap<SpirvWord, Option<(ast::Type, ast::StateSpace, bool)>>,
}
impl<'a, 'b> FnStringIdResolver<'a, 'b> {
@ -160,7 +495,7 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> {
.unwrap()
.insert(Cow::Borrowed(id), numeric_id);
self.type_check.insert(
numeric_id.0,
numeric_id,
typ.map(|(typ, space)| (typ, space, is_variable)),
);
self.current_id.0 += 1;
@ -183,7 +518,7 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> {
SpirvWord(numeric_id.0 + i),
);
self.type_check.insert(
numeric_id.0 + i,
SpirvWord(numeric_id.0 + i),
Some((typ.clone(), state_space, is_variable)),
);
}
@ -196,8 +531,8 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> {
struct NumericIdResolver<'b> {
current_id: &'b mut SpirvWord,
global_type_check: &'b HashMap<u32, Option<(ast::Type, ast::StateSpace, bool)>>,
type_check: HashMap<u32, Option<(ast::Type, ast::StateSpace, bool)>>,
global_type_check: &'b HashMap<SpirvWord, Option<(ast::Type, ast::StateSpace, bool)>>,
type_check: HashMap<SpirvWord, Option<(ast::Type, ast::StateSpace, bool)>>,
special_registers: &'b mut SpecialRegistersMap,
}
@ -210,12 +545,12 @@ impl<'b> NumericIdResolver<'b> {
&self,
id: SpirvWord,
) -> Result<(ast::Type, ast::StateSpace, bool), TranslateError> {
match self.type_check.get(&id.0) {
match self.type_check.get(&id) {
Some(Some(x)) => Ok(x.clone()),
Some(None) => Err(TranslateError::UntypedSymbol),
None => match self.special_registers.get(id) {
Some(x) => Ok((x.get_type(), ast::StateSpace::Sreg, true)),
None => match self.global_type_check.get(&id.0) {
None => match self.global_type_check.get(&id) {
Some(Some(result)) => Ok(result.clone()),
Some(None) | None => Err(TranslateError::UntypedSymbol),
},
@ -228,7 +563,7 @@ impl<'b> NumericIdResolver<'b> {
fn register_variable(&mut self, typ: ast::Type, state_space: ast::StateSpace) -> SpirvWord {
let new_id = *self.current_id;
self.type_check
.insert(new_id.0, Some((typ, state_space, true)));
.insert(new_id, Some((typ, state_space, true)));
self.current_id.0 += 1;
new_id
}
@ -236,7 +571,7 @@ impl<'b> NumericIdResolver<'b> {
fn register_intermediate(&mut self, typ: Option<(ast::Type, ast::StateSpace)>) -> SpirvWord {
let new_id = *self.current_id;
self.type_check
.insert(new_id.0, typ.map(|(t, space)| (t, space, false)));
.insert(new_id, typ.map(|(t, space)| (t, space, false)));
self.current_id.0 += 1;
new_id
}
@ -490,6 +825,10 @@ impl From<SpirvWord> for spirv::Word {
impl ast::Operand for SpirvWord {
type Ident = Self;
fn from_ident(ident: Self::Ident) -> Self {
ident
}
}
fn pred_map_variable<U, T, F: FnMut(T) -> Result<U, TranslateError>>(
@ -503,29 +842,18 @@ fn pred_map_variable<U, T, F: FnMut(T) -> Result<U, TranslateError>>(
})
}
impl<T: ast::Operand, U: ast::Operand, X: FnMut(&str) -> Result<SpirvWord, Err>, Err> ast::VisitorMap<T, U, Err> for X {
fn visit(
&mut self,
args: T,
type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>,
is_dst: bool,
) -> U {
todo!()
}
fn visit_ident(
&mut self,
args: <T as ptx_parser::Operand>::Ident,
type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>,
is_dst: bool,
) -> <U as ptx_parser::Operand>::Ident {
todo!()
}
pub(crate) enum Directive<'input> {
Variable(ast::LinkingDirective, ast::Variable<SpirvWord>),
Method(Function<'input>),
}
fn op_map_variable<'a, F: FnMut(&str) -> Result<SpirvWord, TranslateError>>(
this: ast::Instruction<ast::ParsedOperand<&'a str>>,
f: &mut F,
) -> Result<ast::Instruction<ast::ParsedOperand<SpirvWord>>, TranslateError> {
ast::visit_map(this , f)
pub(crate) struct Function<'input> {
pub func_decl: Rc<RefCell<ast::MethodDeclaration<'input, SpirvWord>>>,
pub globals: Vec<ast::Variable<SpirvWord>>,
pub body: Option<Vec<ExpandedStatement>>,
import_as: Option<String>,
tuning: Vec<ast::TuningDirective>,
linkage: ast::LinkingDirective,
}
type ExpandedStatement = Statement<ast::Instruction<SpirvWord>, SpirvWord>;

View file

@ -9,7 +9,7 @@ type NormalizedStatement = Statement<
ast::ParsedOperand<SpirvWord>,
>;
fn run<'input, 'b>(
pub(crate) fn run<'input, 'b>(
id_defs: &mut FnStringIdResolver<'input, 'b>,
fn_defs: &GlobalFnDeclResolver<'input, 'b>,
func: Vec<ast::Statement<ast::ParsedOperand<&'input str>>>,
@ -47,7 +47,11 @@ fn expand_map_variables<'a, 'b>(
ast::Statement::Instruction(p, i) => result.push(Statement::Instruction((
p.map(|p| pred_map_variable(p, &mut |id| id_defs.get_id(id)))
.transpose()?,
op_map_variable(i, &mut |id| id_defs.get_id(id))?,
ast::visit_map(i, &mut |id,
_: Option<(&ast::Type, ast::StateSpace)>,
_: bool| {
id_defs.get_id(id)
})?,
))),
ast::Statement::Variable(var) => {
let var_type = var.var.v_type.clone();

View file

@ -1,60 +0,0 @@
use std::collections::HashMap;
use half::f16;
use ptx_parser as ast;
fn to_ssa<'input, 'b>(
ptx_impl_imports: &'b mut HashMap<String, Directive<'input>>,
mut id_defs: FnStringIdResolver<'input, 'b>,
fn_defs: GlobalFnDeclResolver<'input, 'b>,
func_decl: Rc<RefCell<ast::MethodDeclaration<'input, spirv::Word>>>,
f_body: Option<Vec<ast::Statement<ast::ParsedArgParams<'input>>>>,
tuning: Vec<ast::TuningDirective>,
linkage: ast::LinkingDirective,
) -> Result<Function<'input>, TranslateError> {
//deparamize_function_decl(&func_decl)?;
let f_body = match f_body {
Some(vec) => vec,
None => {
return Ok(Function {
func_decl: func_decl,
body: None,
globals: Vec::new(),
import_as: None,
tuning,
linkage,
})
}
};
let normalized_ids = normalize_identifiers(&mut id_defs, &fn_defs, f_body)?;
/*
let mut numeric_id_defs = id_defs.finish();
let unadorned_statements = normalize_predicates(normalized_ids, &mut numeric_id_defs)?;
let typed_statements =
convert_to_typed_statements(unadorned_statements, &fn_defs, &mut numeric_id_defs)?;
let typed_statements =
fix_special_registers2(ptx_impl_imports, typed_statements, &mut numeric_id_defs)?;
let (func_decl, typed_statements) =
convert_to_stateful_memory_access(func_decl, typed_statements, &mut numeric_id_defs)?;
let ssa_statements = insert_mem_ssa_statements(
typed_statements,
&mut numeric_id_defs,
&mut (*func_decl).borrow_mut(),
)?;
let mut numeric_id_defs = numeric_id_defs.finish();
let expanded_statements = expand_arguments(ssa_statements, &mut numeric_id_defs)?;
let expanded_statements =
insert_implicit_conversions(expanded_statements, &mut numeric_id_defs)?;
let mut numeric_id_defs = numeric_id_defs.unmut();
let labeled_statements = normalize_labels(expanded_statements, &mut numeric_id_defs);
let (f_body, globals) =
extract_globals(labeled_statements, ptx_impl_imports, &mut numeric_id_defs)?;
Ok(Function {
func_decl: func_decl,
globals: globals,
body: Some(f_body),
import_as: None,
tuning,
linkage,
})
*/
}

View file

@ -555,12 +555,46 @@ pub trait VisitorMap<From: Operand, To: Operand, Err> {
) -> Result<To::Ident, Err>;
}
impl<
T: Operand,
U: Operand,
Err,
Fn: FnMut(T, Option<(&Type, StateSpace)>, bool) -> Result<U, Err>,
> VisitorMap<T, U, Err> for Fn
impl<T: Copy, U: Copy, Err, Fn> VisitorMap<ParsedOperand<T>, ParsedOperand<U>, Err> for Fn
where
Fn: FnMut(T, Option<(&Type, StateSpace)>, bool) -> Result<U, Err>,
{
fn visit(
&mut self,
args: ParsedOperand<T>,
type_space: Option<(&Type, StateSpace)>,
is_dst: bool,
) -> Result<ParsedOperand<U>, Err> {
Ok(match args {
ParsedOperand::Reg(ident) => ParsedOperand::Reg((self)(ident, type_space, is_dst)?),
ParsedOperand::RegOffset(ident, imm) => {
ParsedOperand::RegOffset((self)(ident, type_space, is_dst)?, imm)
}
ParsedOperand::Imm(imm) => ParsedOperand::Imm(imm),
ParsedOperand::VecMember(ident, index) => {
ParsedOperand::VecMember((self)(ident, type_space, is_dst)?, index)
}
ParsedOperand::VecPack(vec) => ParsedOperand::VecPack(
vec.into_iter()
.map(|ident| (self)(ident, type_space, is_dst))
.collect::<Result<Vec<_>, _>>()?,
),
})
}
fn visit_ident(
&mut self,
args: T,
type_space: Option<(&Type, StateSpace)>,
is_dst: bool,
) -> Result<U, Err> {
(self)(args, type_space, is_dst)
}
}
impl<T: Operand<Ident = T>, U: Operand<Ident = U>, Err, Fn> VisitorMap<T, U, Err> for Fn
where
Fn: FnMut(T, Option<(&Type, StateSpace)>, bool) -> Result<U, Err>,
{
fn visit(
&mut self,
@ -573,12 +607,11 @@ impl<
fn visit_ident(
&mut self,
args: T::Ident,
args: T,
type_space: Option<(&Type, StateSpace)>,
is_dst: bool,
) -> Result<U::Ident, Err> {
let value: U = (self)(T::from_ident(args), type_space, is_dst)?;
Ok(value)
) -> Result<U, Err> {
(self)(args, type_space, is_dst)
}
}
@ -925,6 +958,15 @@ pub struct MethodDeclaration<'input, ID> {
pub shared_mem: Option<ID>,
}
impl<'input> MethodDeclaration<'input, &'input str> {
pub fn name(&self) -> &'input str {
match self.name {
MethodName::Kernel(n) => n,
MethodName::Func(n) => n,
}
}
}
#[derive(Hash, PartialEq, Eq, Copy, Clone)]
pub enum MethodName<'input, ID> {
Kernel(&'input str),