commit 4b6d238436361d2184b955a5455e7bdd0ba7df06
parent 60f76bc617e3f810705846339a8ec1cb2d5a6c45
Author: equalsraf <undisclosed>
Date: Thu, 16 Mar 2023 15:21:29 +0000
Add support for listening UNIX sockets
A new CLI option --socket enables listening on UNIX sockets. This is
similar to the --addr option, but takes a path as argument.
If the given path already exists and it is a socket, attempt to remove
it before listening.
The port check was refactored to avoid the retrieval of the TCP port on
every request.
Diffstat:
M | src/main.rs | | | 115 | ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++----- |
1 file changed, 109 insertions(+), 6 deletions(-)
diff --git a/src/main.rs b/src/main.rs
@@ -32,6 +32,11 @@ use {
url::{Host, Url},
};
+#[cfg(target_family = "unix")]
+use std::os::unix::fs::FileTypeExt;
+#[cfg(target_family = "unix")]
+use tokio::net::{UnixListener, UnixStream};
+
static DEFAULT_PORT: u16 = 1965;
fn main() {
@@ -96,6 +101,48 @@ fn main() {
}))
};
+ #[cfg(target_family = "unix")]
+ for socketpath in &ARGS.sockets {
+ let arc = mimetypes.clone();
+
+ if socketpath.exists() && socketpath.metadata()
+ .expect("Failed to get existing socket metadata")
+ .file_type()
+ .is_socket() {
+ log::warn!("Socket already exists, attempting to remove {}", socketpath.display());
+ let _ = std::fs::remove_file(socketpath);
+ }
+
+ let listener = match UnixListener::bind(socketpath) {
+ Err(e) => {
+ panic!("Failed to listen on {}: {}", socketpath.display(), e)
+ }
+ Ok(listener) => listener,
+ };
+
+ handles.push(tokio::spawn(async move {
+ log::info!("Started listener on {}", socketpath.display());
+
+ loop {
+ let (stream, _) = listener.accept().await.unwrap_or_else(|e| {
+ panic!("could not accept new connection on {}: {}", socketpath.display(), e)
+ });
+ let arc = arc.clone();
+ tokio::spawn(async {
+ match RequestHandle::new_unix(stream, arc).await {
+ Ok(handle) => match handle.handle().await {
+ Ok(info) => log::info!("{}", info),
+ Err(err) => log::warn!("{}", err),
+ },
+ Err(log_line) => {
+ log::warn!("{}", log_line);
+ }
+ }
+ });
+ }
+ }))
+ };
+
futures_util::future::join_all(handles).await;
});
}
@@ -111,6 +158,7 @@ static ARGS: Lazy<Args> = Lazy::new(|| {
struct Args {
addrs: Vec<SocketAddr>,
+ sockets: Vec<PathBuf>,
content_dir: PathBuf,
certs: Arc<certificates::CertStore>,
hostnames: Vec<Host>,
@@ -143,6 +191,13 @@ fn args() -> Result<Args> {
&format!("Address to listen on (default 0.0.0.0:{DEFAULT_PORT} and [::]:{DEFAULT_PORT}; muliple occurences means listening on multiple interfaces)"),
"IP:PORT",
);
+ #[cfg(target_family = "unix")]
+ opts.optmulti(
+ "",
+ "socket",
+ "Unix socket to listen on (muliple occurences means listening on multiple sockets)",
+ "PATH",
+ );
opts.optmulti(
"",
"hostname",
@@ -290,7 +345,13 @@ fn args() -> Result<Args> {
for i in matches.opt_strs("addr") {
addrs.push(i.parse()?);
}
- if addrs.is_empty() {
+
+ let mut sockets = vec![];
+ for i in matches.opt_strs("socket") {
+ sockets.push(i.parse()?);
+ }
+
+ if addrs.is_empty() && sockets.is_empty() {
addrs = vec![
SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), DEFAULT_PORT),
SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), DEFAULT_PORT),
@@ -299,6 +360,7 @@ fn args() -> Result<Args> {
Ok(Args {
addrs,
+ sockets,
content_dir: check_path(matches.opt_get_default("content", "content".into())?)?,
certs: Arc::new(certs),
hostnames,
@@ -338,13 +400,14 @@ fn acceptor() -> TlsAcceptor {
TlsAcceptor::from(Arc::new(config))
}
-struct RequestHandle {
- stream: TlsStream<TcpStream>,
+struct RequestHandle<T> {
+ stream: TlsStream<T>,
+ local_port_check: Option<u16>,
log_line: String,
metadata: Arc<Mutex<FileOptions>>,
}
-impl RequestHandle {
+impl RequestHandle<TcpStream> {
/// Creates a new request handle for the given stream. If establishing the TLS
/// session fails, returns a corresponding log line.
async fn new(stream: TcpStream, metadata: Arc<Mutex<FileOptions>>) -> Result<Self, String> {
@@ -369,9 +432,16 @@ impl RequestHandle {
let log_line = format!("{local_addr} {peer_addr}",);
+ let local_port_check = if ARGS.skip_port_check {
+ None
+ } else {
+ Some(stream.local_addr().unwrap().port())
+ };
+
match TLS.accept(stream).await {
Ok(stream) => Ok(Self {
stream,
+ local_port_check,
log_line,
metadata,
}),
@@ -379,7 +449,40 @@ impl RequestHandle {
Err(e) => Err(format!("{log_line} \"\" 00 \"TLS error\" error:{e}")),
}
}
+}
+
+#[cfg(target_family = "unix")]
+impl RequestHandle<UnixStream> {
+ async fn new_unix(
+ stream: UnixStream,
+ metadata: Arc<Mutex<FileOptions>>,
+ ) -> Result<Self, String> {
+ let log_line = match stream.local_addr() {
+ Ok(a) => match a.as_pathname() {
+ Some(p) => format!("{} -", p.display()),
+ None => "<unnamed socket> -".to_string(),
+ },
+ Err(_) => "<unnamed socket> -".to_string(),
+ };
+
+ match TLS.accept(stream).await {
+ Ok(stream) => Ok(Self {
+ stream,
+ // TODO add port check for unix sockets, requires extra arg for port
+ local_port_check: None,
+ log_line,
+ metadata,
+ }),
+ // use nonexistent status code 00 if connection was not established
+ Err(e) => Err(format!("{} \"\" 00 \"TLS error\" error:{}", log_line, e)),
+ }
+ }
+}
+impl<T> RequestHandle<T>
+where
+ T: AsyncWriteExt + AsyncReadExt + Unpin,
+{
/// Do the necessary actions to handle this request. Returns a corresponding
/// log line as Err or Ok, depending on if the request finished with or
/// without errors.
@@ -476,11 +579,11 @@ impl RequestHandle {
}
// correct port
- if !ARGS.skip_port_check {
+ if let Some(expected_port) = self.local_port_check {
if let Some(port) = url.port() {
// Validate that the port in the URL is the same as for the stream this request
// came in on.
- if port != self.stream.get_ref().0.local_addr().unwrap().port() {
+ if port != expected_port {
return Err((PROXY_REQUEST_REFUSED, "Proxy request refused"));
}
}