This commit is contained in:
deepCurse 2024-08-13 16:27:02 -03:00
parent 489dd0f14b
commit a3b46410e8
Signed by: u1
GPG key ID: 0EA7B9E85212693C
7 changed files with 283 additions and 0 deletions

13
Cargo.toml Normal file
View file

@ -0,0 +1,13 @@
[package]
name = "gpt-sandbox"
version = "0.1.0"
edition = "2021"
[dependencies]
tokio = { version = "1.39.2", features = ["full"] }
tracing = "0.1.40"
tracing-subscriber = "0.3.18"
libloading = "0.8.5"
[workspace]
members = ["plugin-lib"]

12
plugin-lib/Cargo.toml Normal file
View file

@ -0,0 +1,12 @@
[package]
name = "plugin-lib"
version = "0.1.0"
edition = "2021"
[lib]
crate-type = ["cdylib"]
[dependencies]
#tokio = { version = "1.39.2", features = ["full"] }
#tracing = "0.1.40"
gpt-sandbox = { path = "../" }

1
plugin-lib/rustfmt.toml Normal file
View file

@ -0,0 +1 @@
hard_tabs=true

79
plugin-lib/src/lib.rs Normal file
View file

@ -0,0 +1,79 @@
use gpt_sandbox::tokio;
use gpt_sandbox::tokio::runtime::Runtime;
use gpt_sandbox::tracing;
use gpt_sandbox::tracing::Dispatch;
use gpt_sandbox::AsyncFn;
use gpt_sandbox::AsyncFnReturnType;
use gpt_sandbox::PluginError;
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use tracing::info;
// TODO fix static mut black magic, it can and will set fire eventually
static mut FUNCTION_MAP: Option<Arc<Mutex<HashMap<String, AsyncFn>>>> = None;
static REACTOR: ::std::sync::OnceLock<Arc<Runtime>> = ::std::sync::OnceLock::new();
#[no_mangle]
pub extern "C" fn register_reactor(tokio_reactor: Arc<Runtime>) {
REACTOR.set(tokio_reactor);
}
#[no_mangle]
pub extern "C" fn register_dispatcher(dispatcher: Dispatch) {
//tracing::subscriber::set_global_default(dispatcher).unwrap();
tracing::dispatcher::set_global_default(dispatcher).unwrap();
info!("test");
}
#[no_mangle]
pub extern "C" fn register_functions(function_map: Arc<Mutex<HashMap<String, AsyncFn>>>) {
info!("registering functions...");
unsafe {
FUNCTION_MAP = Some(Arc::clone(&function_map));
}
let example_function: AsyncFn = Box::new(example_function);
info!("writing functions...");
function_map
.lock()
.unwrap()
.insert("example_function".to_string(), example_function);
}
#[no_mangle]
pub extern "C" fn unregister_functions() {
unsafe {
if let Some(function_map) = &FUNCTION_MAP {
function_map.lock().unwrap().remove("example_function");
}
}
}
//pub fn example_function() -> AsyncFnReturnType {
// REACTOR.get().unwrap().spawn(async {
// tracing::info!("example_function is running");
// // Simulate some async work
// tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
// tracing::info!("example_function completed");
// // Return an error for demonstration
// Err(PluginError::FunctionError(
// "Something went wrong".to_string(),
// ))
// })
//}
pub async fn example_function() -> AsyncFnReturnType {
REACTOR.get().unwrap().spawn(async {
tracing::info!("example_function is running");
// Simulate some async work
tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
tracing::info!("example_function completed");
// Return an error for demonstration
Err(PluginError::FunctionError(
"Something went wrong".to_string(),
))
})
}

1
rustfmt.toml Normal file
View file

@ -0,0 +1 @@
hard_tabs=true

29
src/lib.rs Normal file
View file

@ -0,0 +1,29 @@
pub use tokio;
pub use tracing;
#[derive(Debug)]
pub enum PluginError {
FunctionError(String),
Other(String),
}
impl core::fmt::Display for PluginError {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
match self {
PluginError::FunctionError(msg) => write!(f, "Function error: {}", msg),
PluginError::Other(msg) => write!(f, "Other error: {}", msg),
}
}
}
impl std::error::Error for PluginError {}
//pub type AsyncFnReturnType = Result<(), PluginError>;
//pub type AsyncFn = Box<
// dyn Fn()
// -> std::pin::Pin<Box<dyn std::future::Future<Output = AsyncFnReturnType> + Send + Sync>>
// + Send
// + Sync,
//>;
pub type AsyncFnReturnType = tokio::task::JoinHandle<Result<(), PluginError>>;
pub type AsyncFn = Box<dyn Fn() -> AsyncFnReturnType + Send + Sync>;

148
src/main.rs Normal file
View file

@ -0,0 +1,148 @@
use gpt_sandbox::AsyncFn;
use libloading::{Library, Symbol};
use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::{Arc, Mutex};
use tokio::runtime::Runtime;
use tokio::task;
use tracing::{error, info, Dispatch};
struct PluginContainer {
function_map: Arc<Mutex<HashMap<String, AsyncFn>>>,
libraries: Arc<Mutex<HashMap<String, Library>>>,
reactor: Arc<Runtime>,
tracing_dispatcher: Dispatch,
}
impl PluginContainer {
fn new(reactor: Arc<tokio::runtime::Runtime>) -> Self {
let mut tracing_dispatcher = None;
tracing::dispatcher::get_default(|what| tracing_dispatcher = Some(what.clone()));
PluginContainer {
function_map: Arc::new(Mutex::new(HashMap::new())),
libraries: Arc::new(Mutex::new(HashMap::new())),
reactor,
tracing_dispatcher: tracing_dispatcher.unwrap(),
}
}
fn load_plugin(&self, path: &PathBuf) -> Result<(), Box<dyn std::error::Error>> {
let lib_name = path.to_string_lossy().to_string();
let lib = unsafe { Library::new(path) }?;
info!("loading plugin {path:?}");
unsafe {
let func: Symbol<unsafe extern "C" fn(Arc<Runtime>)> =
lib.get(b"register_reactor\0")?;
func(Arc::clone(&self.reactor));
}
unsafe {
let func: Symbol<unsafe extern "C" fn(&Dispatch)> =
lib.get(b"register_dispatcher\0")?;
func(&self.tracing_dispatcher.clone());
}
unsafe {
let func: Symbol<unsafe extern "C" fn(Arc<Mutex<HashMap<String, AsyncFn>>>)> =
lib.get(b"register_functions\0")?;
func(Arc::clone(&self.function_map));
}
self.libraries.lock().unwrap().insert(lib_name, lib);
Ok(())
}
fn unload_plugin(&self, path: &PathBuf) -> Result<(), Box<dyn std::error::Error>> {
let lib_name = path.to_string_lossy().to_string();
let lib = self
.libraries
.lock()
.unwrap()
.remove(&lib_name)
.ok_or("Library not loaded")?;
unsafe {
let func: Symbol<unsafe extern "C" fn()> = lib.get(b"unregister_functions\0")?;
func();
}
Ok(())
}
}
//#[tokio::main]
fn main() -> Result<(), Box<dyn std::error::Error>> {
tracing::subscriber::set_global_default(
tracing_subscriber::fmt()
.with_max_level(tracing::Level::DEBUG)
.finish(),
)?;
info!("starting...");
let tokio_reactor = Arc::new(
::tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.unwrap(),
);
let plugin_manager = PluginContainer::new(tokio_reactor.clone());
let plugin_paths = vec![PathBuf::from("target/debug/libplugin_lib.so")]; // Adjust paths as needed
for path in plugin_paths {
if let Err(e) = plugin_manager.load_plugin(&path) {
eprintln!("Failed to load plugin from {:?}: {}", path, e);
}
}
info!("plugins loaded...");
tokio_reactor.block_on(async {
loop {
let name = get_name().await;
if let Some(func) = plugin_manager.function_map.lock().unwrap().get(&name) {
info!("found func...");
let handle = func();
info!("obtained handle...");
task::spawn(async move {
info!("executing function...",);
match handle.await {
//Ok(()) => info!("Function {} executed successfully", name),
//Err(e) => error!("Function {} returned an error: {}", name, e),
Ok(Ok(())) => info!("Function {} executed successfully", name),
Ok(Err(e)) => error!("Function {} returned an error: {}", name, e),
Err(e) => error!("Task join error for function {}: {:?}", name, e),
}
});
//let handle = func();
////info!("obtained handle...");
////task::spawn(async move {
// info!("executing function...",);
// match handle.await {
// Ok(()) => info!("Function {} executed successfully", name),
// Err(e) => error!("Function {} returned an error: {}", name, e),
// //Err(e) => error!("Task join error for function {}: {:?}", name, e),
// }
////});
} else {
info!("Function not found for name: {}", name);
}
tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
if let Err(e) =
plugin_manager.unload_plugin(&PathBuf::from("target/debug/libplugin_lib.so"))
{
eprintln!("Failed to unload plugin: {}", e);
}
}
})
}
async fn get_name() -> String {
"example_function".to_string()
}