mod listener; mod shutdown; mod tcp; mod udp; use serde::Deserialize; use simple_error::bail; use std::collections::{HashMap, HashSet}; use std::error::Error; use std::fs::File; use std::path::Path; use std::sync::Arc; use tokio::signal::unix::{signal, SignalKind}; use tokio::sync::{broadcast, RwLock}; #[derive(Debug, Deserialize)] struct Target { udp: Option, source: String, targets: Vec, } #[derive(Debug, Deserialize)] struct Config { mappings: Vec, } fn load_yaml(path: &Path) -> Result> { let file = File::open(path)?; let config: Config = serde_yaml::from_reader(file).expect("Failed to parse!"); //TODO: Print path return Ok(config); } fn load_json(path: &Path) -> Result> { let file = File::open(path)?; let config: Config = serde_json::from_reader(file).expect("Failed to parse!"); //TODO: Print path return Ok(config); } fn load_config() -> Result> { for path in [ "config.yaml", "config.json", "/etc/rustocat.yaml", "/etc/rustocat.json", ] .iter() { // if(p) let config = if path.ends_with(".yaml") { load_yaml(Path::new(path)) } else { load_json(Path::new(path)) }; if config.is_ok() { return config; } } bail!("No config file found"); } // #[derive(Debug)] struct ActiveListener { udp: bool, notify_shutdown: broadcast::Sender<()>, targets: Arc>>, } #[tokio::main] async fn main() -> Result<(), Box> { let mut listeners: HashMap = HashMap::new(); let mut sighup_stream = signal(SignalKind::hangup())?; loop { let config = load_config().expect("config not found"); let mut required_listeners: HashSet = HashSet::new(); for target in config.mappings { let mut source_str = "".to_owned(); if target.udp == None || target.udp == Some(false) { source_str.push_str("udp:"); } else { source_str.push_str("tcp:"); } source_str.push_str(&target.source); required_listeners.insert(source_str.clone()); if let Some(listener) = listeners.get(&source_str) { let mut invalid = false; let targets = listener.targets.read().await; for t in &target.targets { if !targets.iter().any(|e| *e == *t) { invalid = true; break; } } if !invalid { for t in targets.iter() { if !target.targets.iter().any(|e| *e == *t) { invalid = true; break; } } } if invalid { println!("Found invalid targets! Adjusting!"); let mut w = listener.targets.write().await; *w = target.targets; } } else { let (notify_shutdown, _) = broadcast::channel(1); let listener = ActiveListener { notify_shutdown: notify_shutdown, targets: Arc::new(RwLock::new(target.targets)), udp: target.udp == None || target.udp == Some(false), }; let l = listener::Listener { shutdown: shutdown::Shutdown::new(listener.notify_shutdown.subscribe()), source: target.source.clone(), targets: listener.targets.clone(), }; tokio::spawn(async move { if target.udp == None || target.udp == Some(false) { if let Err(err) = tcp::start_tcp_listener(l).await { println!("tcp listener error: {}", err); } } else { if let Err(err) = udp::start_udp_listener(l).await { println!("udp listener error: {}", err); } } }); listeners.insert(source_str, listener); } } let to_delete: Vec<_> = listeners .keys() .filter(|x| required_listeners.get(*x).is_none()) .cloned() // Clone the result to make listeners mutable again! .collect(); for del_key in to_delete { if let Some(listener) = listeners.get(&del_key) { let _ = listener.notify_shutdown.send(()); //Errors are irrelevant here. I guess.... println!("Removing listener!"); listeners.remove(&del_key); } } sighup_stream.recv().await; println!("Recevied SIGHUP!"); } }