Change the context API to make Context-Ownership more flexible

This commit is contained in:
Fabian Stamm
2026-01-09 15:58:32 +01:00
parent 4c7084563f
commit c29dafb042
7 changed files with 1294 additions and 617 deletions

View File

@@ -12,5 +12,5 @@ anyhow = "1"
lazy_static = "1"
log = "0.4"
regex = "1"
reqwest = { version = "0.12", optional = true, features = ["blocking"] }
reqwest = { version = "0.13", optional = true, features = ["blocking"] }
url = { version = "2", optional = true }

View File

@@ -63,7 +63,7 @@ impl RustCompiler {
fn fix_keyword_name(name: &str) -> String {
if RUST_KEYWORDS.contains(&name) {
format!("{}_", name)
format!("r#{}", name)
} else {
name.to_string()
}
@@ -146,7 +146,7 @@ impl RustCompiler {
f.a0("#[async_trait]");
f.a0(format!("pub trait {} {{", definition.name));
f.a1("type Context: Clone + Sync + Send + 'static;");
f.a1("type Context: Sync + Send;");
for method in definition.methods.iter() {
let mut params = method
.inputs
@@ -159,7 +159,7 @@ impl RustCompiler {
)
})
.collect::<Vec<String>>();
params.push("ctx: Self::Context".to_string());
params.push("ctx: &Self::Context".to_string());
let params = params.join(", ");
let ret = method
@@ -190,7 +190,7 @@ impl RustCompiler {
f.a0("");
f.a0(format!(
"impl<Context: Clone + Sync + Send + 'static> {}Handler<Context> {{",
"impl<Context: Sync + Send> {}Handler<Context> {{",
definition.name
));
f.a1(format!(
@@ -205,7 +205,7 @@ impl RustCompiler {
f.a0("#[async_trait]");
f.a0(format!(
"impl<Context: Clone + Sync + Send + 'static> JRPCServerService for {}Handler<Context> {{",
"impl<Context: Sync + Send> JRPCServerService for {}Handler<Context> {{",
definition.name
));
f.a1("type Context = Context;");
@@ -218,7 +218,7 @@ impl RustCompiler {
f.a1("#[allow(non_snake_case)]");
f.a1(
"async fn handle(&self, msg: &JRPCRequest, function: &str, ctx: Self::Context) -> Result<(bool, Value)> {",
"async fn handle(&self, msg: &JRPCRequest, function: &str, ctx: &Self::Context) -> Result<(bool, Value)> {",
);
f.a2("match function {");
@@ -241,6 +241,9 @@ impl RustCompiler {
f.a5(
"let arr = msg.params.as_array().unwrap(); //TODO: Check if this can fail.",
);
f.a5(format!("if arr.len() != {} {{", method.inputs.len()));
f.a6("return Err(\"Invalid number of arguments!\".into())");
f.a5("}");
}
f.a5(format!("let res = self.implementation.{}(", method.name));
for (i, arg) in method.inputs.iter().enumerate() {
@@ -479,10 +482,10 @@ impl Compile for RustCompiler {
f.a(1, "#[allow(non_snake_case)]");
if Keywords::is_keyword(&field.name) {
warn!(
"[RUST] Warning: Field name '{}' is not allowed in Rust. Renaming to '{}_'",
field.name, field.name
);
// warn!(
// "[RUST] Warning: Field name '{}' is not allowed in Rust. Renaming to '{}_'",
// field.name, field.name
// );
f.a(1, format!("#[serde(rename = \"{}\")]", field.name));
}

View File

@@ -6,7 +6,7 @@ edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
int-enum = { version = "0.5", features = ["serde", "convert"] }
int-enum = { version = "1.2.0" }
serde = { version = "1", features = ["derive"] }
serde_json = "1"
nanoid = "0.4"

View File

@@ -112,19 +112,27 @@ pub trait JRPCServerService: Send + Sync {
&self,
request: &JRPCRequest,
function: &str,
ctx: Self::Context,
ctx: &Self::Context,
) -> Result<(bool, Value)>;
}
pub type JRPCServiceHandle<Context> = Arc<dyn JRPCServerService<Context = Context>>;
#[derive(Clone)]
pub struct JRPCSession<Context> {
server: JRPCServer<Context>,
message_sender: Sender<JRPCResult>,
}
impl<Context: Clone + Send + Sync + 'static> JRPCSession<Context> {
impl<Context> Clone for JRPCSession<Context> {
fn clone(&self) -> Self {
JRPCSession {
server: self.server.clone(),
message_sender: self.message_sender.clone(),
}
}
}
impl<Context: Send + Sync + 'static> JRPCSession<Context> {
pub fn new(server: JRPCServer<Context>, sender: Sender<JRPCResult>) -> Self {
JRPCSession {
server,
@@ -157,64 +165,69 @@ impl<Context: Clone + Send + Sync + 'static> JRPCSession<Context> {
pub fn handle_request(&self, request: JRPCRequest, ctx: Context) -> () {
let session = self.clone();
tokio::task::spawn(async move {
info!("Received request: {}", request.method);
trace!("Request data: {:?}", request);
let method: Vec<&str> = request.method.split('.').collect();
if method.len() != 2 {
warn!("Invalid method received: {}", request.method);
return;
}
let service = method[0];
let function = method[1];
let context = ctx;
session.handle_request_awaiting(request, &context).await;
});
}
let service = session.server.services.get(service);
if let Some(service) = service {
let result = service.handle(&request, function, ctx).await;
match result {
Ok((is_send, result)) => {
if is_send && request.id.is_some() {
let result = session
.message_sender
.send(JRPCResult {
jsonrpc: "2.0".to_string(),
id: request.id.unwrap(),
result: Some(result),
error: None,
})
.await;
if let Err(err) = result {
warn!("Error while sending result: {}", err);
}
pub async fn handle_request_awaiting(&self, request: JRPCRequest, ctx: &Context) -> () {
info!("Received request: {}", request.method);
trace!("Request data: {:?}", request);
let method: Vec<&str> = request.method.split('.').collect();
if method.len() != 2 {
warn!("Invalid method received: {}", request.method);
return;
}
let service = method[0];
let function = method[1];
let service = self.server.services.get(service);
if let Some(service) = service {
let result = service.handle(&request, function, ctx).await;
match result {
Ok((is_send, result)) => {
if is_send && request.id.is_some() {
let result = self
.message_sender
.send(JRPCResult {
jsonrpc: "2.0".to_string(),
id: request.id.unwrap(),
result: Some(result),
error: None,
})
.await;
if let Err(err) = result {
warn!("Error while sending result: {}", err);
}
}
Err(err) => {
warn!("Error while handling request: {}", err);
session
.send_error(
request,
format!("Error while handling request: {}", err),
1,
)
.await;
}
}
} else {
warn!("Service not found: {}", method[0]);
session
.send_error(request, "Service not found".to_string(), 1)
.await;
return;
Err(err) => {
warn!("Error while handling request: {}", err);
self.send_error(request, format!("Error while handling request: {}", err), 1)
.await;
}
}
});
} else {
warn!("Service not found: {}", method[0]);
self.send_error(request, "Service not found".to_string(), 1)
.await;
}
}
}
#[derive(Clone)]
pub struct JRPCServer<Context> {
services: HashMap<String, JRPCServiceHandle<Context>>,
}
impl<Context: Clone + Send + Sync + 'static> JRPCServer<Context> {
impl<Context> Clone for JRPCServer<Context> {
fn clone(&self) -> Self {
JRPCServer {
services: self.services.clone(),
}
}
}
impl<Context: Send + Sync + 'static> JRPCServer<Context> {
pub fn new() -> Self {
JRPCServer {
services: HashMap::new(),