mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-08-04 15:19:49 +00:00
Port first pass
This commit is contained in:
parent
1ec1ca0c30
commit
12ef8dbc90
5 changed files with 421 additions and 108 deletions
|
@ -24,11 +24,10 @@ lalrpop_mod!(
|
||||||
);
|
);
|
||||||
|
|
||||||
pub mod ast;
|
pub mod ast;
|
||||||
mod pass;
|
pub(crate) mod pass;
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod test;
|
mod test;
|
||||||
mod translate;
|
mod translate;
|
||||||
mod translate2;
|
|
||||||
|
|
||||||
use std::fmt;
|
use std::fmt;
|
||||||
|
|
||||||
|
|
|
@ -1,12 +1,347 @@
|
||||||
use ptx_parser as ast;
|
use ptx_parser as ast;
|
||||||
|
use rspirv::{binary::Assemble, dr};
|
||||||
use std::{
|
use std::{
|
||||||
borrow::Cow,
|
borrow::Cow,
|
||||||
cell::RefCell,
|
cell::RefCell,
|
||||||
collections::{hash_map, HashMap},
|
collections::{hash_map, HashMap},
|
||||||
|
ffi::CString,
|
||||||
rc::Rc,
|
rc::Rc,
|
||||||
};
|
};
|
||||||
|
|
||||||
mod normalize;
|
pub(crate) mod normalize;
|
||||||
|
|
||||||
|
static ZLUDA_PTX_IMPL_INTEL: &'static [u8] = include_bytes!("../../lib/zluda_ptx_impl.spv");
|
||||||
|
static ZLUDA_PTX_IMPL_AMD: &'static [u8] = include_bytes!("../../lib/zluda_ptx_impl.bc");
|
||||||
|
const ZLUDA_PTX_PREFIX: &'static str = "__zluda_ptx_impl__";
|
||||||
|
|
||||||
|
pub fn to_spirv_module<'input>(ast: ast::Module<'input>) -> Result<Module, TranslateError> {
|
||||||
|
let mut id_defs = GlobalStringIdResolver::<'input>::new(SpirvWord(1));
|
||||||
|
let mut ptx_impl_imports = HashMap::new();
|
||||||
|
let directives = ast
|
||||||
|
.directives
|
||||||
|
.into_iter()
|
||||||
|
.filter_map(|directive| {
|
||||||
|
translate_directive(&mut id_defs, &mut ptx_impl_imports, directive).transpose()
|
||||||
|
})
|
||||||
|
.collect::<Result<Vec<_>, _>>()?;
|
||||||
|
/*
|
||||||
|
let directives = hoist_function_globals(directives);
|
||||||
|
let must_link_ptx_impl = ptx_impl_imports.len() > 0;
|
||||||
|
let mut directives = ptx_impl_imports
|
||||||
|
.into_iter()
|
||||||
|
.map(|(_, v)| v)
|
||||||
|
.chain(directives.into_iter())
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
let mut builder = dr::Builder::new();
|
||||||
|
builder.reserve_ids(id_defs.current_id());
|
||||||
|
let call_map = MethodsCallMap::new(&directives);
|
||||||
|
let mut directives =
|
||||||
|
convert_dynamic_shared_memory_usage(directives, &call_map, &mut || builder.id());
|
||||||
|
normalize_variable_decls(&mut directives);
|
||||||
|
let denorm_information = compute_denorm_information(&directives);
|
||||||
|
// https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_a_id_logicallayout_a_logical_layout_of_a_module
|
||||||
|
builder.set_version(1, 3);
|
||||||
|
emit_capabilities(&mut builder);
|
||||||
|
emit_extensions(&mut builder);
|
||||||
|
let opencl_id = emit_opencl_import(&mut builder);
|
||||||
|
emit_memory_model(&mut builder);
|
||||||
|
let mut map = TypeWordMap::new(&mut builder);
|
||||||
|
//emit_builtins(&mut builder, &mut map, &id_defs);
|
||||||
|
let mut kernel_info = HashMap::new();
|
||||||
|
let (build_options, should_flush_denorms) =
|
||||||
|
emit_denorm_build_string(&call_map, &denorm_information);
|
||||||
|
let (directives, globals_use_map) = get_globals_use_map(directives);
|
||||||
|
emit_directives(
|
||||||
|
&mut builder,
|
||||||
|
&mut map,
|
||||||
|
&id_defs,
|
||||||
|
opencl_id,
|
||||||
|
should_flush_denorms,
|
||||||
|
&call_map,
|
||||||
|
globals_use_map,
|
||||||
|
directives,
|
||||||
|
&mut kernel_info,
|
||||||
|
)?;
|
||||||
|
let spirv = builder.module();
|
||||||
|
Ok(Module {
|
||||||
|
spirv,
|
||||||
|
kernel_info,
|
||||||
|
should_link_ptx_impl: if must_link_ptx_impl {
|
||||||
|
Some((ZLUDA_PTX_IMPL_INTEL, ZLUDA_PTX_IMPL_AMD))
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
},
|
||||||
|
build_options,
|
||||||
|
})
|
||||||
|
*/
|
||||||
|
todo!()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn translate_directive<'input, 'a>(
|
||||||
|
id_defs: &'a mut GlobalStringIdResolver<'input>,
|
||||||
|
ptx_impl_imports: &'a mut HashMap<String, Directive<'input>>,
|
||||||
|
d: ast::Directive<'input, ast::ParsedOperand<&'input str>>,
|
||||||
|
) -> Result<Option<Directive<'input>>, TranslateError> {
|
||||||
|
Ok(match d {
|
||||||
|
ast::Directive::Variable(linking, var) => Some(Directive::Variable(
|
||||||
|
linking,
|
||||||
|
ast::Variable {
|
||||||
|
align: var.align,
|
||||||
|
v_type: var.v_type.clone(),
|
||||||
|
state_space: var.state_space,
|
||||||
|
name: id_defs.get_or_add_def_typed(var.name, var.v_type, var.state_space, true),
|
||||||
|
array_init: var.array_init,
|
||||||
|
},
|
||||||
|
)),
|
||||||
|
ast::Directive::Method(linkage, f) => {
|
||||||
|
translate_function(id_defs, ptx_impl_imports, linkage, f)?.map(Directive::Method)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
type ParsedFunction<'a> = ast::Function<'a, &'a str, ast::Statement<ast::ParsedOperand<&'a str>>>;
|
||||||
|
|
||||||
|
fn translate_function<'input, 'a>(
|
||||||
|
id_defs: &'a mut GlobalStringIdResolver<'input>,
|
||||||
|
ptx_impl_imports: &'a mut HashMap<String, Directive<'input>>,
|
||||||
|
linkage: ast::LinkingDirective,
|
||||||
|
f: ParsedFunction<'input>,
|
||||||
|
) -> Result<Option<Function<'input>>, TranslateError> {
|
||||||
|
let import_as = match &f.func_directive {
|
||||||
|
ast::MethodDeclaration {
|
||||||
|
name: ast::MethodName::Func(func_name),
|
||||||
|
..
|
||||||
|
} if *func_name == "__assertfail" || *func_name == "vprintf" => {
|
||||||
|
Some([ZLUDA_PTX_PREFIX, func_name].concat())
|
||||||
|
}
|
||||||
|
_ => None,
|
||||||
|
};
|
||||||
|
let (str_resolver, fn_resolver, fn_decl) = id_defs.start_fn(&f.func_directive)?;
|
||||||
|
let mut func = to_ssa(
|
||||||
|
ptx_impl_imports,
|
||||||
|
str_resolver,
|
||||||
|
fn_resolver,
|
||||||
|
fn_decl,
|
||||||
|
f.body,
|
||||||
|
f.tuning,
|
||||||
|
linkage,
|
||||||
|
)?;
|
||||||
|
func.import_as = import_as;
|
||||||
|
if func.import_as.is_some() {
|
||||||
|
ptx_impl_imports.insert(
|
||||||
|
func.import_as.as_ref().unwrap().clone(),
|
||||||
|
Directive::Method(func),
|
||||||
|
);
|
||||||
|
Ok(None)
|
||||||
|
} else {
|
||||||
|
Ok(Some(func))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn to_ssa<'input, 'b>(
|
||||||
|
ptx_impl_imports: &'b mut HashMap<String, Directive<'input>>,
|
||||||
|
mut id_defs: FnStringIdResolver<'input, 'b>,
|
||||||
|
fn_defs: GlobalFnDeclResolver<'input, 'b>,
|
||||||
|
func_decl: Rc<RefCell<ast::MethodDeclaration<'input, SpirvWord>>>,
|
||||||
|
f_body: Option<Vec<ast::Statement<ast::ParsedOperand<&'input str>>>>,
|
||||||
|
tuning: Vec<ast::TuningDirective>,
|
||||||
|
linkage: ast::LinkingDirective,
|
||||||
|
) -> Result<Function<'input>, TranslateError> {
|
||||||
|
//deparamize_function_decl(&func_decl)?;
|
||||||
|
let f_body = match f_body {
|
||||||
|
Some(vec) => vec,
|
||||||
|
None => {
|
||||||
|
return Ok(Function {
|
||||||
|
func_decl: func_decl,
|
||||||
|
body: None,
|
||||||
|
globals: Vec::new(),
|
||||||
|
import_as: None,
|
||||||
|
tuning,
|
||||||
|
linkage,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
};
|
||||||
|
let normalized_ids = normalize::run(&mut id_defs, &fn_defs, f_body)?;
|
||||||
|
todo!()
|
||||||
|
/*
|
||||||
|
let mut numeric_id_defs = id_defs.finish();
|
||||||
|
let unadorned_statements = normalize_predicates(normalized_ids, &mut numeric_id_defs)?;
|
||||||
|
let typed_statements =
|
||||||
|
convert_to_typed_statements(unadorned_statements, &fn_defs, &mut numeric_id_defs)?;
|
||||||
|
let typed_statements =
|
||||||
|
fix_special_registers2(ptx_impl_imports, typed_statements, &mut numeric_id_defs)?;
|
||||||
|
let (func_decl, typed_statements) =
|
||||||
|
convert_to_stateful_memory_access(func_decl, typed_statements, &mut numeric_id_defs)?;
|
||||||
|
let ssa_statements = insert_mem_ssa_statements(
|
||||||
|
typed_statements,
|
||||||
|
&mut numeric_id_defs,
|
||||||
|
&mut (*func_decl).borrow_mut(),
|
||||||
|
)?;
|
||||||
|
let mut numeric_id_defs = numeric_id_defs.finish();
|
||||||
|
let expanded_statements = expand_arguments(ssa_statements, &mut numeric_id_defs)?;
|
||||||
|
let expanded_statements =
|
||||||
|
insert_implicit_conversions(expanded_statements, &mut numeric_id_defs)?;
|
||||||
|
let mut numeric_id_defs = numeric_id_defs.unmut();
|
||||||
|
let labeled_statements = normalize_labels(expanded_statements, &mut numeric_id_defs);
|
||||||
|
let (f_body, globals) =
|
||||||
|
extract_globals(labeled_statements, ptx_impl_imports, &mut numeric_id_defs)?;
|
||||||
|
Ok(Function {
|
||||||
|
func_decl: func_decl,
|
||||||
|
globals: globals,
|
||||||
|
body: Some(f_body),
|
||||||
|
import_as: None,
|
||||||
|
tuning,
|
||||||
|
linkage,
|
||||||
|
})
|
||||||
|
*/
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct Module {
|
||||||
|
pub spirv: dr::Module,
|
||||||
|
pub kernel_info: HashMap<String, KernelInfo>,
|
||||||
|
pub should_link_ptx_impl: Option<(&'static [u8], &'static [u8])>,
|
||||||
|
pub build_options: CString,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Module {
|
||||||
|
pub fn assemble(&self) -> Vec<u32> {
|
||||||
|
self.spirv.assemble()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct GlobalStringIdResolver<'input> {
|
||||||
|
current_id: SpirvWord,
|
||||||
|
variables: HashMap<Cow<'input, str>, SpirvWord>,
|
||||||
|
reverse_variables: HashMap<SpirvWord, &'input str>,
|
||||||
|
variables_type_check: HashMap<SpirvWord, Option<(ast::Type, ast::StateSpace, bool)>>,
|
||||||
|
special_registers: SpecialRegistersMap,
|
||||||
|
fns: HashMap<SpirvWord, FnSigMapper<'input>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'input> GlobalStringIdResolver<'input> {
|
||||||
|
fn new(start_id: SpirvWord) -> Self {
|
||||||
|
Self {
|
||||||
|
current_id: start_id,
|
||||||
|
variables: HashMap::new(),
|
||||||
|
reverse_variables: HashMap::new(),
|
||||||
|
variables_type_check: HashMap::new(),
|
||||||
|
special_registers: SpecialRegistersMap::new(),
|
||||||
|
fns: HashMap::new(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_or_add_def(&mut self, id: &'input str) -> SpirvWord {
|
||||||
|
self.get_or_add_impl(id, None)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_or_add_def_typed(
|
||||||
|
&mut self,
|
||||||
|
id: &'input str,
|
||||||
|
typ: ast::Type,
|
||||||
|
state_space: ast::StateSpace,
|
||||||
|
is_variable: bool,
|
||||||
|
) -> SpirvWord {
|
||||||
|
self.get_or_add_impl(id, Some((typ, state_space, is_variable)))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_or_add_impl(
|
||||||
|
&mut self,
|
||||||
|
id: &'input str,
|
||||||
|
typ: Option<(ast::Type, ast::StateSpace, bool)>,
|
||||||
|
) -> SpirvWord {
|
||||||
|
let id = match self.variables.entry(Cow::Borrowed(id)) {
|
||||||
|
hash_map::Entry::Occupied(e) => *(e.get()),
|
||||||
|
hash_map::Entry::Vacant(e) => {
|
||||||
|
let numeric_id = self.current_id;
|
||||||
|
e.insert(numeric_id);
|
||||||
|
self.reverse_variables.insert(numeric_id, id);
|
||||||
|
self.current_id.0 += 1;
|
||||||
|
numeric_id
|
||||||
|
}
|
||||||
|
};
|
||||||
|
self.variables_type_check.insert(id, typ);
|
||||||
|
id
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_id(&self, id: &str) -> Result<SpirvWord, TranslateError> {
|
||||||
|
self.variables
|
||||||
|
.get(id)
|
||||||
|
.copied()
|
||||||
|
.ok_or_else(error_unknown_symbol)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn current_id(&self) -> SpirvWord {
|
||||||
|
self.current_id
|
||||||
|
}
|
||||||
|
|
||||||
|
fn start_fn<'b>(
|
||||||
|
&'b mut self,
|
||||||
|
header: &'b ast::MethodDeclaration<'input, &'input str>,
|
||||||
|
) -> Result<
|
||||||
|
(
|
||||||
|
FnStringIdResolver<'input, 'b>,
|
||||||
|
GlobalFnDeclResolver<'input, 'b>,
|
||||||
|
Rc<RefCell<ast::MethodDeclaration<'input, SpirvWord>>>,
|
||||||
|
),
|
||||||
|
TranslateError,
|
||||||
|
> {
|
||||||
|
// In case a function decl was inserted earlier we want to use its id
|
||||||
|
let name_id = self.get_or_add_def(header.name());
|
||||||
|
let mut fn_resolver = FnStringIdResolver {
|
||||||
|
current_id: &mut self.current_id,
|
||||||
|
global_variables: &self.variables,
|
||||||
|
global_type_check: &self.variables_type_check,
|
||||||
|
special_registers: &mut self.special_registers,
|
||||||
|
variables: vec![HashMap::new(); 1],
|
||||||
|
type_check: HashMap::new(),
|
||||||
|
};
|
||||||
|
let return_arguments = rename_fn_params(&mut fn_resolver, &header.return_arguments);
|
||||||
|
let input_arguments = rename_fn_params(&mut fn_resolver, &header.input_arguments);
|
||||||
|
let name = match header.name {
|
||||||
|
ast::MethodName::Kernel(name) => ast::MethodName::Kernel(name),
|
||||||
|
ast::MethodName::Func(_) => ast::MethodName::Func(name_id),
|
||||||
|
};
|
||||||
|
let fn_decl = ast::MethodDeclaration {
|
||||||
|
return_arguments,
|
||||||
|
name,
|
||||||
|
input_arguments,
|
||||||
|
shared_mem: None,
|
||||||
|
};
|
||||||
|
let new_fn_decl = if !matches!(fn_decl.name, ast::MethodName::Kernel(_)) {
|
||||||
|
let resolver = FnSigMapper::remap_to_spirv_repr(fn_decl);
|
||||||
|
let new_fn_decl = resolver.func_decl.clone();
|
||||||
|
self.fns.insert(name_id, resolver);
|
||||||
|
new_fn_decl
|
||||||
|
} else {
|
||||||
|
Rc::new(RefCell::new(fn_decl))
|
||||||
|
};
|
||||||
|
Ok((
|
||||||
|
fn_resolver,
|
||||||
|
GlobalFnDeclResolver { fns: &self.fns },
|
||||||
|
new_fn_decl,
|
||||||
|
))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn rename_fn_params<'a, 'b>(
|
||||||
|
fn_resolver: &mut FnStringIdResolver<'a, 'b>,
|
||||||
|
args: &'b [ast::Variable<&'a str>],
|
||||||
|
) -> Vec<ast::Variable<SpirvWord>> {
|
||||||
|
args.iter()
|
||||||
|
.map(|a| ast::Variable {
|
||||||
|
name: fn_resolver.add_def(a.name, Some((a.v_type.clone(), a.state_space)), true),
|
||||||
|
v_type: a.v_type.clone(),
|
||||||
|
state_space: a.state_space,
|
||||||
|
align: a.align,
|
||||||
|
array_init: a.array_init.clone(),
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct KernelInfo {
|
||||||
|
pub arguments_sizes: Vec<(usize, bool)>,
|
||||||
|
pub uses_shared_mem: bool,
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Ord, PartialOrd, Eq, PartialEq, Hash, Copy, Clone)]
|
#[derive(Ord, PartialOrd, Eq, PartialEq, Hash, Copy, Clone)]
|
||||||
enum PtxSpecialRegister {
|
enum PtxSpecialRegister {
|
||||||
|
@ -108,10 +443,10 @@ impl SpecialRegistersMap {
|
||||||
struct FnStringIdResolver<'input, 'b> {
|
struct FnStringIdResolver<'input, 'b> {
|
||||||
current_id: &'b mut SpirvWord,
|
current_id: &'b mut SpirvWord,
|
||||||
global_variables: &'b HashMap<Cow<'input, str>, SpirvWord>,
|
global_variables: &'b HashMap<Cow<'input, str>, SpirvWord>,
|
||||||
global_type_check: &'b HashMap<u32, Option<(ast::Type, ast::StateSpace, bool)>>,
|
global_type_check: &'b HashMap<SpirvWord, Option<(ast::Type, ast::StateSpace, bool)>>,
|
||||||
special_registers: &'b mut SpecialRegistersMap,
|
special_registers: &'b mut SpecialRegistersMap,
|
||||||
variables: Vec<HashMap<Cow<'input, str>, SpirvWord>>,
|
variables: Vec<HashMap<Cow<'input, str>, SpirvWord>>,
|
||||||
type_check: HashMap<u32, Option<(ast::Type, ast::StateSpace, bool)>>,
|
type_check: HashMap<SpirvWord, Option<(ast::Type, ast::StateSpace, bool)>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, 'b> FnStringIdResolver<'a, 'b> {
|
impl<'a, 'b> FnStringIdResolver<'a, 'b> {
|
||||||
|
@ -160,7 +495,7 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> {
|
||||||
.unwrap()
|
.unwrap()
|
||||||
.insert(Cow::Borrowed(id), numeric_id);
|
.insert(Cow::Borrowed(id), numeric_id);
|
||||||
self.type_check.insert(
|
self.type_check.insert(
|
||||||
numeric_id.0,
|
numeric_id,
|
||||||
typ.map(|(typ, space)| (typ, space, is_variable)),
|
typ.map(|(typ, space)| (typ, space, is_variable)),
|
||||||
);
|
);
|
||||||
self.current_id.0 += 1;
|
self.current_id.0 += 1;
|
||||||
|
@ -183,7 +518,7 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> {
|
||||||
SpirvWord(numeric_id.0 + i),
|
SpirvWord(numeric_id.0 + i),
|
||||||
);
|
);
|
||||||
self.type_check.insert(
|
self.type_check.insert(
|
||||||
numeric_id.0 + i,
|
SpirvWord(numeric_id.0 + i),
|
||||||
Some((typ.clone(), state_space, is_variable)),
|
Some((typ.clone(), state_space, is_variable)),
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
@ -196,8 +531,8 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> {
|
||||||
|
|
||||||
struct NumericIdResolver<'b> {
|
struct NumericIdResolver<'b> {
|
||||||
current_id: &'b mut SpirvWord,
|
current_id: &'b mut SpirvWord,
|
||||||
global_type_check: &'b HashMap<u32, Option<(ast::Type, ast::StateSpace, bool)>>,
|
global_type_check: &'b HashMap<SpirvWord, Option<(ast::Type, ast::StateSpace, bool)>>,
|
||||||
type_check: HashMap<u32, Option<(ast::Type, ast::StateSpace, bool)>>,
|
type_check: HashMap<SpirvWord, Option<(ast::Type, ast::StateSpace, bool)>>,
|
||||||
special_registers: &'b mut SpecialRegistersMap,
|
special_registers: &'b mut SpecialRegistersMap,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -210,12 +545,12 @@ impl<'b> NumericIdResolver<'b> {
|
||||||
&self,
|
&self,
|
||||||
id: SpirvWord,
|
id: SpirvWord,
|
||||||
) -> Result<(ast::Type, ast::StateSpace, bool), TranslateError> {
|
) -> Result<(ast::Type, ast::StateSpace, bool), TranslateError> {
|
||||||
match self.type_check.get(&id.0) {
|
match self.type_check.get(&id) {
|
||||||
Some(Some(x)) => Ok(x.clone()),
|
Some(Some(x)) => Ok(x.clone()),
|
||||||
Some(None) => Err(TranslateError::UntypedSymbol),
|
Some(None) => Err(TranslateError::UntypedSymbol),
|
||||||
None => match self.special_registers.get(id) {
|
None => match self.special_registers.get(id) {
|
||||||
Some(x) => Ok((x.get_type(), ast::StateSpace::Sreg, true)),
|
Some(x) => Ok((x.get_type(), ast::StateSpace::Sreg, true)),
|
||||||
None => match self.global_type_check.get(&id.0) {
|
None => match self.global_type_check.get(&id) {
|
||||||
Some(Some(result)) => Ok(result.clone()),
|
Some(Some(result)) => Ok(result.clone()),
|
||||||
Some(None) | None => Err(TranslateError::UntypedSymbol),
|
Some(None) | None => Err(TranslateError::UntypedSymbol),
|
||||||
},
|
},
|
||||||
|
@ -228,7 +563,7 @@ impl<'b> NumericIdResolver<'b> {
|
||||||
fn register_variable(&mut self, typ: ast::Type, state_space: ast::StateSpace) -> SpirvWord {
|
fn register_variable(&mut self, typ: ast::Type, state_space: ast::StateSpace) -> SpirvWord {
|
||||||
let new_id = *self.current_id;
|
let new_id = *self.current_id;
|
||||||
self.type_check
|
self.type_check
|
||||||
.insert(new_id.0, Some((typ, state_space, true)));
|
.insert(new_id, Some((typ, state_space, true)));
|
||||||
self.current_id.0 += 1;
|
self.current_id.0 += 1;
|
||||||
new_id
|
new_id
|
||||||
}
|
}
|
||||||
|
@ -236,7 +571,7 @@ impl<'b> NumericIdResolver<'b> {
|
||||||
fn register_intermediate(&mut self, typ: Option<(ast::Type, ast::StateSpace)>) -> SpirvWord {
|
fn register_intermediate(&mut self, typ: Option<(ast::Type, ast::StateSpace)>) -> SpirvWord {
|
||||||
let new_id = *self.current_id;
|
let new_id = *self.current_id;
|
||||||
self.type_check
|
self.type_check
|
||||||
.insert(new_id.0, typ.map(|(t, space)| (t, space, false)));
|
.insert(new_id, typ.map(|(t, space)| (t, space, false)));
|
||||||
self.current_id.0 += 1;
|
self.current_id.0 += 1;
|
||||||
new_id
|
new_id
|
||||||
}
|
}
|
||||||
|
@ -490,6 +825,10 @@ impl From<SpirvWord> for spirv::Word {
|
||||||
|
|
||||||
impl ast::Operand for SpirvWord {
|
impl ast::Operand for SpirvWord {
|
||||||
type Ident = Self;
|
type Ident = Self;
|
||||||
|
|
||||||
|
fn from_ident(ident: Self::Ident) -> Self {
|
||||||
|
ident
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn pred_map_variable<U, T, F: FnMut(T) -> Result<U, TranslateError>>(
|
fn pred_map_variable<U, T, F: FnMut(T) -> Result<U, TranslateError>>(
|
||||||
|
@ -503,29 +842,18 @@ fn pred_map_variable<U, T, F: FnMut(T) -> Result<U, TranslateError>>(
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: ast::Operand, U: ast::Operand, X: FnMut(&str) -> Result<SpirvWord, Err>, Err> ast::VisitorMap<T, U, Err> for X {
|
pub(crate) enum Directive<'input> {
|
||||||
fn visit(
|
Variable(ast::LinkingDirective, ast::Variable<SpirvWord>),
|
||||||
&mut self,
|
Method(Function<'input>),
|
||||||
args: T,
|
|
||||||
type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>,
|
|
||||||
is_dst: bool,
|
|
||||||
) -> U {
|
|
||||||
todo!()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn visit_ident(
|
|
||||||
&mut self,
|
|
||||||
args: <T as ptx_parser::Operand>::Ident,
|
|
||||||
type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>,
|
|
||||||
is_dst: bool,
|
|
||||||
) -> <U as ptx_parser::Operand>::Ident {
|
|
||||||
todo!()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn op_map_variable<'a, F: FnMut(&str) -> Result<SpirvWord, TranslateError>>(
|
pub(crate) struct Function<'input> {
|
||||||
this: ast::Instruction<ast::ParsedOperand<&'a str>>,
|
pub func_decl: Rc<RefCell<ast::MethodDeclaration<'input, SpirvWord>>>,
|
||||||
f: &mut F,
|
pub globals: Vec<ast::Variable<SpirvWord>>,
|
||||||
) -> Result<ast::Instruction<ast::ParsedOperand<SpirvWord>>, TranslateError> {
|
pub body: Option<Vec<ExpandedStatement>>,
|
||||||
ast::visit_map(this , f)
|
import_as: Option<String>,
|
||||||
|
tuning: Vec<ast::TuningDirective>,
|
||||||
|
linkage: ast::LinkingDirective,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type ExpandedStatement = Statement<ast::Instruction<SpirvWord>, SpirvWord>;
|
|
@ -9,7 +9,7 @@ type NormalizedStatement = Statement<
|
||||||
ast::ParsedOperand<SpirvWord>,
|
ast::ParsedOperand<SpirvWord>,
|
||||||
>;
|
>;
|
||||||
|
|
||||||
fn run<'input, 'b>(
|
pub(crate) fn run<'input, 'b>(
|
||||||
id_defs: &mut FnStringIdResolver<'input, 'b>,
|
id_defs: &mut FnStringIdResolver<'input, 'b>,
|
||||||
fn_defs: &GlobalFnDeclResolver<'input, 'b>,
|
fn_defs: &GlobalFnDeclResolver<'input, 'b>,
|
||||||
func: Vec<ast::Statement<ast::ParsedOperand<&'input str>>>,
|
func: Vec<ast::Statement<ast::ParsedOperand<&'input str>>>,
|
||||||
|
@ -47,7 +47,11 @@ fn expand_map_variables<'a, 'b>(
|
||||||
ast::Statement::Instruction(p, i) => result.push(Statement::Instruction((
|
ast::Statement::Instruction(p, i) => result.push(Statement::Instruction((
|
||||||
p.map(|p| pred_map_variable(p, &mut |id| id_defs.get_id(id)))
|
p.map(|p| pred_map_variable(p, &mut |id| id_defs.get_id(id)))
|
||||||
.transpose()?,
|
.transpose()?,
|
||||||
op_map_variable(i, &mut |id| id_defs.get_id(id))?,
|
ast::visit_map(i, &mut |id,
|
||||||
|
_: Option<(&ast::Type, ast::StateSpace)>,
|
||||||
|
_: bool| {
|
||||||
|
id_defs.get_id(id)
|
||||||
|
})?,
|
||||||
))),
|
))),
|
||||||
ast::Statement::Variable(var) => {
|
ast::Statement::Variable(var) => {
|
||||||
let var_type = var.var.v_type.clone();
|
let var_type = var.var.v_type.clone();
|
||||||
|
|
|
@ -1,60 +0,0 @@
|
||||||
use std::collections::HashMap;
|
|
||||||
use half::f16;
|
|
||||||
use ptx_parser as ast;
|
|
||||||
|
|
||||||
fn to_ssa<'input, 'b>(
|
|
||||||
ptx_impl_imports: &'b mut HashMap<String, Directive<'input>>,
|
|
||||||
mut id_defs: FnStringIdResolver<'input, 'b>,
|
|
||||||
fn_defs: GlobalFnDeclResolver<'input, 'b>,
|
|
||||||
func_decl: Rc<RefCell<ast::MethodDeclaration<'input, spirv::Word>>>,
|
|
||||||
f_body: Option<Vec<ast::Statement<ast::ParsedArgParams<'input>>>>,
|
|
||||||
tuning: Vec<ast::TuningDirective>,
|
|
||||||
linkage: ast::LinkingDirective,
|
|
||||||
) -> Result<Function<'input>, TranslateError> {
|
|
||||||
//deparamize_function_decl(&func_decl)?;
|
|
||||||
let f_body = match f_body {
|
|
||||||
Some(vec) => vec,
|
|
||||||
None => {
|
|
||||||
return Ok(Function {
|
|
||||||
func_decl: func_decl,
|
|
||||||
body: None,
|
|
||||||
globals: Vec::new(),
|
|
||||||
import_as: None,
|
|
||||||
tuning,
|
|
||||||
linkage,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
};
|
|
||||||
let normalized_ids = normalize_identifiers(&mut id_defs, &fn_defs, f_body)?;
|
|
||||||
/*
|
|
||||||
let mut numeric_id_defs = id_defs.finish();
|
|
||||||
let unadorned_statements = normalize_predicates(normalized_ids, &mut numeric_id_defs)?;
|
|
||||||
let typed_statements =
|
|
||||||
convert_to_typed_statements(unadorned_statements, &fn_defs, &mut numeric_id_defs)?;
|
|
||||||
let typed_statements =
|
|
||||||
fix_special_registers2(ptx_impl_imports, typed_statements, &mut numeric_id_defs)?;
|
|
||||||
let (func_decl, typed_statements) =
|
|
||||||
convert_to_stateful_memory_access(func_decl, typed_statements, &mut numeric_id_defs)?;
|
|
||||||
let ssa_statements = insert_mem_ssa_statements(
|
|
||||||
typed_statements,
|
|
||||||
&mut numeric_id_defs,
|
|
||||||
&mut (*func_decl).borrow_mut(),
|
|
||||||
)?;
|
|
||||||
let mut numeric_id_defs = numeric_id_defs.finish();
|
|
||||||
let expanded_statements = expand_arguments(ssa_statements, &mut numeric_id_defs)?;
|
|
||||||
let expanded_statements =
|
|
||||||
insert_implicit_conversions(expanded_statements, &mut numeric_id_defs)?;
|
|
||||||
let mut numeric_id_defs = numeric_id_defs.unmut();
|
|
||||||
let labeled_statements = normalize_labels(expanded_statements, &mut numeric_id_defs);
|
|
||||||
let (f_body, globals) =
|
|
||||||
extract_globals(labeled_statements, ptx_impl_imports, &mut numeric_id_defs)?;
|
|
||||||
Ok(Function {
|
|
||||||
func_decl: func_decl,
|
|
||||||
globals: globals,
|
|
||||||
body: Some(f_body),
|
|
||||||
import_as: None,
|
|
||||||
tuning,
|
|
||||||
linkage,
|
|
||||||
})
|
|
||||||
*/
|
|
||||||
}
|
|
|
@ -555,12 +555,46 @@ pub trait VisitorMap<From: Operand, To: Operand, Err> {
|
||||||
) -> Result<To::Ident, Err>;
|
) -> Result<To::Ident, Err>;
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<
|
impl<T: Copy, U: Copy, Err, Fn> VisitorMap<ParsedOperand<T>, ParsedOperand<U>, Err> for Fn
|
||||||
T: Operand,
|
where
|
||||||
U: Operand,
|
Fn: FnMut(T, Option<(&Type, StateSpace)>, bool) -> Result<U, Err>,
|
||||||
Err,
|
{
|
||||||
Fn: FnMut(T, Option<(&Type, StateSpace)>, bool) -> Result<U, Err>,
|
fn visit(
|
||||||
> VisitorMap<T, U, Err> for Fn
|
&mut self,
|
||||||
|
args: ParsedOperand<T>,
|
||||||
|
type_space: Option<(&Type, StateSpace)>,
|
||||||
|
is_dst: bool,
|
||||||
|
) -> Result<ParsedOperand<U>, Err> {
|
||||||
|
Ok(match args {
|
||||||
|
ParsedOperand::Reg(ident) => ParsedOperand::Reg((self)(ident, type_space, is_dst)?),
|
||||||
|
ParsedOperand::RegOffset(ident, imm) => {
|
||||||
|
ParsedOperand::RegOffset((self)(ident, type_space, is_dst)?, imm)
|
||||||
|
}
|
||||||
|
ParsedOperand::Imm(imm) => ParsedOperand::Imm(imm),
|
||||||
|
ParsedOperand::VecMember(ident, index) => {
|
||||||
|
ParsedOperand::VecMember((self)(ident, type_space, is_dst)?, index)
|
||||||
|
}
|
||||||
|
ParsedOperand::VecPack(vec) => ParsedOperand::VecPack(
|
||||||
|
vec.into_iter()
|
||||||
|
.map(|ident| (self)(ident, type_space, is_dst))
|
||||||
|
.collect::<Result<Vec<_>, _>>()?,
|
||||||
|
),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn visit_ident(
|
||||||
|
&mut self,
|
||||||
|
args: T,
|
||||||
|
type_space: Option<(&Type, StateSpace)>,
|
||||||
|
is_dst: bool,
|
||||||
|
) -> Result<U, Err> {
|
||||||
|
(self)(args, type_space, is_dst)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: Operand<Ident = T>, U: Operand<Ident = U>, Err, Fn> VisitorMap<T, U, Err> for Fn
|
||||||
|
where
|
||||||
|
Fn: FnMut(T, Option<(&Type, StateSpace)>, bool) -> Result<U, Err>,
|
||||||
{
|
{
|
||||||
fn visit(
|
fn visit(
|
||||||
&mut self,
|
&mut self,
|
||||||
|
@ -573,12 +607,11 @@ impl<
|
||||||
|
|
||||||
fn visit_ident(
|
fn visit_ident(
|
||||||
&mut self,
|
&mut self,
|
||||||
args: T::Ident,
|
args: T,
|
||||||
type_space: Option<(&Type, StateSpace)>,
|
type_space: Option<(&Type, StateSpace)>,
|
||||||
is_dst: bool,
|
is_dst: bool,
|
||||||
) -> Result<U::Ident, Err> {
|
) -> Result<U, Err> {
|
||||||
let value: U = (self)(T::from_ident(args), type_space, is_dst)?;
|
(self)(args, type_space, is_dst)
|
||||||
Ok(value)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -925,6 +958,15 @@ pub struct MethodDeclaration<'input, ID> {
|
||||||
pub shared_mem: Option<ID>,
|
pub shared_mem: Option<ID>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl<'input> MethodDeclaration<'input, &'input str> {
|
||||||
|
pub fn name(&self) -> &'input str {
|
||||||
|
match self.name {
|
||||||
|
MethodName::Kernel(n) => n,
|
||||||
|
MethodName::Func(n) => n,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Hash, PartialEq, Eq, Copy, Clone)]
|
#[derive(Hash, PartialEq, Eq, Copy, Clone)]
|
||||||
pub enum MethodName<'input, ID> {
|
pub enum MethodName<'input, ID> {
|
||||||
Kernel(&'input str),
|
Kernel(&'input str),
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue