Finish implementation of new Rust Tokio Target

This commit is contained in:
Fabian Stamm
2022-12-17 17:11:28 +01:00
parent 890b903f04
commit 46aff0c61b
21 changed files with 1598 additions and 685 deletions

View File

@ -0,0 +1,2 @@
[*.rs]
indent_size = 4

View File

@ -6,10 +6,11 @@ edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
int-enum = "0.4.0"
serde = { version = "1.0.136", features = ["derive"] }
serde_json = "1.0.79"
threadpool = "1.8.1"
int-enum = "0.5.0"
serde = { version = "1.0.147", features = ["derive"] }
serde_json = "1.0.88"
nanoid = "0.4.0"
tokio = { version = "1.22.0", features = ["full"] }
log = "0.4.17"
simple_logger = { version = "4.0.0", features = ["threads", "colored", "timestamps", "stderr"] }
async-trait = "0.1.59"

View File

@ -1,286 +1,223 @@
use log::{info, trace, warn};
use nanoid::nanoid;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::boxed::Box;
use std::collections::HashMap;
use std::error::Error;
use std::marker::PhantomData;
use std::marker::Send;
use std::sync::mpsc::{Receiver, Sender};
use std::sync::{Arc, Mutex};
use threadpool::ThreadPool;
use std::{collections::HashMap, sync::Arc};
use tokio::sync::{mpsc::Sender, Mutex};
pub type Result<T> = std::result::Result<T, Box<dyn Error>>;
pub type Result<T> = std::result::Result<T, Box<dyn Error + Send + Sync>>;
// TODO: Check what happens when error code is not included
// #[repr(i64)]
// #[derive(Clone, Copy, Debug, Eq, PartialEq, IntEnum, Deserialize, Serialize)]
// pub enum ErrorCodes {
// ParseError = -32700,
// InvalidRequest = -32600,
// MethodNotFound = -32601,
// InvalidParams = -32602,
// InternalError = -32603,
// }
#[derive(Serialize, Deserialize)]
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct JRPCRequest {
pub jsonrpc: String,
pub id: Option<String>,
pub method: String,
pub params: Value,
pub jsonrpc: String,
pub id: Option<String>,
pub method: String,
pub params: Value,
}
#[derive(Serialize, Deserialize)]
impl JRPCRequest {
pub fn new_request(method: String, params: Value) -> JRPCRequest {
JRPCRequest {
jsonrpc: "2.0".to_string(),
id: None,
method,
params,
}
}
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct JRPCError {
pub code: i64,
pub message: String,
pub data: Value,
pub code: i64,
pub message: String,
pub data: Value,
}
#[derive(Serialize, Deserialize)]
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct JRPCResult {
pub jsonrpc: String,
pub id: String,
pub result: Value,
pub error: Option<JRPCError>,
pub jsonrpc: String,
pub id: String,
pub result: Value,
pub error: Option<JRPCError>,
}
// ******************************************************************************
// * SERVER
// ******************************************************************************
pub trait JRPCServiceHandler<C: Sync>: Send {
fn get_name(&self) -> String;
fn on_message(&self, msg: JRPCRequest, function: String, ctx: &C) -> Result<(bool, Value)>;
}
type Shared<T> = Arc<Mutex<T>>;
type SharedHM<K, V> = Shared<HashMap<K, V>>;
type ServiceSharedHM<C> = SharedHM<String, Box<dyn JRPCServiceHandler<C>>>;
type SharedThreadPool = Shared<ThreadPool>;
pub struct JRPCServer<CTX: 'static + Sync + Send + Copy> {
services: ServiceSharedHM<CTX>,
pool: SharedThreadPool,
}
impl<CTX: 'static + Sync + Send + Copy> JRPCServer<CTX> {
pub fn new() -> Self {
return Self {
services: Arc::new(Mutex::new(HashMap::new())),
pool: Arc::new(Mutex::new(ThreadPool::new(32))),
};
}
pub fn add_service(&mut self, service: Box<dyn JRPCServiceHandler<CTX>>) {
let mut services = self.services.lock().unwrap();
services.insert(service.get_name(), service);
}
pub fn start_session(
&mut self,
read_ch: Receiver<String>,
write_ch: Sender<String>,
context: CTX,
) {
let services = self.services.clone();
let p = self.pool.lock().unwrap();
let pool = self.pool.clone();
p.execute(move || {
JRPCSession::start(read_ch, write_ch, context, services, pool);
});
}
}
pub struct JRPCSession<CTX: 'static + Sync + Send + Copy> {
_ctx: PhantomData<CTX>,
}
unsafe impl<CTX: 'static + Sync + Send + Copy> Sync for JRPCSession<CTX> {}
impl<CTX: 'static + Sync + Send + Copy> JRPCSession<CTX> {
fn start(
read_ch: Receiver<String>,
write_ch: Sender<String>,
context: CTX,
services: ServiceSharedHM<CTX>,
pool: SharedThreadPool,
) {
loop {
let pkg = read_ch.recv();
let data = match pkg {
Err(_) => return,
Ok(res) => res,
};
if data.len() == 0 {
//TODO: This can be done better
return;
}
let ctx = context.clone();
let svs = services.clone();
let wc = write_ch.clone();
pool.lock().unwrap().execute(move || {
JRPCSession::handle_packet(data, wc, ctx, svs);
})
}
}
fn handle_packet(
data: String,
write_ch: Sender<String>,
context: CTX,
services: ServiceSharedHM<CTX>,
) {
let req: Result<JRPCRequest> =
serde_json::from_str(data.as_str()).map_err(|err| Box::from(err));
let req = match req {
Err(_) => {
return;
}
Ok(parsed) => parsed,
};
let req_id = req.id.clone();
let mut parts: Vec<String> = req.method.splitn(2, '.').map(|e| e.to_owned()).collect();
if parts.len() != 2 {
return Self::send_err_res(req_id, write_ch, Box::from("Error".to_owned()));
}
let service = parts.remove(0);
let function = parts.remove(0);
let svs = services.lock().unwrap();
let srv = svs.get(&service);
if let Some(srv) = srv {
match srv.on_message(req, function, &context) {
Ok((is_send, value)) => {
if is_send {
if let Some(id) = req_id {
let r = JRPCResult {
jsonrpc: "2.0".to_owned(),
id,
result: value,
error: None,
};
let s = serde_json::to_string(&r);
if s.is_ok() {
write_ch
.send(s.unwrap())
.expect("Sending data into channel failed!");
}
}
}
}
Err(err) => return Self::send_err_res(req_id, write_ch, err),
}
}
}
fn send_err_res(id: Option<String>, write_ch: Sender<String>, err: Box<dyn Error>) {
if let Some(id) = id {
let error = JRPCError {
code: 0, //TODO: Make this better?
message: err.to_string(),
data: Value::Null,
};
let r = JRPCResult {
jsonrpc: "2.0".to_owned(),
id: id.clone(),
result: Value::Null,
error: Option::from(error),
};
let s = serde_json::to_string(&r);
if s.is_ok() {
write_ch
.send(s.unwrap())
.expect("Sending data into channel failed!");
}
}
return ();
}
}
// ******************************************************************************
// * CLIENT
// ******************************************************************************
#[derive(Clone)]
#[derive(Debug, Clone)]
pub struct JRPCClient {
write_ch: Sender<String>,
requests: SharedHM<String, Sender<Result<Value>>>,
message_sender: Sender<JRPCRequest>,
requests: Arc<Mutex<HashMap<String, Sender<JRPCResult>>>>,
}
unsafe impl Send for JRPCClient {} //TODO: Is this a problem
impl JRPCClient {
pub fn new(write_ch: Sender<String>, read_ch: Receiver<String>) -> Self {
let n = Self {
write_ch,
requests: Arc::new(Mutex::new(HashMap::new())),
};
pub fn new(sender: Sender<JRPCRequest>) -> JRPCClient {
JRPCClient {
message_sender: sender,
requests: Arc::new(Mutex::new(HashMap::new())),
}
}
n.start(read_ch);
return n;
}
pub async fn send_request(&self, mut request: JRPCRequest) -> Result<Value> {
let (sender, mut receiver) = tokio::sync::mpsc::channel(1);
pub fn start(&self, read_ch: Receiver<String>) {
let s = self.clone();
std::thread::spawn(move || {
s.start_reader(read_ch);
});
}
if request.id.is_none() {
request.id = Some(nanoid!());
}
fn start_reader(&self, read_ch: Receiver<String>) {
loop {
let data = read_ch.recv().expect("Error receiving packet!");
let response: JRPCResult =
serde_json::from_str(data.as_str()).expect("Error decoding response!");
let id = response.id;
{
let mut self_requests = self.requests.lock().await;
self_requests.insert(request.id.clone().unwrap(), sender);
}
self.message_sender.send(request).await?;
let reqs = self.requests.lock().expect("Error locking requests map!");
let req = reqs.get(&id);
if let Some(req) = req {
let res = if let Some(err) = response.error {
Err(Box::from(err.message))
let result = receiver.recv().await;
if let Some(result) = result {
if let Some(error) = result.error {
return Err(format!("Error while receiving result: {}", error.message).into());
} else {
Ok(response.result)
return Ok(result.result);
}
} else {
return Err("Error while receiving result".into());
}
}
pub async fn send_notification(&self, mut request: JRPCRequest) {
request.id = None;
_ = self.message_sender.send(request).await;
}
pub async fn on_result(&self, result: JRPCResult) {
let id = result.id.clone();
let mut self_requests = self.requests.lock().await;
let sender = self_requests.get(&id);
if let Some(sender) = sender {
_ = sender.send(result).await;
self_requests.remove(&id);
}
}
}
#[async_trait::async_trait]
pub trait JRPCServerService: Send + Sync + 'static {
fn get_id(&self) -> String;
async fn handle(&self, request: &JRPCRequest, function: &str) -> Result<(bool, Value)>;
}
pub type JRPCServiceHandle = Arc<dyn JRPCServerService>;
#[derive(Clone)]
pub struct JRPCSession {
server: JRPCServer,
message_sender: Sender<JRPCResult>,
}
impl JRPCSession {
pub fn new(server: JRPCServer, sender: Sender<JRPCResult>) -> JRPCSession {
JRPCSession {
server,
message_sender: sender,
}
}
async fn send_error(&self, request: JRPCRequest, error_msg: String, error_code: i64) -> () {
if let Some(request_id) = request.id {
let error = JRPCError {
code: error_code,
message: error_msg,
data: Value::Null,
};
let result = JRPCResult {
jsonrpc: "2.0".to_string(),
id: request_id,
result: Value::Null,
error: Some(error),
};
req.send(res).expect("Error sending reponse!");
}
}
}
// Send result
let result = self.message_sender.send(result).await;
if let Err(err) = result {
warn!("Error while sending result: {}", err);
}
}
}
pub fn send_request(&self, mut req: JRPCRequest) -> Result<Value> {
let mut reqs = self.requests.lock().expect("Error locking requests map!");
let id = nanoid!();
req.id = Some(id.clone());
pub fn handle_request(&self, request: JRPCRequest) -> () {
let session = self.clone();
tokio::task::spawn(async move {
info!("Received request: {}", request.method);
trace!("Request data: {:?}", request);
let method: Vec<&str> = request.method.split('.').collect();
if method.len() != 2 {
warn!("Invalid method received: {}", request.method);
return;
}
let service = method[0];
let function = method[1];
let (tx, rx) = std::sync::mpsc::channel();
reqs.insert(id, tx);
self
.write_ch
.send(serde_json::to_string(&req).expect("Error converting Request to JSON!"))
.expect("Error Sending to Channel!");
return rx.recv().expect("Error getting response!");
}
pub fn send_notification(&self, mut req: JRPCRequest) {
req.id = None;
self
.write_ch
.send(serde_json::to_string(&req).expect("Error converting Request to JSON!"))
.expect("Error Sending to Channel!");
}
let service = session.server.services.get(service);
if let Some(service) = service {
let result = service.handle(&request, function).await;
match result {
Ok((is_send, result)) => {
if is_send && request.id.is_some() {
let result = session
.message_sender
.send(JRPCResult {
jsonrpc: "2.0".to_string(),
id: request.id.unwrap(),
result,
error: None,
})
.await;
if let Err(err) = result {
warn!("Error while sending result: {}", err);
}
}
}
Err(err) => {
warn!("Error while handling request: {}", err);
session
.send_error(
request,
format!("Error while handling request: {}", err),
1,
)
.await;
}
}
} else {
warn!("Service not found: {}", method[0]);
session
.send_error(request, "Service not found".to_string(), 1)
.await;
return;
}
});
}
}
#[derive(Clone)]
pub struct JRPCServer {
services: HashMap<String, JRPCServiceHandle>,
}
impl JRPCServer {
pub fn new() -> JRPCServer {
JRPCServer {
services: HashMap::new(),
}
}
pub fn add_service(&mut self, service: JRPCServiceHandle) -> () {
let id = service.get_id();
self.services.insert(id, service);
}
pub fn get_session(&self, sender: Sender<JRPCResult>) -> JRPCSession {
JRPCSession::new(self.clone(), sender)
}
}

