15 Commits

Author SHA1 Message Date
5e6552dd6c Change docker repository
All checks were successful
CI / build (push) Successful in 23m2s
2023-07-13 14:58:40 +02:00
41d10c5c94 Add listener closing when the target set has changed.
All checks were successful
CI / build (push) Successful in 22m46s
2023-07-13 11:53:51 +02:00
c90a8fff4a Bumping version to 0.1.11 2023-06-07 08:08:25 +02:00
f290fa2eb1 Adjusting Earthfile 2023-06-07 08:07:40 +02:00
635488ddeb Bumping version 2023-06-06 22:28:23 +02:00
a7e06ca5ce Adding Earthfile for CI 2023-06-06 22:27:49 +02:00
ef1a922933 Make it work in nomad 2023-04-14 04:13:41 +02:00
747c4ddf6f Remove unused dependency 2023-02-06 14:24:46 +01:00
db06b7c7d0 Integrating consul catalog for autoconfig of listeners 2023-02-06 14:21:05 +01:00
05da7a97e0 Adding Dockerfile 2023-02-05 15:11:45 +01:00
f74409c9eb Remove unused imports 2023-02-05 11:37:18 +01:00
bcc332bd7b Adjust version and README 2023-02-05 11:34:57 +01:00
3eab1ae216 Remove config.yaml from repository 2023-02-05 11:31:39 +01:00
517889f875 UDP seems to work now 2023-02-05 00:36:53 +01:00
7530ce3aae Implementing udp support partially...in theory. Still untested :) 2022-09-21 22:22:10 +02:00
16 changed files with 1949 additions and 359 deletions

1
.earthlyignore Normal file
View File

@ -0,0 +1 @@
target/

2
.editorconfig Normal file
View File

@ -0,0 +1,2 @@
[*.rs]
indent_size = 4

36
.github/workflows/ci.yml vendored Normal file
View File

@ -0,0 +1,36 @@
# .github/workflows/ci.yml
name: CI
on:
push:
pull_request:
branches: [main]
jobs:
build:
runs-on: ubuntu-latest
env:
MY_DOCKER_USERNAME: ${{ secrets.MY_DOCKER_USERNAME }}
MY_DOCKER_PASSWORD: ${{ secrets.MY_DOCKER_PASSWORD }}
FORCE_COLOR: 1
steps:
- uses: https://github.com/earthly/actions-setup@v1
with:
version: v0.7.0
- uses: actions/checkout@v2
- name: Put back the git branch into git (Earthly uses it for tagging)
run: |
branch=""
if [ -n "$GITHUB_HEAD_REF" ]; then
branch="$GITHUB_HEAD_REF"
else
branch="${GITHUB_REF##*/}"
fi
git checkout -b "$branch" || true
- name: Docker Login
run: docker login git.hibas.dev --username "$MY_DOCKER_USERNAME" --password "$MY_DOCKER_PASSWORD"
- name: Earthly version
run: earthly --version
- name: Run build
run: earthly --push +docker-multi

1
.gitignore vendored
View File

@ -1,2 +1,3 @@
/target
config.yaml
config.json

1287
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -1,24 +1,30 @@
[package]
name = "rustocat"
version = "0.0.3"
version = "0.1.12"
edition = "2021"
description = "Socat in rust with many less features and a configuration file"
license = "ISC"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[features]
default = ["consul"]
consul = []
[dependencies]
tokio = { version = "1.17.0", features = ["full"] }
futures = "0.3.21"
serde = { version = "1.0.136", features = ["derive"] }
serde_json = "1.0.79"
serde_yaml = "0.8.23"
simple-error = "0.2.3"
rand = "0.8.5"
tokio = { version = "1", features = ["full"] }
serde = { version = "1", features = ["derive"] }
serde_json = "1"
serde_yaml = "0.9"
rand = "0.8"
log = "0.4"
fern = "0.6"
chrono = "0.4"
async-trait = "0.1"
reqwest = { version = "0.11", features = ["json", "rustls", "hyper-tls"], default-features = false }
[profile.release]
opt-level = 3 # Optimize for size.
opt-level = 3 # Optimize for size.
lto = true # Enable Link Time Optimization
codegen-units = 1 # Reduce number of codegen units to increase optimizations.
panic = 'abort' # Abort on panic
strip = true # Strip symbols from binary*
strip = true # Strip symbols from binary*

