diff --git a/ptx/src/pass/emit_llvm.rs b/ptx/src/pass/emit_llvm.rs index b816481..da46329 100644 --- a/ptx/src/pass/emit_llvm.rs +++ b/ptx/src/pass/emit_llvm.rs @@ -65,17 +65,21 @@ impl Drop for Context { } } -pub struct Module(LLVMModuleRef); +pub struct Module(LLVMModuleRef, Context); impl Module { - fn new(ctx: &Context, name: &CStr) -> Self { - Self(unsafe { LLVMModuleCreateWithNameInContext(name.as_ptr(), ctx.get()) }) + fn new(ctx: Context, name: &CStr) -> Self { + Self(unsafe { LLVMModuleCreateWithNameInContext(name.as_ptr(), ctx.get()) }, ctx) } fn get(&self) -> LLVMModuleRef { self.0 } + fn context(&self) -> &Context { + &self.1 + } + fn verify(&self) -> Result<(), Message> { let mut err = ptr::null_mut(); let error = unsafe { @@ -180,10 +184,9 @@ impl Deref for MemoryBuffer { pub(super) fn run<'input>( id_defs: GlobalStringIdentResolver2<'input>, directives: Vec, SpirvWord>>, -) -> Result<(Module, Context), TranslateError> { - let context = Context::new(); - let module = Module::new(&context, LLVM_UNNAMED); - let mut emit_ctx = ModuleEmitContext::new(&context, &module, &id_defs); +) -> Result { + let module = Module::new(Context::new(), LLVM_UNNAMED); + let mut emit_ctx = ModuleEmitContext::new(&module, &id_defs); for directive in directives { match directive { Directive2::Variable(linking, variable) => emit_ctx.emit_global(linking, variable)?, @@ -193,7 +196,7 @@ pub(super) fn run<'input>( if let Err(err) = module.verify() { panic!("{:?}", err); } - Ok((module, context)) + Ok(module) } struct ModuleEmitContext<'a, 'input> { @@ -206,10 +209,10 @@ struct ModuleEmitContext<'a, 'input> { impl<'a, 'input> ModuleEmitContext<'a, 'input> { fn new( - context: &Context, module: &Module, id_defs: &'a GlobalStringIdentResolver2<'input>, ) -> Self { + let context= module.context(); ModuleEmitContext { context: context.get(), module: module.get(), diff --git a/ptx/src/pass/mod.rs b/ptx/src/pass/mod.rs index 40df188..f11a381 100644 --- a/ptx/src/pass/mod.rs +++ b/ptx/src/pass/mod.rs @@ -53,17 +53,15 @@ pub fn to_llvm_module<'input>(ast: ast::Module<'input>) -> Result, }