mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-04-20 08:24:44 +00:00
Finish up cleanup for PTX function support
This commit is contained in:
parent
de734305cf
commit
bbb3a6c5cb
3 changed files with 83 additions and 104 deletions
|
@ -8,7 +8,7 @@ use super::{context, device, module, Decuda, Encuda};
|
|||
use std::mem;
|
||||
use std::os::raw::{c_uint, c_ulong, c_ushort};
|
||||
use std::{
|
||||
ffi::{c_void, CStr, CString},
|
||||
ffi::{c_void, CStr},
|
||||
ptr, slice,
|
||||
};
|
||||
|
||||
|
|
|
@ -190,14 +190,17 @@ fn test_spvtxt_assert<'a>(
|
|||
ptr::null_mut()
|
||||
)
|
||||
};
|
||||
assert_eq!(result, spv_result_t::SPV_SUCCESS);
|
||||
let raw_text = unsafe {
|
||||
std::slice::from_raw_parts((*spv_text).str_ as *const u8, (*spv_text).length)
|
||||
};
|
||||
let spv_from_ptx_text = unsafe { str::from_utf8_unchecked(raw_text) };
|
||||
// TODO: stop leaking kernel text
|
||||
unsafe { spirv_tools::spvContextDestroy(spv_context) };
|
||||
panic!(spv_from_ptx_text);
|
||||
if result == spv_result_t::SPV_SUCCESS {
|
||||
let raw_text = unsafe {
|
||||
std::slice::from_raw_parts((*spv_text).str_ as *const u8, (*spv_text).length)
|
||||
};
|
||||
let spv_from_ptx_text = unsafe { str::from_utf8_unchecked(raw_text) };
|
||||
// TODO: stop leaking kernel text
|
||||
panic!(spv_from_ptx_text);
|
||||
} else {
|
||||
panic!(ptx_mod.disassemble());
|
||||
}
|
||||
}
|
||||
unsafe { spirv_tools::spvContextDestroy(spv_context) };
|
||||
Ok(())
|
||||
|
|
|
@ -155,7 +155,14 @@ impl TypeWordMap {
|
|||
}
|
||||
|
||||
pub fn to_spirv_module<'a>(ast: ast::Module<'a>) -> Result<dr::Module, dr::Error> {
|
||||
let mut id_defs = GlobalStringIdResolver::new(1);
|
||||
let ssa_functions = ast
|
||||
.functions
|
||||
.into_iter()
|
||||
.map(|f| to_ssa_function(&mut id_defs, f))
|
||||
.collect::<Vec<_>>();
|
||||
let mut builder = dr::Builder::new();
|
||||
builder.reserve_ids(id_defs.current_id());
|
||||
// https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_a_id_logicallayout_a_logical_layout_of_a_module
|
||||
builder.set_version(1, 3);
|
||||
emit_capabilities(&mut builder);
|
||||
|
@ -163,13 +170,8 @@ pub fn to_spirv_module<'a>(ast: ast::Module<'a>) -> Result<dr::Module, dr::Error
|
|||
let opencl_id = emit_opencl_import(&mut builder);
|
||||
emit_memory_model(&mut builder);
|
||||
let mut map = TypeWordMap::new(&mut builder);
|
||||
let mut id_defs = GlobalStringIdResolver::new(builder.id());
|
||||
let ssa_functions = ast
|
||||
.functions
|
||||
.into_iter()
|
||||
.map(|f| to_ssa_function(&mut id_defs, opencl_id, f))
|
||||
.collect::<Vec<_>>();
|
||||
for f in ssa_functions {
|
||||
emit_function_header(&mut builder, &mut map, &id_defs, f.func_directive, &*f.args)?;
|
||||
emit_function_args(&mut builder, &mut map, &*f.args);
|
||||
emit_function_body_ops(&mut builder, &mut map, opencl_id, &f.body)?;
|
||||
builder.end_function()?;
|
||||
|
@ -177,6 +179,31 @@ pub fn to_spirv_module<'a>(ast: ast::Module<'a>) -> Result<dr::Module, dr::Error
|
|||
Ok(builder.module())
|
||||
}
|
||||
|
||||
fn emit_function_header(
|
||||
builder: &mut dr::Builder,
|
||||
map: &mut TypeWordMap,
|
||||
global: &GlobalStringIdResolver,
|
||||
func_directive: ast::FunctionHeader<ExpandedArgParams>,
|
||||
params: &[ast::Argument<ExpandedArgParams>],
|
||||
) -> Result<(), dr::Error> {
|
||||
let func_type = get_function_type(builder, map, params);
|
||||
let (fn_id, ret_type) = match func_directive {
|
||||
ast::FunctionHeader::Kernel(name) => {
|
||||
let fn_id = global.get_id(name);
|
||||
builder.entry_point(spirv::ExecutionModel::Kernel, fn_id, name, &[]);
|
||||
(fn_id, map.void())
|
||||
}
|
||||
ast::FunctionHeader::Func(params, name) => todo!(),
|
||||
};
|
||||
builder.begin_function(
|
||||
ret_type,
|
||||
Some(fn_id),
|
||||
spirv::FunctionControl::NONE,
|
||||
func_type,
|
||||
)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn to_spirv<'a>(ast: ast::Module<'a>) -> Result<Vec<u32>, dr::Error> {
|
||||
let module = to_spirv_module(ast)?;
|
||||
Ok(module.assemble())
|
||||
|
@ -206,21 +233,19 @@ fn emit_memory_model(builder: &mut dr::Builder) {
|
|||
|
||||
fn to_ssa_function<'a>(
|
||||
id_defs: &mut GlobalStringIdResolver<'a>,
|
||||
opencl_id: spirv::Word,
|
||||
f: ast::ParsedFunction<'a>,
|
||||
) -> ExpandedFunction<'a> {
|
||||
let ids_start = id_defs.current_id();
|
||||
let fn_resolver = FnStringIdResolver::new(id_defs);
|
||||
let mut fn_resolver = FnStringIdResolver::new(id_defs, f.func_directive.name());
|
||||
let f_header = match f.func_directive {
|
||||
ast::FunctionHeader::Kernel(name) => todo!(),
|
||||
ast::FunctionHeader::Func(ret_params, name) => todo!(),
|
||||
ast::FunctionHeader::Kernel(name) => ast::FunctionHeader::Kernel(name),
|
||||
ast::FunctionHeader::Func(ret_params, name) => {
|
||||
let name_id = fn_resolver.add_global_def(name);
|
||||
let ret_ids = expand_fn_params(&mut fn_resolver, ret_params);
|
||||
ast::FunctionHeader::Func(ret_ids, name_id)
|
||||
}
|
||||
};
|
||||
let f_args = todo!();
|
||||
let f_body = Some(to_ssa(
|
||||
fn_resolver,
|
||||
&f.args,
|
||||
f.body.unwrap_or_else(|| todo!()),
|
||||
));
|
||||
let f_args = expand_fn_params(&mut fn_resolver, f.args);
|
||||
let f_body = Some(to_ssa(fn_resolver, f.body.unwrap_or_else(|| Vec::new())));
|
||||
ExpandedFunction {
|
||||
func_directive: f_header,
|
||||
args: f_args,
|
||||
|
@ -228,19 +253,24 @@ fn to_ssa_function<'a>(
|
|||
}
|
||||
}
|
||||
|
||||
fn apply_id_offset(func_body: Vec<ExpandedStatement>, id_offset: u32) -> Vec<ExpandedStatement> {
|
||||
func_body
|
||||
.into_iter()
|
||||
.map(|s| s.visit_variable(&mut |id| id + id_offset))
|
||||
fn expand_fn_params<'a, 'b>(
|
||||
fn_resolver: &mut FnStringIdResolver<'a, 'b>,
|
||||
args: Vec<ast::Argument<ast::ParsedArgParams<'a>>>,
|
||||
) -> Vec<ast::Argument<ExpandedArgParams>> {
|
||||
args.into_iter()
|
||||
.map(|a| ast::Argument {
|
||||
name: fn_resolver.add_def(a.name, Some(ast::Type::Scalar(a.a_type))),
|
||||
a_type: a.a_type,
|
||||
length: a.length,
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn to_ssa<'a, 'b>(
|
||||
mut id_defs: FnStringIdResolver<'a, 'b>,
|
||||
f_args: &'b [ast::Argument<ast::ParsedArgParams<'a>>],
|
||||
f_body: Vec<ast::Statement<ast::ParsedArgParams<'a>>>,
|
||||
) -> Vec<ExpandedStatement> {
|
||||
let normalized_ids = normalize_identifiers(&mut id_defs, &f_args, f_body);
|
||||
let normalized_ids = normalize_identifiers(&mut id_defs, f_body);
|
||||
let mut numeric_id_defs = id_defs.finish();
|
||||
let normalized_statements = normalize_predicates(normalized_ids, &mut numeric_id_defs);
|
||||
let ssa_statements = insert_mem_ssa_statements(normalized_statements, &mut numeric_id_defs);
|
||||
|
@ -593,7 +623,7 @@ fn insert_implicit_conversions(
|
|||
fn get_function_type(
|
||||
builder: &mut dr::Builder,
|
||||
map: &mut TypeWordMap,
|
||||
args: &[ast::Argument<ast::ParsedArgParams>],
|
||||
args: &[ast::Argument<ExpandedArgParams>],
|
||||
) -> spirv::Word {
|
||||
map.get_or_add_fn(builder, args.iter().map(|arg| SpirvType::from(arg.a_type)))
|
||||
}
|
||||
|
@ -603,17 +633,15 @@ fn emit_function_args(
|
|||
map: &mut TypeWordMap,
|
||||
args: &[ast::Argument<ExpandedArgParams>],
|
||||
) {
|
||||
let mut id = todo!();
|
||||
for arg in args {
|
||||
let result_type = map.get_or_add_scalar(builder, arg.a_type);
|
||||
let inst = dr::Instruction::new(
|
||||
spirv::Op::FunctionParameter,
|
||||
Some(result_type),
|
||||
Some(id),
|
||||
Some(arg.name),
|
||||
Vec::new(),
|
||||
);
|
||||
builder.function.as_mut().unwrap().parameters.push(inst);
|
||||
id += 1;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1095,12 +1123,8 @@ fn emit_implicit_conversion(
|
|||
// TODO: support scopes
|
||||
fn normalize_identifiers<'a, 'b>(
|
||||
id_defs: &mut FnStringIdResolver<'a, 'b>,
|
||||
args: &[ast::Argument<ast::ParsedArgParams<'a>>],
|
||||
func: Vec<ast::Statement<ast::ParsedArgParams<'a>>>,
|
||||
) -> Vec<ast::Statement<NormalizedArgParams>> {
|
||||
for arg in args {
|
||||
id_defs.add_def(arg.name, Some(ast::Type::Scalar(arg.a_type)));
|
||||
}
|
||||
for s in func.iter() {
|
||||
match s {
|
||||
ast::Statement::Label(id) => {
|
||||
|
@ -1180,8 +1204,8 @@ impl<'a> GlobalStringIdResolver<'a> {
|
|||
numeric_id
|
||||
}
|
||||
|
||||
fn reserve_id(&mut self) {
|
||||
self.current_id += 1;
|
||||
fn get_id(&self, id: &str) -> spirv::Word {
|
||||
self.variables[id]
|
||||
}
|
||||
|
||||
fn current_id(&self) -> spirv::Word {
|
||||
|
@ -1196,7 +1220,8 @@ struct FnStringIdResolver<'a, 'b> {
|
|||
}
|
||||
|
||||
impl<'a, 'b> FnStringIdResolver<'a, 'b> {
|
||||
fn new(global: &'b mut GlobalStringIdResolver<'a>) -> Self {
|
||||
fn new(global: &'b mut GlobalStringIdResolver<'a>, f_name: &'a str) -> Self {
|
||||
global.add_def(f_name);
|
||||
Self {
|
||||
global: global,
|
||||
variables: vec![HashMap::new(); 1],
|
||||
|
@ -1229,6 +1254,10 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> {
|
|||
self.global.variables[id]
|
||||
}
|
||||
|
||||
fn add_global_def(&mut self, id: &'a str) -> spirv::Word {
|
||||
self.global.add_def(id)
|
||||
}
|
||||
|
||||
fn add_def(&mut self, id: &'a str, typ: Option<ast::Type>) -> spirv::Word {
|
||||
let numeric_id = self.global.current_id;
|
||||
self.variables
|
||||
|
@ -1294,25 +1323,6 @@ enum Statement<I> {
|
|||
Constant(ConstantDefinition),
|
||||
}
|
||||
|
||||
impl Statement<ast::Instruction<ExpandedArgParams>> {
|
||||
fn visit_variable<F: FnMut(spirv::Word) -> spirv::Word>(self, f: &mut F) -> Self {
|
||||
match self {
|
||||
Statement::Variable(id, t, ss, align) => Statement::Variable(f(id), t, ss, align),
|
||||
Statement::LoadVar(a, t) => {
|
||||
Statement::LoadVar(a.map(&mut reduced_visitor(f), Some(t)), t)
|
||||
}
|
||||
Statement::StoreVar(a, t) => {
|
||||
Statement::StoreVar(a.map(&mut reduced_visitor(f), Some(t)), t)
|
||||
}
|
||||
Statement::Label(id) => Statement::Label(f(id)),
|
||||
Statement::Instruction(inst) => Statement::Instruction(inst.visit_variable(f)),
|
||||
Statement::Conditional(bra) => Statement::Conditional(bra.map(f)),
|
||||
Statement::Conversion(conv) => Statement::Conversion(conv.map(f)),
|
||||
Statement::Constant(cons) => Statement::Constant(cons.map(f)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
enum NormalizedArgParams {}
|
||||
type NormalizedStatement = Statement<ast::Instruction<NormalizedArgParams>>;
|
||||
|
||||
|
@ -1513,18 +1523,7 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
fn reduced_visitor<'a>(
|
||||
f: &'a mut impl FnMut(spirv::Word) -> spirv::Word,
|
||||
) -> impl FnMut(ArgumentDescriptor<spirv::Word>) -> spirv::Word + 'a {
|
||||
move |desc| f(desc.op)
|
||||
}
|
||||
|
||||
impl ast::Instruction<ExpandedArgParams> {
|
||||
fn visit_variable<F: FnMut(spirv::Word) -> spirv::Word>(self, f: &mut F) -> Self {
|
||||
let mut visitor = reduced_visitor(f);
|
||||
self.map(&mut visitor)
|
||||
}
|
||||
|
||||
fn visit_variable_extended<F: FnMut(ArgumentDescriptor<spirv::Word>) -> spirv::Word>(
|
||||
self,
|
||||
f: &mut F,
|
||||
|
@ -1562,32 +1561,12 @@ struct ConstantDefinition {
|
|||
pub value: i128,
|
||||
}
|
||||
|
||||
impl ConstantDefinition {
|
||||
fn map<F: FnMut(spirv::Word) -> spirv::Word>(self, f: &mut F) -> Self {
|
||||
Self {
|
||||
dst: f(self.dst),
|
||||
typ: self.typ,
|
||||
value: self.value,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct BrachCondition {
|
||||
predicate: spirv::Word,
|
||||
if_true: spirv::Word,
|
||||
if_false: spirv::Word,
|
||||
}
|
||||
|
||||
impl BrachCondition {
|
||||
fn map<F: FnMut(spirv::Word) -> spirv::Word>(self, f: &mut F) -> Self {
|
||||
Self {
|
||||
predicate: f(self.predicate),
|
||||
if_true: f(self.if_true),
|
||||
if_false: f(self.if_false),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct ImplicitConversion {
|
||||
src: spirv::Word,
|
||||
dst: spirv::Word,
|
||||
|
@ -1604,18 +1583,6 @@ enum ConversionKind {
|
|||
Ptr(ast::LdStateSpace),
|
||||
}
|
||||
|
||||
impl ImplicitConversion {
|
||||
fn map<F: FnMut(spirv::Word) -> spirv::Word>(self, f: &mut F) -> Self {
|
||||
Self {
|
||||
src: f(self.src),
|
||||
dst: f(self.dst),
|
||||
from: self.from,
|
||||
to: self.to,
|
||||
kind: self.kind,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> ast::PredAt<T> {
|
||||
fn map_variable<U, F: FnMut(T) -> U>(self, f: &mut F) -> ast::PredAt<U> {
|
||||
ast::PredAt {
|
||||
|
@ -2354,6 +2321,15 @@ fn insert_implicit_bitcasts(
|
|||
}
|
||||
}
|
||||
|
||||
impl<'a> ast::FunctionHeader<'a, ast::ParsedArgParams<'a>> {
|
||||
fn name(&self) -> &'a str {
|
||||
match self {
|
||||
ast::FunctionHeader::Kernel(name) => name,
|
||||
ast::FunctionHeader::Func(_, name) => name,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// CFGs below taken from "Modern Compiler Implementation in Java"
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
|
|
Loading…
Add table
Reference in a new issue