diff --git a/ptx/src/pass/emit_llvm.rs b/ptx/src/pass/emit_llvm.rs index 2e8c5fc..5271157 100644 --- a/ptx/src/pass/emit_llvm.rs +++ b/ptx/src/pass/emit_llvm.rs @@ -486,7 +486,7 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> { ast::Instruction::Cos { data, arguments } => todo!(), ast::Instruction::Lg2 { data, arguments } => todo!(), ast::Instruction::Ex2 { data, arguments } => todo!(), - ast::Instruction::Clz { data, arguments } => todo!(), + ast::Instruction::Clz { data, arguments } => self.emit_clz(data, arguments), ast::Instruction::Brev { data, arguments } => self.emit_brev(data, arguments), ast::Instruction::Popc { data, arguments } => todo!(), ast::Instruction::Xor { data, arguments } => todo!(), @@ -828,6 +828,43 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> { }; Ok(()) } + + fn emit_clz( + &mut self, + data: ptx_parser::ScalarType, + arguments: ptx_parser::ClzArgs, + ) -> Result<(), TranslateError> { + let llvm_fn = match data.size_of() { + 4 => c"llvm.ctlz.i32", + 8 => c"llvm.ctlz.i64", + _ => return Err(error_unreachable()), + }; + let type_ = get_scalar_type(self.context, data.into()); + let pred = get_scalar_type(self.context, ast::ScalarType::Pred); + let fn_type = get_function_type( + self.context, + iter::once(&ast::ScalarType::U32.into()), + [Ok(type_), Ok(pred)].into_iter(), + )?; + let mut fn_ = unsafe { LLVMGetNamedFunction(self.module, llvm_fn.as_ptr()) }; + if fn_ == ptr::null_mut() { + fn_ = unsafe { LLVMAddFunction(self.module, llvm_fn.as_ptr(), fn_type) }; + } + let src = self.resolver.value(arguments.src)?; + let false_ = unsafe { LLVMConstInt(pred, 0, 0) }; + let mut args = [src, false_]; + self.resolver.with_result(arguments.dst, |dst| unsafe { + LLVMBuildCall2( + self.builder, + fn_type, + fn_, + args.as_mut_ptr(), + args.len() as u32, + dst, + ) + }); + Ok(()) + } } fn get_pointer_type<'ctx>(