rustocat/src/udp.rs

236 lines
7.2 KiB
Rust

use std::collections::HashMap;
use std::sync::atomic::AtomicI32;
use std::sync::Arc;
use log::{debug, error, info, trace};
use rand::seq::SliceRandom;
use tokio::net::UdpSocket;
use tokio::sync::Mutex;
use crate::listener::Listener;
use crate::shutdown::{Shutdown, ShutdownReceiver};
use crate::Result;
/// Idle budget for a [`UDPChannel`], measured in stale-check ticks (the stale checker ticks
/// once per second). Every packet relayed in either direction resets the channel's countdown
/// to this value, and a channel whose countdown has run out is closed, so an idle channel
/// survives for roughly this many seconds.
const CONNECTION_TIMEOUT: i32 = 30;
/// Cloneable sending handle for a UDP socket.
///
/// Clones share one socket through an [`Arc`], so several owners (e.g. many relay tasks)
/// can send through the same bound address.
#[derive(Clone)]
struct UDPMultiSender {
/// The shared socket every `send`/`send_to` goes through.
socket: Arc<UdpSocket>,
}
impl UDPMultiSender {
    /// Takes ownership of `socket` and wraps it so clones of this sender share it.
    fn new(socket: UdpSocket) -> Self {
        let shared = Arc::new(socket);
        Self { socket: shared }
    }

    /// Sends `buf` to the socket's connected peer; the byte count is discarded.
    async fn send(&self, buf: &[u8]) -> Result<()> {
        let _bytes_sent = self.socket.send(buf).await?;
        Ok(())
    }

