Implementing udp support partially...in theory. Still untested :)
This commit is contained in:
58
src/main.rs
58
src/main.rs
@ -1,6 +1,7 @@
|
||||
mod listener;
|
||||
mod shutdown;
|
||||
mod tcp;
|
||||
mod udp;
|
||||
|
||||
use serde::Deserialize;
|
||||
use simple_error::bail;
|
||||
@ -14,14 +15,14 @@ use tokio::sync::{broadcast, RwLock};
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct Target {
|
||||
udp: Option<bool>,
|
||||
source: String,
|
||||
targets: Vec<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct Config {
|
||||
tcp: Vec<Target>,
|
||||
// udp: Vec<Target>,
|
||||
mappings: Vec<Target>,
|
||||
}
|
||||
|
||||
fn load_yaml(path: &Path) -> Result<Config, Box<dyn Error>> {
|
||||
@ -62,6 +63,7 @@ fn load_config() -> Result<Config, Box<dyn Error>> {
|
||||
|
||||
// #[derive(Debug)]
|
||||
struct ActiveListener {
|
||||
udp: bool,
|
||||
notify_shutdown: broadcast::Sender<()>,
|
||||
targets: Arc<RwLock<Vec<String>>>,
|
||||
}
|
||||
@ -75,27 +77,34 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let config = load_config().expect("config not found");
|
||||
let mut required_listeners: HashSet<String> = HashSet::new();
|
||||
|
||||
for target in config.tcp {
|
||||
required_listeners.insert(target.source.clone());
|
||||
if let Some(listener) = listeners.get(&target.source) {
|
||||
for target in config.mappings {
|
||||
let mut source_str = "".to_owned();
|
||||
if target.udp == None || target.udp == Some(false) {
|
||||
source_str.push_str("udp:");
|
||||
} else {
|
||||
source_str.push_str("tcp:");
|
||||
}
|
||||
source_str.push_str(&target.source);
|
||||
|
||||
required_listeners.insert(source_str.clone());
|
||||
if let Some(listener) = listeners.get(&source_str) {
|
||||
let mut invalid = false;
|
||||
{
|
||||
let targets = listener.targets.read().await;
|
||||
for t in &target.targets {
|
||||
if !targets.iter().any(|e| *e == *t) {
|
||||
|
||||
let targets = listener.targets.read().await;
|
||||
for t in &target.targets {
|
||||
if !targets.iter().any(|e| *e == *t) {
|
||||
invalid = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if !invalid {
|
||||
for t in targets.iter() {
|
||||
if !target.targets.iter().any(|e| *e == *t) {
|
||||
invalid = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if !invalid {
|
||||
for t in targets.iter() {
|
||||
if !target.targets.iter().any(|e| *e == *t) {
|
||||
invalid = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if invalid {
|
||||
@ -109,6 +118,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let listener = ActiveListener {
|
||||
notify_shutdown: notify_shutdown,
|
||||
targets: Arc::new(RwLock::new(target.targets)),
|
||||
udp: target.udp == None || target.udp == Some(false),
|
||||
};
|
||||
|
||||
let l = listener::Listener {
|
||||
@ -118,12 +128,18 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
};
|
||||
|
||||
tokio::spawn(async move {
|
||||
if let Err(err) = tcp::start_tcp_listener(l).await {
|
||||
println!("listener error: {}", err);
|
||||
if target.udp == None || target.udp == Some(false) {
|
||||
if let Err(err) = tcp::start_tcp_listener(l).await {
|
||||
println!("tcp listener error: {}", err);
|
||||
}
|
||||
} else {
|
||||
if let Err(err) = udp::start_udp_listener(l).await {
|
||||
println!("udp listener error: {}", err);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
listeners.insert(target.source, listener);
|
||||
listeners.insert(source_str, listener);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -11,39 +11,48 @@ use tokio::sync::broadcast;
|
||||
/// received or not.
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct Shutdown {
|
||||
/// `true` if the shutdown signal has been received
|
||||
shutdown: bool,
|
||||
/// `true` if the shutdown signal has been received
|
||||
shutdown: bool,
|
||||
|
||||
/// The receive half of the channel used to listen for shutdown.
|
||||
notify: broadcast::Receiver<()>,
|
||||
/// The receive half of the channel used to listen for shutdown.
|
||||
notify: broadcast::Receiver<()>,
|
||||
}
|
||||
|
||||
impl Shutdown {
|
||||
/// Create a new `Shutdown` backed by the given `broadcast::Receiver`.
|
||||
pub(crate) fn new(notify: broadcast::Receiver<()>) -> Shutdown {
|
||||
Shutdown {
|
||||
shutdown: false,
|
||||
notify,
|
||||
}
|
||||
}
|
||||
/// Create a new `Shutdown` backed by the given `broadcast::Receiver`.
|
||||
pub(crate) fn new(notify: broadcast::Receiver<()>) -> Shutdown {
|
||||
Shutdown {
|
||||
shutdown: false,
|
||||
notify,
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns `true` if the shutdown signal has been received.
|
||||
// pub(crate) fn is_shutdown(&self) -> bool {
|
||||
// self.shutdown
|
||||
// }
|
||||
/// Returns `true` if the shutdown signal has been received.
|
||||
// pub(crate) fn is_shutdown(&self) -> bool {
|
||||
// self.shutdown
|
||||
// }
|
||||
|
||||
/// Receive the shutdown notice, waiting if necessary.
|
||||
pub(crate) async fn recv(&mut self) {
|
||||
// If the shutdown signal has already been received, then return
|
||||
// immediately.
|
||||
if self.shutdown {
|
||||
return;
|
||||
}
|
||||
/// Receive the shutdown notice, waiting if necessary.
|
||||
pub(crate) async fn recv(&mut self) {
|
||||
// If the shutdown signal has already been received, then return
|
||||
// immediately.
|
||||
if self.shutdown {
|
||||
return;
|
||||
}
|
||||
|
||||
// Cannot receive a "lag error" as only one value is ever sent.
|
||||
let _ = self.notify.recv().await;
|
||||
// Cannot receive a "lag error" as only one value is ever sent.
|
||||
let _ = self.notify.recv().await;
|
||||
|
||||
// Remember that the signal has been received.
|
||||
self.shutdown = true;
|
||||
}
|
||||
// Remember that the signal has been received.
|
||||
self.shutdown = true;
|
||||
}
|
||||
}
|
||||
|
||||
impl Clone for Shutdown {
|
||||
fn clone(&self) -> Self {
|
||||
Shutdown {
|
||||
shutdown: self.shutdown,
|
||||
notify: self.notify.resubscribe(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
74
src/tcp.rs
74
src/tcp.rs
@ -5,54 +5,54 @@ use tokio::net::{TcpListener, TcpStream};
|
||||
|
||||
#[derive(Debug)]
|
||||
struct TcpHandler {
|
||||
stream: TcpStream,
|
||||
target: String,
|
||||
stream: TcpStream,
|
||||
target: String,
|
||||
}
|
||||
|
||||
impl TcpHandler {
|
||||
async fn run(&mut self) -> Result<(), Box<dyn Error>> {
|
||||
let mut stream = TcpStream::connect(&self.target).await?;
|
||||
async fn run(&mut self) -> Result<(), Box<dyn Error>> {
|
||||
let mut stream = TcpStream::connect(&self.target).await?;
|
||||
|
||||
tokio::io::copy_bidirectional(&mut self.stream, &mut stream).await?;
|
||||
tokio::io::copy_bidirectional(&mut self.stream, &mut stream).await?;
|
||||
|
||||
return Ok(());
|
||||
}
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn start_tcp_listener(
|
||||
mut listener_config: Listener,
|
||||
mut listener_config: Listener,
|
||||
) -> Result<(), Box<dyn Error>> {
|
||||
println!("start listening on {}", &listener_config.source);
|
||||
let listener = TcpListener::bind(&listener_config.source).await?;
|
||||
println!("start listening on {}", &listener_config.source);
|
||||
let listener = TcpListener::bind(&listener_config.source).await?;
|
||||
|
||||
loop {
|
||||
let (next_socket, _) = tokio::select! {
|
||||
res = listener.accept() => res?,
|
||||
_ = listener_config.shutdown.recv() => {
|
||||
println!("Exiting listener!");
|
||||
return Ok(());
|
||||
}
|
||||
};
|
||||
loop {
|
||||
let (next_socket, _) = tokio::select! {
|
||||
res = listener.accept() => res?,
|
||||
_ = listener_config.shutdown.recv() => {
|
||||
println!("Exiting listener!");
|
||||
return Ok(());
|
||||
}
|
||||
};
|
||||
|
||||
let targets = listener_config.targets.read().await;
|
||||
let mut rng = rand::thread_rng();
|
||||
let selected_target = targets.choose(&mut rng).unwrap();
|
||||
let targets = listener_config.targets.read().await;
|
||||
let mut rng = rand::thread_rng();
|
||||
let selected_target = targets.choose(&mut rng).unwrap();
|
||||
|
||||
println!(
|
||||
"new connection from {} forwarding to {}",
|
||||
next_socket.peer_addr()?,
|
||||
&selected_target
|
||||
);
|
||||
let mut handler = TcpHandler {
|
||||
stream: next_socket,
|
||||
target: selected_target.clone(),
|
||||
};
|
||||
println!(
|
||||
"new connection from {} forwarding to {}",
|
||||
next_socket.peer_addr()?,
|
||||
&selected_target
|
||||
);
|
||||
let mut handler = TcpHandler {
|
||||
stream: next_socket,
|
||||
target: selected_target.clone(),
|
||||
};
|
||||
|
||||
tokio::spawn(async move {
|
||||
// Process the connection. If an error is encountered, log it.
|
||||
if let Err(err) = handler.run().await {
|
||||
println!("connection error {}", err);
|
||||
}
|
||||
});
|
||||
}
|
||||
tokio::spawn(async move {
|
||||
// Process the connection. If an error is encountered, log it.
|
||||
if let Err(err) = handler.run().await {
|
||||
println!("connection error {}", err);
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
178
src/udp.rs
178
src/udp.rs
@ -1,2 +1,176 @@
|
||||
// For UDP to work, there is some magic required with matches the returning packets
|
||||
// back to the original sender. Idk. how to do that right now, but maybe some day.
|
||||
use crate::listener::Listener;
|
||||
use crate::shutdown::Shutdown;
|
||||
use rand::seq::SliceRandom;
|
||||
use std::collections::HashMap;
|
||||
use std::error::Error;
|
||||
use std::sync::Arc;
|
||||
use tokio::net::UdpSocket;
|
||||
use tokio::sync::broadcast;
|
||||
use tokio::sync::mpsc::{channel, Sender};
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::time::Instant;
|
||||
|
||||
async fn create_dual_udpsocket(
|
||||
bind_addr: &String,
|
||||
) -> Result<(UdpSocket, UdpSocket), Box<dyn Error>> {
|
||||
let listener_std = std::net::UdpSocket::bind(bind_addr)?;
|
||||
let responder_std = listener_std.try_clone()?;
|
||||
let listener = UdpSocket::from_std(listener_std)?;
|
||||
let responder = UdpSocket::from_std(responder_std)?;
|
||||
|
||||
return Ok((listener, responder));
|
||||
}
|
||||
|
||||
fn get_udp_background_send(socket: UdpSocket, mut exit: Shutdown) -> Sender<(Vec<u8>, String)> {
|
||||
let (tx, mut rx) = channel::<(Vec<u8>, String)>(1);
|
||||
|
||||
tokio::spawn(async move {
|
||||
loop {
|
||||
let (buf, dest) = (tokio::select! {
|
||||
res = rx.recv() => Some(res.unwrap()),
|
||||
_ = exit.recv() => {
|
||||
println!("Exiting listener!");
|
||||
return;
|
||||
}
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
let to_send = buf.as_slice();
|
||||
socket.send_to(to_send, &dest).await.expect(&format!(
|
||||
"Failed to forward response from upstream server to client {}",
|
||||
dest
|
||||
));
|
||||
}
|
||||
});
|
||||
|
||||
return tx;
|
||||
}
|
||||
|
||||
struct UdpHandler {
|
||||
last_packet: Arc<Mutex<Instant>>,
|
||||
kill: broadcast::Sender<()>,
|
||||
target: String,
|
||||
sender: Sender<(Vec<u8>, String)>,
|
||||
}
|
||||
|
||||
impl UdpHandler {
|
||||
async fn start(
|
||||
target: String,
|
||||
source: String,
|
||||
sender: Sender<(Vec<u8>, String)>,
|
||||
) -> Result<UdpHandler, Box<dyn Error>> {
|
||||
// Kill Channel
|
||||
let (tx, mut rx) = broadcast::channel::<()>(1);
|
||||
|
||||
let (listener, responder) = create_dual_udpsocket(&"0.0.0.0:0".to_owned()).await?;
|
||||
|
||||
let s = get_udp_background_send(responder, Shutdown::new(tx.subscribe()));
|
||||
|
||||
let last_packet = Arc::new(Mutex::new(Instant::now()));
|
||||
|
||||
let handler = UdpHandler {
|
||||
kill: tx,
|
||||
last_packet: last_packet.clone(),
|
||||
// source: source.clone(),
|
||||
target: target.clone(),
|
||||
sender: s,
|
||||
};
|
||||
|
||||
listener.connect(target).await?;
|
||||
tokio::spawn(async move {
|
||||
let mut buf = [0; 64 * 1024];
|
||||
loop {
|
||||
let (num_bytes, _) = tokio::select! {
|
||||
res = listener.recv_from(&mut buf) => res.unwrap(),
|
||||
_ = rx.recv() => {
|
||||
// FIXME: Source of memory leaks?
|
||||
return ;
|
||||
}
|
||||
};
|
||||
let mut n = last_packet.lock().await;
|
||||
*n = Instant::now();
|
||||
sender
|
||||
.send((buf[0..num_bytes].to_vec(), source.clone()))
|
||||
.await
|
||||
.expect(&format!("Failed to send answer to sender {}", source));
|
||||
}
|
||||
});
|
||||
return Ok(handler);
|
||||
}
|
||||
|
||||
async fn exit(&self) -> Result<(), Box<dyn Error>> {
|
||||
self.kill.send(())?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn on_packet(&mut self, pkg: Vec<u8>) -> Result<(), Box<dyn Error>> {
|
||||
let mut n = self.last_packet.lock().await;
|
||||
*n = Instant::now();
|
||||
self.sender.send((pkg, self.target.clone())).await?;
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn start_udp_listener(
|
||||
mut listener_config: Listener,
|
||||
) -> Result<(), Box<dyn Error>> {
|
||||
println!("start listening on {}", &listener_config.source);
|
||||
let (listener, responder) =
|
||||
create_dual_udpsocket(&listener_config.source)
|
||||
.await
|
||||
.expect(&format!(
|
||||
"Failed to clone primary listening address socket {}",
|
||||
&listener_config.source,
|
||||
));
|
||||
|
||||
let sender = get_udp_background_send(responder, listener_config.shutdown.clone());
|
||||
|
||||
let mut connections: HashMap<String, UdpHandler> = HashMap::new();
|
||||
|
||||
let mut buf = [0; 64 * 1024];
|
||||
loop {
|
||||
let (num_bytes, src_addr) = tokio::select! {
|
||||
res = listener.recv_from(&mut buf) => res?,
|
||||
_ = listener_config.shutdown.recv() => {
|
||||
println!("Exiting listener!");
|
||||
break;
|
||||
}
|
||||
};
|
||||
|
||||
let addr = src_addr.to_string();
|
||||
let handler_opt = connections.get_mut(&addr);
|
||||
let vec = buf[0..num_bytes].to_vec();
|
||||
let handler = match handler_opt {
|
||||
Some(handler) => handler,
|
||||
None => {
|
||||
let targets = listener_config.targets.read().await;
|
||||
let selected_target = {
|
||||
let mut rng = rand::thread_rng();
|
||||
targets.choose(&mut rng).unwrap()
|
||||
};
|
||||
|
||||
let handler =
|
||||
UdpHandler::start(selected_target.clone(), addr.clone(), sender.clone())
|
||||
.await?;
|
||||
connections.insert(addr.clone(), handler);
|
||||
|
||||
connections.get_mut(&addr).unwrap()
|
||||
}
|
||||
};
|
||||
|
||||
match handler.on_packet(vec).await {
|
||||
Ok(_) => (),
|
||||
Err(err) => {
|
||||
println!("Failed to forward request from client to server {}", err);
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for handler in connections.values() {
|
||||
handler.exit().await?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
Reference in New Issue
Block a user