25
Earthfile Normal file
View File

@ -0,0 +1,25 @@
VERSION 0.7
FROM rust:alpine
WORKDIR /rustbuild
prepare-env:
RUN apk add --no-cache musl-dev libssl3 openssl-dev
RUN cargo search test # Super simple way to cache the cargo index
docker-multi:
BUILD --platform linux/amd64 --platform linux/arm64 +docker
build:
FROM +prepare-env
COPY . .
RUN cargo build --release
SAVE ARTIFACT target/release/rustocat rustocat
docker:
FROM alpine
RUN apk add --no-cache libssl3
COPY +build/rustocat rustocat
ENTRYPOINT ["./rustocat"]
ARG EARTHLY_TARGET_TAG
ARG TAG=$EARTHLY_TARGET_TAG
SAVE IMAGE --push git.hibas.dev/ops/rustocat:$TAG

View File

@ -7,12 +7,13 @@ Rustocat is a simple socat alternative with way less features, but it has a conf
Configs can be either yaml or json and can be located in /etc/rustocat.{yaml|json} or in the current working directory as config.{yaml|json}.
```yaml
tcp:
- source: 0.0.0.0:2222
mappings:
- udp: false
source: 0.0.0.0:2222
targets: [127.0.0.1:22]
```
Currently only TCP is supported, UDP/Unix Socket support might be added later.
There is support for UDP and TCP sockets. Each socket can have multiple targets.
When multiple targets are set, it will randomly pick one of them.

View File

@ -1,5 +0,0 @@
tcp:
- source: 127.0.0.1:4422
targets: [127.0.0.1:22, fury.infra.stamm.me:22]
- source: 127.0.0.1:4423
targets: [127.0.0.1:22, fury.infra.stamm.me:22]

93
src/config.rs Normal file
View File

@ -0,0 +1,93 @@
use std::{fs::File, path::Path};
use log::info;
use serde::Deserialize;
use tokio::signal::unix::Signal;
use crate::Result;
#[derive(Debug, Deserialize)]
pub struct Target {
pub udp: Option<bool>,
pub source: String,
pub targets: Vec<String>,
}
#[derive(Debug, Deserialize)]
pub struct Config {
pub consul: Option<bool>,
pub consul_http_addr: Option<String>,
pub consul_http_token: Option<String>,
mappings: Vec<Target>,
}
#[async_trait::async_trait]
pub trait ConfigProvider {
async fn get_targets(&self) -> Result<Vec<Target>>;
async fn wait_for_change(&mut self) -> Result<()>;
}
pub struct FileConfigProvider {
sighup_stream: Signal,
}
impl FileConfigProvider {
pub fn new() -> Self {
Self {
sighup_stream: tokio::signal::unix::signal(tokio::signal::unix::SignalKind::hangup())
.expect("Failed to create sighup stream"),
}
}
fn load_yaml(&self, path: &Path) -> Result<Config> {
let file = File::open(path)?;
let config: Config = serde_yaml::from_reader(file).expect("Failed to parse!"); //TODO: Print path
return Ok(config);
}
fn load_json(&self, path: &Path) -> Result<Config> {
let file = File::open(path)?;
let config: Config = serde_json::from_reader(file).expect("Failed to parse!"); //TODO: Print path
return Ok(config);
}
pub fn load_config(&self) -> Result<Config> {
for path in [
"config.yaml",
"config.json",
"/etc/rustocat.yaml",
"/etc/rustocat.json",
]
.iter()
{
// if(p)
let config = if path.ends_with(".yaml") {
self.load_yaml(Path::new(path))
} else {
self.load_json(Path::new(path))
};
if config.is_ok() {
return config;
}
}
Err("No config file found".into())
}
}
#[async_trait::async_trait]
impl ConfigProvider for FileConfigProvider {
async fn get_targets(&self) -> Result<Vec<Target>> {
info!("Getting targets from file");
let config = self.load_config()?;
return Ok(config.mappings);
}
async fn wait_for_change(&mut self) -> Result<()> {
info!("Waiting for file config change (SIGHUP)");
self.sighup_stream.recv().await;
return Ok(());
}
}

