rustocat/src/main.rs

164 lines
4.9 KiB
Rust

mod listener;
mod shutdown;
mod tcp;
mod udp;
use serde::Deserialize;
use simple_error::bail;
use std::collections::{HashMap, HashSet};
use std::error::Error;
use std::fs::File;
use std::path::Path;
use std::sync::Arc;
use tokio::signal::unix::{signal, SignalKind};
use tokio::sync::{broadcast, RwLock};
/// One entry of the `mappings` list in the configuration file.
#[derive(Debug, Deserialize)]
struct Target {
/// Protocol selector: `Some(true)` starts a UDP listener, while a missing
/// value or `Some(false)` starts a TCP listener.
udp: Option<bool>,
/// Value handed to the listener as its `source` — presumably the local
/// address to listen on; confirm against the `listener` module.
source: String,
/// Values handed to the listener as its `targets` — presumably the upstream
/// addresses traffic is forwarded to; confirm against the `listener` module.
targets: Vec<String>,
}
/// Top-level configuration document, deserialized from a YAML or JSON file.
#[derive(Debug, Deserialize)]
struct Config {
/// Forwarding rules; `main` keeps one listener running per entry.
mappings: Vec<Target>,
}
/// Loads a [`Config`] from the YAML file at `path`.
///
/// # Errors
///
/// Returns the underlying `std::io::Error` if the file cannot be opened, or an
/// error naming `path` if its contents cannot be deserialized.
fn load_yaml(path: &Path) -> Result<Config, Box<dyn Error>> {
    let file = File::open(path)?;
    // Report malformed input to the caller instead of panicking.
    let config: Config = serde_yaml::from_reader(file)
        .map_err(|err| format!("Failed to parse {}: {}", path.display(), err))?;
    Ok(config)
}

/// Loads a [`Config`] from the JSON file at `path`.
///
/// # Errors
///
/// Returns the underlying `std::io::Error` if the file cannot be opened, or an
/// error naming `path` if its contents cannot be deserialized.
fn load_json(path: &Path) -> Result<Config, Box<dyn Error>> {
    let file = File::open(path)?;
    // `serde_json::from_reader` does not buffer on its own, so wrap the file.
    let reader = std::io::BufReader::new(file);
    let config: Config = serde_json::from_reader(reader)
        .map_err(|err| format!("Failed to parse {}: {}", path.display(), err))?;
    Ok(config)
}

/// Loads the first usable configuration file from a fixed list of locations.
///
/// Candidates are tried in order: `config.yaml`, `config.json`,
/// `/etc/rustocat.yaml`, `/etc/rustocat.json`. The parser is chosen by file
/// extension.
///
/// # Errors
///
/// * A candidate that cannot be opened is skipped (with a warning on stderr
///   unless it simply does not exist) and the next one is tried.
/// * A candidate that opens but fails to parse aborts the search and its
///   error is returned, so a broken file is never silently replaced by a
///   lower-priority one.
/// * If no candidate could be opened, `"No config file found"` is returned.
fn load_config() -> Result<Config, Box<dyn Error>> {
    const CANDIDATES: [&str; 4] = [
        "config.yaml",
        "config.json",
        "/etc/rustocat.yaml",
        "/etc/rustocat.json",
    ];

    for name in CANDIDATES.iter() {
        let path = Path::new(*name);
        let loaded = if name.ends_with(".yaml") {
            load_yaml(path)
        } else {
            load_json(path)
        };

        let err = match loaded {
            Ok(config) => return Ok(config),
            Err(err) => err,
        };

        // The loaders return a bare `std::io::Error` only when opening the
        // file failed; every other failure means the file exists but is bad.
        match err.downcast_ref::<std::io::Error>() {
            Some(io_err) if io_err.kind() == std::io::ErrorKind::NotFound => {}
            Some(io_err) => eprintln!("Skipping config file {}: {}", path.display(), io_err),
            None => return Err(err),
        }
    }
    bail!("No config file found");
}
// #[derive(Debug)]
/// Bookkeeping that `main` keeps for every listener it has spawned.
struct ActiveListener {
/// Protocol flag recorded when the listener is created (the name suggests
/// `true` means UDP). It is only written, never read, in the code visible here.
udp: bool,
/// Sender side of the listener's shutdown channel; `main` sends `()` on it
/// when the listener's mapping disappears from the configuration.
notify_shutdown: broadcast::Sender<()>,
/// Target list shared with the spawned listener; `main` overwrites it when
/// a reloaded configuration lists different targets for the same source.
targets: Arc<RwLock<Vec<String>>>,
}
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
let mut listeners: HashMap<String, ActiveListener> = HashMap::new();
let mut sighup_stream = signal(SignalKind::hangup())?;
loop {
let config = load_config().expect("config not found");
let mut required_listeners: HashSet<String> = HashSet::new();
for target in config.mappings {
let mut source_str = "".to_owned();
if target.udp == None || target.udp == Some(false) {
source_str.push_str("udp:");
} else {
source_str.push_str("tcp:");
}
source_str.push_str(&target.source);
required_listeners.insert(source_str.clone());
if let Some(listener) = listeners.get(&source_str) {
let mut invalid = false;
let targets = listener.targets.read().await;
for t in &target.targets {
if !targets.iter().any(|e| *e == *t) {
invalid = true;
break;
}
}
if !invalid {
for t in targets.iter() {
if !target.targets.iter().any(|e| *e == *t) {
invalid = true;
break;
}
}
}
if invalid {
println!("Found invalid targets! Adjusting!");
let mut w = listener.targets.write().await;
*w = target.targets;
}
} else {
let (notify_shutdown, _) = broadcast::channel(1);
let listener = ActiveListener {
notify_shutdown: notify_shutdown,
targets: Arc::new(RwLock::new(target.targets)),
udp: target.udp == None || target.udp == Some(false),
};
let l = listener::Listener {
shutdown: shutdown::Shutdown::new(listener.notify_shutdown.subscribe()),
source: target.source.clone(),
targets: listener.targets.clone(),
};
tokio::spawn(async move {
if target.udp == None || target.udp == Some(false) {
if let Err(err) = tcp::start_tcp_listener(l).await {
println!("tcp listener error: {}", err);
}
} else {
if let Err(err) = udp::start_udp_listener(l).await {
println!("udp listener error: {}", err);
}
}
});
listeners.insert(source_str, listener);
}
}
let to_delete: Vec<_> = listeners
.keys()
.filter(|x| required_listeners.get(*x).is_none())
.cloned() // Clone the result to make listeners mutable again!
.collect();
for del_key in to_delete {
if let Some(listener) = listeners.get(&del_key) {
let _ = listener.notify_shutdown.send(()); //Errors are irrelevant here. I guess....
println!("Removing listener!");
listeners.remove(&del_key);
}
}
sighup_stream.recv().await;
println!("Recevied SIGHUP!");
}
}