225 lines
6.8 KiB
Rust
225 lines
6.8 KiB
Rust
use std::collections::HashMap;
|
|
use std::sync::atomic::AtomicI32;
|
|
use std::sync::Arc;
|
|
|
|
use log::{debug, error, info, trace};
|
|
use rand::seq::SliceRandom;
|
|
|
|
use tokio::net::UdpSocket;
|
|
use tokio::sync::Mutex;
|
|
|
|
use crate::listener::Listener;
|
|
use crate::shutdown::{Shutdown, ShutdownReceiver};
|
|
use crate::Result;
|
|
|
|
/// Start value of each channel's idle countdown (`UDPChannel::last_packet`).
///
/// The counter is reset to this value whenever a packet passes through the channel in either
/// direction and is decremented once per pass of the stale check, which ticks every second.
/// A channel is closed once the value observed before a decrement is `<= 0`.
const CONNECTION_TIMEOUT: i32 = 5;
|
|
|
|
#[derive(Clone)]
|
|
struct UDPMultiSender {
|
|
socket: Arc<UdpSocket>,
|
|
}
|
|
|
|
impl UDPMultiSender {
|
|
fn new(socket: UdpSocket) -> Self {
|
|
return Self {
|
|
socket: Arc::new(socket),
|
|
};
|
|
}
|
|
|
|
async fn send(&self, buf: &[u8]) -> Result<()> {
|
|
self.socket.send(buf).await?;
|
|
return Ok(());
|
|
}
|
|
|
|
async fn send_to(&self, buf: &[u8], dest: &str) -> Result<()> {
|
|
self.socket.send_to(buf, dest).await?;
|
|
return Ok(());
|
|
}
|
|
}
|
|
|
|
async fn splitted_udp_socket(bind_addr: &str) -> Result<(UdpSocket, UDPMultiSender)> {
|
|
let listener_std = std::net::UdpSocket::bind(bind_addr)?;
|
|
let responder_std = listener_std.try_clone()?;
|
|
listener_std.set_nonblocking(true)?;
|
|
responder_std.set_nonblocking(true)?;
|
|
let listener = UdpSocket::from_std(listener_std)?;
|
|
let responder = UDPMultiSender::new(UdpSocket::from_std(responder_std)?);
|
|
|
|
return Ok((listener, responder));
|
|
}
|
|
|
|
struct UDPChannel {
|
|
last_packet: Arc<AtomicI32>,
|
|
sender: UDPMultiSender,
|
|
shutdown: Shutdown,
|
|
from: String,
|
|
upstream: String,
|
|
}
|
|
|
|
impl UDPChannel {
|
|
async fn start(
|
|
upstream: String,
|
|
responder: UDPMultiSender,
|
|
source_addr: String,
|
|
) -> Result<Self> {
|
|
let (upstream_listener, upstream_responder) = splitted_udp_socket("0.0.0.0:0").await?;
|
|
upstream_listener.connect(upstream.clone()).await?;
|
|
|
|
let shutdown = Shutdown::new();
|
|
let mut shutdown_receiver = shutdown.receiver();
|
|
|
|
let channel = Self {
|
|
last_packet: Arc::new(AtomicI32::new(CONNECTION_TIMEOUT)),
|
|
sender: upstream_responder,
|
|
shutdown,
|
|
from: source_addr.clone(),
|
|
upstream: upstream.clone(),
|
|
};
|
|
|
|
let last_packet = channel.last_packet.clone();
|
|
|
|
tokio::spawn(async move {
|
|
let mut buf = [0; 64 * 1024];
|
|
loop {
|
|
let num_bytes = tokio::select! {
|
|
res = upstream_listener.recv(&mut buf) => res.unwrap(),
|
|
_ = shutdown_receiver.recv() => {
|
|
info!("Exiting");
|
|
return;
|
|
}
|
|
};
|
|
|
|
trace!("[{}] <- [{}] ...", source_addr, upstream);
|
|
last_packet.store(CONNECTION_TIMEOUT, std::sync::atomic::Ordering::Relaxed);
|
|
|
|
match responder.send_to(&buf[..num_bytes], &source_addr).await {
|
|
Ok(_) => {}
|
|
Err(e) => {
|
|
error!("Failed to send packet: {}", e);
|
|
}
|
|
};
|
|
|
|
trace!("[{}] <- [{}] ---", source_addr, upstream);
|
|
}
|
|
});
|
|
|
|
return Ok(channel);
|
|
}
|
|
|
|
async fn close(&mut self) {
|
|
self.shutdown.shutdown();
|
|
debug!("Closing connection from {}", self.from);
|
|
}
|
|
|
|
async fn handle(&self, data: &[u8]) {
|
|
trace!("[{}] -> [{}] ...", self.from, self.upstream);
|
|
self.last_packet
|
|
.store(CONNECTION_TIMEOUT, std::sync::atomic::Ordering::Relaxed);
|
|
match self.sender.send(data).await {
|
|
Ok(_) => {}
|
|
Err(e) => {
|
|
error!("Failed to send packet: {}", e);
|
|
}
|
|
}
|
|
trace!("[{}] -> [{}] ---", self.from, self.upstream);
|
|
}
|
|
}
|
|
|
|
fn start_stale_check(
|
|
connections: Arc<Mutex<HashMap<String, UDPChannel>>>,
|
|
mut shutdown: ShutdownReceiver,
|
|
) {
|
|
tokio::spawn(async move {
|
|
let mut interval = tokio::time::interval(std::time::Duration::from_secs(1));
|
|
loop {
|
|
tokio::select! {
|
|
_ = interval.tick() => {}
|
|
_ = shutdown.recv() => {
|
|
info!("Exiting listener!");
|
|
break;
|
|
}
|
|
}
|
|
|
|
trace!("Checking for stale connections");
|
|
trace!("Waiting for connections lock");
|
|
let mut connections = connections.lock().await;
|
|
trace!("Got connections lock");
|
|
let mut to_remove: Vec<String> = Vec::new();
|
|
for (source_addr, channel) in connections.iter() {
|
|
let last = channel
|
|
.last_packet
|
|
.fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
|
|
|
|
if last <= 0 {
|
|
to_remove.push(source_addr.clone());
|
|
}
|
|
}
|
|
|
|
for source_addr in to_remove {
|
|
debug!("Closing connection from {}", source_addr);
|
|
let mut channel = connections.remove(&source_addr).unwrap();
|
|
channel.close().await;
|
|
}
|
|
trace!("Checking for stale connections - done");
|
|
drop(connections);
|
|
}
|
|
|
|
info!("Exiting stale check");
|
|
});
|
|
}
|
|
|
|
pub(crate) async fn start_udp_listener(mut listener_config: Listener) -> Result<()> {
|
|
info!("start listening on {}", &listener_config.source);
|
|
let (listener, responder) = splitted_udp_socket(&listener_config.source).await?;
|
|
let connections = Arc::new(Mutex::new(HashMap::<String, UDPChannel>::new()));
|
|
|
|
start_stale_check(connections.clone(), listener_config.shutdown.clone());
|
|
|
|
loop {
|
|
let mut buf = vec![0; 1024 * 64];
|
|
trace!("Waiting for packet or shutdown");
|
|
let (num_bytes, src_addr) = tokio::select! {
|
|
res = listener.recv_from(&mut buf) => res?,
|
|
_ = listener_config.shutdown.recv() => {
|
|
info!("Exiting listener!");
|
|
break;
|
|
}
|
|
};
|
|
|
|
let data = buf[0..num_bytes].to_vec();
|
|
let source_addr = src_addr.to_string();
|
|
|
|
let mut connections = connections.lock().await;
|
|
|
|
let handler_opt = connections.get_mut(&source_addr);
|
|
let handler = match handler_opt {
|
|
Some(handler) => handler,
|
|
None => {
|
|
debug!("New connection from {}", source_addr);
|
|
let targets = listener_config.targets.read().await;
|
|
let upstream = {
|
|
let mut rng = rand::thread_rng();
|
|
targets.choose(&mut rng).unwrap()
|
|
};
|
|
|
|
let handler =
|
|
UDPChannel::start(upstream.to_string(), responder.clone(), source_addr.clone())
|
|
.await?;
|
|
|
|
connections.insert(source_addr.clone(), handler);
|
|
connections.get_mut(&source_addr).unwrap()
|
|
}
|
|
};
|
|
|
|
handler.handle(&data).await;
|
|
trace!("Handled packet");
|
|
drop(connections);
|
|
}
|
|
|
|
for (_, handler) in connections.lock().await.iter_mut() {
|
|
handler.close().await;
|
|
}
|
|
|
|
return Ok(());
|
|
}
|