204
src/consul.rs Normal file
View File

@ -0,0 +1,204 @@
#![allow(non_snake_case)]
use std::collections::HashMap;
use log::debug;
use log::info;
use log::trace;
use log::warn;
use reqwest::header::HeaderMap;
use serde::{Deserialize, Serialize};
use crate::config::Config;
use crate::config::ConfigProvider;
use crate::config::Target;
use crate::Result;
pub struct ConsulConfigProvider {
consul_config: ConsulConfig,
interval: tokio::time::Interval,
}
impl ConsulConfigProvider {
pub fn new(config: Option<&Config>) -> Self {
Self {
consul_config: if let Some(config) = config {
ConsulConfig::from_config_or_env(config)
} else {
ConsulConfig::from_env()
},
interval: tokio::time::interval(tokio::time::Duration::from_secs(10)),
}
}
}
#[async_trait::async_trait]
impl ConfigProvider for ConsulConfigProvider {
async fn get_targets(&self) -> Result<Vec<Target>> {
info!("Getting targets from consul");
let mut targets: Vec<Target> = Vec::new();
debug!("Calling consul_get_services");
let services = consul_get_services(&self.consul_config).await?;
// Find consul services and tags
// Format of tags: rustocat:udp:port
// rustocat:tcp:port
debug!("Iterating over services: {:?}", services);
for (name, tags) in services {
for tag in tags {
if tag.starts_with("rustocat") {
trace!("Found rustocat tag: {}", tag);
let parts = tag.split(":").collect::<Vec<&str>>();
if parts.len() != 3 {
warn!("Invalid tag: {} on service {}", tag, name);
continue;
}
let port = parts[2];
trace!("Getting nodes for service: {}", name);
let nodes = consul_get_service_nodes(&self.consul_config, &name).await?;
let mut t = vec![];
for node in nodes {
t.push(format!("{}:{}", node.ServiceAddress, node.ServicePort));
}
let target = Target {
udp: Some(parts[1] == "udp"),
source: format!("0.0.0.0:{}", port),
targets: t,
};
trace!("Adding target: {:?}", target);
targets.push(target);
}
}
}
Ok(targets)
}
async fn wait_for_change(&mut self) -> Result<()> {
info!("Waiting for consul config change");
self.interval.tick().await;
Ok(())
}
}
async fn consul_get_services(config: &ConsulConfig) -> Result<HashMap<String, Vec<String>>> {
let mut headers = HeaderMap::new();
if let Some(token) = config.token.clone() {
headers.insert("X-Consul-Token", token.parse()?);
}
return Ok(reqwest::Client::new()
.get(format!("{}/v1/catalog/services", config.baseurl))
.headers(headers)
.send()
.await?
.json::<HashMap<String, Vec<String>>>()
.await?);
}
async fn consul_get_service_nodes(config: &ConsulConfig, service: &str) -> Result<Vec<Node>> {
let mut headers = HeaderMap::new();
if let Some(token) = config.token.clone() {
headers.insert("X-Consul-Token", token.parse()?);
}
trace!(
"Calling consul_get_service_nodes: {}/v1/catalog/service/{service}",
config.baseurl
);
return Ok(reqwest::Client::new()
.get(format!("{}/v1/catalog/service/{service}", config.baseurl))
.headers(headers)
.send()
.await?
.json::<Vec<Node>>()
.await?);
}
#[derive(Eq, Default, PartialEq, Serialize, Deserialize, Debug)]
#[serde(default)]
struct AgentService {
ID: String,
Service: String,
Tags: Option<Vec<String>>,
Port: u16,
Address: String,
EnableTagOverride: bool,
CreateIndex: u64,
ModifyIndex: u64,
}
#[derive(Eq, Default, PartialEq, Serialize, Deserialize, Debug)]
#[serde(default)]
struct HealthCheck {
Node: String,
CheckID: String,
Name: String,
Status: String,
Notes: String,
Output: String,
ServiceID: String,
ServiceName: String,
ServiceTags: Option<Vec<String>>,
}
#[derive(Eq, Default, PartialEq, Serialize, Deserialize, Debug)]
#[serde(default)]
struct Node {
ID: String,
Node: String,
ServiceAddress: String,
ServicePort: u16,
Datacenter: Option<String>,
TaggedAddresses: Option<HashMap<String, String>>,
Meta: Option<HashMap<String, String>>,
CreateIndex: u64,
ModifyIndex: u64,
}
#[derive(Eq, Default, PartialEq, Serialize, Deserialize, Debug)]
#[serde(default)]
struct ServiceEntry {
Node: Node,
Service: AgentService,
Checks: Vec<HealthCheck>,
}
struct ConsulConfig {
baseurl: String,
token: Option<String>,
}
impl ConsulConfig {
fn from_env() -> Self {
Self {
baseurl: option_env!("CONSUL_HTTP_ADDR")
.expect("CONSUL_HTTP_ADDR not set")
.to_string(),
token: option_env!("CONSUL_HTTP_TOKEN").map(|s| s.to_string()),
}
}
fn from_config_or_env(config: &Config) -> Self {
let baseurl = match config.consul_http_addr {
Some(ref s) => s.clone(),
None => option_env!("CONSUL_HTTP_ADDR")
.expect("CONSUL_HTTP_ADDR not set")
.to_string(),
};
let token = match config.consul_http_token {
Some(ref s) => Some(s.clone()),
None => option_env!("CONSUL_HTTP_TOKEN").map(|s| s.to_string()),
};
Self { baseurl, token }
}
}

