diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index 79c070b..db4e3e0 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -392,7 +392,7 @@ impl<'a, 'input> LinkingResolver<'a, 'input> { linking, Cow::Borrowed(decl.name()), symbol, - decl.name.is_kernel(), + decl.name.is_kernel() && is_definition, ) } @@ -591,10 +591,21 @@ impl<'input> ResolvedLinking<'input> { explicit_initializer: bool, ) -> Result { if linking == ast::LinkingDirective::None { - if self.implicit_globals.get(&name).copied() == Some((module, directive)) { - Ok(VisibilityAdjustment::Global) - } else { - Ok(VisibilityAdjustment::Module) + match self.implicit_globals.get(&name).copied() { + Some((implicit_module, implicit_directive)) => { + if implicit_module == module { + if implicit_directive == directive { + Ok(VisibilityAdjustment::Global) + } else { + // If it were something other than a declaration it would + // fail module-level symbol resolution + Ok(VisibilityAdjustment::GlobalDeclaration(None)) + } + } else { + Ok(VisibilityAdjustment::Module) + } + } + None => Ok(VisibilityAdjustment::Module), } } else { if let Some((global_module, global_directive, type_)) = self.explicit_globals.get(&name) diff --git a/zluda/tests/linking.rs b/zluda/tests/linking.rs index 6cd861e..57ada55 100644 --- a/zluda/tests/linking.rs +++ b/zluda/tests/linking.rs @@ -229,16 +229,6 @@ impl Directive { Directive::Shared => unimplemented!(), } } - - fn assert_exact(self) -> bool { - match self { - Directive::Kernel => false, - Directive::Method => true, - Directive::Global => false, - Directive::Const => false, - Directive::Shared => unimplemented!(), - } - } } #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] @@ -370,45 +360,6 @@ fn create_kernel(linking: Linking, directive: Directive, defined: bool) -> Strin kernel } -fn assert_compatible( - results: Vec<(Linking, Directive, bool, i32, Option)>, - expected: [(Linking, Directive, bool, i32, Option); 50], -) { - if results.len() != expected.len() { - panic!(); - } - let mut broken = Vec::new(); - for (result, expected) in results.into_iter().zip(IntoIterator::into_iter(expected)) { - let (linking, directive, defined, build_result, load_result) = result; - let (_, _, _, expected_build, expected_load) = expected; - if expected_build == 0 { - if build_result != 0 { - broken.push(( - linking, - directive, - defined, - (build_result, load_result), - (expected_build, expected_load), - )); - continue; - } - if expected_load == Some(0) { - if load_result != Some(0) { - broken.push(( - linking, - directive, - defined, - (build_result, load_result), - (expected_build, expected_load), - )); - continue; - } - } - } - } - assert_eq!(broken, []); -} - fn assert_compatible_compile( compiled: &[T], compiled_expected: &[T], @@ -1109,22 +1060,16 @@ unsafe fn emit_weak_fn(cuda: T) { } -cuda_driver_test!(weak_func_address); +cuda_driver_test!(static_entry_decl); -unsafe fn weak_func_address(cuda: T) { +unsafe fn static_entry_decl(cuda: T) { let input1 = " .version 6.5 - .target sm_50 + .target sm_35 .address_size 64 - .weak .func foobar(.reg .b32 input); - - .weak .global .align 8 .u64 fn_ptrs[2] = {0, foobar}; - - .weak .func foobar(.reg .b32 input) - { - ret; - }\0" + .entry foobar(); + .entry foobar() { ret; }\0" .to_string(); assert_eq!(cuda.cuInit(0), CUresult::CUDA_SUCCESS); let mut ctx = ptr::null_mut();