View File

@ -1,2 +0,0 @@
target
Cargo.lock

View File

@ -1,15 +0,0 @@
[package]
name = "__name__"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
int-enum = "0.5.0"
serde = { version = "1.0.147", features = ["derive"] }
serde_json = "1.0.88"
threadpool = "1.8.1"
nanoid = "0.4.0"
tokio = { version = "1.22.0", features = ["full"] }

View File

@ -1,44 +0,0 @@
use nanoid::nanoid;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::boxed::Box;
use std::collections::HashMap;
use std::error::Error;
use std::marker::PhantomData;
use std::marker::Send;
use std::sync::mpsc::{Receiver, Sender};
use std::sync::{Arc, Mutex};
use threadpool::ThreadPool;
pub type Result<T> = std::result::Result<T, Box<dyn Error>>;
#[derive(Serialize, Deserialize)]
pub struct JRPCRequest {
pub jsonrpc: String,
pub id: Option<String>,
pub method: String,
pub params: Value,
}
#[derive(Serialize, Deserialize)]
pub struct JRPCError {
pub code: i64,
pub message: String,
pub data: Value,
}
#[derive(Serialize, Deserialize)]
pub struct JRPCResult {
pub jsonrpc: String,
pub id: String,
pub result: Value,
pub error: Option<JRPCError>,
}
struct JRPCServer {}
impl JRPCServer {
fn handle(&self) -> () {}
}
struct JRPCServerService {}