View File

@ -1,9 +1,11 @@
use crate::shutdown::Shutdown;
use crate::shutdown::ShutdownReceiver;
use std::sync::Arc;
use tokio::sync::broadcast;
use tokio::sync::RwLock;
pub(crate) struct Listener {
pub(crate) shutdown: Shutdown,
pub(crate) source: String,
pub(crate) targets: Arc<RwLock<Vec<String>>>,
pub(crate) shutdown: ShutdownReceiver,
pub(crate) source: String,
pub(crate) targets: Arc<RwLock<Vec<String>>>,
pub(crate) targets_changed: broadcast::Receiver<()>,
}

View File

@ -1,129 +1,158 @@
mod config;
mod listener;
mod shutdown;
mod tcp;
mod udp;
use serde::Deserialize;
use simple_error::bail;
#[cfg(feature = "consul")]
mod consul;
use config::ConfigProvider;
use log::{debug, error, info, warn};
use std::collections::{HashMap, HashSet};
use std::error::Error;
use std::fs::File;
use std::path::Path;
use std::sync::Arc;
use tokio::signal::unix::{signal, SignalKind};
use tokio::sync::{broadcast, RwLock};
#[derive(Debug, Deserialize)]
struct Target {
source: String,
targets: Vec<String>,
}
#[derive(Debug, Deserialize)]
struct Config {
tcp: Vec<Target>,
// udp: Vec<Target>,
}
fn load_yaml(path: &Path) -> Result<Config, Box<dyn Error>> {
let file = File::open(path)?;
let config: Config = serde_yaml::from_reader(file).expect("Failed to parse!"); //TODO: Print path
return Ok(config);
}
fn load_json(path: &Path) -> Result<Config, Box<dyn Error>> {
let file = File::open(path)?;
let config: Config = serde_json::from_reader(file).expect("Failed to parse!"); //TODO: Print path
return Ok(config);
}
fn load_config() -> Result<Config, Box<dyn Error>> {
for path in [
"config.yaml",
"config.json",
"/etc/rustocat.yaml",
"/etc/rustocat.json",
]
.iter()
{
// if(p)
let config = if path.ends_with(".yaml") {
load_yaml(Path::new(path))
} else {
load_json(Path::new(path))
};
if config.is_ok() {
return config;
}
}
bail!("No config file found");
}
pub type Result<T> = std::result::Result<T, Box<dyn Error>>;
// #[derive(Debug)]
struct ActiveListener {
notify_shutdown: broadcast::Sender<()>,
targets: Arc<RwLock<Vec<String>>>,
notify_targets: broadcast::Sender<()>,
}
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
let mut listeners: HashMap<String, ActiveListener> = HashMap::new();
let mut sighup_stream = signal(SignalKind::hangup())?;
async fn main() -> Result<()> {
fern::Dispatch::new()
.format(|out, message, record| {
out.finish(format_args!(
"{}[{}][{}] {}",
chrono::Local::now().format("[%Y-%m-%d][%H:%M:%S]"),
record.target(),
record.level(),
message
))
})
// Add blanket level filter -
.level(log::LevelFilter::Info)
.level_for("rustocat", log::LevelFilter::Debug)
.chain(std::io::stdout())
.apply()?;
let listeners: HashMap<String, ActiveListener> = HashMap::new();
let text_config_provider = config::FileConfigProvider::new();
let target_provider: Box<dyn config::ConfigProvider> = {
#[cfg(feature = "consul")]
{
let cfg = text_config_provider.load_config()?;
info!("Loaded yaml config");
if cfg.consul.is_some() && cfg.consul.unwrap() {
info!("Using consul config provider");
let consul_config_provider = consul::ConsulConfigProvider::new(Some(&cfg));
info!("Loaded consul config provider");
Box::new(consul_config_provider)
} else {
Box::new(text_config_provider)
}
}
#[cfg(not(feature = "consul"))]
{
let cfg = Box::new(text_config_provider);
info!("Loaded yaml config");
cfg
}
};
match run_loop(target_provider, listeners).await {
Ok(_) => {}
Err(e) => {
error!("Error in run loop: {}", e);
info!("Exiting");
}
}
Ok(())
}
async fn run_loop(
mut target_provider: Box<dyn ConfigProvider>,
mut listeners: HashMap<String, ActiveListener>,
) -> Result<()> {
loop {
let config = load_config().expect("config not found");
let mappings = target_provider.get_targets().await?;
let mut required_listeners: HashSet<String> = HashSet::new();
for target in config.tcp {
required_listeners.insert(target.source.clone());
if let Some(listener) = listeners.get(&target.source) {
for target in mappings {
let mut source_str = "".to_owned();
if target.udp == None || target.udp == Some(false) {
source_str.push_str("udp:");
} else {
source_str.push_str("tcp:");
}
source_str.push_str(&target.source);
required_listeners.insert(source_str.clone());
if let Some(listener) = listeners.get(&source_str) {
let mut invalid = false;
{
let targets = listener.targets.read().await;
for t in &target.targets {
if !targets.iter().any(|e| *e == *t) {
let targets = listener.targets.read().await;
for t in &target.targets {
if !targets.iter().any(|e| *e == *t) {
invalid = true;
break;
}
}
if !invalid {
for t in targets.iter() {
if !target.targets.iter().any(|e| *e == *t) {
invalid = true;
break;
}
}
if !invalid {
for t in targets.iter() {
if !target.targets.iter().any(|e| *e == *t) {
invalid = true;
break;
}
}
}
}
if invalid {
println!("Found invalid targets! Adjusting!");
warn!("Found invalid targets! Adjusting!");
let mut w = listener.targets.write().await;
*w = target.targets;
_ = listener.notify_targets.send(())?; //TODO: Maybe ignore this error
}
} else {
let (notify_shutdown, _) = broadcast::channel(1);
let (notify_targets, _) = broadcast::channel(1);
let listener = ActiveListener {
notify_shutdown: notify_shutdown,
targets: Arc::new(RwLock::new(target.targets)),
notify_targets: notify_targets,
};
let l = listener::Listener {
shutdown: shutdown::Shutdown::new(listener.notify_shutdown.subscribe()),
shutdown: shutdown::ShutdownReceiver::new(listener.notify_shutdown.subscribe()),
source: target.source.clone(),
targets: listener.targets.clone(),
targets_changed: listener.notify_targets.subscribe(),
};
tokio::spawn(async move {
if let Err(err) = tcp::start_tcp_listener(l).await {
println!("listener error: {}", err);
if target.udp == None || target.udp == Some(false) {
if let Err(err) = tcp::start_tcp_listener(l).await {
error!("tcp listener error: {}", err);
}
} else {
if let Err(err) = udp::start_udp_listener(l).await {
error!("udp listener error {}: {}", target.source, err);
}
}
});
listeners.insert(target.source, listener);
listeners.insert(source_str, listener);
}
}
@ -135,13 +164,24 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
for del_key in to_delete {
if let Some(listener) = listeners.get(&del_key) {
let _ = listener.notify_shutdown.send(()); //Errors are irrelevant here. I guess....
println!("Removing listener!");
let res = listener.notify_shutdown.send(()); //Errors are irrelevant here. I guess....
match res {
Ok(_) => {
info!("Sent shutdown signal!");
}
Err(_) => {
warn!("Failed to send shutdown signal!");
}
}
// TODO: Maybe Wait for the listener to shut down
debug!("Removing listener!");
listeners.remove(&del_key);
}
}
sighup_stream.recv().await;
println!("Recevied SIGHUP!");
target_provider.wait_for_change().await?;
info!("Reloading config!");
}
}

View File

@ -1,49 +1,62 @@
use tokio::sync::broadcast;
/// Listens for the server shutdown signal.
///
/// Shutdown is signalled using a `broadcast::Receiver`. Only a single value is
/// ever sent. Once a value has been sent via the broadcast channel, the server
/// should shutdown.
///
/// The `Shutdown` struct listens for the signal and tracks that the signal has
/// been received. Callers may query for whether the shutdown signal has been
/// received or not.
#[derive(Debug)]
pub(crate) struct Shutdown {
/// `true` if the shutdown signal has been received
shutdown: bool,
/// The receive half of the channel used to listen for shutdown.
notify: broadcast::Receiver<()>,
shutdown: bool,
notify: broadcast::Receiver<()>,
sender: broadcast::Sender<()>,
}
impl Shutdown {
/// Create a new `Shutdown` backed by the given `broadcast::Receiver`.
pub(crate) fn new(notify: broadcast::Receiver<()>) -> Shutdown {
Shutdown {
shutdown: false,
notify,
}
}
pub(crate) fn new() -> Shutdown {
let (sender, notify) = broadcast::channel(1);
Shutdown {
shutdown: false,
notify,
sender,
}
}
/// Returns `true` if the shutdown signal has been received.
// pub(crate) fn is_shutdown(&self) -> bool {
// self.shutdown
// }
pub(crate) fn shutdown(&mut self) {
if self.shutdown {
return;
}
let _ = self.sender.send(());
self.shutdown = true;
}
/// Receive the shutdown notice, waiting if necessary.
pub(crate) async fn recv(&mut self) {
// If the shutdown signal has already been received, then return
// immediately.
if self.shutdown {
return;
}
// Cannot receive a "lag error" as only one value is ever sent.
let _ = self.notify.recv().await;
// Remember that the signal has been received.
self.shutdown = true;
}
pub(crate) fn receiver(&self) -> ShutdownReceiver {
ShutdownReceiver::new(self.notify.resubscribe())
}
}
#[derive(Debug)]
pub(crate) struct ShutdownReceiver {
shutdown: bool,
notify: broadcast::Receiver<()>,
}
impl ShutdownReceiver {
pub(crate) fn new(notify: broadcast::Receiver<()>) -> ShutdownReceiver {
ShutdownReceiver {
shutdown: false,
notify,
}
}
pub(crate) async fn recv(&mut self) {
if self.shutdown {
return;
}
let _ = self.notify.recv().await;
self.shutdown = true;
}
}
impl Clone for ShutdownReceiver {
fn clone(&self) -> Self {
ShutdownReceiver {
shutdown: self.shutdown,
notify: self.notify.resubscribe(),
}
}
}

View File

@ -1,58 +1,63 @@
use crate::listener::Listener;
use log::{info, trace, warn};
use rand::seq::SliceRandom;
use std::error::Error;
use tokio::net::{TcpListener, TcpStream};
#[derive(Debug)]
struct TcpHandler {
stream: TcpStream,
target: String,
stream: TcpStream,
target: String,
}
impl TcpHandler {
async fn run(&mut self) -> Result<(), Box<dyn Error>> {
let mut stream = TcpStream::connect(&self.target).await?;
async fn run(&mut self) -> Result<(), Box<dyn Error>> {
let mut stream = TcpStream::connect(&self.target).await?;
tokio::io::copy_bidirectional(&mut self.stream, &mut stream).await?;
tokio::io::copy_bidirectional(&mut self.stream, &mut stream).await?;
return Ok(());
}
return Ok(());
}
}
pub(crate) async fn start_tcp_listener(
mut listener_config: Listener,
mut listener_config: Listener,
) -> Result<(), Box<dyn Error>> {
println!("start listening on {}", &listener_config.source);
let listener = TcpListener::bind(&listener_config.source).await?;
info!("start listening on {}", &listener_config.source);
let listener = TcpListener::bind(&listener_config.source).await?;
loop {
let (next_socket, _) = tokio::select! {
res = listener.accept() => res?,
_ = listener_config.shutdown.recv() => {
println!("Exiting listener!");
return Ok(());
}
};
loop {
let (next_socket, _) = tokio::select! {
res = listener.accept() => res?,
_ = listener_config.targets_changed.recv() => {
info!("Targets changed!");
continue;
}
_ = listener_config.shutdown.recv() => {
info!("Exiting listener!");
return Ok(());
}
};
let targets = listener_config.targets.read().await;
let mut rng = rand::thread_rng();
let selected_target = targets.choose(&mut rng).unwrap();
let targets = listener_config.targets.read().await;
let mut rng = rand::thread_rng();
let selected_target = targets.choose(&mut rng).unwrap();
println!(
"new connection from {} forwarding to {}",
next_socket.peer_addr()?,
&selected_target
);
let mut handler = TcpHandler {
stream: next_socket,
target: selected_target.clone(),
};
trace!(
"new connection from {} forwarding to {}",
next_socket.peer_addr()?,
&selected_target
);
let mut handler = TcpHandler {
stream: next_socket,
target: selected_target.clone(),
};
tokio::spawn(async move {
// Process the connection. If an error is encountered, log it.
if let Err(err) = handler.run().await {
println!("connection error {}", err);
}
});
}
tokio::spawn(async move {
// Process the connection. If an error is encountered, log it.
if let Err(err) = handler.run().await {
warn!("connection error {}", err);
}
});
}
}

View File

@ -1,2 +1,235 @@
// For UDP to work, there is some magic required with matches the returning packets
// back to the original sender. Idk. how to do that right now, but maybe some day.
use std::collections::HashMap;
use std::sync::atomic::AtomicI32;
use std::sync::Arc;
use log::{debug, error, info, trace};
use rand::seq::SliceRandom;
use tokio::net::UdpSocket;
use tokio::sync::Mutex;
use crate::listener::Listener;
use crate::shutdown::{Shutdown, ShutdownReceiver};
use crate::Result;
const CONNECTION_TIMEOUT: i32 = 30;
#[derive(Clone)]
struct UDPMultiSender {
socket: Arc<UdpSocket>,
}
impl UDPMultiSender {
fn new(socket: UdpSocket) -> Self {
return Self {
socket: Arc::new(socket),
};
}
async fn send(&self, buf: &[u8]) -> Result<()> {
self.socket.send(buf).await?;
return Ok(());
}
async fn send_to(&self, buf: &[u8], dest: &str) -> Result<()> {
self.socket.send_to(buf, dest).await?;
return Ok(());
}
}
async fn splitted_udp_socket(bind_addr: &str) -> Result<(UdpSocket, UDPMultiSender)> {
let listener_std = std::net::UdpSocket::bind(bind_addr)?;
let responder_std = listener_std.try_clone()?;
listener_std.set_nonblocking(true)?;
responder_std.set_nonblocking(true)?;
let listener = UdpSocket::from_std(listener_std)?;
let responder = UDPMultiSender::new(UdpSocket::from_std(responder_std)?);
return Ok((listener, responder));
}
struct UDPChannel {
last_packet: Arc<AtomicI32>,
sender: UDPMultiSender,
shutdown: Shutdown,
from: String,
upstream: String,
}
impl UDPChannel {
async fn start(
upstream: String,
responder: UDPMultiSender,
source_addr: String,
) -> Result<Self> {
let (upstream_listener, upstream_responder) = splitted_udp_socket("0.0.0.0:0").await?;
upstream_listener.connect(upstream.clone()).await?;
let shutdown = Shutdown::new();
let mut shutdown_receiver = shutdown.receiver();
let channel = Self {
last_packet: Arc::new(AtomicI32::new(CONNECTION_TIMEOUT)),
sender: upstream_responder,
shutdown,
from: source_addr.clone(),
upstream: upstream.clone(),
};
let last_packet = channel.last_packet.clone();
tokio::spawn(async move {
let mut buf = [0; 64 * 1024];
loop {
let num_bytes = tokio::select! {
res = upstream_listener.recv(&mut buf) => res.unwrap(),
_ = shutdown_receiver.recv() => {
info!("Exiting");
return;
}
};
trace!("[{}] <- [{}] ...", source_addr, upstream);
last_packet.store(CONNECTION_TIMEOUT, std::sync::atomic::Ordering::Relaxed);
match responder.send_to(&buf[..num_bytes], &source_addr).await {
Ok(_) => {}
Err(e) => {
error!("Failed to send packet: {}", e);
}
};
trace!("[{}] <- [{}] ---", source_addr, upstream);
}
});
return Ok(channel);
}
async fn close(&mut self) {
self.shutdown.shutdown();
debug!("Closing connection from {}", self.from);
}
async fn handle(&self, data: &[u8]) {
trace!("[{}] -> [{}] ...", self.from, self.upstream);
self.last_packet
.store(CONNECTION_TIMEOUT, std::sync::atomic::Ordering::Relaxed);
match self.sender.send(data).await {
Ok(_) => {}
Err(e) => {
error!("Failed to send packet: {}", e);
}
}
trace!("[{}] -> [{}] ---", self.from, self.upstream);
}
}
fn start_stale_check(
connections: Arc<Mutex<HashMap<String, UDPChannel>>>,
mut shutdown: ShutdownReceiver,
) {
tokio::spawn(async move {
let mut interval = tokio::time::interval(std::time::Duration::from_secs(1));
loop {
tokio::select! {
_ = interval.tick() => {}
_ = shutdown.recv() => {
info!("Exiting listener!");
break;
}
}
trace!("Checking for stale connections");
trace!("Waiting for connections lock");
let mut connections = connections.lock().await;
trace!("Got connections lock");
let mut to_remove: Vec<String> = Vec::new();
for (source_addr, channel) in connections.iter() {
let last = channel
.last_packet
.fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
if last <= 0 {
to_remove.push(source_addr.clone());
}
}
for source_addr in to_remove {
debug!("Closing connection from {}", source_addr);
let mut channel = connections.remove(&source_addr).unwrap();
channel.close().await;
}
trace!("Checking for stale connections - done");
drop(connections);
}
info!("Exiting stale check");
});
}
pub(crate) async fn start_udp_listener(mut listener_config: Listener) -> Result<()> {
info!("start listening on {}", &listener_config.source);
let (listener, responder) = splitted_udp_socket(&listener_config.source).await?;
let connections = Arc::new(Mutex::new(HashMap::<String, UDPChannel>::new()));
start_stale_check(connections.clone(), listener_config.shutdown.clone());
loop {
let mut buf = vec![0; 1024 * 64];
trace!("Waiting for packet or shutdown");
let (num_bytes, src_addr) = tokio::select! {
res = listener.recv_from(&mut buf) => res?,
_ = listener_config.targets_changed.recv() => {
info!("Targets changed!");
info!("Closing all connections!");
for (_, handler) in connections.lock().await.iter_mut() {
handler.close().await;
}
info!("Closed all connections!");
continue;
}
_ = listener_config.shutdown.recv() => {
info!("Exiting listener!");
break;
}
};
let data = buf[0..num_bytes].to_vec();
let source_addr = src_addr.to_string();
let mut connections = connections.lock().await;
let handler_opt = connections.get_mut(&source_addr);
let handler = match handler_opt {
Some(handler) => handler,
None => {
debug!("New connection from {}", source_addr);
let targets = listener_config.targets.read().await;
let upstream = {
let mut rng = rand::thread_rng();
targets.choose(&mut rng).unwrap()
};
let handler =
UDPChannel::start(upstream.to_string(), responder.clone(), source_addr.clone())
.await?;
connections.insert(source_addr.clone(), handler);
connections.get_mut(&source_addr).unwrap()
}
};
handler.handle(&data).await;
trace!("Handled packet");
drop(connections);
}
for (_, handler) in connections.lock().await.iter_mut() {
handler.close().await;
}
println!("Listener closed.");
return Ok(());
}