diff --git a/Cargo.lock b/Cargo.lock index ae45c78..b96bda8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2,6 +2,15 @@ # It is not intended for manual editing. version = 3 +[[package]] +name = "android_system_properties" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "819e7219dbd41043ac279b19830f2efc897156490d7fd6ea916720117ee66311" +dependencies = [ + "libc", +] + [[package]] name = "autocfg" version = "1.1.0" @@ -14,18 +23,114 @@ version = "1.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" +[[package]] +name = "bumpalo" +version = "3.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0d261e256854913907f67ed06efbc3338dfe6179796deefc1ff763fc1aee5535" + [[package]] name = "bytes" version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c4872d67bab6358e59559027aa3b9157c53d9358c51423c17554809a8858e0f8" +[[package]] +name = "cc" +version = "1.0.79" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "50d30906286121d95be3d479533b458f87493b30a4b5f79a607db8f5d11aa91f" + [[package]] name = "cfg-if" version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" +[[package]] +name = "chrono" +version = "0.4.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "16b0a3d9ed01224b22057780a37bb8c5dbfe1be8ba48678e7bf57ec4b385411f" +dependencies = [ + "iana-time-zone", + "js-sys", + "num-integer", + "num-traits", + "time", + "wasm-bindgen", + "winapi", +] + +[[package]] +name = "codespan-reporting" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3538270d33cc669650c4b093848450d380def10c331d38c768e34cac80576e6e" +dependencies = [ + "termcolor", + "unicode-width", +] + +[[package]] +name = "core-foundation-sys" +version = "0.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5827cebf4670468b8772dd191856768aedcb1b0278a04f989f7766351917b9dc" + +[[package]] +name = "cxx" +version = "1.0.89" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bc831ee6a32dd495436e317595e639a587aa9907bef96fe6e6abc290ab6204e9" +dependencies = [ + "cc", + "cxxbridge-flags", + "cxxbridge-macro", + "link-cplusplus", +] + +[[package]] +name = "cxx-build" +version = "1.0.89" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94331d54f1b1a8895cd81049f7eaaaef9d05a7dcb4d1fd08bf3ff0806246789d" +dependencies = [ + "cc", + "codespan-reporting", + "once_cell", + "proc-macro2", + "quote", + "scratch", + "syn", +] + +[[package]] +name = "cxxbridge-flags" +version = "1.0.89" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "48dcd35ba14ca9b40d6e4b4b39961f23d835dbb8eed74565ded361d93e1feb8a" + +[[package]] +name = "cxxbridge-macro" +version = "1.0.89" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "81bbeb29798b407ccd82a3324ade1a7286e0d29851475990b612670f6f5124d2" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "fern" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3bdd7b0849075e79ee9a1836df22c717d1eba30451796fdc631b04565dd11e2a" +dependencies = [ + "log", +] + [[package]] name = "futures" version = "0.3.21" @@ -123,7 +228,7 @@ checksum = "4eb1a864a501629691edf6c15a593b7a51eebaa1e8468e9ddc623de7c9b58ec6" dependencies = [ "cfg-if", "libc", - "wasi", + "wasi 0.11.0+wasi-snapshot-preview1", ] [[package]] @@ -141,6 +246,30 @@ dependencies = [ "libc", ] +[[package]] +name = "iana-time-zone" +version = "0.1.53" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "64c122667b287044802d6ce17ee2ddf13207ed924c712de9a66a5814d5b64765" +dependencies = [ + "android_system_properties", + "core-foundation-sys", + "iana-time-zone-haiku", + "js-sys", + "wasm-bindgen", + "winapi", +] + +[[package]] +name = "iana-time-zone-haiku" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0703ae284fc167426161c2e3f1da3ea71d94b21bedbcc9494e92b28e334e3dca" +dependencies = [ + "cxx", + "cxx-build", +] + [[package]] name = "indexmap" version = "1.9.1" @@ -158,10 +287,28 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1aab8fc367588b89dcee83ab0fd66b72b50b72fa1904d7095045ace2b0c81c35" [[package]] -name = "libc" -version = "0.2.121" +name = "js-sys" +version = "0.3.61" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "efaa7b300f3b5fe8eb6bf21ce3895e1751d9665086af2d64b42f19701015ff4f" +checksum = "445dde2150c55e483f3d8416706b97ec8e8237c307e5b7b4b8dd15e6af2a0730" +dependencies = [ + "wasm-bindgen", +] + +[[package]] +name = "libc" +version = "0.2.139" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "201de327520df007757c1f0adce6e827fe8562fbc28bfd9c15571c66ca1f5f79" + +[[package]] +name = "link-cplusplus" +version = "1.0.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ecd207c9c713c34f95a097a5b029ac2ce6010530c7b49d7fea24d977dede04f5" +dependencies = [ + "cc", +] [[package]] name = "lock_api" @@ -174,9 +321,9 @@ dependencies = [ [[package]] name = "log" -version = "0.4.14" +version = "0.4.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "51b9bbe6c47d51fc3e1a9b945965946b4c44142ab8792c50835a980d362c2710" +checksum = "abb12e687cfb44aa40f41fc3978ef76448f9b6038cad6aef4259d3c095a2382e" dependencies = [ "cfg-if", ] @@ -195,10 +342,29 @@ checksum = "57ee1c23c7c63b0c9250c339ffdc69255f110b298b901b9f6c82547b7b87caaf" dependencies = [ "libc", "log", - "wasi", + "wasi 0.11.0+wasi-snapshot-preview1", "windows-sys 0.36.1", ] +[[package]] +name = "num-integer" +version = "0.1.45" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "225d3389fb3509a24c93f5c29eb6bde2586b98d9f016636dff58d7c6f7569cd9" +dependencies = [ + "autocfg", + "num-traits", +] + +[[package]] +name = "num-traits" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "578ede34cf02f8924ab9447f50c28075b4d3e5b269972345e7e0372b38c6cdcd" +dependencies = [ + "autocfg", +] + [[package]] name = "num_cpus" version = "1.13.1" @@ -211,9 +377,9 @@ dependencies = [ [[package]] name = "once_cell" -version = "1.10.0" +version = "1.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87f3e037eac156d1775da914196f0f37741a274155e34a0b7e427c35d2a2ecb9" +checksum = "6f61fba1741ea2b3d6a1e3178721804bb716a68a6aeba1149b5d52e3d464ea66" [[package]] name = "parking_lot" @@ -317,7 +483,10 @@ dependencies = [ name = "rustocat" version = "0.0.3" dependencies = [ + "chrono", + "fern", "futures", + "log", "rand", "serde", "serde_json", @@ -338,6 +507,12 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd" +[[package]] +name = "scratch" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ddccb15bcce173023b3fedd9436f882a0739b8dfb45e4f6b6002bee5929f61b2" + [[package]] name = "serde" version = "1.0.144" @@ -430,6 +605,26 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "termcolor" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be55cf8942feac5c765c2c993422806843c9a9a45d4d5c407ad6dd2ea95eb9b6" +dependencies = [ + "winapi-util", +] + +[[package]] +name = "time" +version = "0.1.45" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1b797afad3f312d1c66a56d11d0316f916356d11bd158fbc6ca6389ff6bf805a" +dependencies = [ + "libc", + "wasi 0.10.0+wasi-snapshot-preview1", + "winapi", +] + [[package]] name = "tokio" version = "1.21.1" @@ -468,18 +663,84 @@ version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dcc811dc4066ac62f84f11307873c4850cb653bfa9b1719cee2bd2204a4bc5dd" +[[package]] +name = "unicode-width" +version = "0.1.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c0edd1e5b14653f783770bce4a4dabb4a5108a5370a5f5d8cfe8710c361f6c8b" + [[package]] name = "unsafe-libyaml" version = "0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c1e5fa573d8ac5f1a856f8d7be41d390ee973daf97c806b2c1a465e4e1406e68" +[[package]] +name = "wasi" +version = "0.10.0+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1a143597ca7c7793eff794def352d41792a93c481eb1042423ff7ff72ba2c31f" + [[package]] name = "wasi" version = "0.11.0+wasi-snapshot-preview1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" +[[package]] +name = "wasm-bindgen" +version = "0.2.84" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "31f8dcbc21f30d9b8f2ea926ecb58f6b91192c17e9d33594b3df58b2007ca53b" +dependencies = [ + "cfg-if", + "wasm-bindgen-macro", +] + +[[package]] +name = "wasm-bindgen-backend" +version = "0.2.84" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "95ce90fd5bcc06af55a641a86428ee4229e44e07033963a2290a8e241607ccb9" +dependencies = [ + "bumpalo", + "log", + "once_cell", + "proc-macro2", + "quote", + "syn", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-macro" +version = "0.2.84" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4c21f77c0bedc37fd5dc21f897894a5ca01e7bb159884559461862ae90c0b4c5" +dependencies = [ + "quote", + "wasm-bindgen-macro-support", +] + +[[package]] +name = "wasm-bindgen-macro-support" +version = "0.2.84" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2aff81306fcac3c7515ad4e177f521b5c9a15f2b08f4e32d823066102f35a5f6" +dependencies = [ + "proc-macro2", + "quote", + "syn", + "wasm-bindgen-backend", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-shared" +version = "0.2.84" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0046fef7e28c3804e5e38bfa31ea2a0f73905319b677e57ebe37e49358989b5d" + [[package]] name = "winapi" version = "0.3.9" @@ -496,6 +757,15 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" +[[package]] +name = "winapi-util" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "70ec6ce85bb158151cae5e5c87f95a8e97d2c0c4b001223f33a334e3ce5de178" +dependencies = [ + "winapi", +] + [[package]] name = "winapi-x86_64-pc-windows-gnu" version = "0.4.0" diff --git a/Cargo.toml b/Cargo.toml index c086116..5252025 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,17 +8,20 @@ license = "ISC" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -tokio = { version = "1.21.1", features = ["full"] } +tokio = { version = "1", features = ["full"] } futures = "0.3.21" -serde = { version = "1.0.136", features = ["derive"] } -serde_json = "1.0.79" +serde = { version = "1", features = ["derive"] } +serde_json = "1" serde_yaml = "0.9.13" simple-error = "0.2.3" rand = "0.8.5" +log = "0.4.17" +fern = "0.6.1" +chrono = "0.4.23" [profile.release] opt-level = 3 # Optimize for size. lto = true # Enable Link Time Optimization codegen-units = 1 # Reduce number of codegen units to increase optimizations. panic = 'abort' # Abort on panic -strip = true # Strip symbols from binary* \ No newline at end of file +strip = true # Strip symbols from binary* diff --git a/src/listener.rs b/src/listener.rs index 2f9102d..bdc3160 100644 --- a/src/listener.rs +++ b/src/listener.rs @@ -1,9 +1,9 @@ -use crate::shutdown::Shutdown; +use crate::shutdown::ShutdownReceiver; use std::sync::Arc; use tokio::sync::RwLock; pub(crate) struct Listener { - pub(crate) shutdown: Shutdown, - pub(crate) source: String, - pub(crate) targets: Arc>>, + pub(crate) shutdown: ShutdownReceiver, + pub(crate) source: String, + pub(crate) targets: Arc>>, } diff --git a/src/main.rs b/src/main.rs index 9de7a14..26951dc 100644 --- a/src/main.rs +++ b/src/main.rs @@ -3,6 +3,7 @@ mod shutdown; mod tcp; mod udp; +use log::{debug, error, info, warn}; use serde::Deserialize; use simple_error::bail; use std::collections::{HashMap, HashSet}; @@ -13,6 +14,8 @@ use std::sync::Arc; use tokio::signal::unix::{signal, SignalKind}; use tokio::sync::{broadcast, RwLock}; +pub type Result = std::result::Result>; + #[derive(Debug, Deserialize)] struct Target { udp: Option, @@ -25,21 +28,21 @@ struct Config { mappings: Vec, } -fn load_yaml(path: &Path) -> Result> { +fn load_yaml(path: &Path) -> Result { let file = File::open(path)?; let config: Config = serde_yaml::from_reader(file).expect("Failed to parse!"); //TODO: Print path return Ok(config); } -fn load_json(path: &Path) -> Result> { +fn load_json(path: &Path) -> Result { let file = File::open(path)?; let config: Config = serde_json::from_reader(file).expect("Failed to parse!"); //TODO: Print path return Ok(config); } -fn load_config() -> Result> { +fn load_config() -> Result { for path in [ "config.yaml", "config.json", @@ -63,13 +66,28 @@ fn load_config() -> Result> { // #[derive(Debug)] struct ActiveListener { - udp: bool, notify_shutdown: broadcast::Sender<()>, targets: Arc>>, } #[tokio::main] -async fn main() -> Result<(), Box> { +async fn main() -> Result<()> { + fern::Dispatch::new() + .format(|out, message, record| { + out.finish(format_args!( + "{}[{}][{}] {}", + chrono::Local::now().format("[%Y-%m-%d][%H:%M:%S]"), + record.target(), + record.level(), + message + )) + }) + // Add blanket level filter - + .level(log::LevelFilter::Info) + .level_for("rustocat", log::LevelFilter::Trace) + .chain(std::io::stdout()) + .apply()?; + let mut listeners: HashMap = HashMap::new(); let mut sighup_stream = signal(SignalKind::hangup())?; @@ -108,7 +126,7 @@ async fn main() -> Result<(), Box> { } if invalid { - println!("Found invalid targets! Adjusting!"); + warn!("Found invalid targets! Adjusting!"); let mut w = listener.targets.write().await; *w = target.targets; } @@ -118,11 +136,10 @@ async fn main() -> Result<(), Box> { let listener = ActiveListener { notify_shutdown: notify_shutdown, targets: Arc::new(RwLock::new(target.targets)), - udp: target.udp == None || target.udp == Some(false), }; let l = listener::Listener { - shutdown: shutdown::Shutdown::new(listener.notify_shutdown.subscribe()), + shutdown: shutdown::ShutdownReceiver::new(listener.notify_shutdown.subscribe()), source: target.source.clone(), targets: listener.targets.clone(), }; @@ -130,11 +147,11 @@ async fn main() -> Result<(), Box> { tokio::spawn(async move { if target.udp == None || target.udp == Some(false) { if let Err(err) = tcp::start_tcp_listener(l).await { - println!("tcp listener error: {}", err); + error!("tcp listener error: {}", err); } } else { if let Err(err) = udp::start_udp_listener(l).await { - println!("udp listener error: {}", err); + error!("udp listener error: {}", err); } } }); @@ -151,13 +168,22 @@ async fn main() -> Result<(), Box> { for del_key in to_delete { if let Some(listener) = listeners.get(&del_key) { - let _ = listener.notify_shutdown.send(()); //Errors are irrelevant here. I guess.... - println!("Removing listener!"); + let res = listener.notify_shutdown.send(()); //Errors are irrelevant here. I guess.... + match res { + Ok(_) => { + info!("Sent shutdown signal!"); + } + Err(_) => { + warn!("Failed to send shutdown signal!"); + } + } + + debug!("Removing listener!"); listeners.remove(&del_key); } } sighup_stream.recv().await; - println!("Recevied SIGHUP!"); + info!("Recevied SIGHUP, reloading config!"); } } diff --git a/src/shutdown.rs b/src/shutdown.rs index 69687fc..896a15e 100644 --- a/src/shutdown.rs +++ b/src/shutdown.rs @@ -1,56 +1,60 @@ use tokio::sync::broadcast; -/// Listens for the server shutdown signal. -/// -/// Shutdown is signalled using a `broadcast::Receiver`. Only a single value is -/// ever sent. Once a value has been sent via the broadcast channel, the server -/// should shutdown. -/// -/// The `Shutdown` struct listens for the signal and tracks that the signal has -/// been received. Callers may query for whether the shutdown signal has been -/// received or not. -#[derive(Debug)] pub(crate) struct Shutdown { - /// `true` if the shutdown signal has been received shutdown: bool, - - /// The receive half of the channel used to listen for shutdown. notify: broadcast::Receiver<()>, + sender: broadcast::Sender<()>, } impl Shutdown { - /// Create a new `Shutdown` backed by the given `broadcast::Receiver`. - pub(crate) fn new(notify: broadcast::Receiver<()>) -> Shutdown { + pub(crate) fn new() -> Shutdown { + let (sender, notify) = broadcast::channel(1); Shutdown { shutdown: false, notify, + sender, + } + } + + pub(crate) fn shutdown(&mut self) { + if self.shutdown { + return; + } + let _ = self.sender.send(()); + self.shutdown = true; + } + + pub(crate) fn receiver(&self) -> ShutdownReceiver { + ShutdownReceiver::new(self.notify.resubscribe()) + } +} + +#[derive(Debug)] +pub(crate) struct ShutdownReceiver { + shutdown: bool, + notify: broadcast::Receiver<()>, +} + +impl ShutdownReceiver { + pub(crate) fn new(notify: broadcast::Receiver<()>) -> ShutdownReceiver { + ShutdownReceiver { + shutdown: false, + notify, } } - /// Returns `true` if the shutdown signal has been received. - // pub(crate) fn is_shutdown(&self) -> bool { - // self.shutdown - // } - - /// Receive the shutdown notice, waiting if necessary. pub(crate) async fn recv(&mut self) { - // If the shutdown signal has already been received, then return - // immediately. if self.shutdown { return; } - - // Cannot receive a "lag error" as only one value is ever sent. let _ = self.notify.recv().await; - - // Remember that the signal has been received. self.shutdown = true; } } -impl Clone for Shutdown { +impl Clone for ShutdownReceiver { fn clone(&self) -> Self { - Shutdown { + ShutdownReceiver { shutdown: self.shutdown, notify: self.notify.resubscribe(), } diff --git a/src/tcp.rs b/src/tcp.rs index 176d99b..45f4d9a 100644 --- a/src/tcp.rs +++ b/src/tcp.rs @@ -1,4 +1,5 @@ use crate::listener::Listener; +use log::{info, trace, warn}; use rand::seq::SliceRandom; use std::error::Error; use tokio::net::{TcpListener, TcpStream}; @@ -22,14 +23,14 @@ impl TcpHandler { pub(crate) async fn start_tcp_listener( mut listener_config: Listener, ) -> Result<(), Box> { - println!("start listening on {}", &listener_config.source); + info!("start listening on {}", &listener_config.source); let listener = TcpListener::bind(&listener_config.source).await?; loop { let (next_socket, _) = tokio::select! { res = listener.accept() => res?, _ = listener_config.shutdown.recv() => { - println!("Exiting listener!"); + info!("Exiting listener!"); return Ok(()); } }; @@ -38,7 +39,7 @@ pub(crate) async fn start_tcp_listener( let mut rng = rand::thread_rng(); let selected_target = targets.choose(&mut rng).unwrap(); - println!( + trace!( "new connection from {} forwarding to {}", next_socket.peer_addr()?, &selected_target @@ -51,7 +52,7 @@ pub(crate) async fn start_tcp_listener( tokio::spawn(async move { // Process the connection. If an error is encountered, log it. if let Err(err) = handler.run().await { - println!("connection error {}", err); + warn!("connection error {}", err); } }); } diff --git a/src/udp.rs b/src/udp.rs index cc38cd1..46af522 100644 --- a/src/udp.rs +++ b/src/udp.rs @@ -1,176 +1,225 @@ -use crate::listener::Listener; -use crate::shutdown::Shutdown; -use rand::seq::SliceRandom; use std::collections::HashMap; -use std::error::Error; +use std::sync::atomic::{AtomicI32, AtomicU64}; use std::sync::Arc; + +use log::{debug, error, info, trace}; +use rand::seq::SliceRandom; + use tokio::net::UdpSocket; -use tokio::sync::broadcast; -use tokio::sync::mpsc::{channel, Sender}; use tokio::sync::Mutex; use tokio::time::Instant; -async fn create_dual_udpsocket( - bind_addr: &String, -) -> Result<(UdpSocket, UdpSocket), Box> { - let listener_std = std::net::UdpSocket::bind(bind_addr)?; - let responder_std = listener_std.try_clone()?; - let listener = UdpSocket::from_std(listener_std)?; - let responder = UdpSocket::from_std(responder_std)?; +use crate::listener::Listener; +use crate::shutdown::{Shutdown, ShutdownReceiver}; +use crate::Result; - return Ok((listener, responder)); +const CONNECTION_TIMEOUT: i32 = 5; + +#[derive(Clone)] +struct UDPMultiSender { + socket: Arc, } -fn get_udp_background_send(socket: UdpSocket, mut exit: Shutdown) -> Sender<(Vec, String)> { - let (tx, mut rx) = channel::<(Vec, String)>(1); - - tokio::spawn(async move { - loop { - let (buf, dest) = (tokio::select! { - res = rx.recv() => Some(res.unwrap()), - _ = exit.recv() => { - println!("Exiting listener!"); - return; - } - }) - .unwrap(); - - let to_send = buf.as_slice(); - socket.send_to(to_send, &dest).await.expect(&format!( - "Failed to forward response from upstream server to client {}", - dest - )); - } - }); - - return tx; -} - -struct UdpHandler { - last_packet: Arc>, - kill: broadcast::Sender<()>, - target: String, - sender: Sender<(Vec, String)>, -} - -impl UdpHandler { - async fn start( - target: String, - source: String, - sender: Sender<(Vec, String)>, - ) -> Result> { - // Kill Channel - let (tx, mut rx) = broadcast::channel::<()>(1); - - let (listener, responder) = create_dual_udpsocket(&"0.0.0.0:0".to_owned()).await?; - - let s = get_udp_background_send(responder, Shutdown::new(tx.subscribe())); - - let last_packet = Arc::new(Mutex::new(Instant::now())); - - let handler = UdpHandler { - kill: tx, - last_packet: last_packet.clone(), - // source: source.clone(), - target: target.clone(), - sender: s, +impl UDPMultiSender { + fn new(socket: UdpSocket) -> Self { + return Self { + socket: Arc::new(socket), }; - - listener.connect(target).await?; - tokio::spawn(async move { - let mut buf = [0; 64 * 1024]; - loop { - let (num_bytes, _) = tokio::select! { - res = listener.recv_from(&mut buf) => res.unwrap(), - _ = rx.recv() => { - // FIXME: Source of memory leaks? - return ; - } - }; - let mut n = last_packet.lock().await; - *n = Instant::now(); - sender - .send((buf[0..num_bytes].to_vec(), source.clone())) - .await - .expect(&format!("Failed to send answer to sender {}", source)); - } - }); - return Ok(handler); } - async fn exit(&self) -> Result<(), Box> { - self.kill.send(())?; - - Ok(()) + async fn send(&self, buf: &[u8]) -> Result<()> { + self.socket.send(buf).await?; + return Ok(()); } - async fn on_packet(&mut self, pkg: Vec) -> Result<(), Box> { - let mut n = self.last_packet.lock().await; - *n = Instant::now(); - self.sender.send((pkg, self.target.clone())).await?; + async fn send_to(&self, buf: &[u8], dest: &str) -> Result<()> { + self.socket.send_to(buf, dest).await?; return Ok(()); } } -pub(crate) async fn start_udp_listener( - mut listener_config: Listener, -) -> Result<(), Box> { - println!("start listening on {}", &listener_config.source); - let (listener, responder) = - create_dual_udpsocket(&listener_config.source) - .await - .expect(&format!( - "Failed to clone primary listening address socket {}", - &listener_config.source, - )); +async fn splitted_udp_socket(bind_addr: &str) -> Result<(UdpSocket, UDPMultiSender)> { + let listener_std = std::net::UdpSocket::bind(bind_addr)?; + let responder_std = listener_std.try_clone()?; + listener_std.set_nonblocking(true)?; + responder_std.set_nonblocking(true)?; + let listener = UdpSocket::from_std(listener_std)?; + let responder = UDPMultiSender::new(UdpSocket::from_std(responder_std)?); - let sender = get_udp_background_send(responder, listener_config.shutdown.clone()); + return Ok((listener, responder)); +} - let mut connections: HashMap = HashMap::new(); +struct UDPChannel { + last_packet: Arc, + sender: UDPMultiSender, + shutdown: Shutdown, + from: String, + upstream: String, +} + +impl UDPChannel { + async fn start( + upstream: String, + responder: UDPMultiSender, + source_addr: String, + ) -> Result { + let (upstream_listener, upstream_responder) = splitted_udp_socket("0.0.0.0:0").await?; + upstream_listener.connect(upstream.clone()).await?; + + let shutdown = Shutdown::new(); + let mut shutdown_receiver = shutdown.receiver(); + + let channel = Self { + last_packet: Arc::new(AtomicI32::new(CONNECTION_TIMEOUT)), + sender: upstream_responder, + shutdown, + from: source_addr.clone(), + upstream: upstream.clone(), + }; + + let last_packet = channel.last_packet.clone(); + + tokio::spawn(async move { + let mut buf = [0; 64 * 1024]; + loop { + let num_bytes = tokio::select! { + res = upstream_listener.recv(&mut buf) => res.unwrap(), + _ = shutdown_receiver.recv() => { + info!("Exiting"); + return; + } + }; + + trace!("[{}] <- [{}] ...", source_addr, upstream); + last_packet.store(CONNECTION_TIMEOUT, std::sync::atomic::Ordering::Relaxed); + + match responder.send_to(&buf[..num_bytes], &source_addr).await { + Ok(_) => {} + Err(e) => { + error!("Failed to send packet: {}", e); + } + }; + + trace!("[{}] <- [{}] ---", source_addr, upstream); + } + }); + + return Ok(channel); + } + + async fn close(&mut self) { + self.shutdown.shutdown(); + debug!("Closing connection from {}", self.from); + } + + async fn handle(&self, data: &[u8]) { + trace!("[{}] -> [{}] ...", self.from, self.upstream); + self.last_packet + .store(CONNECTION_TIMEOUT, std::sync::atomic::Ordering::Relaxed); + match self.sender.send(data).await { + Ok(_) => {} + Err(e) => { + error!("Failed to send packet: {}", e); + } + } + trace!("[{}] -> [{}] ---", self.from, self.upstream); + } +} + +fn start_stale_check( + connections: Arc>>, + mut shutdown: ShutdownReceiver, +) { + tokio::spawn(async move { + let mut interval = tokio::time::interval(std::time::Duration::from_secs(1)); + loop { + tokio::select! { + _ = interval.tick() => {} + _ = shutdown.recv() => { + info!("Exiting listener!"); + break; + } + } + + trace!("Checking for stale connections"); + trace!("Waiting for connections lock"); + let mut connections = connections.lock().await; + trace!("Got connections lock"); + let mut to_remove: Vec = Vec::new(); + for (source_addr, channel) in connections.iter() { + let last = channel + .last_packet + .fetch_sub(1, std::sync::atomic::Ordering::Relaxed); + + if last <= 0 { + to_remove.push(source_addr.clone()); + } + } + + for source_addr in to_remove { + debug!("Closing connection from {}", source_addr); + let mut channel = connections.remove(&source_addr).unwrap(); + channel.close().await; + } + trace!("Checking for stale connections - done"); + drop(connections); + } + + info!("Exiting stale check"); + }); +} + +pub(crate) async fn start_udp_listener(mut listener_config: Listener) -> Result<()> { + info!("start listening on {}", &listener_config.source); + let (listener, responder) = splitted_udp_socket(&listener_config.source).await?; + let connections = Arc::new(Mutex::new(HashMap::::new())); + + start_stale_check(connections.clone(), listener_config.shutdown.clone()); - let mut buf = [0; 64 * 1024]; loop { + let mut buf = vec![0; 1024 * 64]; + trace!("Waiting for packet or shutdown"); let (num_bytes, src_addr) = tokio::select! { res = listener.recv_from(&mut buf) => res?, _ = listener_config.shutdown.recv() => { - println!("Exiting listener!"); + info!("Exiting listener!"); break; } }; - let addr = src_addr.to_string(); - let handler_opt = connections.get_mut(&addr); - let vec = buf[0..num_bytes].to_vec(); + let data = buf[0..num_bytes].to_vec(); + let source_addr = src_addr.to_string(); + + let mut connections = connections.lock().await; + + let handler_opt = connections.get_mut(&source_addr); let handler = match handler_opt { Some(handler) => handler, None => { + debug!("New connection from {}", source_addr); let targets = listener_config.targets.read().await; - let selected_target = { + let upstream = { let mut rng = rand::thread_rng(); targets.choose(&mut rng).unwrap() }; let handler = - UdpHandler::start(selected_target.clone(), addr.clone(), sender.clone()) + UDPChannel::start(upstream.to_string(), responder.clone(), source_addr.clone()) .await?; - connections.insert(addr.clone(), handler); - connections.get_mut(&addr).unwrap() + connections.insert(source_addr.clone(), handler); + connections.get_mut(&source_addr).unwrap() } }; - match handler.on_packet(vec).await { - Ok(_) => (), - Err(err) => { - println!("Failed to forward request from client to server {}", err); - return Ok(()); - } - } + handler.handle(&data).await; + trace!("Handled packet"); + drop(connections); } - for handler in connections.values() { - handler.exit().await?; + for (_, handler) in connections.lock().await.iter_mut() { + handler.close().await; } - Ok(()) + return Ok(()); }