    /// Sends `buf` to `dest` (passed straight to `UdpSocket::send_to`); the byte count is
    /// discarded.
    async fn send_to(&self, buf: &[u8], dest: &str) -> Result<()> {
        let _bytes_sent = self.socket.send_to(buf, dest).await?;
        Ok(())
    }
}
/// Binds a UDP socket on `bind_addr` and returns two handles to it: a receiving
/// [`UdpSocket`] and a cloneable [`UDPMultiSender`] for sending.
///
/// The sending handle is a `try_clone` of the receiving one, so both refer to the same OS
/// socket and outgoing datagrams leave from the bound address.
async fn splitted_udp_socket(bind_addr: &str) -> Result<(UdpSocket, UDPMultiSender)> {
    let recv_std = std::net::UdpSocket::bind(bind_addr)?;
    let send_std = recv_std.try_clone()?;
    // Tokio's `from_std` expects the socket to already be in non-blocking mode.
    for handle in &[&recv_std, &send_std] {
        handle.set_nonblocking(true)?;
    }
    let recv_half = UdpSocket::from_std(recv_std)?;
    let send_half = UdpSocket::from_std(send_std)?;
    Ok((recv_half, UDPMultiSender::new(send_half)))
}
/// State for one proxied client: a dedicated upstream socket plus the bookkeeping needed to
/// relay traffic between a single client address and its chosen upstream target.
struct UDPChannel {
/// Remaining stale-check ticks before eviction; reset to `CONNECTION_TIMEOUT` on every
/// packet. Shared with the upstream relay task spawned in `start`.
last_packet: Arc<AtomicI32>,
/// Sending half of this channel's upstream socket; `handle` sends through it with `send`,
/// relying on the socket having been connected to `upstream` in `start`.
sender: UDPMultiSender,
/// Used by `close` to tell the upstream relay task to stop.
shutdown: Shutdown,
/// Client address (`SocketAddr` rendered as a string) this channel serves.
from: String,
/// Upstream target this channel forwards to.
upstream: String,
}
impl UDPChannel {
async fn start(
upstream: String,
responder: UDPMultiSender,
source_addr: String,
) -> Result<Self> {
let (upstream_listener, upstream_responder) = splitted_udp_socket("0.0.0.0:0").await?;
upstream_listener.connect(upstream.clone()).await?;
let shutdown = Shutdown::new();
let mut shutdown_receiver = shutdown.receiver();
let channel = Self {
last_packet: Arc::new(AtomicI32::new(CONNECTION_TIMEOUT)),
sender: upstream_responder,
shutdown,
from: source_addr.clone(),
upstream: upstream.clone(),
};
let last_packet = channel.last_packet.clone();
tokio::spawn(async move {
let mut buf = [0; 64 * 1024];
loop {
let num_bytes = tokio::select! {
res = upstream_listener.recv(&mut buf) => res.unwrap(),
_ = shutdown_receiver.recv() => {
info!("Exiting");
return;
}
};
trace!("[{}] <- [{}] ...", source_addr, upstream);
last_packet.store(CONNECTION_TIMEOUT, std::sync::atomic::Ordering::Relaxed);
match responder.send_to(&buf[..num_bytes], &source_addr).await {
Ok(_) => {}
Err(e) => {
error!("Failed to send packet: {}", e);
}
};
trace!("[{}] <- [{}] ---", source_addr, upstream);
}
});
return Ok(channel);
}
async fn close(&mut self) {
self.shutdown.shutdown();
debug!("Closing connection from {}", self.from);
}
async fn handle(&self, data: &[u8]) {
trace!("[{}] -> [{}] ...", self.from, self.upstream);
self.last_packet
.store(CONNECTION_TIMEOUT, std::sync::atomic::Ordering::Relaxed);
match self.sender.send(data).await {
Ok(_) => {}
Err(e) => {
error!("Failed to send packet: {}", e);
}
}
trace!("[{}] -> [{}] ---", self.from, self.upstream);
}
}
/// Spawns the background task that evicts idle channels from `connections`.
///
/// Once per second the task decrements every channel's `last_packet` countdown. Any channel
/// whose countdown was already at or below zero before that decrement is removed from the
/// map and closed. The task ends when `shutdown` yields.
fn start_stale_check(
    connections: Arc<Mutex<HashMap<String, UDPChannel>>>,
    mut shutdown: ShutdownReceiver,
) {
    tokio::spawn(async move {
        let mut ticker = tokio::time::interval(std::time::Duration::from_secs(1));
        loop {
            tokio::select! {
                _ = ticker.tick() => {}
                _ = shutdown.recv() => {
                    info!("Exiting listener!");
                    break;
                }
            }
            trace!("Checking for stale connections");
            trace!("Waiting for connections lock");
            let mut guard = connections.lock().await;
            trace!("Got connections lock");
            let expired: Vec<String> = guard
                .iter()
                .filter_map(|(addr, channel)| {
                    // fetch_sub hands back the value from *before* this tick's decrement.
                    let before = channel
                        .last_packet
                        .fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
                    if before <= 0 {
                        Some(addr.clone())
                    } else {
                        None
                    }
                })
                .collect();
            for addr in expired {
                debug!("Closing connection from {}", addr);
                let mut channel = guard.remove(&addr).unwrap();
                channel.close().await;
            }
            trace!("Checking for stale connections - done");
            drop(guard);
        }
        info!("Exiting stale check");
    });
}
pub(crate) async fn start_udp_listener(mut listener_config: Listener) -> Result<()> {
info!("start listening on {}", &listener_config.source);
let (listener, responder) = splitted_udp_socket(&listener_config.source).await?;
let connections = Arc::new(Mutex::new(HashMap::<String, UDPChannel>::new()));
start_stale_check(connections.clone(), listener_config.shutdown.clone());
loop {
let mut buf = vec![0; 1024 * 64];
trace!("Waiting for packet or shutdown");
let (num_bytes, src_addr) = tokio::select! {
res = listener.recv_from(&mut buf) => res?,
_ = listener_config.targets_changed.recv() => {
info!("Targets changed!");
info!("Closing all connections!");
for (_, handler) in connections.lock().await.iter_mut() {
handler.close().await;
}
info!("Closed all connections!");
continue;
}
_ = listener_config.shutdown.recv() => {
info!("Exiting listener!");
break;
}
};
let data = buf[0..num_bytes].to_vec();
let source_addr = src_addr.to_string();
let mut connections = connections.lock().await;
let handler_opt = connections.get_mut(&source_addr);
let handler = match handler_opt {
Some(handler) => handler,
None => {
debug!("New connection from {}", source_addr);
let targets = listener_config.targets.read().await;
let upstream = {
let mut rng = rand::thread_rng();
targets.choose(&mut rng).unwrap()
};
let handler =
UDPChannel::start(upstream.to_string(), responder.clone(), source_addr.clone())
.await?;
connections.insert(source_addr.clone(), handler);
connections.get_mut(&source_addr).unwrap()
}
};
handler.handle(&data).await;
trace!("Handled packet");
drop(connections);
}
for (_, handler) in connections.lock().await.iter_mut() {
handler.close().await;
}
println!("Listener closed.");
return Ok(());
}