Be more correct when emitting brev, refactor inst->func call pass

This commit is contained in:
Andrzej Janik 2021-07-02 22:45:09 +02:00
parent 7d4fbedfcf
commit e328ecc550
3 changed files with 226 additions and 334 deletions

View file

@ -253,6 +253,14 @@ ulong FUNC(bfi_b64)(ulong insert, ulong base, uint offset, uint count) {
return intel_bfi(base, insert, offset, count);
}
// 32-bit `brev` helper: forwards `base` to `intel_bfrev` and returns the result unchanged.
// NOTE(review): `intel_bfrev` is presumably the Intel bit-reverse built-in and `FUNC` presumably
// decorates the exported symbol name; neither is defined in this view - confirm.
uint FUNC(brev_b32)(uint base) {
return intel_bfrev(base);
}
// 64-bit `brev` helper: forwards `base` to `intel_bfrev` and returns the result unchanged.
// NOTE(review): assumes `intel_bfrev` has a 64-bit (`ulong`) overload - confirm against the
// extension that provides it; an implicit narrowing to 32 bits here would silently drop bits.
ulong FUNC(brev_b64)(ulong base) {
return intel_bfrev(base);
}
void FUNC(__assertfail)(
__private ulong* message,
__private ulong* file,

View file

@ -7,17 +7,22 @@
OpCapability Int64
OpCapability Float16
OpCapability Float64
%21 = OpExtInstImport "OpenCL.std"
%24 = OpExtInstImport "OpenCL.std"
OpMemoryModel Physical64 OpenCL
OpEntryPoint Kernel %1 "brev"
OpDecorate %20 LinkageAttributes "__zluda_ptx_impl__brev_b32" Import
%void = OpTypeVoid
%ulong = OpTypeInt 64 0
%24 = OpTypeFunction %void %ulong %ulong
%_ptr_Function_ulong = OpTypePointer Function %ulong
%uint = OpTypeInt 32 0
%27 = OpTypeFunction %uint %uint
%ulong = OpTypeInt 64 0
%29 = OpTypeFunction %void %ulong %ulong
%_ptr_Function_ulong = OpTypePointer Function %ulong
%_ptr_Function_uint = OpTypePointer Function %uint
%_ptr_Generic_uint = OpTypePointer Generic %uint
%1 = OpFunction %void None %24
%20 = OpFunction %uint None %27
%22 = OpFunctionParameter %uint
OpFunctionEnd
%1 = OpFunction %void None %29
%7 = OpFunctionParameter %ulong
%8 = OpFunctionParameter %ulong
%19 = OpLabel
@ -37,7 +42,7 @@
%11 = OpLoad %uint %17 Aligned 4
OpStore %6 %11
%14 = OpLoad %uint %6
%13 = OpBitReverse %uint %14
%13 = OpFunctionCall %uint %20 %14
OpStore %6 %13
%15 = OpLoad %ulong %5
%16 = OpLoad %uint %6

View file

@ -8,6 +8,7 @@ use std::{borrow::Cow, collections::BTreeSet, ffi::CString, hash::Hash, iter, me
use rspirv::binary::Assemble;
// Raw bytes of `../lib/zluda_ptx_impl.spv`, embedded into the binary at compile time.
static ZLUDA_PTX_IMPL: &'static [u8] = include_bytes!("../lib/zluda_ptx_impl.spv");
// Common prefix of the helper function names assembled in this file
// (prefix + operation name + type/ordering suffixes, e.g. prefix + "brev_" + "b32").
static ZLUDA_PTX_PREFIX: &'static str = "__zluda_ptx_impl__";
quick_error! {
#[derive(Debug)]
@ -1248,7 +1249,7 @@ fn to_ssa<'input, 'b>(
let mut numeric_id_defs = numeric_id_defs.unmut();
let labeled_statements = normalize_labels(expanded_statements, &mut numeric_id_defs);
let (f_body, globals) =
extract_globals(labeled_statements, ptx_impl_imports, &mut numeric_id_defs);
extract_globals(labeled_statements, ptx_impl_imports, &mut numeric_id_defs)?;
Ok(Function {
func_decl: func_decl,
globals: globals,
@ -1344,7 +1345,7 @@ fn extract_globals<'input, 'b>(
sorted_statements: Vec<ExpandedStatement>,
ptx_impl_imports: &mut HashMap<String, Directive>,
id_def: &mut NumericIdResolver,
) -> (Vec<ExpandedStatement>, Vec<ast::Variable<spirv::Word>>) {
) -> Result<(Vec<ExpandedStatement>, Vec<ast::Variable<spirv::Word>>), TranslateError> {
let mut local = Vec::with_capacity(sorted_statements.len());
let mut global = Vec::new();
for statement in sorted_statements {
@ -1366,13 +1367,34 @@ fn extract_globals<'input, 'b>(
},
) => global.push(var),
Statement::Instruction(ast::Instruction::Bfe { typ, arg }) => {
local.push(to_ptx_impl_bfe_call(id_def, ptx_impl_imports, typ, arg));
let fn_name = [ZLUDA_PTX_PREFIX, "bfe_", typ.to_ptx_name()].concat();
local.push(instruction_to_fn_call(
id_def,
ptx_impl_imports,
ast::Instruction::Bfe { typ, arg },
fn_name,
)?);
}
Statement::Instruction(ast::Instruction::Bfi { typ, arg }) => {
local.push(to_ptx_impl_bfi_call(id_def, ptx_impl_imports, typ, arg));
let fn_name = [ZLUDA_PTX_PREFIX, "bfi_", typ.to_ptx_name()].concat();
local.push(instruction_to_fn_call(
id_def,
ptx_impl_imports,
ast::Instruction::Bfi { typ, arg },
fn_name,
)?);
}
Statement::Instruction(ast::Instruction::Brev { typ, arg }) => {
let fn_name = [ZLUDA_PTX_PREFIX, "brev_", typ.to_ptx_name()].concat();
local.push(instruction_to_fn_call(
id_def,
ptx_impl_imports,
ast::Instruction::Brev { typ, arg },
fn_name,
)?);
}
Statement::Instruction(ast::Instruction::Atom(
d
details
@
ast::AtomDetails {
inner:
@ -1382,19 +1404,28 @@ fn extract_globals<'input, 'b>(
},
..
},
a,
args,
)) => {
local.push(to_ptx_impl_atomic_call(
let fn_name = [
ZLUDA_PTX_PREFIX,
"atom_",
details.semantics.to_ptx_name(),
"_",
details.scope.to_ptx_name(),
"_",
details.space.to_ptx_name(),
"_inc",
]
.concat();
local.push(instruction_to_fn_call(
id_def,
ptx_impl_imports,
d,
a,
"inc",
ast::ScalarType::U32,
));
ast::Instruction::Atom(details, args),
fn_name,
)?);
}
Statement::Instruction(ast::Instruction::Atom(
d
details
@
ast::AtomDetails {
inner:
@ -1404,57 +1435,122 @@ fn extract_globals<'input, 'b>(
},
..
},
a,
args,
)) => {
local.push(to_ptx_impl_atomic_call(
let fn_name = [
ZLUDA_PTX_PREFIX,
"atom_",
details.semantics.to_ptx_name(),
"_",
details.scope.to_ptx_name(),
"_",
details.space.to_ptx_name(),
"_dec",
]
.concat();
local.push(instruction_to_fn_call(
id_def,
ptx_impl_imports,
d,
a,
"dec",
ast::ScalarType::U32,
));
ast::Instruction::Atom(details, args),
fn_name,
)?);
}
Statement::Instruction(ast::Instruction::Atom(
details
@
ast::AtomDetails {
inner:
ast::AtomInnerDetails::Float {
op: ast::AtomFloatOp::Add,
typ,
..
},
semantics,
scope,
space,
..
},
a,
args,
)) => {
let details = ast::AtomDetails {
inner: ast::AtomInnerDetails::Float {
op: ast::AtomFloatOp::Add,
typ,
},
semantics,
scope,
space,
};
let (op, typ) = match typ {
ast::ScalarType::F32 => ("add_f32", ast::ScalarType::F32),
ast::ScalarType::F64 => ("add_f64", ast::ScalarType::F64),
_ => unreachable!(),
};
local.push(to_ptx_impl_atomic_call(
let fn_name = [
ZLUDA_PTX_PREFIX,
"atom_",
details.semantics.to_ptx_name(),
"_",
details.scope.to_ptx_name(),
"_",
details.space.to_ptx_name(),
"_add_",
details.inner.get_type().to_ptx_name(),
]
.concat();
local.push(instruction_to_fn_call(
id_def,
ptx_impl_imports,
details,
a,
op,
typ,
));
ast::Instruction::Atom(details, args),
fn_name,
)?);
}
s => local.push(s),
}
}
(local, global)
Ok((local, global))
}
impl ast::ScalarType {
fn to_ptx_name(self) -> &'static str {
match self {
ast::ScalarType::B8 => "b8",
ast::ScalarType::B16 => "b16",
ast::ScalarType::B32 => "b32",
ast::ScalarType::B64 => "b64",
ast::ScalarType::U8 => "u8",
ast::ScalarType::U16 => "u16",
ast::ScalarType::U32 => "u32",
ast::ScalarType::U64 => "u64",
ast::ScalarType::S8 => "s8",
ast::ScalarType::S16 => "s16",
ast::ScalarType::S32 => "s32",
ast::ScalarType::S64 => "s64",
ast::ScalarType::F16 => "f16",
ast::ScalarType::F32 => "f32",
ast::ScalarType::F64 => "f64",
ast::ScalarType::F16x2 => "f16x2",
ast::ScalarType::Pred => "pred",
}
}
}
impl ast::AtomSemantics {
fn to_ptx_name(self) -> &'static str {
match self {
ast::AtomSemantics::Relaxed => "relaxed",
ast::AtomSemantics::Acquire => "acquire",
ast::AtomSemantics::Release => "release",
ast::AtomSemantics::AcquireRelease => "acq_rel",
}
}
}
impl ast::MemScope {
fn to_ptx_name(self) -> &'static str {
match self {
ast::MemScope::Cta => "cta",
ast::MemScope::Gpu => "gpu",
ast::MemScope::Sys => "sys",
}
}
}
impl ast::StateSpace {
fn to_ptx_name(self) -> &'static str {
match self {
ast::StateSpace::Generic => "generic",
ast::StateSpace::Global => "global",
ast::StateSpace::Shared => "shared",
ast::StateSpace::Reg => "reg",
ast::StateSpace::Const => "const",
ast::StateSpace::Local => "local",
ast::StateSpace::Param => "param",
ast::StateSpace::Sreg => "sreg",
}
}
}
fn normalize_variable_decls(directives: &mut Vec<Directive>) {
@ -1591,142 +1687,58 @@ impl<'a, 'b> ArgumentMapVisitor<NormalizedArgParams, TypedArgParams>
}
}
//TODO: share common code between this and to_ptx_impl_bfe_call
fn to_ptx_impl_atomic_call(
fn instruction_to_fn_call(
id_defs: &mut NumericIdResolver,
ptx_impl_imports: &mut HashMap<String, Directive>,
details: ast::AtomDetails,
arg: ast::Arg3<ExpandedArgParams>,
op: &'static str,
typ: ast::ScalarType,
) -> ExpandedStatement {
let semantics = ptx_semantics_name(details.semantics);
let scope = ptx_scope_name(details.scope);
let space = ptx_space_name(details.space);
let fn_name = format!(
"__zluda_ptx_impl__atom_{}_{}_{}_{}",
semantics, scope, space, op
);
// TODO: extract to a function
let ptr_space = details.space;
let scalar_typ = ast::ScalarType::from(typ);
let fn_id = match ptx_impl_imports.entry(fn_name) {
hash_map::Entry::Vacant(entry) => {
let fn_id = id_defs.register_intermediate(None);
let func_decl = ast::MethodDeclaration::<spirv::Word> {
return_arguments: vec![ast::Variable {
align: None,
v_type: ast::Type::Scalar(scalar_typ),
state_space: ast::StateSpace::Reg,
name: id_defs.register_intermediate(None),
array_init: Vec::new(),
}],
name: ast::MethodName::Func(fn_id),
input_arguments: vec![
ast::Variable {
align: None,
v_type: ast::Type::Pointer(typ, ptr_space),
state_space: ast::StateSpace::Reg,
name: id_defs.register_intermediate(None),
array_init: Vec::new(),
},
ast::Variable {
align: None,
v_type: ast::Type::Scalar(scalar_typ),
state_space: ast::StateSpace::Reg,
name: id_defs.register_intermediate(None),
array_init: Vec::new(),
},
],
shared_mem: None,
};
let func = Function {
func_decl: Rc::new(RefCell::new(func_decl)),
globals: Vec::new(),
body: None,
import_as: Some(entry.key().clone()),
tuning: Vec::new(),
};
entry.insert(Directive::Method(func));
fn_id
}
hash_map::Entry::Occupied(entry) => match entry.get() {
Directive::Method(Function { func_decl, .. }) => match (**func_decl).borrow().name {
ast::MethodName::Func(fn_id) => fn_id,
ast::MethodName::Kernel(_) => unreachable!(),
},
_ => unreachable!(),
},
};
Statement::Call(ResolvedCall {
inst: ast::Instruction<ExpandedArgParams>,
fn_name: String,
) -> Result<ExpandedStatement, TranslateError> {
let mut arguments = Vec::new();
inst.visit(&mut |desc: ArgumentDescriptor<spirv::Word>,
typ: Option<(&ast::Type, ast::StateSpace)>| {
let (typ, space) = match typ {
Some((typ, space)) => (typ.clone(), space),
None => return Err(error_unreachable()),
};
arguments.push((desc, typ, space));
Ok(0)
})?;
let return_arguments_count = arguments
.iter()
.position(|(desc, _, _)| !desc.is_dst)
.unwrap_or(0);
let (return_arguments, input_arguments) = arguments.split_at(return_arguments_count);
let fn_id = register_external_fn_call(
id_defs,
ptx_impl_imports,
fn_name,
return_arguments,
input_arguments,
)?;
Ok(Statement::Call(ResolvedCall {
uniform: false,
name: fn_id,
return_arguments: vec![(arg.dst, ast::Type::Scalar(scalar_typ), ast::StateSpace::Reg)],
input_arguments: vec![
(
arg.src1,
ast::Type::Pointer(typ, ptr_space),
ast::StateSpace::Reg,
),
(
arg.src2,
ast::Type::Scalar(scalar_typ),
ast::StateSpace::Reg,
),
],
})
return_arguments: arguments_to_resolved_arguments(return_arguments),
input_arguments: arguments_to_resolved_arguments(input_arguments),
}))
}
fn to_ptx_impl_bfe_call(
fn register_external_fn_call(
id_defs: &mut NumericIdResolver,
ptx_impl_imports: &mut HashMap<String, Directive>,
typ: ast::ScalarType,
arg: ast::Arg4<ExpandedArgParams>,
) -> ExpandedStatement {
let prefix = "__zluda_ptx_impl__";
let suffix = match typ {
ast::ScalarType::U32 => "bfe_u32",
ast::ScalarType::U64 => "bfe_u64",
ast::ScalarType::S32 => "bfe_s32",
ast::ScalarType::S64 => "bfe_s64",
_ => unreachable!(),
};
let fn_name = format!("{}{}", prefix, suffix);
let fn_id = match ptx_impl_imports.entry(fn_name) {
name: String,
return_arguments: &[(ArgumentDescriptor<spirv::Word>, ast::Type, ast::StateSpace)],
input_arguments: &[(ArgumentDescriptor<spirv::Word>, ast::Type, ast::StateSpace)],
) -> Result<spirv::Word, TranslateError> {
match ptx_impl_imports.entry(name) {
hash_map::Entry::Vacant(entry) => {
let fn_id = id_defs.register_intermediate(None);
let return_arguments = fn_arguments_to_variables(id_defs, return_arguments);
let input_arguments = fn_arguments_to_variables(id_defs, input_arguments);
let func_decl = ast::MethodDeclaration::<spirv::Word> {
return_arguments: vec![ast::Variable {
align: None,
v_type: ast::Type::Scalar(typ.into()),
state_space: ast::StateSpace::Reg,
name: id_defs.register_intermediate(None),
array_init: Vec::new(),
}],
return_arguments,
name: ast::MethodName::Func(fn_id),
input_arguments: vec![
ast::Variable {
align: None,
v_type: ast::Type::Scalar(typ.into()),
state_space: ast::StateSpace::Reg,
name: id_defs.register_intermediate(None),
array_init: Vec::new(),
},
ast::Variable {
align: None,
v_type: ast::Type::Scalar(ast::ScalarType::U32),
state_space: ast::StateSpace::Reg,
name: id_defs.register_intermediate(None),
array_init: Vec::new(),
},
ast::Variable {
align: None,
v_type: ast::Type::Scalar(ast::ScalarType::U32),
state_space: ast::StateSpace::Reg,
name: id_defs.register_intermediate(None),
array_init: Vec::new(),
},
],
input_arguments,
shared_mem: None,
};
let func = Function {
@ -1737,142 +1749,39 @@ fn to_ptx_impl_bfe_call(
tuning: Vec::new(),
};
entry.insert(Directive::Method(func));
fn_id
Ok(fn_id)
}
hash_map::Entry::Occupied(entry) => match entry.get() {
Directive::Method(Function { func_decl, .. }) => match (**func_decl).borrow().name {
ast::MethodName::Func(fn_id) => fn_id,
ast::MethodName::Kernel(_) => unreachable!(),
ast::MethodName::Func(fn_id) => Ok(fn_id),
ast::MethodName::Kernel(_) => Err(error_unreachable()),
},
_ => unreachable!(),
_ => Err(error_unreachable()),
},
};
Statement::Call(ResolvedCall {
uniform: false,
name: fn_id,
return_arguments: vec![(arg.dst, ast::Type::Scalar(typ.into()), ast::StateSpace::Reg)],
input_arguments: vec![
(
arg.src1,
ast::Type::Scalar(typ.into()),
ast::StateSpace::Reg,
),
(
arg.src2,
ast::Type::Scalar(ast::ScalarType::U32),
ast::StateSpace::Reg,
),
(
arg.src3,
ast::Type::Scalar(ast::ScalarType::U32),
ast::StateSpace::Reg,
),
],
})
}
}
fn to_ptx_impl_bfi_call(
fn fn_arguments_to_variables(
id_defs: &mut NumericIdResolver,
ptx_impl_imports: &mut HashMap<String, Directive>,
typ: ast::ScalarType,
arg: ast::Arg5<ExpandedArgParams>,
) -> ExpandedStatement {
let prefix = "__zluda_ptx_impl__";
let suffix = match typ {
ast::ScalarType::B32 => "bfi_b32",
ast::ScalarType::B64 => "bfi_b64",
_ => unreachable!(),
};
let fn_name = format!("{}{}", prefix, suffix);
let fn_id = match ptx_impl_imports.entry(fn_name) {
hash_map::Entry::Vacant(entry) => {
let fn_id = id_defs.register_intermediate(None);
let func_decl = ast::MethodDeclaration::<spirv::Word> {
return_arguments: vec![ast::Variable {
align: None,
v_type: ast::Type::Scalar(typ.into()),
state_space: ast::StateSpace::Reg,
name: id_defs.register_intermediate(None),
array_init: Vec::new(),
}],
name: ast::MethodName::Func(fn_id),
input_arguments: vec![
ast::Variable {
align: None,
v_type: ast::Type::Scalar(typ.into()),
state_space: ast::StateSpace::Reg,
name: id_defs.register_intermediate(None),
array_init: Vec::new(),
},
ast::Variable {
align: None,
v_type: ast::Type::Scalar(typ.into()),
state_space: ast::StateSpace::Reg,
name: id_defs.register_intermediate(None),
array_init: Vec::new(),
},
ast::Variable {
align: None,
v_type: ast::Type::Scalar(ast::ScalarType::U32),
state_space: ast::StateSpace::Reg,
name: id_defs.register_intermediate(None),
array_init: Vec::new(),
},
ast::Variable {
align: None,
v_type: ast::Type::Scalar(ast::ScalarType::U32),
state_space: ast::StateSpace::Reg,
name: id_defs.register_intermediate(None),
array_init: Vec::new(),
},
],
shared_mem: None,
};
let func = Function {
func_decl: Rc::new(RefCell::new(func_decl)),
globals: Vec::new(),
body: None,
import_as: Some(entry.key().clone()),
tuning: Vec::new(),
};
entry.insert(Directive::Method(func));
fn_id
}
hash_map::Entry::Occupied(entry) => match entry.get() {
Directive::Method(Function { func_decl, .. }) => match (**func_decl).borrow().name {
ast::MethodName::Func(fn_id) => fn_id,
ast::MethodName::Kernel(_) => unreachable!(),
},
_ => unreachable!(),
},
};
Statement::Call(ResolvedCall {
uniform: false,
name: fn_id,
return_arguments: vec![(arg.dst, ast::Type::Scalar(typ.into()), ast::StateSpace::Reg)],
input_arguments: vec![
(
arg.src1,
ast::Type::Scalar(typ.into()),
ast::StateSpace::Reg,
),
(
arg.src2,
ast::Type::Scalar(typ.into()),
ast::StateSpace::Reg,
),
(
arg.src3,
ast::Type::Scalar(ast::ScalarType::U32),
ast::StateSpace::Reg,
),
(
arg.src4,
ast::Type::Scalar(ast::ScalarType::U32),
ast::StateSpace::Reg,
),
],
})
args: &[(ArgumentDescriptor<spirv::Word>, ast::Type, ast::StateSpace)],
) -> Vec<ast::Variable<spirv::Word>> {
args.iter()
.map(|(_, typ, space)| ast::Variable {
align: None,
v_type: typ.clone(),
state_space: *space,
name: id_defs.register_intermediate(None),
array_init: Vec::new(),
})
.collect::<Vec<_>>()
}
/// Converts visited instruction arguments into the `(id, type, state space)`
/// triples stored in a resolved call.
///
/// Only the `op` field of each descriptor is kept; the type is cloned and the
/// state space copied. Order and length of the input are preserved.
fn arguments_to_resolved_arguments(
    args: &[(ArgumentDescriptor<spirv::Word>, ast::Type, ast::StateSpace)],
) -> Vec<(spirv::Word, ast::Type, ast::StateSpace)> {
    let mut resolved = Vec::with_capacity(args.len());
    for (descriptor, arg_type, state_space) in args {
        resolved.push((descriptor.op, arg_type.clone(), *state_space));
    }
    resolved
}
fn normalize_labels(
@ -3305,36 +3214,6 @@ struct PtxImplImport {
in_args: Vec<ast::Type>,
}
fn ptx_semantics_name(sema: ast::AtomSemantics) -> &'static str {
match sema {
ast::AtomSemantics::Relaxed => "relaxed",
ast::AtomSemantics::Acquire => "acquire",
ast::AtomSemantics::Release => "release",
ast::AtomSemantics::AcquireRelease => "acq_rel",
}
}
fn ptx_scope_name(scope: ast::MemScope) -> &'static str {
match scope {
ast::MemScope::Cta => "cta",
ast::MemScope::Gpu => "gpu",
ast::MemScope::Sys => "sys",
}
}
fn ptx_space_name(space: ast::StateSpace) -> &'static str {
match space {
ast::StateSpace::Generic => "generic",
ast::StateSpace::Global => "global",
ast::StateSpace::Shared => "shared",
ast::StateSpace::Reg => "reg",
ast::StateSpace::Const => "const",
ast::StateSpace::Local => "local",
ast::StateSpace::Param => "param",
ast::StateSpace::Sreg => "sreg",
}
}
fn emit_mul_float(
builder: &mut dr::Builder,
map: &mut TypeWordMap,