diff --git a/Cargo.toml b/Cargo.toml index 55ca9f4..89ac780 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,6 +8,7 @@ tokio = { version = "1.39.2", features = ["full"] } tracing = "0.1.40" tracing-subscriber = "0.3.18" libloading = "0.8.5" +once_cell = "1.19.0" [workspace] -members = ["plugin-lib"] \ No newline at end of file +members = ["plugins/testing"] diff --git a/plugin-lib/Cargo.toml b/plugin-lib/Cargo.toml deleted file mode 100644 index 04e4311..0000000 --- a/plugin-lib/Cargo.toml +++ /dev/null @@ -1,12 +0,0 @@ -[package] -name = "plugin-lib" -version = "0.1.0" -edition = "2021" - -[lib] -crate-type = ["cdylib"] - -[dependencies] -#tokio = { version = "1.39.2", features = ["full"] } -#tracing = "0.1.40" -gpt-sandbox = { path = "../" } \ No newline at end of file diff --git a/plugin-lib/rustfmt.toml b/plugin-lib/rustfmt.toml deleted file mode 100644 index 7c224aa..0000000 --- a/plugin-lib/rustfmt.toml +++ /dev/null @@ -1 +0,0 @@ -hard_tabs=true diff --git a/plugin-lib/src/lib.rs b/plugin-lib/src/lib.rs deleted file mode 100644 index 37dd696..0000000 --- a/plugin-lib/src/lib.rs +++ /dev/null @@ -1,79 +0,0 @@ -use gpt_sandbox::tokio; -use gpt_sandbox::tokio::runtime::Runtime; -use gpt_sandbox::tracing; - -use gpt_sandbox::tracing::Dispatch; -use gpt_sandbox::AsyncFn; -use gpt_sandbox::AsyncFnReturnType; -use gpt_sandbox::PluginError; -use std::collections::HashMap; -use std::sync::{Arc, Mutex}; -use tracing::info; - -// TODO fix static mut black magic, it can and will set fire eventually -static mut FUNCTION_MAP: Option>>> = None; -static REACTOR: ::std::sync::OnceLock> = ::std::sync::OnceLock::new(); - -#[no_mangle] -pub extern "C" fn register_reactor(tokio_reactor: Arc) { - REACTOR.set(tokio_reactor); -} - -#[no_mangle] -pub extern "C" fn register_dispatcher(dispatcher: Dispatch) { - //tracing::subscriber::set_global_default(dispatcher).unwrap(); - tracing::dispatcher::set_global_default(dispatcher).unwrap(); - info!("test"); -} - -#[no_mangle] -pub extern "C" fn register_functions(function_map: Arc>>) { - info!("registering functions..."); - unsafe { - FUNCTION_MAP = Some(Arc::clone(&function_map)); - } - - let example_function: AsyncFn = Box::new(example_function); - - info!("writing functions..."); - - function_map - .lock() - .unwrap() - .insert("example_function".to_string(), example_function); -} - -#[no_mangle] -pub extern "C" fn unregister_functions() { - unsafe { - if let Some(function_map) = &FUNCTION_MAP { - function_map.lock().unwrap().remove("example_function"); - } - } -} - -//pub fn example_function() -> AsyncFnReturnType { -// REACTOR.get().unwrap().spawn(async { -// tracing::info!("example_function is running"); -// // Simulate some async work -// tokio::time::sleep(tokio::time::Duration::from_secs(1)).await; -// tracing::info!("example_function completed"); -// // Return an error for demonstration -// Err(PluginError::FunctionError( -// "Something went wrong".to_string(), -// )) -// }) -//} - -pub async fn example_function() -> AsyncFnReturnType { - REACTOR.get().unwrap().spawn(async { - tracing::info!("example_function is running"); - // Simulate some async work - tokio::time::sleep(tokio::time::Duration::from_secs(1)).await; - tracing::info!("example_function completed"); - // Return an error for demonstration - Err(PluginError::FunctionError( - "Something went wrong".to_string(), - )) - }) -} diff --git a/plugins/testing/Cargo.toml b/plugins/testing/Cargo.toml new file mode 100644 index 0000000..6648dcc --- /dev/null +++ b/plugins/testing/Cargo.toml @@ -0,0 +1,14 @@ +[package] +name = "testing" +version = "0.1.0" +edition = "2021" + +[lib] +crate-type = ["cdylib"] + +[dependencies] +tokio = "1.39.3" +tracing = "0.1.40" +gpt-sandbox = { path = "../../" } +async-trait = "0.1.81" +once_cell = "1.19.0" diff --git a/plugins/testing/src/lib.rs b/plugins/testing/src/lib.rs new file mode 100644 index 0000000..719c493 --- /dev/null +++ b/plugins/testing/src/lib.rs @@ -0,0 +1,34 @@ +use async_trait::async_trait; +use gpt_sandbox::{AsyncFn, Plugin, PluginResult, RegisterFn}; +use std::collections::HashMap; +use tokio::runtime::Handle; + +#[allow(improper_ctypes_definitions)] // will never be used by c programs, would use rust dylib if it existed +#[no_mangle] +pub extern "C" fn register_plugin() -> RegisterFn { + Box::into_raw(Box::new(ExamplePlugin)) +} + +struct ExamplePlugin; + +#[async_trait] +impl Plugin for ExamplePlugin { + fn register_functions(&self) -> HashMap { + let mut map = HashMap::new(); + map.insert("example_function".to_string(), example_function as AsyncFn); + dbg!(map) + } + + fn unregister_functions(&self) -> Vec { + dbg!(vec!["example_function".to_string()]) + } +} + +fn example_function(handle: Handle) -> tokio::task::JoinHandle { + println!("start of example_function"); + + handle.spawn(async { + println!("task running"); + PluginResult::Ok + }) +} diff --git a/src/lib.rs b/src/lib.rs index e12925d..664869a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,29 +1,19 @@ -pub use tokio; -pub use tracing; +use std::collections::HashMap; + +use tokio::runtime::Handle; + +pub type AsyncFn = fn(Handle) -> tokio::task::JoinHandle; + +pub type RegisterFn = *mut dyn Plugin; #[derive(Debug)] -pub enum PluginError { - FunctionError(String), +pub enum PluginResult { + Ok, + FunctionNotFound, Other(String), } -impl core::fmt::Display for PluginError { - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - match self { - PluginError::FunctionError(msg) => write!(f, "Function error: {}", msg), - PluginError::Other(msg) => write!(f, "Other error: {}", msg), - } - } +pub trait Plugin { + fn register_functions(&self) -> HashMap; + fn unregister_functions(&self) -> Vec; } - -impl std::error::Error for PluginError {} - -//pub type AsyncFnReturnType = Result<(), PluginError>; -//pub type AsyncFn = Box< -// dyn Fn() -// -> std::pin::Pin + Send + Sync>> -// + Send -// + Sync, -//>; -pub type AsyncFnReturnType = tokio::task::JoinHandle>; -pub type AsyncFn = Box AsyncFnReturnType + Send + Sync>; diff --git a/src/main.rs b/src/main.rs index b2acddd..d77c208 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,148 +1,141 @@ -use gpt_sandbox::AsyncFn; +use gpt_sandbox::{AsyncFn, PluginResult, RegisterFn}; use libloading::{Library, Symbol}; use std::collections::HashMap; -use std::path::PathBuf; -use std::sync::{Arc, Mutex}; -use tokio::runtime::Runtime; -use tokio::task; -use tracing::{error, info, Dispatch}; +use std::sync::Arc; +use tokio::runtime::Handle; +use tokio::sync::RwLock; +use tracing::{error, info}; -struct PluginContainer { - function_map: Arc>>, - libraries: Arc>>, - reactor: Arc, - tracing_dispatcher: Dispatch, -} - -impl PluginContainer { - fn new(reactor: Arc) -> Self { - let mut tracing_dispatcher = None; - tracing::dispatcher::get_default(|what| tracing_dispatcher = Some(what.clone())); - PluginContainer { - function_map: Arc::new(Mutex::new(HashMap::new())), - libraries: Arc::new(Mutex::new(HashMap::new())), - reactor, - tracing_dispatcher: tracing_dispatcher.unwrap(), - } - } - - fn load_plugin(&self, path: &PathBuf) -> Result<(), Box> { - let lib_name = path.to_string_lossy().to_string(); - let lib = unsafe { Library::new(path) }?; - - info!("loading plugin {path:?}"); - - unsafe { - let func: Symbol)> = - lib.get(b"register_reactor\0")?; - func(Arc::clone(&self.reactor)); - } - - unsafe { - let func: Symbol = - lib.get(b"register_dispatcher\0")?; - func(&self.tracing_dispatcher.clone()); - } - - unsafe { - let func: Symbol>>)> = - lib.get(b"register_functions\0")?; - func(Arc::clone(&self.function_map)); - } - - self.libraries.lock().unwrap().insert(lib_name, lib); - Ok(()) - } - - fn unload_plugin(&self, path: &PathBuf) -> Result<(), Box> { - let lib_name = path.to_string_lossy().to_string(); - let lib = self - .libraries - .lock() - .unwrap() - .remove(&lib_name) - .ok_or("Library not loaded")?; - - unsafe { - let func: Symbol = lib.get(b"unregister_functions\0")?; - func(); - } - - Ok(()) - } -} - -//#[tokio::main] -fn main() -> Result<(), Box> { - tracing::subscriber::set_global_default( - tracing_subscriber::fmt() - .with_max_level(tracing::Level::DEBUG) - .finish(), - )?; - - info!("starting..."); - - let tokio_reactor = Arc::new( - ::tokio::runtime::Builder::new_multi_thread() - .enable_all() - .build() - .unwrap(), - ); - - let plugin_manager = PluginContainer::new(tokio_reactor.clone()); - - let plugin_paths = vec![PathBuf::from("target/debug/libplugin_lib.so")]; // Adjust paths as needed - for path in plugin_paths { - if let Err(e) = plugin_manager.load_plugin(&path) { - eprintln!("Failed to load plugin from {:?}: {}", path, e); - } - } - - info!("plugins loaded..."); - - tokio_reactor.block_on(async { - loop { - let name = get_name().await; - if let Some(func) = plugin_manager.function_map.lock().unwrap().get(&name) { - info!("found func..."); - - let handle = func(); - info!("obtained handle..."); - task::spawn(async move { - info!("executing function...",); - match handle.await { - //Ok(()) => info!("Function {} executed successfully", name), - //Err(e) => error!("Function {} returned an error: {}", name, e), - Ok(Ok(())) => info!("Function {} executed successfully", name), - Ok(Err(e)) => error!("Function {} returned an error: {}", name, e), - Err(e) => error!("Task join error for function {}: {:?}", name, e), - } - }); - - //let handle = func(); - ////info!("obtained handle..."); - ////task::spawn(async move { - // info!("executing function...",); - // match handle.await { - // Ok(()) => info!("Function {} executed successfully", name), - // Err(e) => error!("Function {} returned an error: {}", name, e), - // //Err(e) => error!("Task join error for function {}: {:?}", name, e), - // } - ////}); - } else { - info!("Function not found for name: {}", name); - } - - tokio::time::sleep(tokio::time::Duration::from_secs(5)).await; - if let Err(e) = - plugin_manager.unload_plugin(&PathBuf::from("target/debug/libplugin_lib.so")) - { - eprintln!("Failed to unload plugin: {}", e); - } - } - }) -} - -async fn get_name() -> String { +fn get_name() -> String { "example_function".to_string() } + +async fn load_plugin( + lib_path: &str, + function_map: Arc>>, +) -> Result { + let lib = unsafe { Library::new(lib_path).map_err(|e| PluginResult::Other(e.to_string())) }?; + + let plugin: Symbol RegisterFn> = + unsafe { lib.get(b"register_plugin") }.map_err(|e| PluginResult::Other(e.to_string()))?; + let plugin_instance = unsafe { Box::from_raw(plugin()) }; + + let functions = plugin_instance.register_functions(); + function_map.write().await.extend(functions); + + Ok(lib) +} + +async fn unload_plugin( + lib: Library, + function_map: Arc>>, +) -> Result<(), PluginResult> { + let plugin: Symbol RegisterFn> = + unsafe { lib.get(b"register_plugin") }.map_err(|e| PluginResult::Other(e.to_string()))?; + let plugin_instance = unsafe { Box::from_raw(plugin()) }; + + info!("found plugin instance"); + + let functions = plugin_instance.unregister_functions(); + info!("runregister functions"); + { + let mut map = function_map.write().await; + for function in functions { + map.remove(&function); + } + } + + info!("removed functions"); + + let _ = lib; + + info!("unloaded lib"); + Ok(()) +} +//#[tokio::main(flavor = "multi_thread", worker_threads = 6)] +//async fn main() -> Result<(), PluginResult> { +// tracing::subscriber::set_global_default(tracing_subscriber::fmt().finish()).unwrap(); +// let function_map: Arc>> = Arc::new(RwLock::new(HashMap::new())); + +// // Load the plugin +// info!("Loading plugin"); +// let plugin_lib = load_plugin("target/release/libtesting.so", Arc::clone(&function_map)).await?; +// info!("Plugin loaded"); + +// let function_name = get_name(); +// let mut handles = Vec::new(); + +// // Start tasks +// info!("Starting task execution loop"); +// for i in 0..3 { +// if let Some(function) = function_map.read().await.get(&function_name).cloned() { +// info!("Starting task for function '{}'", function_name); +// let handle = function(Handle::current().clone()); +// handles.push(handle); +// } else { +// info!("Function '{}' not found in map", function_name); +// } +// } + +// // Await all tasks +// info!("Awaiting tasks to complete"); +// for (i, handle) in handles.into_iter().enumerate() { +// match handle.await.unwrap() { +// PluginResult::Ok => info!("Task {} exited with Ok", i), +// PluginResult::FunctionNotFound => info!("Task {} found FunctionNotFound", i), +// PluginResult::Other(err) => error!("Task {} exited with error: `{}`", i, err), +// } +// } + +// // Unload the plugin +// info!("Unloading plugin"); +// unload_plugin(plugin_lib, Arc::clone(&function_map)).await?; +// info!("Plugin unloaded successfully"); + +// // Ensure all tasks are done +// info!("Waiting for all tasks to complete"); +// tokio::time::sleep(tokio::time::Duration::from_secs(1)).await; // Add a short sleep to ensure tasks finish + +// info!("Exiting main function"); +// Ok(()) +//} + +#[tokio::main(flavor = "multi_thread", worker_threads = 6)] +async fn main() -> Result<(), PluginResult> { + tracing::subscriber::set_global_default(tracing_subscriber::fmt().finish()).unwrap(); + let function_map: Arc>> = Arc::new(RwLock::new(HashMap::new())); + + let plugin_lib = load_plugin("target/release/libtesting.so", Arc::clone(&function_map)).await?; + + let function_name = get_name(); + + //for i in 0..3 { + // if let Some(function) = function_map.read().await.get(&function_name).cloned() { + // match function(Handle::current().clone()).await.unwrap() { + // PluginResult::Ok => info!("Function {i} exited with Ok"), + // PluginResult::FunctionNotFound => todo!(), + // PluginResult::Other(err) => error!("Function exited with error: `{err}`"), + // }; + // } else { + // println!("Function '{}' not found", function_name); + // } + //} + + for i in 0..3 { + if let Some(function) = function_map.read().await.get(&function_name).cloned() { + let join_handle = function(Handle::current().clone()); + match join_handle.await.unwrap() { + PluginResult::Ok => info!("Function {i} exited with Ok"), + PluginResult::FunctionNotFound => todo!(), + PluginResult::Other(err) => error!("Function exited with error: `{err}`"), + }; + } else { + println!("Function '{}' not found", function_name); + } + } + + info!("unloading"); + unload_plugin(plugin_lib, Arc::clone(&function_map)).await?; + + Ok(()) +}