diff --git a/ptx/lib/zluda_ptx_impl.cl b/ptx/lib/zluda_ptx_impl.cl index 85958d5..a6a5a37 100644 --- a/ptx/lib/zluda_ptx_impl.cl +++ b/ptx/lib/zluda_ptx_impl.cl @@ -253,6 +253,14 @@ ulong FUNC(bfi_b64)(ulong insert, ulong base, uint offset, uint count) { return intel_bfi(base, insert, offset, count); } +uint FUNC(brev_b32)(uint base) { + return intel_bfrev(base); +} + +ulong FUNC(brev_b64)(ulong base) { + return intel_bfrev(base); +} + void FUNC(__assertfail)( __private ulong* message, __private ulong* file, diff --git a/ptx/src/test/spirv_run/brev.spvtxt b/ptx/src/test/spirv_run/brev.spvtxt index 68faeca..7341adb 100644 --- a/ptx/src/test/spirv_run/brev.spvtxt +++ b/ptx/src/test/spirv_run/brev.spvtxt @@ -7,17 +7,22 @@ OpCapability Int64 OpCapability Float16 OpCapability Float64 - %21 = OpExtInstImport "OpenCL.std" + %24 = OpExtInstImport "OpenCL.std" OpMemoryModel Physical64 OpenCL OpEntryPoint Kernel %1 "brev" + OpDecorate %20 LinkageAttributes "__zluda_ptx_impl__brev_b32" Import %void = OpTypeVoid - %ulong = OpTypeInt 64 0 - %24 = OpTypeFunction %void %ulong %ulong -%_ptr_Function_ulong = OpTypePointer Function %ulong %uint = OpTypeInt 32 0 + %27 = OpTypeFunction %uint %uint + %ulong = OpTypeInt 64 0 + %29 = OpTypeFunction %void %ulong %ulong +%_ptr_Function_ulong = OpTypePointer Function %ulong %_ptr_Function_uint = OpTypePointer Function %uint %_ptr_Generic_uint = OpTypePointer Generic %uint - %1 = OpFunction %void None %24 + %20 = OpFunction %uint None %27 + %22 = OpFunctionParameter %uint + OpFunctionEnd + %1 = OpFunction %void None %29 %7 = OpFunctionParameter %ulong %8 = OpFunctionParameter %ulong %19 = OpLabel @@ -37,7 +42,7 @@ %11 = OpLoad %uint %17 Aligned 4 OpStore %6 %11 %14 = OpLoad %uint %6 - %13 = OpBitReverse %uint %14 + %13 = OpFunctionCall %uint %20 %14 OpStore %6 %13 %15 = OpLoad %ulong %5 %16 = OpLoad %uint %6 diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index 7cefdd6..16a4cfb 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -8,6 +8,7 @@ use std::{borrow::Cow, collections::BTreeSet, ffi::CString, hash::Hash, iter, me use rspirv::binary::Assemble; static ZLUDA_PTX_IMPL: &'static [u8] = include_bytes!("../lib/zluda_ptx_impl.spv"); +static ZLUDA_PTX_PREFIX: &'static str = "__zluda_ptx_impl__"; quick_error! { #[derive(Debug)] @@ -1248,7 +1249,7 @@ fn to_ssa<'input, 'b>( let mut numeric_id_defs = numeric_id_defs.unmut(); let labeled_statements = normalize_labels(expanded_statements, &mut numeric_id_defs); let (f_body, globals) = - extract_globals(labeled_statements, ptx_impl_imports, &mut numeric_id_defs); + extract_globals(labeled_statements, ptx_impl_imports, &mut numeric_id_defs)?; Ok(Function { func_decl: func_decl, globals: globals, @@ -1344,7 +1345,7 @@ fn extract_globals<'input, 'b>( sorted_statements: Vec, ptx_impl_imports: &mut HashMap, id_def: &mut NumericIdResolver, -) -> (Vec, Vec>) { +) -> Result<(Vec, Vec>), TranslateError> { let mut local = Vec::with_capacity(sorted_statements.len()); let mut global = Vec::new(); for statement in sorted_statements { @@ -1366,13 +1367,34 @@ fn extract_globals<'input, 'b>( }, ) => global.push(var), Statement::Instruction(ast::Instruction::Bfe { typ, arg }) => { - local.push(to_ptx_impl_bfe_call(id_def, ptx_impl_imports, typ, arg)); + let fn_name = [ZLUDA_PTX_PREFIX, "bfe_", typ.to_ptx_name()].concat(); + local.push(instruction_to_fn_call( + id_def, + ptx_impl_imports, + ast::Instruction::Bfe { typ, arg }, + fn_name, + )?); } Statement::Instruction(ast::Instruction::Bfi { typ, arg }) => { - local.push(to_ptx_impl_bfi_call(id_def, ptx_impl_imports, typ, arg)); + let fn_name = [ZLUDA_PTX_PREFIX, "bfi_", typ.to_ptx_name()].concat(); + local.push(instruction_to_fn_call( + id_def, + ptx_impl_imports, + ast::Instruction::Bfi { typ, arg }, + fn_name, + )?); + } + Statement::Instruction(ast::Instruction::Brev { typ, arg }) => { + let fn_name = [ZLUDA_PTX_PREFIX, "brev_", typ.to_ptx_name()].concat(); + local.push(instruction_to_fn_call( + id_def, + ptx_impl_imports, + ast::Instruction::Brev { typ, arg }, + fn_name, + )?); } Statement::Instruction(ast::Instruction::Atom( - d + details @ ast::AtomDetails { inner: @@ -1382,19 +1404,28 @@ fn extract_globals<'input, 'b>( }, .. }, - a, + args, )) => { - local.push(to_ptx_impl_atomic_call( + let fn_name = [ + ZLUDA_PTX_PREFIX, + "atom_", + details.semantics.to_ptx_name(), + "_", + details.scope.to_ptx_name(), + "_", + details.space.to_ptx_name(), + "_inc", + ] + .concat(); + local.push(instruction_to_fn_call( id_def, ptx_impl_imports, - d, - a, - "inc", - ast::ScalarType::U32, - )); + ast::Instruction::Atom(details, args), + fn_name, + )?); } Statement::Instruction(ast::Instruction::Atom( - d + details @ ast::AtomDetails { inner: @@ -1404,57 +1435,122 @@ fn extract_globals<'input, 'b>( }, .. }, - a, + args, )) => { - local.push(to_ptx_impl_atomic_call( + let fn_name = [ + ZLUDA_PTX_PREFIX, + "atom_", + details.semantics.to_ptx_name(), + "_", + details.scope.to_ptx_name(), + "_", + details.space.to_ptx_name(), + "_dec", + ] + .concat(); + local.push(instruction_to_fn_call( id_def, ptx_impl_imports, - d, - a, - "dec", - ast::ScalarType::U32, - )); + ast::Instruction::Atom(details, args), + fn_name, + )?); } Statement::Instruction(ast::Instruction::Atom( + details + @ ast::AtomDetails { inner: ast::AtomInnerDetails::Float { op: ast::AtomFloatOp::Add, - typ, + .. }, - semantics, - scope, - space, + .. }, - a, + args, )) => { - let details = ast::AtomDetails { - inner: ast::AtomInnerDetails::Float { - op: ast::AtomFloatOp::Add, - typ, - }, - semantics, - scope, - space, - }; - let (op, typ) = match typ { - ast::ScalarType::F32 => ("add_f32", ast::ScalarType::F32), - ast::ScalarType::F64 => ("add_f64", ast::ScalarType::F64), - _ => unreachable!(), - }; - local.push(to_ptx_impl_atomic_call( + let fn_name = [ + ZLUDA_PTX_PREFIX, + "atom_", + details.semantics.to_ptx_name(), + "_", + details.scope.to_ptx_name(), + "_", + details.space.to_ptx_name(), + "_add_", + details.inner.get_type().to_ptx_name(), + ] + .concat(); + local.push(instruction_to_fn_call( id_def, ptx_impl_imports, - details, - a, - op, - typ, - )); + ast::Instruction::Atom(details, args), + fn_name, + )?); } s => local.push(s), } } - (local, global) + Ok((local, global)) +} + +impl ast::ScalarType { + fn to_ptx_name(self) -> &'static str { + match self { + ast::ScalarType::B8 => "b8", + ast::ScalarType::B16 => "b16", + ast::ScalarType::B32 => "b32", + ast::ScalarType::B64 => "b64", + ast::ScalarType::U8 => "u8", + ast::ScalarType::U16 => "u16", + ast::ScalarType::U32 => "u32", + ast::ScalarType::U64 => "u64", + ast::ScalarType::S8 => "s8", + ast::ScalarType::S16 => "s16", + ast::ScalarType::S32 => "s32", + ast::ScalarType::S64 => "s64", + ast::ScalarType::F16 => "f16", + ast::ScalarType::F32 => "f32", + ast::ScalarType::F64 => "f64", + ast::ScalarType::F16x2 => "f16x2", + ast::ScalarType::Pred => "pred", + } + } +} + +impl ast::AtomSemantics { + fn to_ptx_name(self) -> &'static str { + match self { + ast::AtomSemantics::Relaxed => "relaxed", + ast::AtomSemantics::Acquire => "acquire", + ast::AtomSemantics::Release => "release", + ast::AtomSemantics::AcquireRelease => "acq_rel", + } + } +} + +impl ast::MemScope { + fn to_ptx_name(self) -> &'static str { + match self { + ast::MemScope::Cta => "cta", + ast::MemScope::Gpu => "gpu", + ast::MemScope::Sys => "sys", + } + } +} + +impl ast::StateSpace { + fn to_ptx_name(self) -> &'static str { + match self { + ast::StateSpace::Generic => "generic", + ast::StateSpace::Global => "global", + ast::StateSpace::Shared => "shared", + ast::StateSpace::Reg => "reg", + ast::StateSpace::Const => "const", + ast::StateSpace::Local => "local", + ast::StateSpace::Param => "param", + ast::StateSpace::Sreg => "sreg", + } + } } fn normalize_variable_decls(directives: &mut Vec) { @@ -1591,142 +1687,58 @@ impl<'a, 'b> ArgumentMapVisitor } } -//TODO: share common code between this and to_ptx_impl_bfe_call -fn to_ptx_impl_atomic_call( +fn instruction_to_fn_call( id_defs: &mut NumericIdResolver, ptx_impl_imports: &mut HashMap, - details: ast::AtomDetails, - arg: ast::Arg3, - op: &'static str, - typ: ast::ScalarType, -) -> ExpandedStatement { - let semantics = ptx_semantics_name(details.semantics); - let scope = ptx_scope_name(details.scope); - let space = ptx_space_name(details.space); - let fn_name = format!( - "__zluda_ptx_impl__atom_{}_{}_{}_{}", - semantics, scope, space, op - ); - // TODO: extract to a function - let ptr_space = details.space; - let scalar_typ = ast::ScalarType::from(typ); - let fn_id = match ptx_impl_imports.entry(fn_name) { - hash_map::Entry::Vacant(entry) => { - let fn_id = id_defs.register_intermediate(None); - let func_decl = ast::MethodDeclaration:: { - return_arguments: vec![ast::Variable { - align: None, - v_type: ast::Type::Scalar(scalar_typ), - state_space: ast::StateSpace::Reg, - name: id_defs.register_intermediate(None), - array_init: Vec::new(), - }], - name: ast::MethodName::Func(fn_id), - input_arguments: vec![ - ast::Variable { - align: None, - v_type: ast::Type::Pointer(typ, ptr_space), - state_space: ast::StateSpace::Reg, - name: id_defs.register_intermediate(None), - array_init: Vec::new(), - }, - ast::Variable { - align: None, - v_type: ast::Type::Scalar(scalar_typ), - state_space: ast::StateSpace::Reg, - name: id_defs.register_intermediate(None), - array_init: Vec::new(), - }, - ], - shared_mem: None, - }; - let func = Function { - func_decl: Rc::new(RefCell::new(func_decl)), - globals: Vec::new(), - body: None, - import_as: Some(entry.key().clone()), - tuning: Vec::new(), - }; - entry.insert(Directive::Method(func)); - fn_id - } - hash_map::Entry::Occupied(entry) => match entry.get() { - Directive::Method(Function { func_decl, .. }) => match (**func_decl).borrow().name { - ast::MethodName::Func(fn_id) => fn_id, - ast::MethodName::Kernel(_) => unreachable!(), - }, - _ => unreachable!(), - }, - }; - Statement::Call(ResolvedCall { + inst: ast::Instruction, + fn_name: String, +) -> Result { + let mut arguments = Vec::new(); + inst.visit(&mut |desc: ArgumentDescriptor, + typ: Option<(&ast::Type, ast::StateSpace)>| { + let (typ, space) = match typ { + Some((typ, space)) => (typ.clone(), space), + None => return Err(error_unreachable()), + }; + arguments.push((desc, typ, space)); + Ok(0) + })?; + let return_arguments_count = arguments + .iter() + .position(|(desc, _, _)| !desc.is_dst) + .unwrap_or(0); + let (return_arguments, input_arguments) = arguments.split_at(return_arguments_count); + let fn_id = register_external_fn_call( + id_defs, + ptx_impl_imports, + fn_name, + return_arguments, + input_arguments, + )?; + Ok(Statement::Call(ResolvedCall { uniform: false, name: fn_id, - return_arguments: vec![(arg.dst, ast::Type::Scalar(scalar_typ), ast::StateSpace::Reg)], - input_arguments: vec![ - ( - arg.src1, - ast::Type::Pointer(typ, ptr_space), - ast::StateSpace::Reg, - ), - ( - arg.src2, - ast::Type::Scalar(scalar_typ), - ast::StateSpace::Reg, - ), - ], - }) + return_arguments: arguments_to_resolved_arguments(return_arguments), + input_arguments: arguments_to_resolved_arguments(input_arguments), + })) } -fn to_ptx_impl_bfe_call( +fn register_external_fn_call( id_defs: &mut NumericIdResolver, ptx_impl_imports: &mut HashMap, - typ: ast::ScalarType, - arg: ast::Arg4, -) -> ExpandedStatement { - let prefix = "__zluda_ptx_impl__"; - let suffix = match typ { - ast::ScalarType::U32 => "bfe_u32", - ast::ScalarType::U64 => "bfe_u64", - ast::ScalarType::S32 => "bfe_s32", - ast::ScalarType::S64 => "bfe_s64", - _ => unreachable!(), - }; - let fn_name = format!("{}{}", prefix, suffix); - let fn_id = match ptx_impl_imports.entry(fn_name) { + name: String, + return_arguments: &[(ArgumentDescriptor, ast::Type, ast::StateSpace)], + input_arguments: &[(ArgumentDescriptor, ast::Type, ast::StateSpace)], +) -> Result { + match ptx_impl_imports.entry(name) { hash_map::Entry::Vacant(entry) => { let fn_id = id_defs.register_intermediate(None); + let return_arguments = fn_arguments_to_variables(id_defs, return_arguments); + let input_arguments = fn_arguments_to_variables(id_defs, input_arguments); let func_decl = ast::MethodDeclaration:: { - return_arguments: vec![ast::Variable { - align: None, - v_type: ast::Type::Scalar(typ.into()), - state_space: ast::StateSpace::Reg, - name: id_defs.register_intermediate(None), - array_init: Vec::new(), - }], + return_arguments, name: ast::MethodName::Func(fn_id), - input_arguments: vec![ - ast::Variable { - align: None, - v_type: ast::Type::Scalar(typ.into()), - state_space: ast::StateSpace::Reg, - name: id_defs.register_intermediate(None), - array_init: Vec::new(), - }, - ast::Variable { - align: None, - v_type: ast::Type::Scalar(ast::ScalarType::U32), - state_space: ast::StateSpace::Reg, - name: id_defs.register_intermediate(None), - array_init: Vec::new(), - }, - ast::Variable { - align: None, - v_type: ast::Type::Scalar(ast::ScalarType::U32), - state_space: ast::StateSpace::Reg, - name: id_defs.register_intermediate(None), - array_init: Vec::new(), - }, - ], + input_arguments, shared_mem: None, }; let func = Function { @@ -1737,142 +1749,39 @@ fn to_ptx_impl_bfe_call( tuning: Vec::new(), }; entry.insert(Directive::Method(func)); - fn_id + Ok(fn_id) } hash_map::Entry::Occupied(entry) => match entry.get() { Directive::Method(Function { func_decl, .. }) => match (**func_decl).borrow().name { - ast::MethodName::Func(fn_id) => fn_id, - ast::MethodName::Kernel(_) => unreachable!(), + ast::MethodName::Func(fn_id) => Ok(fn_id), + ast::MethodName::Kernel(_) => Err(error_unreachable()), }, - _ => unreachable!(), + _ => Err(error_unreachable()), }, - }; - Statement::Call(ResolvedCall { - uniform: false, - name: fn_id, - return_arguments: vec![(arg.dst, ast::Type::Scalar(typ.into()), ast::StateSpace::Reg)], - input_arguments: vec![ - ( - arg.src1, - ast::Type::Scalar(typ.into()), - ast::StateSpace::Reg, - ), - ( - arg.src2, - ast::Type::Scalar(ast::ScalarType::U32), - ast::StateSpace::Reg, - ), - ( - arg.src3, - ast::Type::Scalar(ast::ScalarType::U32), - ast::StateSpace::Reg, - ), - ], - }) + } } -fn to_ptx_impl_bfi_call( +fn fn_arguments_to_variables( id_defs: &mut NumericIdResolver, - ptx_impl_imports: &mut HashMap, - typ: ast::ScalarType, - arg: ast::Arg5, -) -> ExpandedStatement { - let prefix = "__zluda_ptx_impl__"; - let suffix = match typ { - ast::ScalarType::B32 => "bfi_b32", - ast::ScalarType::B64 => "bfi_b64", - _ => unreachable!(), - }; - let fn_name = format!("{}{}", prefix, suffix); - let fn_id = match ptx_impl_imports.entry(fn_name) { - hash_map::Entry::Vacant(entry) => { - let fn_id = id_defs.register_intermediate(None); - let func_decl = ast::MethodDeclaration:: { - return_arguments: vec![ast::Variable { - align: None, - v_type: ast::Type::Scalar(typ.into()), - state_space: ast::StateSpace::Reg, - name: id_defs.register_intermediate(None), - array_init: Vec::new(), - }], - name: ast::MethodName::Func(fn_id), - input_arguments: vec![ - ast::Variable { - align: None, - v_type: ast::Type::Scalar(typ.into()), - state_space: ast::StateSpace::Reg, - name: id_defs.register_intermediate(None), - array_init: Vec::new(), - }, - ast::Variable { - align: None, - v_type: ast::Type::Scalar(typ.into()), - state_space: ast::StateSpace::Reg, - name: id_defs.register_intermediate(None), - array_init: Vec::new(), - }, - ast::Variable { - align: None, - v_type: ast::Type::Scalar(ast::ScalarType::U32), - state_space: ast::StateSpace::Reg, - name: id_defs.register_intermediate(None), - array_init: Vec::new(), - }, - ast::Variable { - align: None, - v_type: ast::Type::Scalar(ast::ScalarType::U32), - state_space: ast::StateSpace::Reg, - name: id_defs.register_intermediate(None), - array_init: Vec::new(), - }, - ], - shared_mem: None, - }; - let func = Function { - func_decl: Rc::new(RefCell::new(func_decl)), - globals: Vec::new(), - body: None, - import_as: Some(entry.key().clone()), - tuning: Vec::new(), - }; - entry.insert(Directive::Method(func)); - fn_id - } - hash_map::Entry::Occupied(entry) => match entry.get() { - Directive::Method(Function { func_decl, .. }) => match (**func_decl).borrow().name { - ast::MethodName::Func(fn_id) => fn_id, - ast::MethodName::Kernel(_) => unreachable!(), - }, - _ => unreachable!(), - }, - }; - Statement::Call(ResolvedCall { - uniform: false, - name: fn_id, - return_arguments: vec![(arg.dst, ast::Type::Scalar(typ.into()), ast::StateSpace::Reg)], - input_arguments: vec![ - ( - arg.src1, - ast::Type::Scalar(typ.into()), - ast::StateSpace::Reg, - ), - ( - arg.src2, - ast::Type::Scalar(typ.into()), - ast::StateSpace::Reg, - ), - ( - arg.src3, - ast::Type::Scalar(ast::ScalarType::U32), - ast::StateSpace::Reg, - ), - ( - arg.src4, - ast::Type::Scalar(ast::ScalarType::U32), - ast::StateSpace::Reg, - ), - ], - }) + args: &[(ArgumentDescriptor, ast::Type, ast::StateSpace)], +) -> Vec> { + args.iter() + .map(|(_, typ, space)| ast::Variable { + align: None, + v_type: typ.clone(), + state_space: *space, + name: id_defs.register_intermediate(None), + array_init: Vec::new(), + }) + .collect::>() +} + +fn arguments_to_resolved_arguments( + args: &[(ArgumentDescriptor, ast::Type, ast::StateSpace)], +) -> Vec<(spirv::Word, ast::Type, ast::StateSpace)> { + args.iter() + .map(|(desc, typ, space)| (desc.op, typ.clone(), *space)) + .collect::>() } fn normalize_labels( @@ -3305,36 +3214,6 @@ struct PtxImplImport { in_args: Vec, } -fn ptx_semantics_name(sema: ast::AtomSemantics) -> &'static str { - match sema { - ast::AtomSemantics::Relaxed => "relaxed", - ast::AtomSemantics::Acquire => "acquire", - ast::AtomSemantics::Release => "release", - ast::AtomSemantics::AcquireRelease => "acq_rel", - } -} - -fn ptx_scope_name(scope: ast::MemScope) -> &'static str { - match scope { - ast::MemScope::Cta => "cta", - ast::MemScope::Gpu => "gpu", - ast::MemScope::Sys => "sys", - } -} - -fn ptx_space_name(space: ast::StateSpace) -> &'static str { - match space { - ast::StateSpace::Generic => "generic", - ast::StateSpace::Global => "global", - ast::StateSpace::Shared => "shared", - ast::StateSpace::Reg => "reg", - ast::StateSpace::Const => "const", - ast::StateSpace::Local => "local", - ast::StateSpace::Param => "param", - ast::StateSpace::Sreg => "sreg", - } -} - fn emit_mul_float( builder: &mut dr::Builder, map: &mut TypeWordMap,