use std::collections::HashMap; use std::sync::atomic::AtomicI32; use std::sync::Arc; use log::{debug, error, info, trace, warn}; use rand::seq::SliceRandom; use tokio::net::UdpSocket; use tokio::sync::Mutex; use crate::listener::Listener; use crate::shutdown::{Shutdown, ShutdownReceiver}; use crate::Result; const CONNECTION_TIMEOUT: i32 = 30; #[derive(Clone)] struct UDPMultiSender { socket: Arc, } impl UDPMultiSender { fn new(socket: UdpSocket) -> Self { return Self { socket: Arc::new(socket), }; } async fn send(&self, buf: &[u8]) -> Result<()> { self.socket.send(buf).await?; return Ok(()); } async fn send_to(&self, buf: &[u8], dest: &str) -> Result<()> { self.socket.send_to(buf, dest).await?; return Ok(()); } } async fn splitted_udp_socket(bind_addr: &str) -> Result<(UdpSocket, UDPMultiSender)> { let listener_std = loop { match std::net::UdpSocket::bind(bind_addr) { Ok(listener) => break listener, Err(e) => match e.kind() { std::io::ErrorKind::AddrInUse => { warn!("Address in use: {}", bind_addr); tokio::time::sleep(std::time::Duration::from_secs(5)).await; } std::io::ErrorKind::AddrNotAvailable => { warn!("Address not available: {}", bind_addr); tokio::time::sleep(std::time::Duration::from_secs(5)).await; } _ => { error!("Error binding to {}: {} ({})", bind_addr, e, e.kind()); trace!("Error: {}", e); return Err(Box::new(e)); } }, } }; let responder_std = listener_std.try_clone()?; listener_std.set_nonblocking(true)?; responder_std.set_nonblocking(true)?; let listener = UdpSocket::from_std(listener_std)?; let responder = UDPMultiSender::new(UdpSocket::from_std(responder_std)?); return Ok((listener, responder)); } struct UDPChannel { last_packet: Arc, sender: UDPMultiSender, shutdown: Shutdown, from: String, upstream: String, } impl UDPChannel { async fn start( upstream: String, responder: UDPMultiSender, source_addr: String, ) -> Result { let (upstream_listener, upstream_responder) = splitted_udp_socket("0.0.0.0:0").await?; upstream_listener.connect(upstream.clone()).await?; let shutdown = Shutdown::new(); let mut shutdown_receiver = shutdown.receiver(); let channel = Self { last_packet: Arc::new(AtomicI32::new(CONNECTION_TIMEOUT)), sender: upstream_responder, shutdown, from: source_addr.clone(), upstream: upstream.clone(), }; let last_packet = channel.last_packet.clone(); tokio::spawn(async move { let mut buf = [0; 64 * 1024]; loop { let num_bytes = tokio::select! { res = upstream_listener.recv(&mut buf) => res.unwrap(), _ = shutdown_receiver.recv() => { info!("Exiting"); return; } }; trace!("[{}] <- [{}] ...", source_addr, upstream); last_packet.store(CONNECTION_TIMEOUT, std::sync::atomic::Ordering::Relaxed); match responder.send_to(&buf[..num_bytes], &source_addr).await { Ok(_) => {} Err(e) => { error!("Failed to send packet: {}", e); } }; trace!("[{}] <- [{}] ---", source_addr, upstream); } }); return Ok(channel); } async fn close(&mut self) { self.shutdown.shutdown(); debug!("Closing connection from {}", self.from); } async fn handle(&self, data: &[u8]) { trace!("[{}] -> [{}] ...", self.from, self.upstream); self.last_packet .store(CONNECTION_TIMEOUT, std::sync::atomic::Ordering::Relaxed); match self.sender.send(data).await { Ok(_) => {} Err(e) => { error!("Failed to send packet: {}", e); } } trace!("[{}] -> [{}] ---", self.from, self.upstream); } } fn start_stale_check( connections: Arc>>, mut shutdown: ShutdownReceiver, ) { tokio::spawn(async move { let mut interval = tokio::time::interval(std::time::Duration::from_secs(1)); loop { tokio::select! { _ = interval.tick() => {} _ = shutdown.recv() => { info!("Exiting listener!"); break; } } trace!("Checking for stale connections"); trace!("Waiting for connections lock"); let mut connections = connections.lock().await; trace!("Got connections lock"); let mut to_remove: Vec = Vec::new(); for (source_addr, channel) in connections.iter() { let last = channel .last_packet .fetch_sub(1, std::sync::atomic::Ordering::Relaxed); if last <= 0 { to_remove.push(source_addr.clone()); } } for source_addr in to_remove { debug!("Closing connection from {}", source_addr); let mut channel = connections.remove(&source_addr).unwrap(); channel.close().await; } trace!("Checking for stale connections - done"); drop(connections); } info!("Exiting stale check"); }); } pub(crate) async fn start_udp_listener(mut listener_config: Listener) -> Result<()> { info!("start listening on {}", &listener_config.source); let (listener, responder) = splitted_udp_socket(&listener_config.source).await?; let connections = Arc::new(Mutex::new(HashMap::::new())); start_stale_check(connections.clone(), listener_config.shutdown.clone()); loop { let mut buf = vec![0; 1024 * 64]; trace!("Waiting for packet or shutdown"); let (num_bytes, src_addr) = tokio::select! { res = listener.recv_from(&mut buf) => res?, _ = listener_config.targets_changed.recv() => { info!("Targets changed!"); info!("Closing all connections!"); for (_, handler) in connections.lock().await.iter_mut() { handler.close().await; } info!("Closed all connections!"); continue; } _ = listener_config.shutdown.recv() => { info!("Exiting listener!"); break; } }; let data = buf[0..num_bytes].to_vec(); let source_addr = src_addr.to_string(); let mut connections = connections.lock().await; let handler_opt = connections.get_mut(&source_addr); let handler = match handler_opt { Some(handler) => handler, None => { debug!("New connection from {}", source_addr); let targets = listener_config.targets.read().await; let upstream = { let mut rng = rand::thread_rng(); targets.choose(&mut rng).unwrap() }; let handler = UDPChannel::start(upstream.to_string(), responder.clone(), source_addr.clone()) .await?; connections.insert(source_addr.clone(), handler); connections.get_mut(&source_addr).unwrap() } }; handler.handle(&data).await; trace!("Handled packet"); drop(connections); } for (_, handler) in connections.lock().await.iter_mut() { handler.close().await; } println!("Listener closed."); return Ok(()); }