mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-08-06 00:00:13 +00:00
Be more correct when emitting brev, refactor inst->func call pass
This commit is contained in:
parent
7d4fbedfcf
commit
e328ecc550
3 changed files with 226 additions and 334 deletions
|
@ -253,6 +253,14 @@ ulong FUNC(bfi_b64)(ulong insert, ulong base, uint offset, uint count) {
|
||||||
return intel_bfi(base, insert, offset, count);
|
return intel_bfi(base, insert, offset, count);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
uint FUNC(brev_b32)(uint base) {
|
||||||
|
return intel_bfrev(base);
|
||||||
|
}
|
||||||
|
|
||||||
|
ulong FUNC(brev_b64)(ulong base) {
|
||||||
|
return intel_bfrev(base);
|
||||||
|
}
|
||||||
|
|
||||||
void FUNC(__assertfail)(
|
void FUNC(__assertfail)(
|
||||||
__private ulong* message,
|
__private ulong* message,
|
||||||
__private ulong* file,
|
__private ulong* file,
|
||||||
|
|
|
@ -7,17 +7,22 @@
|
||||||
OpCapability Int64
|
OpCapability Int64
|
||||||
OpCapability Float16
|
OpCapability Float16
|
||||||
OpCapability Float64
|
OpCapability Float64
|
||||||
%21 = OpExtInstImport "OpenCL.std"
|
%24 = OpExtInstImport "OpenCL.std"
|
||||||
OpMemoryModel Physical64 OpenCL
|
OpMemoryModel Physical64 OpenCL
|
||||||
OpEntryPoint Kernel %1 "brev"
|
OpEntryPoint Kernel %1 "brev"
|
||||||
|
OpDecorate %20 LinkageAttributes "__zluda_ptx_impl__brev_b32" Import
|
||||||
%void = OpTypeVoid
|
%void = OpTypeVoid
|
||||||
%ulong = OpTypeInt 64 0
|
|
||||||
%24 = OpTypeFunction %void %ulong %ulong
|
|
||||||
%_ptr_Function_ulong = OpTypePointer Function %ulong
|
|
||||||
%uint = OpTypeInt 32 0
|
%uint = OpTypeInt 32 0
|
||||||
|
%27 = OpTypeFunction %uint %uint
|
||||||
|
%ulong = OpTypeInt 64 0
|
||||||
|
%29 = OpTypeFunction %void %ulong %ulong
|
||||||
|
%_ptr_Function_ulong = OpTypePointer Function %ulong
|
||||||
%_ptr_Function_uint = OpTypePointer Function %uint
|
%_ptr_Function_uint = OpTypePointer Function %uint
|
||||||
%_ptr_Generic_uint = OpTypePointer Generic %uint
|
%_ptr_Generic_uint = OpTypePointer Generic %uint
|
||||||
%1 = OpFunction %void None %24
|
%20 = OpFunction %uint None %27
|
||||||
|
%22 = OpFunctionParameter %uint
|
||||||
|
OpFunctionEnd
|
||||||
|
%1 = OpFunction %void None %29
|
||||||
%7 = OpFunctionParameter %ulong
|
%7 = OpFunctionParameter %ulong
|
||||||
%8 = OpFunctionParameter %ulong
|
%8 = OpFunctionParameter %ulong
|
||||||
%19 = OpLabel
|
%19 = OpLabel
|
||||||
|
@ -37,7 +42,7 @@
|
||||||
%11 = OpLoad %uint %17 Aligned 4
|
%11 = OpLoad %uint %17 Aligned 4
|
||||||
OpStore %6 %11
|
OpStore %6 %11
|
||||||
%14 = OpLoad %uint %6
|
%14 = OpLoad %uint %6
|
||||||
%13 = OpBitReverse %uint %14
|
%13 = OpFunctionCall %uint %20 %14
|
||||||
OpStore %6 %13
|
OpStore %6 %13
|
||||||
%15 = OpLoad %ulong %5
|
%15 = OpLoad %ulong %5
|
||||||
%16 = OpLoad %uint %6
|
%16 = OpLoad %uint %6
|
||||||
|
|
|
@ -8,6 +8,7 @@ use std::{borrow::Cow, collections::BTreeSet, ffi::CString, hash::Hash, iter, me
|
||||||
use rspirv::binary::Assemble;
|
use rspirv::binary::Assemble;
|
||||||
|
|
||||||
static ZLUDA_PTX_IMPL: &'static [u8] = include_bytes!("../lib/zluda_ptx_impl.spv");
|
static ZLUDA_PTX_IMPL: &'static [u8] = include_bytes!("../lib/zluda_ptx_impl.spv");
|
||||||
|
static ZLUDA_PTX_PREFIX: &'static str = "__zluda_ptx_impl__";
|
||||||
|
|
||||||
quick_error! {
|
quick_error! {
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
|
@ -1248,7 +1249,7 @@ fn to_ssa<'input, 'b>(
|
||||||
let mut numeric_id_defs = numeric_id_defs.unmut();
|
let mut numeric_id_defs = numeric_id_defs.unmut();
|
||||||
let labeled_statements = normalize_labels(expanded_statements, &mut numeric_id_defs);
|
let labeled_statements = normalize_labels(expanded_statements, &mut numeric_id_defs);
|
||||||
let (f_body, globals) =
|
let (f_body, globals) =
|
||||||
extract_globals(labeled_statements, ptx_impl_imports, &mut numeric_id_defs);
|
extract_globals(labeled_statements, ptx_impl_imports, &mut numeric_id_defs)?;
|
||||||
Ok(Function {
|
Ok(Function {
|
||||||
func_decl: func_decl,
|
func_decl: func_decl,
|
||||||
globals: globals,
|
globals: globals,
|
||||||
|
@ -1344,7 +1345,7 @@ fn extract_globals<'input, 'b>(
|
||||||
sorted_statements: Vec<ExpandedStatement>,
|
sorted_statements: Vec<ExpandedStatement>,
|
||||||
ptx_impl_imports: &mut HashMap<String, Directive>,
|
ptx_impl_imports: &mut HashMap<String, Directive>,
|
||||||
id_def: &mut NumericIdResolver,
|
id_def: &mut NumericIdResolver,
|
||||||
) -> (Vec<ExpandedStatement>, Vec<ast::Variable<spirv::Word>>) {
|
) -> Result<(Vec<ExpandedStatement>, Vec<ast::Variable<spirv::Word>>), TranslateError> {
|
||||||
let mut local = Vec::with_capacity(sorted_statements.len());
|
let mut local = Vec::with_capacity(sorted_statements.len());
|
||||||
let mut global = Vec::new();
|
let mut global = Vec::new();
|
||||||
for statement in sorted_statements {
|
for statement in sorted_statements {
|
||||||
|
@ -1366,13 +1367,34 @@ fn extract_globals<'input, 'b>(
|
||||||
},
|
},
|
||||||
) => global.push(var),
|
) => global.push(var),
|
||||||
Statement::Instruction(ast::Instruction::Bfe { typ, arg }) => {
|
Statement::Instruction(ast::Instruction::Bfe { typ, arg }) => {
|
||||||
local.push(to_ptx_impl_bfe_call(id_def, ptx_impl_imports, typ, arg));
|
let fn_name = [ZLUDA_PTX_PREFIX, "bfe_", typ.to_ptx_name()].concat();
|
||||||
|
local.push(instruction_to_fn_call(
|
||||||
|
id_def,
|
||||||
|
ptx_impl_imports,
|
||||||
|
ast::Instruction::Bfe { typ, arg },
|
||||||
|
fn_name,
|
||||||
|
)?);
|
||||||
}
|
}
|
||||||
Statement::Instruction(ast::Instruction::Bfi { typ, arg }) => {
|
Statement::Instruction(ast::Instruction::Bfi { typ, arg }) => {
|
||||||
local.push(to_ptx_impl_bfi_call(id_def, ptx_impl_imports, typ, arg));
|
let fn_name = [ZLUDA_PTX_PREFIX, "bfi_", typ.to_ptx_name()].concat();
|
||||||
|
local.push(instruction_to_fn_call(
|
||||||
|
id_def,
|
||||||
|
ptx_impl_imports,
|
||||||
|
ast::Instruction::Bfi { typ, arg },
|
||||||
|
fn_name,
|
||||||
|
)?);
|
||||||
|
}
|
||||||
|
Statement::Instruction(ast::Instruction::Brev { typ, arg }) => {
|
||||||
|
let fn_name = [ZLUDA_PTX_PREFIX, "brev_", typ.to_ptx_name()].concat();
|
||||||
|
local.push(instruction_to_fn_call(
|
||||||
|
id_def,
|
||||||
|
ptx_impl_imports,
|
||||||
|
ast::Instruction::Brev { typ, arg },
|
||||||
|
fn_name,
|
||||||
|
)?);
|
||||||
}
|
}
|
||||||
Statement::Instruction(ast::Instruction::Atom(
|
Statement::Instruction(ast::Instruction::Atom(
|
||||||
d
|
details
|
||||||
@
|
@
|
||||||
ast::AtomDetails {
|
ast::AtomDetails {
|
||||||
inner:
|
inner:
|
||||||
|
@ -1382,19 +1404,28 @@ fn extract_globals<'input, 'b>(
|
||||||
},
|
},
|
||||||
..
|
..
|
||||||
},
|
},
|
||||||
a,
|
args,
|
||||||
)) => {
|
)) => {
|
||||||
local.push(to_ptx_impl_atomic_call(
|
let fn_name = [
|
||||||
|
ZLUDA_PTX_PREFIX,
|
||||||
|
"atom_",
|
||||||
|
details.semantics.to_ptx_name(),
|
||||||
|
"_",
|
||||||
|
details.scope.to_ptx_name(),
|
||||||
|
"_",
|
||||||
|
details.space.to_ptx_name(),
|
||||||
|
"_inc",
|
||||||
|
]
|
||||||
|
.concat();
|
||||||
|
local.push(instruction_to_fn_call(
|
||||||
id_def,
|
id_def,
|
||||||
ptx_impl_imports,
|
ptx_impl_imports,
|
||||||
d,
|
ast::Instruction::Atom(details, args),
|
||||||
a,
|
fn_name,
|
||||||
"inc",
|
)?);
|
||||||
ast::ScalarType::U32,
|
|
||||||
));
|
|
||||||
}
|
}
|
||||||
Statement::Instruction(ast::Instruction::Atom(
|
Statement::Instruction(ast::Instruction::Atom(
|
||||||
d
|
details
|
||||||
@
|
@
|
||||||
ast::AtomDetails {
|
ast::AtomDetails {
|
||||||
inner:
|
inner:
|
||||||
|
@ -1404,57 +1435,122 @@ fn extract_globals<'input, 'b>(
|
||||||
},
|
},
|
||||||
..
|
..
|
||||||
},
|
},
|
||||||
a,
|
args,
|
||||||
)) => {
|
)) => {
|
||||||
local.push(to_ptx_impl_atomic_call(
|
let fn_name = [
|
||||||
|
ZLUDA_PTX_PREFIX,
|
||||||
|
"atom_",
|
||||||
|
details.semantics.to_ptx_name(),
|
||||||
|
"_",
|
||||||
|
details.scope.to_ptx_name(),
|
||||||
|
"_",
|
||||||
|
details.space.to_ptx_name(),
|
||||||
|
"_dec",
|
||||||
|
]
|
||||||
|
.concat();
|
||||||
|
local.push(instruction_to_fn_call(
|
||||||
id_def,
|
id_def,
|
||||||
ptx_impl_imports,
|
ptx_impl_imports,
|
||||||
d,
|
ast::Instruction::Atom(details, args),
|
||||||
a,
|
fn_name,
|
||||||
"dec",
|
)?);
|
||||||
ast::ScalarType::U32,
|
|
||||||
));
|
|
||||||
}
|
}
|
||||||
Statement::Instruction(ast::Instruction::Atom(
|
Statement::Instruction(ast::Instruction::Atom(
|
||||||
|
details
|
||||||
|
@
|
||||||
ast::AtomDetails {
|
ast::AtomDetails {
|
||||||
inner:
|
inner:
|
||||||
ast::AtomInnerDetails::Float {
|
ast::AtomInnerDetails::Float {
|
||||||
op: ast::AtomFloatOp::Add,
|
op: ast::AtomFloatOp::Add,
|
||||||
typ,
|
..
|
||||||
},
|
},
|
||||||
semantics,
|
..
|
||||||
scope,
|
|
||||||
space,
|
|
||||||
},
|
},
|
||||||
a,
|
args,
|
||||||
)) => {
|
)) => {
|
||||||
let details = ast::AtomDetails {
|
let fn_name = [
|
||||||
inner: ast::AtomInnerDetails::Float {
|
ZLUDA_PTX_PREFIX,
|
||||||
op: ast::AtomFloatOp::Add,
|
"atom_",
|
||||||
typ,
|
details.semantics.to_ptx_name(),
|
||||||
},
|
"_",
|
||||||
semantics,
|
details.scope.to_ptx_name(),
|
||||||
scope,
|
"_",
|
||||||
space,
|
details.space.to_ptx_name(),
|
||||||
};
|
"_add_",
|
||||||
let (op, typ) = match typ {
|
details.inner.get_type().to_ptx_name(),
|
||||||
ast::ScalarType::F32 => ("add_f32", ast::ScalarType::F32),
|
]
|
||||||
ast::ScalarType::F64 => ("add_f64", ast::ScalarType::F64),
|
.concat();
|
||||||
_ => unreachable!(),
|
local.push(instruction_to_fn_call(
|
||||||
};
|
|
||||||
local.push(to_ptx_impl_atomic_call(
|
|
||||||
id_def,
|
id_def,
|
||||||
ptx_impl_imports,
|
ptx_impl_imports,
|
||||||
details,
|
ast::Instruction::Atom(details, args),
|
||||||
a,
|
fn_name,
|
||||||
op,
|
)?);
|
||||||
typ,
|
|
||||||
));
|
|
||||||
}
|
}
|
||||||
s => local.push(s),
|
s => local.push(s),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
(local, global)
|
Ok((local, global))
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ast::ScalarType {
|
||||||
|
fn to_ptx_name(self) -> &'static str {
|
||||||
|
match self {
|
||||||
|
ast::ScalarType::B8 => "b8",
|
||||||
|
ast::ScalarType::B16 => "b16",
|
||||||
|
ast::ScalarType::B32 => "b32",
|
||||||
|
ast::ScalarType::B64 => "b64",
|
||||||
|
ast::ScalarType::U8 => "u8",
|
||||||
|
ast::ScalarType::U16 => "u16",
|
||||||
|
ast::ScalarType::U32 => "u32",
|
||||||
|
ast::ScalarType::U64 => "u64",
|
||||||
|
ast::ScalarType::S8 => "s8",
|
||||||
|
ast::ScalarType::S16 => "s16",
|
||||||
|
ast::ScalarType::S32 => "s32",
|
||||||
|
ast::ScalarType::S64 => "s64",
|
||||||
|
ast::ScalarType::F16 => "f16",
|
||||||
|
ast::ScalarType::F32 => "f32",
|
||||||
|
ast::ScalarType::F64 => "f64",
|
||||||
|
ast::ScalarType::F16x2 => "f16x2",
|
||||||
|
ast::ScalarType::Pred => "pred",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ast::AtomSemantics {
|
||||||
|
fn to_ptx_name(self) -> &'static str {
|
||||||
|
match self {
|
||||||
|
ast::AtomSemantics::Relaxed => "relaxed",
|
||||||
|
ast::AtomSemantics::Acquire => "acquire",
|
||||||
|
ast::AtomSemantics::Release => "release",
|
||||||
|
ast::AtomSemantics::AcquireRelease => "acq_rel",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ast::MemScope {
|
||||||
|
fn to_ptx_name(self) -> &'static str {
|
||||||
|
match self {
|
||||||
|
ast::MemScope::Cta => "cta",
|
||||||
|
ast::MemScope::Gpu => "gpu",
|
||||||
|
ast::MemScope::Sys => "sys",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ast::StateSpace {
|
||||||
|
fn to_ptx_name(self) -> &'static str {
|
||||||
|
match self {
|
||||||
|
ast::StateSpace::Generic => "generic",
|
||||||
|
ast::StateSpace::Global => "global",
|
||||||
|
ast::StateSpace::Shared => "shared",
|
||||||
|
ast::StateSpace::Reg => "reg",
|
||||||
|
ast::StateSpace::Const => "const",
|
||||||
|
ast::StateSpace::Local => "local",
|
||||||
|
ast::StateSpace::Param => "param",
|
||||||
|
ast::StateSpace::Sreg => "sreg",
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn normalize_variable_decls(directives: &mut Vec<Directive>) {
|
fn normalize_variable_decls(directives: &mut Vec<Directive>) {
|
||||||
|
@ -1591,142 +1687,58 @@ impl<'a, 'b> ArgumentMapVisitor<NormalizedArgParams, TypedArgParams>
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
//TODO: share common code between this and to_ptx_impl_bfe_call
|
fn instruction_to_fn_call(
|
||||||
fn to_ptx_impl_atomic_call(
|
|
||||||
id_defs: &mut NumericIdResolver,
|
id_defs: &mut NumericIdResolver,
|
||||||
ptx_impl_imports: &mut HashMap<String, Directive>,
|
ptx_impl_imports: &mut HashMap<String, Directive>,
|
||||||
details: ast::AtomDetails,
|
inst: ast::Instruction<ExpandedArgParams>,
|
||||||
arg: ast::Arg3<ExpandedArgParams>,
|
fn_name: String,
|
||||||
op: &'static str,
|
) -> Result<ExpandedStatement, TranslateError> {
|
||||||
typ: ast::ScalarType,
|
let mut arguments = Vec::new();
|
||||||
) -> ExpandedStatement {
|
inst.visit(&mut |desc: ArgumentDescriptor<spirv::Word>,
|
||||||
let semantics = ptx_semantics_name(details.semantics);
|
typ: Option<(&ast::Type, ast::StateSpace)>| {
|
||||||
let scope = ptx_scope_name(details.scope);
|
let (typ, space) = match typ {
|
||||||
let space = ptx_space_name(details.space);
|
Some((typ, space)) => (typ.clone(), space),
|
||||||
let fn_name = format!(
|
None => return Err(error_unreachable()),
|
||||||
"__zluda_ptx_impl__atom_{}_{}_{}_{}",
|
|
||||||
semantics, scope, space, op
|
|
||||||
);
|
|
||||||
// TODO: extract to a function
|
|
||||||
let ptr_space = details.space;
|
|
||||||
let scalar_typ = ast::ScalarType::from(typ);
|
|
||||||
let fn_id = match ptx_impl_imports.entry(fn_name) {
|
|
||||||
hash_map::Entry::Vacant(entry) => {
|
|
||||||
let fn_id = id_defs.register_intermediate(None);
|
|
||||||
let func_decl = ast::MethodDeclaration::<spirv::Word> {
|
|
||||||
return_arguments: vec![ast::Variable {
|
|
||||||
align: None,
|
|
||||||
v_type: ast::Type::Scalar(scalar_typ),
|
|
||||||
state_space: ast::StateSpace::Reg,
|
|
||||||
name: id_defs.register_intermediate(None),
|
|
||||||
array_init: Vec::new(),
|
|
||||||
}],
|
|
||||||
name: ast::MethodName::Func(fn_id),
|
|
||||||
input_arguments: vec![
|
|
||||||
ast::Variable {
|
|
||||||
align: None,
|
|
||||||
v_type: ast::Type::Pointer(typ, ptr_space),
|
|
||||||
state_space: ast::StateSpace::Reg,
|
|
||||||
name: id_defs.register_intermediate(None),
|
|
||||||
array_init: Vec::new(),
|
|
||||||
},
|
|
||||||
ast::Variable {
|
|
||||||
align: None,
|
|
||||||
v_type: ast::Type::Scalar(scalar_typ),
|
|
||||||
state_space: ast::StateSpace::Reg,
|
|
||||||
name: id_defs.register_intermediate(None),
|
|
||||||
array_init: Vec::new(),
|
|
||||||
},
|
|
||||||
],
|
|
||||||
shared_mem: None,
|
|
||||||
};
|
};
|
||||||
let func = Function {
|
arguments.push((desc, typ, space));
|
||||||
func_decl: Rc::new(RefCell::new(func_decl)),
|
Ok(0)
|
||||||
globals: Vec::new(),
|
})?;
|
||||||
body: None,
|
let return_arguments_count = arguments
|
||||||
import_as: Some(entry.key().clone()),
|
.iter()
|
||||||
tuning: Vec::new(),
|
.position(|(desc, _, _)| !desc.is_dst)
|
||||||
};
|
.unwrap_or(0);
|
||||||
entry.insert(Directive::Method(func));
|
let (return_arguments, input_arguments) = arguments.split_at(return_arguments_count);
|
||||||
fn_id
|
let fn_id = register_external_fn_call(
|
||||||
}
|
id_defs,
|
||||||
hash_map::Entry::Occupied(entry) => match entry.get() {
|
ptx_impl_imports,
|
||||||
Directive::Method(Function { func_decl, .. }) => match (**func_decl).borrow().name {
|
fn_name,
|
||||||
ast::MethodName::Func(fn_id) => fn_id,
|
return_arguments,
|
||||||
ast::MethodName::Kernel(_) => unreachable!(),
|
input_arguments,
|
||||||
},
|
)?;
|
||||||
_ => unreachable!(),
|
Ok(Statement::Call(ResolvedCall {
|
||||||
},
|
|
||||||
};
|
|
||||||
Statement::Call(ResolvedCall {
|
|
||||||
uniform: false,
|
uniform: false,
|
||||||
name: fn_id,
|
name: fn_id,
|
||||||
return_arguments: vec![(arg.dst, ast::Type::Scalar(scalar_typ), ast::StateSpace::Reg)],
|
return_arguments: arguments_to_resolved_arguments(return_arguments),
|
||||||
input_arguments: vec![
|
input_arguments: arguments_to_resolved_arguments(input_arguments),
|
||||||
(
|
}))
|
||||||
arg.src1,
|
|
||||||
ast::Type::Pointer(typ, ptr_space),
|
|
||||||
ast::StateSpace::Reg,
|
|
||||||
),
|
|
||||||
(
|
|
||||||
arg.src2,
|
|
||||||
ast::Type::Scalar(scalar_typ),
|
|
||||||
ast::StateSpace::Reg,
|
|
||||||
),
|
|
||||||
],
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn to_ptx_impl_bfe_call(
|
fn register_external_fn_call(
|
||||||
id_defs: &mut NumericIdResolver,
|
id_defs: &mut NumericIdResolver,
|
||||||
ptx_impl_imports: &mut HashMap<String, Directive>,
|
ptx_impl_imports: &mut HashMap<String, Directive>,
|
||||||
typ: ast::ScalarType,
|
name: String,
|
||||||
arg: ast::Arg4<ExpandedArgParams>,
|
return_arguments: &[(ArgumentDescriptor<spirv::Word>, ast::Type, ast::StateSpace)],
|
||||||
) -> ExpandedStatement {
|
input_arguments: &[(ArgumentDescriptor<spirv::Word>, ast::Type, ast::StateSpace)],
|
||||||
let prefix = "__zluda_ptx_impl__";
|
) -> Result<spirv::Word, TranslateError> {
|
||||||
let suffix = match typ {
|
match ptx_impl_imports.entry(name) {
|
||||||
ast::ScalarType::U32 => "bfe_u32",
|
|
||||||
ast::ScalarType::U64 => "bfe_u64",
|
|
||||||
ast::ScalarType::S32 => "bfe_s32",
|
|
||||||
ast::ScalarType::S64 => "bfe_s64",
|
|
||||||
_ => unreachable!(),
|
|
||||||
};
|
|
||||||
let fn_name = format!("{}{}", prefix, suffix);
|
|
||||||
let fn_id = match ptx_impl_imports.entry(fn_name) {
|
|
||||||
hash_map::Entry::Vacant(entry) => {
|
hash_map::Entry::Vacant(entry) => {
|
||||||
let fn_id = id_defs.register_intermediate(None);
|
let fn_id = id_defs.register_intermediate(None);
|
||||||
|
let return_arguments = fn_arguments_to_variables(id_defs, return_arguments);
|
||||||
|
let input_arguments = fn_arguments_to_variables(id_defs, input_arguments);
|
||||||
let func_decl = ast::MethodDeclaration::<spirv::Word> {
|
let func_decl = ast::MethodDeclaration::<spirv::Word> {
|
||||||
return_arguments: vec![ast::Variable {
|
return_arguments,
|
||||||
align: None,
|
|
||||||
v_type: ast::Type::Scalar(typ.into()),
|
|
||||||
state_space: ast::StateSpace::Reg,
|
|
||||||
name: id_defs.register_intermediate(None),
|
|
||||||
array_init: Vec::new(),
|
|
||||||
}],
|
|
||||||
name: ast::MethodName::Func(fn_id),
|
name: ast::MethodName::Func(fn_id),
|
||||||
input_arguments: vec![
|
input_arguments,
|
||||||
ast::Variable {
|
|
||||||
align: None,
|
|
||||||
v_type: ast::Type::Scalar(typ.into()),
|
|
||||||
state_space: ast::StateSpace::Reg,
|
|
||||||
name: id_defs.register_intermediate(None),
|
|
||||||
array_init: Vec::new(),
|
|
||||||
},
|
|
||||||
ast::Variable {
|
|
||||||
align: None,
|
|
||||||
v_type: ast::Type::Scalar(ast::ScalarType::U32),
|
|
||||||
state_space: ast::StateSpace::Reg,
|
|
||||||
name: id_defs.register_intermediate(None),
|
|
||||||
array_init: Vec::new(),
|
|
||||||
},
|
|
||||||
ast::Variable {
|
|
||||||
align: None,
|
|
||||||
v_type: ast::Type::Scalar(ast::ScalarType::U32),
|
|
||||||
state_space: ast::StateSpace::Reg,
|
|
||||||
name: id_defs.register_intermediate(None),
|
|
||||||
array_init: Vec::new(),
|
|
||||||
},
|
|
||||||
],
|
|
||||||
shared_mem: None,
|
shared_mem: None,
|
||||||
};
|
};
|
||||||
let func = Function {
|
let func = Function {
|
||||||
|
@ -1737,142 +1749,39 @@ fn to_ptx_impl_bfe_call(
|
||||||
tuning: Vec::new(),
|
tuning: Vec::new(),
|
||||||
};
|
};
|
||||||
entry.insert(Directive::Method(func));
|
entry.insert(Directive::Method(func));
|
||||||
fn_id
|
Ok(fn_id)
|
||||||
}
|
}
|
||||||
hash_map::Entry::Occupied(entry) => match entry.get() {
|
hash_map::Entry::Occupied(entry) => match entry.get() {
|
||||||
Directive::Method(Function { func_decl, .. }) => match (**func_decl).borrow().name {
|
Directive::Method(Function { func_decl, .. }) => match (**func_decl).borrow().name {
|
||||||
ast::MethodName::Func(fn_id) => fn_id,
|
ast::MethodName::Func(fn_id) => Ok(fn_id),
|
||||||
ast::MethodName::Kernel(_) => unreachable!(),
|
ast::MethodName::Kernel(_) => Err(error_unreachable()),
|
||||||
},
|
},
|
||||||
_ => unreachable!(),
|
_ => Err(error_unreachable()),
|
||||||
},
|
},
|
||||||
};
|
}
|
||||||
Statement::Call(ResolvedCall {
|
|
||||||
uniform: false,
|
|
||||||
name: fn_id,
|
|
||||||
return_arguments: vec![(arg.dst, ast::Type::Scalar(typ.into()), ast::StateSpace::Reg)],
|
|
||||||
input_arguments: vec![
|
|
||||||
(
|
|
||||||
arg.src1,
|
|
||||||
ast::Type::Scalar(typ.into()),
|
|
||||||
ast::StateSpace::Reg,
|
|
||||||
),
|
|
||||||
(
|
|
||||||
arg.src2,
|
|
||||||
ast::Type::Scalar(ast::ScalarType::U32),
|
|
||||||
ast::StateSpace::Reg,
|
|
||||||
),
|
|
||||||
(
|
|
||||||
arg.src3,
|
|
||||||
ast::Type::Scalar(ast::ScalarType::U32),
|
|
||||||
ast::StateSpace::Reg,
|
|
||||||
),
|
|
||||||
],
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn to_ptx_impl_bfi_call(
|
fn fn_arguments_to_variables(
|
||||||
id_defs: &mut NumericIdResolver,
|
id_defs: &mut NumericIdResolver,
|
||||||
ptx_impl_imports: &mut HashMap<String, Directive>,
|
args: &[(ArgumentDescriptor<spirv::Word>, ast::Type, ast::StateSpace)],
|
||||||
typ: ast::ScalarType,
|
) -> Vec<ast::Variable<spirv::Word>> {
|
||||||
arg: ast::Arg5<ExpandedArgParams>,
|
args.iter()
|
||||||
) -> ExpandedStatement {
|
.map(|(_, typ, space)| ast::Variable {
|
||||||
let prefix = "__zluda_ptx_impl__";
|
|
||||||
let suffix = match typ {
|
|
||||||
ast::ScalarType::B32 => "bfi_b32",
|
|
||||||
ast::ScalarType::B64 => "bfi_b64",
|
|
||||||
_ => unreachable!(),
|
|
||||||
};
|
|
||||||
let fn_name = format!("{}{}", prefix, suffix);
|
|
||||||
let fn_id = match ptx_impl_imports.entry(fn_name) {
|
|
||||||
hash_map::Entry::Vacant(entry) => {
|
|
||||||
let fn_id = id_defs.register_intermediate(None);
|
|
||||||
let func_decl = ast::MethodDeclaration::<spirv::Word> {
|
|
||||||
return_arguments: vec![ast::Variable {
|
|
||||||
align: None,
|
align: None,
|
||||||
v_type: ast::Type::Scalar(typ.into()),
|
v_type: typ.clone(),
|
||||||
state_space: ast::StateSpace::Reg,
|
state_space: *space,
|
||||||
name: id_defs.register_intermediate(None),
|
name: id_defs.register_intermediate(None),
|
||||||
array_init: Vec::new(),
|
array_init: Vec::new(),
|
||||||
}],
|
|
||||||
name: ast::MethodName::Func(fn_id),
|
|
||||||
input_arguments: vec![
|
|
||||||
ast::Variable {
|
|
||||||
align: None,
|
|
||||||
v_type: ast::Type::Scalar(typ.into()),
|
|
||||||
state_space: ast::StateSpace::Reg,
|
|
||||||
name: id_defs.register_intermediate(None),
|
|
||||||
array_init: Vec::new(),
|
|
||||||
},
|
|
||||||
ast::Variable {
|
|
||||||
align: None,
|
|
||||||
v_type: ast::Type::Scalar(typ.into()),
|
|
||||||
state_space: ast::StateSpace::Reg,
|
|
||||||
name: id_defs.register_intermediate(None),
|
|
||||||
array_init: Vec::new(),
|
|
||||||
},
|
|
||||||
ast::Variable {
|
|
||||||
align: None,
|
|
||||||
v_type: ast::Type::Scalar(ast::ScalarType::U32),
|
|
||||||
state_space: ast::StateSpace::Reg,
|
|
||||||
name: id_defs.register_intermediate(None),
|
|
||||||
array_init: Vec::new(),
|
|
||||||
},
|
|
||||||
ast::Variable {
|
|
||||||
align: None,
|
|
||||||
v_type: ast::Type::Scalar(ast::ScalarType::U32),
|
|
||||||
state_space: ast::StateSpace::Reg,
|
|
||||||
name: id_defs.register_intermediate(None),
|
|
||||||
array_init: Vec::new(),
|
|
||||||
},
|
|
||||||
],
|
|
||||||
shared_mem: None,
|
|
||||||
};
|
|
||||||
let func = Function {
|
|
||||||
func_decl: Rc::new(RefCell::new(func_decl)),
|
|
||||||
globals: Vec::new(),
|
|
||||||
body: None,
|
|
||||||
import_as: Some(entry.key().clone()),
|
|
||||||
tuning: Vec::new(),
|
|
||||||
};
|
|
||||||
entry.insert(Directive::Method(func));
|
|
||||||
fn_id
|
|
||||||
}
|
|
||||||
hash_map::Entry::Occupied(entry) => match entry.get() {
|
|
||||||
Directive::Method(Function { func_decl, .. }) => match (**func_decl).borrow().name {
|
|
||||||
ast::MethodName::Func(fn_id) => fn_id,
|
|
||||||
ast::MethodName::Kernel(_) => unreachable!(),
|
|
||||||
},
|
|
||||||
_ => unreachable!(),
|
|
||||||
},
|
|
||||||
};
|
|
||||||
Statement::Call(ResolvedCall {
|
|
||||||
uniform: false,
|
|
||||||
name: fn_id,
|
|
||||||
return_arguments: vec![(arg.dst, ast::Type::Scalar(typ.into()), ast::StateSpace::Reg)],
|
|
||||||
input_arguments: vec![
|
|
||||||
(
|
|
||||||
arg.src1,
|
|
||||||
ast::Type::Scalar(typ.into()),
|
|
||||||
ast::StateSpace::Reg,
|
|
||||||
),
|
|
||||||
(
|
|
||||||
arg.src2,
|
|
||||||
ast::Type::Scalar(typ.into()),
|
|
||||||
ast::StateSpace::Reg,
|
|
||||||
),
|
|
||||||
(
|
|
||||||
arg.src3,
|
|
||||||
ast::Type::Scalar(ast::ScalarType::U32),
|
|
||||||
ast::StateSpace::Reg,
|
|
||||||
),
|
|
||||||
(
|
|
||||||
arg.src4,
|
|
||||||
ast::Type::Scalar(ast::ScalarType::U32),
|
|
||||||
ast::StateSpace::Reg,
|
|
||||||
),
|
|
||||||
],
|
|
||||||
})
|
})
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn arguments_to_resolved_arguments(
|
||||||
|
args: &[(ArgumentDescriptor<spirv::Word>, ast::Type, ast::StateSpace)],
|
||||||
|
) -> Vec<(spirv::Word, ast::Type, ast::StateSpace)> {
|
||||||
|
args.iter()
|
||||||
|
.map(|(desc, typ, space)| (desc.op, typ.clone(), *space))
|
||||||
|
.collect::<Vec<_>>()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn normalize_labels(
|
fn normalize_labels(
|
||||||
|
@ -3305,36 +3214,6 @@ struct PtxImplImport {
|
||||||
in_args: Vec<ast::Type>,
|
in_args: Vec<ast::Type>,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn ptx_semantics_name(sema: ast::AtomSemantics) -> &'static str {
|
|
||||||
match sema {
|
|
||||||
ast::AtomSemantics::Relaxed => "relaxed",
|
|
||||||
ast::AtomSemantics::Acquire => "acquire",
|
|
||||||
ast::AtomSemantics::Release => "release",
|
|
||||||
ast::AtomSemantics::AcquireRelease => "acq_rel",
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn ptx_scope_name(scope: ast::MemScope) -> &'static str {
|
|
||||||
match scope {
|
|
||||||
ast::MemScope::Cta => "cta",
|
|
||||||
ast::MemScope::Gpu => "gpu",
|
|
||||||
ast::MemScope::Sys => "sys",
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn ptx_space_name(space: ast::StateSpace) -> &'static str {
|
|
||||||
match space {
|
|
||||||
ast::StateSpace::Generic => "generic",
|
|
||||||
ast::StateSpace::Global => "global",
|
|
||||||
ast::StateSpace::Shared => "shared",
|
|
||||||
ast::StateSpace::Reg => "reg",
|
|
||||||
ast::StateSpace::Const => "const",
|
|
||||||
ast::StateSpace::Local => "local",
|
|
||||||
ast::StateSpace::Param => "param",
|
|
||||||
ast::StateSpace::Sreg => "sreg",
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn emit_mul_float(
|
fn emit_mul_float(
|
||||||
builder: &mut dr::Builder,
|
builder: &mut dr::Builder,
|
||||||
map: &mut TypeWordMap,
|
map: &mut TypeWordMap,
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue