mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-04-20 00:14:45 +00:00
Be more correct when emitting brev, refactor inst->func call pass
This commit is contained in:
parent
7d4fbedfcf
commit
e328ecc550
3 changed files with 226 additions and 334 deletions
|
@ -253,6 +253,14 @@ ulong FUNC(bfi_b64)(ulong insert, ulong base, uint offset, uint count) {
|
|||
return intel_bfi(base, insert, offset, count);
|
||||
}
|
||||
|
||||
// 32-bit `brev` helper: hands the operand to the `intel_bfrev` built-in and
// returns whatever that produces.
uint FUNC(brev_b32)(uint base) {
    const uint reversed = intel_bfrev(base);
    return reversed;
}
|
||||
|
||||
// 64-bit `brev` helper: hands the operand to the `intel_bfrev` built-in and
// returns whatever that produces.
ulong FUNC(brev_b64)(ulong base) {
    const ulong reversed = intel_bfrev(base);
    return reversed;
}
|
||||
|
||||
void FUNC(__assertfail)(
|
||||
__private ulong* message,
|
||||
__private ulong* file,
|
||||
|
|
|
@ -7,17 +7,22 @@
|
|||
OpCapability Int64
|
||||
OpCapability Float16
|
||||
OpCapability Float64
|
||||
%21 = OpExtInstImport "OpenCL.std"
|
||||
%24 = OpExtInstImport "OpenCL.std"
|
||||
OpMemoryModel Physical64 OpenCL
|
||||
OpEntryPoint Kernel %1 "brev"
|
||||
OpDecorate %20 LinkageAttributes "__zluda_ptx_impl__brev_b32" Import
|
||||
%void = OpTypeVoid
|
||||
%ulong = OpTypeInt 64 0
|
||||
%24 = OpTypeFunction %void %ulong %ulong
|
||||
%_ptr_Function_ulong = OpTypePointer Function %ulong
|
||||
%uint = OpTypeInt 32 0
|
||||
%27 = OpTypeFunction %uint %uint
|
||||
%ulong = OpTypeInt 64 0
|
||||
%29 = OpTypeFunction %void %ulong %ulong
|
||||
%_ptr_Function_ulong = OpTypePointer Function %ulong
|
||||
%_ptr_Function_uint = OpTypePointer Function %uint
|
||||
%_ptr_Generic_uint = OpTypePointer Generic %uint
|
||||
%1 = OpFunction %void None %24
|
||||
%20 = OpFunction %uint None %27
|
||||
%22 = OpFunctionParameter %uint
|
||||
OpFunctionEnd
|
||||
%1 = OpFunction %void None %29
|
||||
%7 = OpFunctionParameter %ulong
|
||||
%8 = OpFunctionParameter %ulong
|
||||
%19 = OpLabel
|
||||
|
@ -37,7 +42,7 @@
|
|||
%11 = OpLoad %uint %17 Aligned 4
|
||||
OpStore %6 %11
|
||||
%14 = OpLoad %uint %6
|
||||
%13 = OpBitReverse %uint %14
|
||||
%13 = OpFunctionCall %uint %20 %14
|
||||
OpStore %6 %13
|
||||
%15 = OpLoad %ulong %5
|
||||
%16 = OpLoad %uint %6
|
||||
|
|
|
@ -8,6 +8,7 @@ use std::{borrow::Cow, collections::BTreeSet, ffi::CString, hash::Hash, iter, me
|
|||
use rspirv::binary::Assemble;
|
||||
|
||||
static ZLUDA_PTX_IMPL: &'static [u8] = include_bytes!("../lib/zluda_ptx_impl.spv");
|
||||
static ZLUDA_PTX_PREFIX: &'static str = "__zluda_ptx_impl__";
|
||||
|
||||
quick_error! {
|
||||
#[derive(Debug)]
|
||||
|
@ -1248,7 +1249,7 @@ fn to_ssa<'input, 'b>(
|
|||
let mut numeric_id_defs = numeric_id_defs.unmut();
|
||||
let labeled_statements = normalize_labels(expanded_statements, &mut numeric_id_defs);
|
||||
let (f_body, globals) =
|
||||
extract_globals(labeled_statements, ptx_impl_imports, &mut numeric_id_defs);
|
||||
extract_globals(labeled_statements, ptx_impl_imports, &mut numeric_id_defs)?;
|
||||
Ok(Function {
|
||||
func_decl: func_decl,
|
||||
globals: globals,
|
||||
|
@ -1344,7 +1345,7 @@ fn extract_globals<'input, 'b>(
|
|||
sorted_statements: Vec<ExpandedStatement>,
|
||||
ptx_impl_imports: &mut HashMap<String, Directive>,
|
||||
id_def: &mut NumericIdResolver,
|
||||
) -> (Vec<ExpandedStatement>, Vec<ast::Variable<spirv::Word>>) {
|
||||
) -> Result<(Vec<ExpandedStatement>, Vec<ast::Variable<spirv::Word>>), TranslateError> {
|
||||
let mut local = Vec::with_capacity(sorted_statements.len());
|
||||
let mut global = Vec::new();
|
||||
for statement in sorted_statements {
|
||||
|
@ -1366,13 +1367,34 @@ fn extract_globals<'input, 'b>(
|
|||
},
|
||||
) => global.push(var),
|
||||
Statement::Instruction(ast::Instruction::Bfe { typ, arg }) => {
|
||||
local.push(to_ptx_impl_bfe_call(id_def, ptx_impl_imports, typ, arg));
|
||||
let fn_name = [ZLUDA_PTX_PREFIX, "bfe_", typ.to_ptx_name()].concat();
|
||||
local.push(instruction_to_fn_call(
|
||||
id_def,
|
||||
ptx_impl_imports,
|
||||
ast::Instruction::Bfe { typ, arg },
|
||||
fn_name,
|
||||
)?);
|
||||
}
|
||||
Statement::Instruction(ast::Instruction::Bfi { typ, arg }) => {
|
||||
local.push(to_ptx_impl_bfi_call(id_def, ptx_impl_imports, typ, arg));
|
||||
let fn_name = [ZLUDA_PTX_PREFIX, "bfi_", typ.to_ptx_name()].concat();
|
||||
local.push(instruction_to_fn_call(
|
||||
id_def,
|
||||
ptx_impl_imports,
|
||||
ast::Instruction::Bfi { typ, arg },
|
||||
fn_name,
|
||||
)?);
|
||||
}
|
||||
Statement::Instruction(ast::Instruction::Brev { typ, arg }) => {
|
||||
let fn_name = [ZLUDA_PTX_PREFIX, "brev_", typ.to_ptx_name()].concat();
|
||||
local.push(instruction_to_fn_call(
|
||||
id_def,
|
||||
ptx_impl_imports,
|
||||
ast::Instruction::Brev { typ, arg },
|
||||
fn_name,
|
||||
)?);
|
||||
}
|
||||
Statement::Instruction(ast::Instruction::Atom(
|
||||
d
|
||||
details
|
||||
@
|
||||
ast::AtomDetails {
|
||||
inner:
|
||||
|
@ -1382,19 +1404,28 @@ fn extract_globals<'input, 'b>(
|
|||
},
|
||||
..
|
||||
},
|
||||
a,
|
||||
args,
|
||||
)) => {
|
||||
local.push(to_ptx_impl_atomic_call(
|
||||
let fn_name = [
|
||||
ZLUDA_PTX_PREFIX,
|
||||
"atom_",
|
||||
details.semantics.to_ptx_name(),
|
||||
"_",
|
||||
details.scope.to_ptx_name(),
|
||||
"_",
|
||||
details.space.to_ptx_name(),
|
||||
"_inc",
|
||||
]
|
||||
.concat();
|
||||
local.push(instruction_to_fn_call(
|
||||
id_def,
|
||||
ptx_impl_imports,
|
||||
d,
|
||||
a,
|
||||
"inc",
|
||||
ast::ScalarType::U32,
|
||||
));
|
||||
ast::Instruction::Atom(details, args),
|
||||
fn_name,
|
||||
)?);
|
||||
}
|
||||
Statement::Instruction(ast::Instruction::Atom(
|
||||
d
|
||||
details
|
||||
@
|
||||
ast::AtomDetails {
|
||||
inner:
|
||||
|
@ -1404,57 +1435,122 @@ fn extract_globals<'input, 'b>(
|
|||
},
|
||||
..
|
||||
},
|
||||
a,
|
||||
args,
|
||||
)) => {
|
||||
local.push(to_ptx_impl_atomic_call(
|
||||
let fn_name = [
|
||||
ZLUDA_PTX_PREFIX,
|
||||
"atom_",
|
||||
details.semantics.to_ptx_name(),
|
||||
"_",
|
||||
details.scope.to_ptx_name(),
|
||||
"_",
|
||||
details.space.to_ptx_name(),
|
||||
"_dec",
|
||||
]
|
||||
.concat();
|
||||
local.push(instruction_to_fn_call(
|
||||
id_def,
|
||||
ptx_impl_imports,
|
||||
d,
|
||||
a,
|
||||
"dec",
|
||||
ast::ScalarType::U32,
|
||||
));
|
||||
ast::Instruction::Atom(details, args),
|
||||
fn_name,
|
||||
)?);
|
||||
}
|
||||
Statement::Instruction(ast::Instruction::Atom(
|
||||
details
|
||||
@
|
||||
ast::AtomDetails {
|
||||
inner:
|
||||
ast::AtomInnerDetails::Float {
|
||||
op: ast::AtomFloatOp::Add,
|
||||
typ,
|
||||
..
|
||||
},
|
||||
semantics,
|
||||
scope,
|
||||
space,
|
||||
..
|
||||
},
|
||||
a,
|
||||
args,
|
||||
)) => {
|
||||
let details = ast::AtomDetails {
|
||||
inner: ast::AtomInnerDetails::Float {
|
||||
op: ast::AtomFloatOp::Add,
|
||||
typ,
|
||||
},
|
||||
semantics,
|
||||
scope,
|
||||
space,
|
||||
};
|
||||
let (op, typ) = match typ {
|
||||
ast::ScalarType::F32 => ("add_f32", ast::ScalarType::F32),
|
||||
ast::ScalarType::F64 => ("add_f64", ast::ScalarType::F64),
|
||||
_ => unreachable!(),
|
||||
};
|
||||
local.push(to_ptx_impl_atomic_call(
|
||||
let fn_name = [
|
||||
ZLUDA_PTX_PREFIX,
|
||||
"atom_",
|
||||
details.semantics.to_ptx_name(),
|
||||
"_",
|
||||
details.scope.to_ptx_name(),
|
||||
"_",
|
||||
details.space.to_ptx_name(),
|
||||
"_add_",
|
||||
details.inner.get_type().to_ptx_name(),
|
||||
]
|
||||
.concat();
|
||||
local.push(instruction_to_fn_call(
|
||||
id_def,
|
||||
ptx_impl_imports,
|
||||
details,
|
||||
a,
|
||||
op,
|
||||
typ,
|
||||
));
|
||||
ast::Instruction::Atom(details, args),
|
||||
fn_name,
|
||||
)?);
|
||||
}
|
||||
s => local.push(s),
|
||||
}
|
||||
}
|
||||
(local, global)
|
||||
Ok((local, global))
|
||||
}
|
||||
|
||||
impl ast::ScalarType {
|
||||
fn to_ptx_name(self) -> &'static str {
|
||||
match self {
|
||||
ast::ScalarType::B8 => "b8",
|
||||
ast::ScalarType::B16 => "b16",
|
||||
ast::ScalarType::B32 => "b32",
|
||||
ast::ScalarType::B64 => "b64",
|
||||
ast::ScalarType::U8 => "u8",
|
||||
ast::ScalarType::U16 => "u16",
|
||||
ast::ScalarType::U32 => "u32",
|
||||
ast::ScalarType::U64 => "u64",
|
||||
ast::ScalarType::S8 => "s8",
|
||||
ast::ScalarType::S16 => "s16",
|
||||
ast::ScalarType::S32 => "s32",
|
||||
ast::ScalarType::S64 => "s64",
|
||||
ast::ScalarType::F16 => "f16",
|
||||
ast::ScalarType::F32 => "f32",
|
||||
ast::ScalarType::F64 => "f64",
|
||||
ast::ScalarType::F16x2 => "f16x2",
|
||||
ast::ScalarType::Pred => "pred",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ast::AtomSemantics {
|
||||
fn to_ptx_name(self) -> &'static str {
|
||||
match self {
|
||||
ast::AtomSemantics::Relaxed => "relaxed",
|
||||
ast::AtomSemantics::Acquire => "acquire",
|
||||
ast::AtomSemantics::Release => "release",
|
||||
ast::AtomSemantics::AcquireRelease => "acq_rel",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ast::MemScope {
|
||||
fn to_ptx_name(self) -> &'static str {
|
||||
match self {
|
||||
ast::MemScope::Cta => "cta",
|
||||
ast::MemScope::Gpu => "gpu",
|
||||
ast::MemScope::Sys => "sys",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ast::StateSpace {
|
||||
fn to_ptx_name(self) -> &'static str {
|
||||
match self {
|
||||
ast::StateSpace::Generic => "generic",
|
||||
ast::StateSpace::Global => "global",
|
||||
ast::StateSpace::Shared => "shared",
|
||||
ast::StateSpace::Reg => "reg",
|
||||
ast::StateSpace::Const => "const",
|
||||
ast::StateSpace::Local => "local",
|
||||
ast::StateSpace::Param => "param",
|
||||
ast::StateSpace::Sreg => "sreg",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn normalize_variable_decls(directives: &mut Vec<Directive>) {
|
||||
|
@ -1591,142 +1687,58 @@ impl<'a, 'b> ArgumentMapVisitor<NormalizedArgParams, TypedArgParams>
|
|||
}
|
||||
}
|
||||
|
||||
//TODO: share common code between this and to_ptx_impl_bfe_call
|
||||
fn to_ptx_impl_atomic_call(
|
||||
fn instruction_to_fn_call(
|
||||
id_defs: &mut NumericIdResolver,
|
||||
ptx_impl_imports: &mut HashMap<String, Directive>,
|
||||
details: ast::AtomDetails,
|
||||
arg: ast::Arg3<ExpandedArgParams>,
|
||||
op: &'static str,
|
||||
typ: ast::ScalarType,
|
||||
) -> ExpandedStatement {
|
||||
let semantics = ptx_semantics_name(details.semantics);
|
||||
let scope = ptx_scope_name(details.scope);
|
||||
let space = ptx_space_name(details.space);
|
||||
let fn_name = format!(
|
||||
"__zluda_ptx_impl__atom_{}_{}_{}_{}",
|
||||
semantics, scope, space, op
|
||||
);
|
||||
// TODO: extract to a function
|
||||
let ptr_space = details.space;
|
||||
let scalar_typ = ast::ScalarType::from(typ);
|
||||
let fn_id = match ptx_impl_imports.entry(fn_name) {
|
||||
hash_map::Entry::Vacant(entry) => {
|
||||
let fn_id = id_defs.register_intermediate(None);
|
||||
let func_decl = ast::MethodDeclaration::<spirv::Word> {
|
||||
return_arguments: vec![ast::Variable {
|
||||
align: None,
|
||||
v_type: ast::Type::Scalar(scalar_typ),
|
||||
state_space: ast::StateSpace::Reg,
|
||||
name: id_defs.register_intermediate(None),
|
||||
array_init: Vec::new(),
|
||||
}],
|
||||
name: ast::MethodName::Func(fn_id),
|
||||
input_arguments: vec![
|
||||
ast::Variable {
|
||||
align: None,
|
||||
v_type: ast::Type::Pointer(typ, ptr_space),
|
||||
state_space: ast::StateSpace::Reg,
|
||||
name: id_defs.register_intermediate(None),
|
||||
array_init: Vec::new(),
|
||||
},
|
||||
ast::Variable {
|
||||
align: None,
|
||||
v_type: ast::Type::Scalar(scalar_typ),
|
||||
state_space: ast::StateSpace::Reg,
|
||||
name: id_defs.register_intermediate(None),
|
||||
array_init: Vec::new(),
|
||||
},
|
||||
],
|
||||
shared_mem: None,
|
||||
};
|
||||
let func = Function {
|
||||
func_decl: Rc::new(RefCell::new(func_decl)),
|
||||
globals: Vec::new(),
|
||||
body: None,
|
||||
import_as: Some(entry.key().clone()),
|
||||
tuning: Vec::new(),
|
||||
};
|
||||
entry.insert(Directive::Method(func));
|
||||
fn_id
|
||||
}
|
||||
hash_map::Entry::Occupied(entry) => match entry.get() {
|
||||
Directive::Method(Function { func_decl, .. }) => match (**func_decl).borrow().name {
|
||||
ast::MethodName::Func(fn_id) => fn_id,
|
||||
ast::MethodName::Kernel(_) => unreachable!(),
|
||||
},
|
||||
_ => unreachable!(),
|
||||
},
|
||||
};
|
||||
Statement::Call(ResolvedCall {
|
||||
inst: ast::Instruction<ExpandedArgParams>,
|
||||
fn_name: String,
|
||||
) -> Result<ExpandedStatement, TranslateError> {
|
||||
let mut arguments = Vec::new();
|
||||
inst.visit(&mut |desc: ArgumentDescriptor<spirv::Word>,
|
||||
typ: Option<(&ast::Type, ast::StateSpace)>| {
|
||||
let (typ, space) = match typ {
|
||||
Some((typ, space)) => (typ.clone(), space),
|
||||
None => return Err(error_unreachable()),
|
||||
};
|
||||
arguments.push((desc, typ, space));
|
||||
Ok(0)
|
||||
})?;
|
||||
let return_arguments_count = arguments
|
||||
.iter()
|
||||
.position(|(desc, _, _)| !desc.is_dst)
|
||||
.unwrap_or(0);
|
||||
let (return_arguments, input_arguments) = arguments.split_at(return_arguments_count);
|
||||
let fn_id = register_external_fn_call(
|
||||
id_defs,
|
||||
ptx_impl_imports,
|
||||
fn_name,
|
||||
return_arguments,
|
||||
input_arguments,
|
||||
)?;
|
||||
Ok(Statement::Call(ResolvedCall {
|
||||
uniform: false,
|
||||
name: fn_id,
|
||||
return_arguments: vec![(arg.dst, ast::Type::Scalar(scalar_typ), ast::StateSpace::Reg)],
|
||||
input_arguments: vec![
|
||||
(
|
||||
arg.src1,
|
||||
ast::Type::Pointer(typ, ptr_space),
|
||||
ast::StateSpace::Reg,
|
||||
),
|
||||
(
|
||||
arg.src2,
|
||||
ast::Type::Scalar(scalar_typ),
|
||||
ast::StateSpace::Reg,
|
||||
),
|
||||
],
|
||||
})
|
||||
return_arguments: arguments_to_resolved_arguments(return_arguments),
|
||||
input_arguments: arguments_to_resolved_arguments(input_arguments),
|
||||
}))
|
||||
}
|
||||
|
||||
fn to_ptx_impl_bfe_call(
|
||||
fn register_external_fn_call(
|
||||
id_defs: &mut NumericIdResolver,
|
||||
ptx_impl_imports: &mut HashMap<String, Directive>,
|
||||
typ: ast::ScalarType,
|
||||
arg: ast::Arg4<ExpandedArgParams>,
|
||||
) -> ExpandedStatement {
|
||||
let prefix = "__zluda_ptx_impl__";
|
||||
let suffix = match typ {
|
||||
ast::ScalarType::U32 => "bfe_u32",
|
||||
ast::ScalarType::U64 => "bfe_u64",
|
||||
ast::ScalarType::S32 => "bfe_s32",
|
||||
ast::ScalarType::S64 => "bfe_s64",
|
||||
_ => unreachable!(),
|
||||
};
|
||||
let fn_name = format!("{}{}", prefix, suffix);
|
||||
let fn_id = match ptx_impl_imports.entry(fn_name) {
|
||||
name: String,
|
||||
return_arguments: &[(ArgumentDescriptor<spirv::Word>, ast::Type, ast::StateSpace)],
|
||||
input_arguments: &[(ArgumentDescriptor<spirv::Word>, ast::Type, ast::StateSpace)],
|
||||
) -> Result<spirv::Word, TranslateError> {
|
||||
match ptx_impl_imports.entry(name) {
|
||||
hash_map::Entry::Vacant(entry) => {
|
||||
let fn_id = id_defs.register_intermediate(None);
|
||||
let return_arguments = fn_arguments_to_variables(id_defs, return_arguments);
|
||||
let input_arguments = fn_arguments_to_variables(id_defs, input_arguments);
|
||||
let func_decl = ast::MethodDeclaration::<spirv::Word> {
|
||||
return_arguments: vec![ast::Variable {
|
||||
align: None,
|
||||
v_type: ast::Type::Scalar(typ.into()),
|
||||
state_space: ast::StateSpace::Reg,
|
||||
name: id_defs.register_intermediate(None),
|
||||
array_init: Vec::new(),
|
||||
}],
|
||||
return_arguments,
|
||||
name: ast::MethodName::Func(fn_id),
|
||||
input_arguments: vec![
|
||||
ast::Variable {
|
||||
align: None,
|
||||
v_type: ast::Type::Scalar(typ.into()),
|
||||
state_space: ast::StateSpace::Reg,
|
||||
name: id_defs.register_intermediate(None),
|
||||
array_init: Vec::new(),
|
||||
},
|
||||
ast::Variable {
|
||||
align: None,
|
||||
v_type: ast::Type::Scalar(ast::ScalarType::U32),
|
||||
state_space: ast::StateSpace::Reg,
|
||||
name: id_defs.register_intermediate(None),
|
||||
array_init: Vec::new(),
|
||||
},
|
||||
ast::Variable {
|
||||
align: None,
|
||||
v_type: ast::Type::Scalar(ast::ScalarType::U32),
|
||||
state_space: ast::StateSpace::Reg,
|
||||
name: id_defs.register_intermediate(None),
|
||||
array_init: Vec::new(),
|
||||
},
|
||||
],
|
||||
input_arguments,
|
||||
shared_mem: None,
|
||||
};
|
||||
let func = Function {
|
||||
|
@ -1737,142 +1749,39 @@ fn to_ptx_impl_bfe_call(
|
|||
tuning: Vec::new(),
|
||||
};
|
||||
entry.insert(Directive::Method(func));
|
||||
fn_id
|
||||
Ok(fn_id)
|
||||
}
|
||||
hash_map::Entry::Occupied(entry) => match entry.get() {
|
||||
Directive::Method(Function { func_decl, .. }) => match (**func_decl).borrow().name {
|
||||
ast::MethodName::Func(fn_id) => fn_id,
|
||||
ast::MethodName::Kernel(_) => unreachable!(),
|
||||
ast::MethodName::Func(fn_id) => Ok(fn_id),
|
||||
ast::MethodName::Kernel(_) => Err(error_unreachable()),
|
||||
},
|
||||
_ => unreachable!(),
|
||||
_ => Err(error_unreachable()),
|
||||
},
|
||||
};
|
||||
Statement::Call(ResolvedCall {
|
||||
uniform: false,
|
||||
name: fn_id,
|
||||
return_arguments: vec![(arg.dst, ast::Type::Scalar(typ.into()), ast::StateSpace::Reg)],
|
||||
input_arguments: vec![
|
||||
(
|
||||
arg.src1,
|
||||
ast::Type::Scalar(typ.into()),
|
||||
ast::StateSpace::Reg,
|
||||
),
|
||||
(
|
||||
arg.src2,
|
||||
ast::Type::Scalar(ast::ScalarType::U32),
|
||||
ast::StateSpace::Reg,
|
||||
),
|
||||
(
|
||||
arg.src3,
|
||||
ast::Type::Scalar(ast::ScalarType::U32),
|
||||
ast::StateSpace::Reg,
|
||||
),
|
||||
],
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn to_ptx_impl_bfi_call(
|
||||
fn fn_arguments_to_variables(
|
||||
id_defs: &mut NumericIdResolver,
|
||||
ptx_impl_imports: &mut HashMap<String, Directive>,
|
||||
typ: ast::ScalarType,
|
||||
arg: ast::Arg5<ExpandedArgParams>,
|
||||
) -> ExpandedStatement {
|
||||
let prefix = "__zluda_ptx_impl__";
|
||||
let suffix = match typ {
|
||||
ast::ScalarType::B32 => "bfi_b32",
|
||||
ast::ScalarType::B64 => "bfi_b64",
|
||||
_ => unreachable!(),
|
||||
};
|
||||
let fn_name = format!("{}{}", prefix, suffix);
|
||||
let fn_id = match ptx_impl_imports.entry(fn_name) {
|
||||
hash_map::Entry::Vacant(entry) => {
|
||||
let fn_id = id_defs.register_intermediate(None);
|
||||
let func_decl = ast::MethodDeclaration::<spirv::Word> {
|
||||
return_arguments: vec![ast::Variable {
|
||||
align: None,
|
||||
v_type: ast::Type::Scalar(typ.into()),
|
||||
state_space: ast::StateSpace::Reg,
|
||||
name: id_defs.register_intermediate(None),
|
||||
array_init: Vec::new(),
|
||||
}],
|
||||
name: ast::MethodName::Func(fn_id),
|
||||
input_arguments: vec![
|
||||
ast::Variable {
|
||||
align: None,
|
||||
v_type: ast::Type::Scalar(typ.into()),
|
||||
state_space: ast::StateSpace::Reg,
|
||||
name: id_defs.register_intermediate(None),
|
||||
array_init: Vec::new(),
|
||||
},
|
||||
ast::Variable {
|
||||
align: None,
|
||||
v_type: ast::Type::Scalar(typ.into()),
|
||||
state_space: ast::StateSpace::Reg,
|
||||
name: id_defs.register_intermediate(None),
|
||||
array_init: Vec::new(),
|
||||
},
|
||||
ast::Variable {
|
||||
align: None,
|
||||
v_type: ast::Type::Scalar(ast::ScalarType::U32),
|
||||
state_space: ast::StateSpace::Reg,
|
||||
name: id_defs.register_intermediate(None),
|
||||
array_init: Vec::new(),
|
||||
},
|
||||
ast::Variable {
|
||||
align: None,
|
||||
v_type: ast::Type::Scalar(ast::ScalarType::U32),
|
||||
state_space: ast::StateSpace::Reg,
|
||||
name: id_defs.register_intermediate(None),
|
||||
array_init: Vec::new(),
|
||||
},
|
||||
],
|
||||
shared_mem: None,
|
||||
};
|
||||
let func = Function {
|
||||
func_decl: Rc::new(RefCell::new(func_decl)),
|
||||
globals: Vec::new(),
|
||||
body: None,
|
||||
import_as: Some(entry.key().clone()),
|
||||
tuning: Vec::new(),
|
||||
};
|
||||
entry.insert(Directive::Method(func));
|
||||
fn_id
|
||||
}
|
||||
hash_map::Entry::Occupied(entry) => match entry.get() {
|
||||
Directive::Method(Function { func_decl, .. }) => match (**func_decl).borrow().name {
|
||||
ast::MethodName::Func(fn_id) => fn_id,
|
||||
ast::MethodName::Kernel(_) => unreachable!(),
|
||||
},
|
||||
_ => unreachable!(),
|
||||
},
|
||||
};
|
||||
Statement::Call(ResolvedCall {
|
||||
uniform: false,
|
||||
name: fn_id,
|
||||
return_arguments: vec![(arg.dst, ast::Type::Scalar(typ.into()), ast::StateSpace::Reg)],
|
||||
input_arguments: vec![
|
||||
(
|
||||
arg.src1,
|
||||
ast::Type::Scalar(typ.into()),
|
||||
ast::StateSpace::Reg,
|
||||
),
|
||||
(
|
||||
arg.src2,
|
||||
ast::Type::Scalar(typ.into()),
|
||||
ast::StateSpace::Reg,
|
||||
),
|
||||
(
|
||||
arg.src3,
|
||||
ast::Type::Scalar(ast::ScalarType::U32),
|
||||
ast::StateSpace::Reg,
|
||||
),
|
||||
(
|
||||
arg.src4,
|
||||
ast::Type::Scalar(ast::ScalarType::U32),
|
||||
ast::StateSpace::Reg,
|
||||
),
|
||||
],
|
||||
})
|
||||
args: &[(ArgumentDescriptor<spirv::Word>, ast::Type, ast::StateSpace)],
|
||||
) -> Vec<ast::Variable<spirv::Word>> {
|
||||
args.iter()
|
||||
.map(|(_, typ, space)| ast::Variable {
|
||||
align: None,
|
||||
v_type: typ.clone(),
|
||||
state_space: *space,
|
||||
name: id_defs.register_intermediate(None),
|
||||
array_init: Vec::new(),
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
}
|
||||
|
||||
fn arguments_to_resolved_arguments(
|
||||
args: &[(ArgumentDescriptor<spirv::Word>, ast::Type, ast::StateSpace)],
|
||||
) -> Vec<(spirv::Word, ast::Type, ast::StateSpace)> {
|
||||
args.iter()
|
||||
.map(|(desc, typ, space)| (desc.op, typ.clone(), *space))
|
||||
.collect::<Vec<_>>()
|
||||
}
|
||||
|
||||
fn normalize_labels(
|
||||
|
@ -3305,36 +3214,6 @@ struct PtxImplImport {
|
|||
in_args: Vec<ast::Type>,
|
||||
}
|
||||
|
||||
fn ptx_semantics_name(sema: ast::AtomSemantics) -> &'static str {
|
||||
match sema {
|
||||
ast::AtomSemantics::Relaxed => "relaxed",
|
||||
ast::AtomSemantics::Acquire => "acquire",
|
||||
ast::AtomSemantics::Release => "release",
|
||||
ast::AtomSemantics::AcquireRelease => "acq_rel",
|
||||
}
|
||||
}
|
||||
|
||||
fn ptx_scope_name(scope: ast::MemScope) -> &'static str {
|
||||
match scope {
|
||||
ast::MemScope::Cta => "cta",
|
||||
ast::MemScope::Gpu => "gpu",
|
||||
ast::MemScope::Sys => "sys",
|
||||
}
|
||||
}
|
||||
|
||||
fn ptx_space_name(space: ast::StateSpace) -> &'static str {
|
||||
match space {
|
||||
ast::StateSpace::Generic => "generic",
|
||||
ast::StateSpace::Global => "global",
|
||||
ast::StateSpace::Shared => "shared",
|
||||
ast::StateSpace::Reg => "reg",
|
||||
ast::StateSpace::Const => "const",
|
||||
ast::StateSpace::Local => "local",
|
||||
ast::StateSpace::Param => "param",
|
||||
ast::StateSpace::Sreg => "sreg",
|
||||
}
|
||||
}
|
||||
|
||||
fn emit_mul_float(
|
||||
builder: &mut dr::Builder,
|
||||
map: &mut TypeWordMap,
|
||||
|
|
Loading…
Add table
Reference in a new issue