UDP seems to work now

This commit is contained in:
2023-02-05 00:36:53 +01:00
parent 7530ce3aae
commit 517889f875
7 changed files with 544 additions and 191 deletions

View File

@ -1,9 +1,9 @@
use crate::shutdown::Shutdown;
use crate::shutdown::ShutdownReceiver;
use std::sync::Arc;
use tokio::sync::RwLock;
pub(crate) struct Listener {
pub(crate) shutdown: Shutdown,
pub(crate) source: String,
pub(crate) targets: Arc<RwLock<Vec<String>>>,
pub(crate) shutdown: ShutdownReceiver,
pub(crate) source: String,
pub(crate) targets: Arc<RwLock<Vec<String>>>,
}

View File

@ -3,6 +3,7 @@ mod shutdown;
mod tcp;
mod udp;
use log::{debug, error, info, warn};
use serde::Deserialize;
use simple_error::bail;
use std::collections::{HashMap, HashSet};
@ -13,6 +14,8 @@ use std::sync::Arc;
use tokio::signal::unix::{signal, SignalKind};
use tokio::sync::{broadcast, RwLock};
pub type Result<T> = std::result::Result<T, Box<dyn Error>>;
#[derive(Debug, Deserialize)]
struct Target {
udp: Option<bool>,
@ -25,21 +28,21 @@ struct Config {
mappings: Vec<Target>,
}
fn load_yaml(path: &Path) -> Result<Config, Box<dyn Error>> {
fn load_yaml(path: &Path) -> Result<Config> {
let file = File::open(path)?;
let config: Config = serde_yaml::from_reader(file).expect("Failed to parse!"); //TODO: Print path
return Ok(config);
}
fn load_json(path: &Path) -> Result<Config, Box<dyn Error>> {
fn load_json(path: &Path) -> Result<Config> {
let file = File::open(path)?;
let config: Config = serde_json::from_reader(file).expect("Failed to parse!"); //TODO: Print path
return Ok(config);
}
fn load_config() -> Result<Config, Box<dyn Error>> {
fn load_config() -> Result<Config> {
for path in [
"config.yaml",
"config.json",
@ -63,13 +66,28 @@ fn load_config() -> Result<Config, Box<dyn Error>> {
// #[derive(Debug)]
struct ActiveListener {
udp: bool,
notify_shutdown: broadcast::Sender<()>,
targets: Arc<RwLock<Vec<String>>>,
}
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
async fn main() -> Result<()> {
fern::Dispatch::new()
.format(|out, message, record| {
out.finish(format_args!(
"{}[{}][{}] {}",
chrono::Local::now().format("[%Y-%m-%d][%H:%M:%S]"),
record.target(),
record.level(),
message
))
})
// Add blanket level filter -
.level(log::LevelFilter::Info)
.level_for("rustocat", log::LevelFilter::Trace)
.chain(std::io::stdout())
.apply()?;
let mut listeners: HashMap<String, ActiveListener> = HashMap::new();
let mut sighup_stream = signal(SignalKind::hangup())?;
@ -108,7 +126,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
}
if invalid {
println!("Found invalid targets! Adjusting!");
warn!("Found invalid targets! Adjusting!");
let mut w = listener.targets.write().await;
*w = target.targets;
}
@ -118,11 +136,10 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
let listener = ActiveListener {
notify_shutdown: notify_shutdown,
targets: Arc::new(RwLock::new(target.targets)),
udp: target.udp == None || target.udp == Some(false),
};
let l = listener::Listener {
shutdown: shutdown::Shutdown::new(listener.notify_shutdown.subscribe()),
shutdown: shutdown::ShutdownReceiver::new(listener.notify_shutdown.subscribe()),
source: target.source.clone(),
targets: listener.targets.clone(),
};
@ -130,11 +147,11 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
tokio::spawn(async move {
if target.udp == None || target.udp == Some(false) {
if let Err(err) = tcp::start_tcp_listener(l).await {
println!("tcp listener error: {}", err);
error!("tcp listener error: {}", err);
}
} else {
if let Err(err) = udp::start_udp_listener(l).await {
println!("udp listener error: {}", err);
error!("udp listener error: {}", err);
}
}
});
@ -151,13 +168,22 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
for del_key in to_delete {
if let Some(listener) = listeners.get(&del_key) {
let _ = listener.notify_shutdown.send(()); //Errors are irrelevant here. I guess....
println!("Removing listener!");
let res = listener.notify_shutdown.send(()); //Errors are irrelevant here. I guess....
match res {
Ok(_) => {
info!("Sent shutdown signal!");
}
Err(_) => {
warn!("Failed to send shutdown signal!");
}
}
debug!("Removing listener!");
listeners.remove(&del_key);
}
}
sighup_stream.recv().await;
println!("Recevied SIGHUP!");
info!("Recevied SIGHUP, reloading config!");
}
}

View File

@ -1,56 +1,60 @@
use tokio::sync::broadcast;
/// Listens for the server shutdown signal.
///
/// Shutdown is signalled using a `broadcast::Receiver`. Only a single value is
/// ever sent. Once a value has been sent via the broadcast channel, the server
/// should shutdown.
///
/// The `Shutdown` struct listens for the signal and tracks that the signal has
/// been received. Callers may query for whether the shutdown signal has been
/// received or not.
#[derive(Debug)]
pub(crate) struct Shutdown {
/// `true` if the shutdown signal has been received
shutdown: bool,
/// The receive half of the channel used to listen for shutdown.
notify: broadcast::Receiver<()>,
sender: broadcast::Sender<()>,
}
impl Shutdown {
/// Create a new `Shutdown` backed by the given `broadcast::Receiver`.
pub(crate) fn new(notify: broadcast::Receiver<()>) -> Shutdown {
pub(crate) fn new() -> Shutdown {
let (sender, notify) = broadcast::channel(1);
Shutdown {
shutdown: false,
notify,
sender,
}
}
pub(crate) fn shutdown(&mut self) {
if self.shutdown {
return;
}
let _ = self.sender.send(());
self.shutdown = true;
}
pub(crate) fn receiver(&self) -> ShutdownReceiver {
ShutdownReceiver::new(self.notify.resubscribe())
}
}
#[derive(Debug)]
pub(crate) struct ShutdownReceiver {
shutdown: bool,
notify: broadcast::Receiver<()>,
}
impl ShutdownReceiver {
pub(crate) fn new(notify: broadcast::Receiver<()>) -> ShutdownReceiver {
ShutdownReceiver {
shutdown: false,
notify,
}
}
/// Returns `true` if the shutdown signal has been received.
// pub(crate) fn is_shutdown(&self) -> bool {
// self.shutdown
// }
/// Receive the shutdown notice, waiting if necessary.
pub(crate) async fn recv(&mut self) {
// If the shutdown signal has already been received, then return
// immediately.
if self.shutdown {
return;
}
// Cannot receive a "lag error" as only one value is ever sent.
let _ = self.notify.recv().await;
// Remember that the signal has been received.
self.shutdown = true;
}
}
impl Clone for Shutdown {
impl Clone for ShutdownReceiver {
fn clone(&self) -> Self {
Shutdown {
ShutdownReceiver {
shutdown: self.shutdown,
notify: self.notify.resubscribe(),
}

View File

@ -1,4 +1,5 @@
use crate::listener::Listener;
use log::{info, trace, warn};
use rand::seq::SliceRandom;
use std::error::Error;
use tokio::net::{TcpListener, TcpStream};
@ -22,14 +23,14 @@ impl TcpHandler {
pub(crate) async fn start_tcp_listener(
mut listener_config: Listener,
) -> Result<(), Box<dyn Error>> {
println!("start listening on {}", &listener_config.source);
info!("start listening on {}", &listener_config.source);
let listener = TcpListener::bind(&listener_config.source).await?;
loop {
let (next_socket, _) = tokio::select! {
res = listener.accept() => res?,
_ = listener_config.shutdown.recv() => {
println!("Exiting listener!");
info!("Exiting listener!");
return Ok(());
}
};
@ -38,7 +39,7 @@ pub(crate) async fn start_tcp_listener(
let mut rng = rand::thread_rng();
let selected_target = targets.choose(&mut rng).unwrap();
println!(
trace!(
"new connection from {} forwarding to {}",
next_socket.peer_addr()?,
&selected_target
@ -51,7 +52,7 @@ pub(crate) async fn start_tcp_listener(
tokio::spawn(async move {
// Process the connection. If an error is encountered, log it.
if let Err(err) = handler.run().await {
println!("connection error {}", err);
warn!("connection error {}", err);
}
});
}

View File

@ -1,176 +1,225 @@
use crate::listener::Listener;
use crate::shutdown::Shutdown;
use rand::seq::SliceRandom;
use std::collections::HashMap;
use std::error::Error;
use std::sync::atomic::{AtomicI32, AtomicU64};
use std::sync::Arc;
use log::{debug, error, info, trace};
use rand::seq::SliceRandom;
use tokio::net::UdpSocket;
use tokio::sync::broadcast;
use tokio::sync::mpsc::{channel, Sender};
use tokio::sync::Mutex;
use tokio::time::Instant;
async fn create_dual_udpsocket(
bind_addr: &String,
) -> Result<(UdpSocket, UdpSocket), Box<dyn Error>> {
let listener_std = std::net::UdpSocket::bind(bind_addr)?;
let responder_std = listener_std.try_clone()?;
let listener = UdpSocket::from_std(listener_std)?;
let responder = UdpSocket::from_std(responder_std)?;
use crate::listener::Listener;
use crate::shutdown::{Shutdown, ShutdownReceiver};
use crate::Result;
return Ok((listener, responder));
const CONNECTION_TIMEOUT: i32 = 5;
#[derive(Clone)]
struct UDPMultiSender {
socket: Arc<UdpSocket>,
}
fn get_udp_background_send(socket: UdpSocket, mut exit: Shutdown) -> Sender<(Vec<u8>, String)> {
let (tx, mut rx) = channel::<(Vec<u8>, String)>(1);
tokio::spawn(async move {
loop {
let (buf, dest) = (tokio::select! {
res = rx.recv() => Some(res.unwrap()),
_ = exit.recv() => {
println!("Exiting listener!");
return;
}
})
.unwrap();
let to_send = buf.as_slice();
socket.send_to(to_send, &dest).await.expect(&format!(
"Failed to forward response from upstream server to client {}",
dest
));
}
});
return tx;
}
struct UdpHandler {
last_packet: Arc<Mutex<Instant>>,
kill: broadcast::Sender<()>,
target: String,
sender: Sender<(Vec<u8>, String)>,
}
impl UdpHandler {
async fn start(
target: String,
source: String,
sender: Sender<(Vec<u8>, String)>,
) -> Result<UdpHandler, Box<dyn Error>> {
// Kill Channel
let (tx, mut rx) = broadcast::channel::<()>(1);
let (listener, responder) = create_dual_udpsocket(&"0.0.0.0:0".to_owned()).await?;
let s = get_udp_background_send(responder, Shutdown::new(tx.subscribe()));
let last_packet = Arc::new(Mutex::new(Instant::now()));
let handler = UdpHandler {
kill: tx,
last_packet: last_packet.clone(),
// source: source.clone(),
target: target.clone(),
sender: s,
impl UDPMultiSender {
fn new(socket: UdpSocket) -> Self {
return Self {
socket: Arc::new(socket),
};
listener.connect(target).await?;
tokio::spawn(async move {
let mut buf = [0; 64 * 1024];
loop {
let (num_bytes, _) = tokio::select! {
res = listener.recv_from(&mut buf) => res.unwrap(),
_ = rx.recv() => {
// FIXME: Source of memory leaks?
return ;
}
};
let mut n = last_packet.lock().await;
*n = Instant::now();
sender
.send((buf[0..num_bytes].to_vec(), source.clone()))
.await
.expect(&format!("Failed to send answer to sender {}", source));
}
});
return Ok(handler);
}
async fn exit(&self) -> Result<(), Box<dyn Error>> {
self.kill.send(())?;
Ok(())
async fn send(&self, buf: &[u8]) -> Result<()> {
self.socket.send(buf).await?;
return Ok(());
}
async fn on_packet(&mut self, pkg: Vec<u8>) -> Result<(), Box<dyn Error>> {
let mut n = self.last_packet.lock().await;
*n = Instant::now();
self.sender.send((pkg, self.target.clone())).await?;
async fn send_to(&self, buf: &[u8], dest: &str) -> Result<()> {
self.socket.send_to(buf, dest).await?;
return Ok(());
}
}
pub(crate) async fn start_udp_listener(
mut listener_config: Listener,
) -> Result<(), Box<dyn Error>> {
println!("start listening on {}", &listener_config.source);
let (listener, responder) =
create_dual_udpsocket(&listener_config.source)
.await
.expect(&format!(
"Failed to clone primary listening address socket {}",
&listener_config.source,
));
async fn splitted_udp_socket(bind_addr: &str) -> Result<(UdpSocket, UDPMultiSender)> {
let listener_std = std::net::UdpSocket::bind(bind_addr)?;
let responder_std = listener_std.try_clone()?;
listener_std.set_nonblocking(true)?;
responder_std.set_nonblocking(true)?;
let listener = UdpSocket::from_std(listener_std)?;
let responder = UDPMultiSender::new(UdpSocket::from_std(responder_std)?);
let sender = get_udp_background_send(responder, listener_config.shutdown.clone());
return Ok((listener, responder));
}
let mut connections: HashMap<String, UdpHandler> = HashMap::new();
struct UDPChannel {
last_packet: Arc<AtomicI32>,
sender: UDPMultiSender,
shutdown: Shutdown,
from: String,
upstream: String,
}
impl UDPChannel {
async fn start(
upstream: String,
responder: UDPMultiSender,
source_addr: String,
) -> Result<Self> {
let (upstream_listener, upstream_responder) = splitted_udp_socket("0.0.0.0:0").await?;
upstream_listener.connect(upstream.clone()).await?;
let shutdown = Shutdown::new();
let mut shutdown_receiver = shutdown.receiver();
let channel = Self {
last_packet: Arc::new(AtomicI32::new(CONNECTION_TIMEOUT)),
sender: upstream_responder,
shutdown,
from: source_addr.clone(),
upstream: upstream.clone(),
};
let last_packet = channel.last_packet.clone();
tokio::spawn(async move {
let mut buf = [0; 64 * 1024];
loop {
let num_bytes = tokio::select! {
res = upstream_listener.recv(&mut buf) => res.unwrap(),
_ = shutdown_receiver.recv() => {
info!("Exiting");
return;
}
};
trace!("[{}] <- [{}] ...", source_addr, upstream);
last_packet.store(CONNECTION_TIMEOUT, std::sync::atomic::Ordering::Relaxed);
match responder.send_to(&buf[..num_bytes], &source_addr).await {
Ok(_) => {}
Err(e) => {
error!("Failed to send packet: {}", e);
}
};
trace!("[{}] <- [{}] ---", source_addr, upstream);
}
});
return Ok(channel);
}
async fn close(&mut self) {
self.shutdown.shutdown();
debug!("Closing connection from {}", self.from);
}
async fn handle(&self, data: &[u8]) {
trace!("[{}] -> [{}] ...", self.from, self.upstream);
self.last_packet
.store(CONNECTION_TIMEOUT, std::sync::atomic::Ordering::Relaxed);
match self.sender.send(data).await {
Ok(_) => {}
Err(e) => {
error!("Failed to send packet: {}", e);
}
}
trace!("[{}] -> [{}] ---", self.from, self.upstream);
}
}
fn start_stale_check(
connections: Arc<Mutex<HashMap<String, UDPChannel>>>,
mut shutdown: ShutdownReceiver,
) {
tokio::spawn(async move {
let mut interval = tokio::time::interval(std::time::Duration::from_secs(1));
loop {
tokio::select! {
_ = interval.tick() => {}
_ = shutdown.recv() => {
info!("Exiting listener!");
break;
}
}
trace!("Checking for stale connections");
trace!("Waiting for connections lock");
let mut connections = connections.lock().await;
trace!("Got connections lock");
let mut to_remove: Vec<String> = Vec::new();
for (source_addr, channel) in connections.iter() {
let last = channel
.last_packet
.fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
if last <= 0 {
to_remove.push(source_addr.clone());
}
}
for source_addr in to_remove {
debug!("Closing connection from {}", source_addr);
let mut channel = connections.remove(&source_addr).unwrap();
channel.close().await;
}
trace!("Checking for stale connections - done");
drop(connections);
}
info!("Exiting stale check");
});
}
pub(crate) async fn start_udp_listener(mut listener_config: Listener) -> Result<()> {
info!("start listening on {}", &listener_config.source);
let (listener, responder) = splitted_udp_socket(&listener_config.source).await?;
let connections = Arc::new(Mutex::new(HashMap::<String, UDPChannel>::new()));
start_stale_check(connections.clone(), listener_config.shutdown.clone());
let mut buf = [0; 64 * 1024];
loop {
let mut buf = vec![0; 1024 * 64];
trace!("Waiting for packet or shutdown");
let (num_bytes, src_addr) = tokio::select! {
res = listener.recv_from(&mut buf) => res?,
_ = listener_config.shutdown.recv() => {
println!("Exiting listener!");
info!("Exiting listener!");
break;
}
};
let addr = src_addr.to_string();
let handler_opt = connections.get_mut(&addr);
let vec = buf[0..num_bytes].to_vec();
let data = buf[0..num_bytes].to_vec();
let source_addr = src_addr.to_string();
let mut connections = connections.lock().await;
let handler_opt = connections.get_mut(&source_addr);
let handler = match handler_opt {
Some(handler) => handler,
None => {
debug!("New connection from {}", source_addr);
let targets = listener_config.targets.read().await;
let selected_target = {
let upstream = {
let mut rng = rand::thread_rng();
targets.choose(&mut rng).unwrap()
};
let handler =
UdpHandler::start(selected_target.clone(), addr.clone(), sender.clone())
UDPChannel::start(upstream.to_string(), responder.clone(), source_addr.clone())
.await?;
connections.insert(addr.clone(), handler);
connections.get_mut(&addr).unwrap()
connections.insert(source_addr.clone(), handler);
connections.get_mut(&source_addr).unwrap()
}
};
match handler.on_packet(vec).await {
Ok(_) => (),
Err(err) => {
println!("Failed to forward request from client to server {}", err);
return Ok(());
}
}
handler.handle(&data).await;
trace!("Handled packet");
drop(connections);
}
for handler in connections.values() {
handler.exit().await?;
for (_, handler) in connections.lock().await.iter_mut() {
handler.close().await;
}
Ok(())
return Ok(());
}