mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-04-20 00:14:45 +00:00
Port first pass
This commit is contained in:
parent
1ec1ca0c30
commit
12ef8dbc90
5 changed files with 421 additions and 108 deletions
|
@ -24,11 +24,10 @@ lalrpop_mod!(
|
|||
);
|
||||
|
||||
pub mod ast;
|
||||
mod pass;
|
||||
pub(crate) mod pass;
|
||||
#[cfg(test)]
|
||||
mod test;
|
||||
mod translate;
|
||||
mod translate2;
|
||||
|
||||
use std::fmt;
|
||||
|
||||
|
|
|
@ -1,12 +1,347 @@
|
|||
use ptx_parser as ast;
|
||||
use rspirv::{binary::Assemble, dr};
|
||||
use std::{
|
||||
borrow::Cow,
|
||||
cell::RefCell,
|
||||
collections::{hash_map, HashMap},
|
||||
ffi::CString,
|
||||
rc::Rc,
|
||||
};
|
||||
|
||||
mod normalize;
|
||||
pub(crate) mod normalize;
|
||||
|
||||
static ZLUDA_PTX_IMPL_INTEL: &'static [u8] = include_bytes!("../../lib/zluda_ptx_impl.spv");
|
||||
static ZLUDA_PTX_IMPL_AMD: &'static [u8] = include_bytes!("../../lib/zluda_ptx_impl.bc");
|
||||
const ZLUDA_PTX_PREFIX: &'static str = "__zluda_ptx_impl__";
|
||||
|
||||
pub fn to_spirv_module<'input>(ast: ast::Module<'input>) -> Result<Module, TranslateError> {
|
||||
let mut id_defs = GlobalStringIdResolver::<'input>::new(SpirvWord(1));
|
||||
let mut ptx_impl_imports = HashMap::new();
|
||||
let directives = ast
|
||||
.directives
|
||||
.into_iter()
|
||||
.filter_map(|directive| {
|
||||
translate_directive(&mut id_defs, &mut ptx_impl_imports, directive).transpose()
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
/*
|
||||
let directives = hoist_function_globals(directives);
|
||||
let must_link_ptx_impl = ptx_impl_imports.len() > 0;
|
||||
let mut directives = ptx_impl_imports
|
||||
.into_iter()
|
||||
.map(|(_, v)| v)
|
||||
.chain(directives.into_iter())
|
||||
.collect::<Vec<_>>();
|
||||
let mut builder = dr::Builder::new();
|
||||
builder.reserve_ids(id_defs.current_id());
|
||||
let call_map = MethodsCallMap::new(&directives);
|
||||
let mut directives =
|
||||
convert_dynamic_shared_memory_usage(directives, &call_map, &mut || builder.id());
|
||||
normalize_variable_decls(&mut directives);
|
||||
let denorm_information = compute_denorm_information(&directives);
|
||||
// https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_a_id_logicallayout_a_logical_layout_of_a_module
|
||||
builder.set_version(1, 3);
|
||||
emit_capabilities(&mut builder);
|
||||
emit_extensions(&mut builder);
|
||||
let opencl_id = emit_opencl_import(&mut builder);
|
||||
emit_memory_model(&mut builder);
|
||||
let mut map = TypeWordMap::new(&mut builder);
|
||||
//emit_builtins(&mut builder, &mut map, &id_defs);
|
||||
let mut kernel_info = HashMap::new();
|
||||
let (build_options, should_flush_denorms) =
|
||||
emit_denorm_build_string(&call_map, &denorm_information);
|
||||
let (directives, globals_use_map) = get_globals_use_map(directives);
|
||||
emit_directives(
|
||||
&mut builder,
|
||||
&mut map,
|
||||
&id_defs,
|
||||
opencl_id,
|
||||
should_flush_denorms,
|
||||
&call_map,
|
||||
globals_use_map,
|
||||
directives,
|
||||
&mut kernel_info,
|
||||
)?;
|
||||
let spirv = builder.module();
|
||||
Ok(Module {
|
||||
spirv,
|
||||
kernel_info,
|
||||
should_link_ptx_impl: if must_link_ptx_impl {
|
||||
Some((ZLUDA_PTX_IMPL_INTEL, ZLUDA_PTX_IMPL_AMD))
|
||||
} else {
|
||||
None
|
||||
},
|
||||
build_options,
|
||||
})
|
||||
*/
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn translate_directive<'input, 'a>(
|
||||
id_defs: &'a mut GlobalStringIdResolver<'input>,
|
||||
ptx_impl_imports: &'a mut HashMap<String, Directive<'input>>,
|
||||
d: ast::Directive<'input, ast::ParsedOperand<&'input str>>,
|
||||
) -> Result<Option<Directive<'input>>, TranslateError> {
|
||||
Ok(match d {
|
||||
ast::Directive::Variable(linking, var) => Some(Directive::Variable(
|
||||
linking,
|
||||
ast::Variable {
|
||||
align: var.align,
|
||||
v_type: var.v_type.clone(),
|
||||
state_space: var.state_space,
|
||||
name: id_defs.get_or_add_def_typed(var.name, var.v_type, var.state_space, true),
|
||||
array_init: var.array_init,
|
||||
},
|
||||
)),
|
||||
ast::Directive::Method(linkage, f) => {
|
||||
translate_function(id_defs, ptx_impl_imports, linkage, f)?.map(Directive::Method)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
type ParsedFunction<'a> = ast::Function<'a, &'a str, ast::Statement<ast::ParsedOperand<&'a str>>>;
|
||||
|
||||
fn translate_function<'input, 'a>(
|
||||
id_defs: &'a mut GlobalStringIdResolver<'input>,
|
||||
ptx_impl_imports: &'a mut HashMap<String, Directive<'input>>,
|
||||
linkage: ast::LinkingDirective,
|
||||
f: ParsedFunction<'input>,
|
||||
) -> Result<Option<Function<'input>>, TranslateError> {
|
||||
let import_as = match &f.func_directive {
|
||||
ast::MethodDeclaration {
|
||||
name: ast::MethodName::Func(func_name),
|
||||
..
|
||||
} if *func_name == "__assertfail" || *func_name == "vprintf" => {
|
||||
Some([ZLUDA_PTX_PREFIX, func_name].concat())
|
||||
}
|
||||
_ => None,
|
||||
};
|
||||
let (str_resolver, fn_resolver, fn_decl) = id_defs.start_fn(&f.func_directive)?;
|
||||
let mut func = to_ssa(
|
||||
ptx_impl_imports,
|
||||
str_resolver,
|
||||
fn_resolver,
|
||||
fn_decl,
|
||||
f.body,
|
||||
f.tuning,
|
||||
linkage,
|
||||
)?;
|
||||
func.import_as = import_as;
|
||||
if func.import_as.is_some() {
|
||||
ptx_impl_imports.insert(
|
||||
func.import_as.as_ref().unwrap().clone(),
|
||||
Directive::Method(func),
|
||||
);
|
||||
Ok(None)
|
||||
} else {
|
||||
Ok(Some(func))
|
||||
}
|
||||
}
|
||||
|
||||
fn to_ssa<'input, 'b>(
|
||||
ptx_impl_imports: &'b mut HashMap<String, Directive<'input>>,
|
||||
mut id_defs: FnStringIdResolver<'input, 'b>,
|
||||
fn_defs: GlobalFnDeclResolver<'input, 'b>,
|
||||
func_decl: Rc<RefCell<ast::MethodDeclaration<'input, SpirvWord>>>,
|
||||
f_body: Option<Vec<ast::Statement<ast::ParsedOperand<&'input str>>>>,
|
||||
tuning: Vec<ast::TuningDirective>,
|
||||
linkage: ast::LinkingDirective,
|
||||
) -> Result<Function<'input>, TranslateError> {
|
||||
//deparamize_function_decl(&func_decl)?;
|
||||
let f_body = match f_body {
|
||||
Some(vec) => vec,
|
||||
None => {
|
||||
return Ok(Function {
|
||||
func_decl: func_decl,
|
||||
body: None,
|
||||
globals: Vec::new(),
|
||||
import_as: None,
|
||||
tuning,
|
||||
linkage,
|
||||
})
|
||||
}
|
||||
};
|
||||
let normalized_ids = normalize::run(&mut id_defs, &fn_defs, f_body)?;
|
||||
todo!()
|
||||
/*
|
||||
let mut numeric_id_defs = id_defs.finish();
|
||||
let unadorned_statements = normalize_predicates(normalized_ids, &mut numeric_id_defs)?;
|
||||
let typed_statements =
|
||||
convert_to_typed_statements(unadorned_statements, &fn_defs, &mut numeric_id_defs)?;
|
||||
let typed_statements =
|
||||
fix_special_registers2(ptx_impl_imports, typed_statements, &mut numeric_id_defs)?;
|
||||
let (func_decl, typed_statements) =
|
||||
convert_to_stateful_memory_access(func_decl, typed_statements, &mut numeric_id_defs)?;
|
||||
let ssa_statements = insert_mem_ssa_statements(
|
||||
typed_statements,
|
||||
&mut numeric_id_defs,
|
||||
&mut (*func_decl).borrow_mut(),
|
||||
)?;
|
||||
let mut numeric_id_defs = numeric_id_defs.finish();
|
||||
let expanded_statements = expand_arguments(ssa_statements, &mut numeric_id_defs)?;
|
||||
let expanded_statements =
|
||||
insert_implicit_conversions(expanded_statements, &mut numeric_id_defs)?;
|
||||
let mut numeric_id_defs = numeric_id_defs.unmut();
|
||||
let labeled_statements = normalize_labels(expanded_statements, &mut numeric_id_defs);
|
||||
let (f_body, globals) =
|
||||
extract_globals(labeled_statements, ptx_impl_imports, &mut numeric_id_defs)?;
|
||||
Ok(Function {
|
||||
func_decl: func_decl,
|
||||
globals: globals,
|
||||
body: Some(f_body),
|
||||
import_as: None,
|
||||
tuning,
|
||||
linkage,
|
||||
})
|
||||
*/
|
||||
}
|
||||
|
||||
pub struct Module {
|
||||
pub spirv: dr::Module,
|
||||
pub kernel_info: HashMap<String, KernelInfo>,
|
||||
pub should_link_ptx_impl: Option<(&'static [u8], &'static [u8])>,
|
||||
pub build_options: CString,
|
||||
}
|
||||
|
||||
impl Module {
|
||||
pub fn assemble(&self) -> Vec<u32> {
|
||||
self.spirv.assemble()
|
||||
}
|
||||
}
|
||||
|
||||
struct GlobalStringIdResolver<'input> {
|
||||
current_id: SpirvWord,
|
||||
variables: HashMap<Cow<'input, str>, SpirvWord>,
|
||||
reverse_variables: HashMap<SpirvWord, &'input str>,
|
||||
variables_type_check: HashMap<SpirvWord, Option<(ast::Type, ast::StateSpace, bool)>>,
|
||||
special_registers: SpecialRegistersMap,
|
||||
fns: HashMap<SpirvWord, FnSigMapper<'input>>,
|
||||
}
|
||||
|
||||
impl<'input> GlobalStringIdResolver<'input> {
|
||||
fn new(start_id: SpirvWord) -> Self {
|
||||
Self {
|
||||
current_id: start_id,
|
||||
variables: HashMap::new(),
|
||||
reverse_variables: HashMap::new(),
|
||||
variables_type_check: HashMap::new(),
|
||||
special_registers: SpecialRegistersMap::new(),
|
||||
fns: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn get_or_add_def(&mut self, id: &'input str) -> SpirvWord {
|
||||
self.get_or_add_impl(id, None)
|
||||
}
|
||||
|
||||
fn get_or_add_def_typed(
|
||||
&mut self,
|
||||
id: &'input str,
|
||||
typ: ast::Type,
|
||||
state_space: ast::StateSpace,
|
||||
is_variable: bool,
|
||||
) -> SpirvWord {
|
||||
self.get_or_add_impl(id, Some((typ, state_space, is_variable)))
|
||||
}
|
||||
|
||||
fn get_or_add_impl(
|
||||
&mut self,
|
||||
id: &'input str,
|
||||
typ: Option<(ast::Type, ast::StateSpace, bool)>,
|
||||
) -> SpirvWord {
|
||||
let id = match self.variables.entry(Cow::Borrowed(id)) {
|
||||
hash_map::Entry::Occupied(e) => *(e.get()),
|
||||
hash_map::Entry::Vacant(e) => {
|
||||
let numeric_id = self.current_id;
|
||||
e.insert(numeric_id);
|
||||
self.reverse_variables.insert(numeric_id, id);
|
||||
self.current_id.0 += 1;
|
||||
numeric_id
|
||||
}
|
||||
};
|
||||
self.variables_type_check.insert(id, typ);
|
||||
id
|
||||
}
|
||||
|
||||
fn get_id(&self, id: &str) -> Result<SpirvWord, TranslateError> {
|
||||
self.variables
|
||||
.get(id)
|
||||
.copied()
|
||||
.ok_or_else(error_unknown_symbol)
|
||||
}
|
||||
|
||||
fn current_id(&self) -> SpirvWord {
|
||||
self.current_id
|
||||
}
|
||||
|
||||
fn start_fn<'b>(
|
||||
&'b mut self,
|
||||
header: &'b ast::MethodDeclaration<'input, &'input str>,
|
||||
) -> Result<
|
||||
(
|
||||
FnStringIdResolver<'input, 'b>,
|
||||
GlobalFnDeclResolver<'input, 'b>,
|
||||
Rc<RefCell<ast::MethodDeclaration<'input, SpirvWord>>>,
|
||||
),
|
||||
TranslateError,
|
||||
> {
|
||||
// In case a function decl was inserted earlier we want to use its id
|
||||
let name_id = self.get_or_add_def(header.name());
|
||||
let mut fn_resolver = FnStringIdResolver {
|
||||
current_id: &mut self.current_id,
|
||||
global_variables: &self.variables,
|
||||
global_type_check: &self.variables_type_check,
|
||||
special_registers: &mut self.special_registers,
|
||||
variables: vec![HashMap::new(); 1],
|
||||
type_check: HashMap::new(),
|
||||
};
|
||||
let return_arguments = rename_fn_params(&mut fn_resolver, &header.return_arguments);
|
||||
let input_arguments = rename_fn_params(&mut fn_resolver, &header.input_arguments);
|
||||
let name = match header.name {
|
||||
ast::MethodName::Kernel(name) => ast::MethodName::Kernel(name),
|
||||
ast::MethodName::Func(_) => ast::MethodName::Func(name_id),
|
||||
};
|
||||
let fn_decl = ast::MethodDeclaration {
|
||||
return_arguments,
|
||||
name,
|
||||
input_arguments,
|
||||
shared_mem: None,
|
||||
};
|
||||
let new_fn_decl = if !matches!(fn_decl.name, ast::MethodName::Kernel(_)) {
|
||||
let resolver = FnSigMapper::remap_to_spirv_repr(fn_decl);
|
||||
let new_fn_decl = resolver.func_decl.clone();
|
||||
self.fns.insert(name_id, resolver);
|
||||
new_fn_decl
|
||||
} else {
|
||||
Rc::new(RefCell::new(fn_decl))
|
||||
};
|
||||
Ok((
|
||||
fn_resolver,
|
||||
GlobalFnDeclResolver { fns: &self.fns },
|
||||
new_fn_decl,
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
fn rename_fn_params<'a, 'b>(
|
||||
fn_resolver: &mut FnStringIdResolver<'a, 'b>,
|
||||
args: &'b [ast::Variable<&'a str>],
|
||||
) -> Vec<ast::Variable<SpirvWord>> {
|
||||
args.iter()
|
||||
.map(|a| ast::Variable {
|
||||
name: fn_resolver.add_def(a.name, Some((a.v_type.clone(), a.state_space)), true),
|
||||
v_type: a.v_type.clone(),
|
||||
state_space: a.state_space,
|
||||
align: a.align,
|
||||
array_init: a.array_init.clone(),
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub struct KernelInfo {
|
||||
pub arguments_sizes: Vec<(usize, bool)>,
|
||||
pub uses_shared_mem: bool,
|
||||
}
|
||||
|
||||
#[derive(Ord, PartialOrd, Eq, PartialEq, Hash, Copy, Clone)]
|
||||
enum PtxSpecialRegister {
|
||||
|
@ -108,10 +443,10 @@ impl SpecialRegistersMap {
|
|||
struct FnStringIdResolver<'input, 'b> {
|
||||
current_id: &'b mut SpirvWord,
|
||||
global_variables: &'b HashMap<Cow<'input, str>, SpirvWord>,
|
||||
global_type_check: &'b HashMap<u32, Option<(ast::Type, ast::StateSpace, bool)>>,
|
||||
global_type_check: &'b HashMap<SpirvWord, Option<(ast::Type, ast::StateSpace, bool)>>,
|
||||
special_registers: &'b mut SpecialRegistersMap,
|
||||
variables: Vec<HashMap<Cow<'input, str>, SpirvWord>>,
|
||||
type_check: HashMap<u32, Option<(ast::Type, ast::StateSpace, bool)>>,
|
||||
type_check: HashMap<SpirvWord, Option<(ast::Type, ast::StateSpace, bool)>>,
|
||||
}
|
||||
|
||||
impl<'a, 'b> FnStringIdResolver<'a, 'b> {
|
||||
|
@ -160,7 +495,7 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> {
|
|||
.unwrap()
|
||||
.insert(Cow::Borrowed(id), numeric_id);
|
||||
self.type_check.insert(
|
||||
numeric_id.0,
|
||||
numeric_id,
|
||||
typ.map(|(typ, space)| (typ, space, is_variable)),
|
||||
);
|
||||
self.current_id.0 += 1;
|
||||
|
@ -183,7 +518,7 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> {
|
|||
SpirvWord(numeric_id.0 + i),
|
||||
);
|
||||
self.type_check.insert(
|
||||
numeric_id.0 + i,
|
||||
SpirvWord(numeric_id.0 + i),
|
||||
Some((typ.clone(), state_space, is_variable)),
|
||||
);
|
||||
}
|
||||
|
@ -196,8 +531,8 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> {
|
|||
|
||||
struct NumericIdResolver<'b> {
|
||||
current_id: &'b mut SpirvWord,
|
||||
global_type_check: &'b HashMap<u32, Option<(ast::Type, ast::StateSpace, bool)>>,
|
||||
type_check: HashMap<u32, Option<(ast::Type, ast::StateSpace, bool)>>,
|
||||
global_type_check: &'b HashMap<SpirvWord, Option<(ast::Type, ast::StateSpace, bool)>>,
|
||||
type_check: HashMap<SpirvWord, Option<(ast::Type, ast::StateSpace, bool)>>,
|
||||
special_registers: &'b mut SpecialRegistersMap,
|
||||
}
|
||||
|
||||
|
@ -210,12 +545,12 @@ impl<'b> NumericIdResolver<'b> {
|
|||
&self,
|
||||
id: SpirvWord,
|
||||
) -> Result<(ast::Type, ast::StateSpace, bool), TranslateError> {
|
||||
match self.type_check.get(&id.0) {
|
||||
match self.type_check.get(&id) {
|
||||
Some(Some(x)) => Ok(x.clone()),
|
||||
Some(None) => Err(TranslateError::UntypedSymbol),
|
||||
None => match self.special_registers.get(id) {
|
||||
Some(x) => Ok((x.get_type(), ast::StateSpace::Sreg, true)),
|
||||
None => match self.global_type_check.get(&id.0) {
|
||||
None => match self.global_type_check.get(&id) {
|
||||
Some(Some(result)) => Ok(result.clone()),
|
||||
Some(None) | None => Err(TranslateError::UntypedSymbol),
|
||||
},
|
||||
|
@ -228,7 +563,7 @@ impl<'b> NumericIdResolver<'b> {
|
|||
fn register_variable(&mut self, typ: ast::Type, state_space: ast::StateSpace) -> SpirvWord {
|
||||
let new_id = *self.current_id;
|
||||
self.type_check
|
||||
.insert(new_id.0, Some((typ, state_space, true)));
|
||||
.insert(new_id, Some((typ, state_space, true)));
|
||||
self.current_id.0 += 1;
|
||||
new_id
|
||||
}
|
||||
|
@ -236,7 +571,7 @@ impl<'b> NumericIdResolver<'b> {
|
|||
fn register_intermediate(&mut self, typ: Option<(ast::Type, ast::StateSpace)>) -> SpirvWord {
|
||||
let new_id = *self.current_id;
|
||||
self.type_check
|
||||
.insert(new_id.0, typ.map(|(t, space)| (t, space, false)));
|
||||
.insert(new_id, typ.map(|(t, space)| (t, space, false)));
|
||||
self.current_id.0 += 1;
|
||||
new_id
|
||||
}
|
||||
|
@ -490,6 +825,10 @@ impl From<SpirvWord> for spirv::Word {
|
|||
|
||||
impl ast::Operand for SpirvWord {
|
||||
type Ident = Self;
|
||||
|
||||
fn from_ident(ident: Self::Ident) -> Self {
|
||||
ident
|
||||
}
|
||||
}
|
||||
|
||||
fn pred_map_variable<U, T, F: FnMut(T) -> Result<U, TranslateError>>(
|
||||
|
@ -503,29 +842,18 @@ fn pred_map_variable<U, T, F: FnMut(T) -> Result<U, TranslateError>>(
|
|||
})
|
||||
}
|
||||
|
||||
impl<T: ast::Operand, U: ast::Operand, X: FnMut(&str) -> Result<SpirvWord, Err>, Err> ast::VisitorMap<T, U, Err> for X {
|
||||
fn visit(
|
||||
&mut self,
|
||||
args: T,
|
||||
type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>,
|
||||
is_dst: bool,
|
||||
) -> U {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn visit_ident(
|
||||
&mut self,
|
||||
args: <T as ptx_parser::Operand>::Ident,
|
||||
type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>,
|
||||
is_dst: bool,
|
||||
) -> <U as ptx_parser::Operand>::Ident {
|
||||
todo!()
|
||||
}
|
||||
pub(crate) enum Directive<'input> {
|
||||
Variable(ast::LinkingDirective, ast::Variable<SpirvWord>),
|
||||
Method(Function<'input>),
|
||||
}
|
||||
|
||||
fn op_map_variable<'a, F: FnMut(&str) -> Result<SpirvWord, TranslateError>>(
|
||||
this: ast::Instruction<ast::ParsedOperand<&'a str>>,
|
||||
f: &mut F,
|
||||
) -> Result<ast::Instruction<ast::ParsedOperand<SpirvWord>>, TranslateError> {
|
||||
ast::visit_map(this , f)
|
||||
pub(crate) struct Function<'input> {
|
||||
pub func_decl: Rc<RefCell<ast::MethodDeclaration<'input, SpirvWord>>>,
|
||||
pub globals: Vec<ast::Variable<SpirvWord>>,
|
||||
pub body: Option<Vec<ExpandedStatement>>,
|
||||
import_as: Option<String>,
|
||||
tuning: Vec<ast::TuningDirective>,
|
||||
linkage: ast::LinkingDirective,
|
||||
}
|
||||
|
||||
type ExpandedStatement = Statement<ast::Instruction<SpirvWord>, SpirvWord>;
|
|
@ -9,7 +9,7 @@ type NormalizedStatement = Statement<
|
|||
ast::ParsedOperand<SpirvWord>,
|
||||
>;
|
||||
|
||||
fn run<'input, 'b>(
|
||||
pub(crate) fn run<'input, 'b>(
|
||||
id_defs: &mut FnStringIdResolver<'input, 'b>,
|
||||
fn_defs: &GlobalFnDeclResolver<'input, 'b>,
|
||||
func: Vec<ast::Statement<ast::ParsedOperand<&'input str>>>,
|
||||
|
@ -47,7 +47,11 @@ fn expand_map_variables<'a, 'b>(
|
|||
ast::Statement::Instruction(p, i) => result.push(Statement::Instruction((
|
||||
p.map(|p| pred_map_variable(p, &mut |id| id_defs.get_id(id)))
|
||||
.transpose()?,
|
||||
op_map_variable(i, &mut |id| id_defs.get_id(id))?,
|
||||
ast::visit_map(i, &mut |id,
|
||||
_: Option<(&ast::Type, ast::StateSpace)>,
|
||||
_: bool| {
|
||||
id_defs.get_id(id)
|
||||
})?,
|
||||
))),
|
||||
ast::Statement::Variable(var) => {
|
||||
let var_type = var.var.v_type.clone();
|
||||
|
|
|
@ -1,60 +0,0 @@
|
|||
use std::collections::HashMap;
|
||||
use half::f16;
|
||||
use ptx_parser as ast;
|
||||
|
||||
fn to_ssa<'input, 'b>(
|
||||
ptx_impl_imports: &'b mut HashMap<String, Directive<'input>>,
|
||||
mut id_defs: FnStringIdResolver<'input, 'b>,
|
||||
fn_defs: GlobalFnDeclResolver<'input, 'b>,
|
||||
func_decl: Rc<RefCell<ast::MethodDeclaration<'input, spirv::Word>>>,
|
||||
f_body: Option<Vec<ast::Statement<ast::ParsedArgParams<'input>>>>,
|
||||
tuning: Vec<ast::TuningDirective>,
|
||||
linkage: ast::LinkingDirective,
|
||||
) -> Result<Function<'input>, TranslateError> {
|
||||
//deparamize_function_decl(&func_decl)?;
|
||||
let f_body = match f_body {
|
||||
Some(vec) => vec,
|
||||
None => {
|
||||
return Ok(Function {
|
||||
func_decl: func_decl,
|
||||
body: None,
|
||||
globals: Vec::new(),
|
||||
import_as: None,
|
||||
tuning,
|
||||
linkage,
|
||||
})
|
||||
}
|
||||
};
|
||||
let normalized_ids = normalize_identifiers(&mut id_defs, &fn_defs, f_body)?;
|
||||
/*
|
||||
let mut numeric_id_defs = id_defs.finish();
|
||||
let unadorned_statements = normalize_predicates(normalized_ids, &mut numeric_id_defs)?;
|
||||
let typed_statements =
|
||||
convert_to_typed_statements(unadorned_statements, &fn_defs, &mut numeric_id_defs)?;
|
||||
let typed_statements =
|
||||
fix_special_registers2(ptx_impl_imports, typed_statements, &mut numeric_id_defs)?;
|
||||
let (func_decl, typed_statements) =
|
||||
convert_to_stateful_memory_access(func_decl, typed_statements, &mut numeric_id_defs)?;
|
||||
let ssa_statements = insert_mem_ssa_statements(
|
||||
typed_statements,
|
||||
&mut numeric_id_defs,
|
||||
&mut (*func_decl).borrow_mut(),
|
||||
)?;
|
||||
let mut numeric_id_defs = numeric_id_defs.finish();
|
||||
let expanded_statements = expand_arguments(ssa_statements, &mut numeric_id_defs)?;
|
||||
let expanded_statements =
|
||||
insert_implicit_conversions(expanded_statements, &mut numeric_id_defs)?;
|
||||
let mut numeric_id_defs = numeric_id_defs.unmut();
|
||||
let labeled_statements = normalize_labels(expanded_statements, &mut numeric_id_defs);
|
||||
let (f_body, globals) =
|
||||
extract_globals(labeled_statements, ptx_impl_imports, &mut numeric_id_defs)?;
|
||||
Ok(Function {
|
||||
func_decl: func_decl,
|
||||
globals: globals,
|
||||
body: Some(f_body),
|
||||
import_as: None,
|
||||
tuning,
|
||||
linkage,
|
||||
})
|
||||
*/
|
||||
}
|
|
@ -555,12 +555,46 @@ pub trait VisitorMap<From: Operand, To: Operand, Err> {
|
|||
) -> Result<To::Ident, Err>;
|
||||
}
|
||||
|
||||
impl<
|
||||
T: Operand,
|
||||
U: Operand,
|
||||
Err,
|
||||
Fn: FnMut(T, Option<(&Type, StateSpace)>, bool) -> Result<U, Err>,
|
||||
> VisitorMap<T, U, Err> for Fn
|
||||
impl<T: Copy, U: Copy, Err, Fn> VisitorMap<ParsedOperand<T>, ParsedOperand<U>, Err> for Fn
|
||||
where
|
||||
Fn: FnMut(T, Option<(&Type, StateSpace)>, bool) -> Result<U, Err>,
|
||||
{
|
||||
fn visit(
|
||||
&mut self,
|
||||
args: ParsedOperand<T>,
|
||||
type_space: Option<(&Type, StateSpace)>,
|
||||
is_dst: bool,
|
||||
) -> Result<ParsedOperand<U>, Err> {
|
||||
Ok(match args {
|
||||
ParsedOperand::Reg(ident) => ParsedOperand::Reg((self)(ident, type_space, is_dst)?),
|
||||
ParsedOperand::RegOffset(ident, imm) => {
|
||||
ParsedOperand::RegOffset((self)(ident, type_space, is_dst)?, imm)
|
||||
}
|
||||
ParsedOperand::Imm(imm) => ParsedOperand::Imm(imm),
|
||||
ParsedOperand::VecMember(ident, index) => {
|
||||
ParsedOperand::VecMember((self)(ident, type_space, is_dst)?, index)
|
||||
}
|
||||
ParsedOperand::VecPack(vec) => ParsedOperand::VecPack(
|
||||
vec.into_iter()
|
||||
.map(|ident| (self)(ident, type_space, is_dst))
|
||||
.collect::<Result<Vec<_>, _>>()?,
|
||||
),
|
||||
})
|
||||
}
|
||||
|
||||
fn visit_ident(
|
||||
&mut self,
|
||||
args: T,
|
||||
type_space: Option<(&Type, StateSpace)>,
|
||||
is_dst: bool,
|
||||
) -> Result<U, Err> {
|
||||
(self)(args, type_space, is_dst)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Operand<Ident = T>, U: Operand<Ident = U>, Err, Fn> VisitorMap<T, U, Err> for Fn
|
||||
where
|
||||
Fn: FnMut(T, Option<(&Type, StateSpace)>, bool) -> Result<U, Err>,
|
||||
{
|
||||
fn visit(
|
||||
&mut self,
|
||||
|
@ -573,12 +607,11 @@ impl<
|
|||
|
||||
fn visit_ident(
|
||||
&mut self,
|
||||
args: T::Ident,
|
||||
args: T,
|
||||
type_space: Option<(&Type, StateSpace)>,
|
||||
is_dst: bool,
|
||||
) -> Result<U::Ident, Err> {
|
||||
let value: U = (self)(T::from_ident(args), type_space, is_dst)?;
|
||||
Ok(value)
|
||||
) -> Result<U, Err> {
|
||||
(self)(args, type_space, is_dst)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -925,6 +958,15 @@ pub struct MethodDeclaration<'input, ID> {
|
|||
pub shared_mem: Option<ID>,
|
||||
}
|
||||
|
||||
impl<'input> MethodDeclaration<'input, &'input str> {
|
||||
pub fn name(&self) -> &'input str {
|
||||
match self.name {
|
||||
MethodName::Kernel(n) => n,
|
||||
MethodName::Func(n) => n,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Hash, PartialEq, Eq, Copy, Clone)]
|
||||
pub enum MethodName<'input, ID> {
|
||||
Kernel(&'input str),
|
||||
|
|
Loading…
Add table
Reference in a new issue