agate

Simple gemini server for static files
git clone https://github.com/mbrubeck/agate.git
Log | Files | Refs | README

commit 259a190a9c898ce4afeaa47e387b3de88e61fc75
parent a165e8e142daeacac9c2ee03beb2bbb71046965f
Author: Matt Brubeck <mbrubeck@limpet.net>
Date:   Tue, 19 May 2020 21:21:13 -0700

Command-line args

Diffstat:
Msrc/main.rs | 62+++++++++++++++++++++++++++++++++++++++-----------------------
1 file changed, 39 insertions(+), 23 deletions(-)

diff --git a/src/main.rs b/src/main.rs @@ -11,7 +11,7 @@ use { error::Error, fs::{File, read}, io::BufReader, - path::{Path, PathBuf}, + path::PathBuf, sync::Arc, }, url::Url, @@ -19,13 +19,22 @@ use { pub type Result<T=()> = std::result::Result<T, Box<dyn Error + Send + Sync>>; -fn main() -> Result { - let addr = "localhost:1965"; +lazy_static! { + static ref ARGS: Args = args().expect("usage: agate <addr:port> <dir> <cert> <key>"); + static ref ACCEPTOR: TlsAcceptor = acceptor().unwrap(); +} + +struct Args { + sock_addr: String, + content_dir: String, + cert_file: String, + key_file: String, +} +fn main() -> Result { task::block_on(async { - let listener = TcpListener::bind(addr).await?; + let listener = TcpListener::bind(&ARGS.sock_addr).await?; let mut incoming = listener.incoming(); - while let Some(Ok(stream)) = incoming.next().await { task::spawn(async { if let Err(e) = connection(stream).await { @@ -37,8 +46,29 @@ fn main() -> Result { }) } +fn args() -> Option<Args> { + let mut args = std::env::args().skip(1); + Some(Args { + sock_addr: args.next()?, + content_dir: args.next()?, + cert_file: args.next()?, + key_file: args.next()?, + }) +} + +fn acceptor() -> Result<TlsAcceptor> { + let cert_file = File::open(&ARGS.cert_file)?; + let key_file = File::open(&ARGS.key_file)?; + + let certs = certs(&mut BufReader::new(cert_file)).or(Err("bad cert"))?; + let mut keys = pkcs8_private_keys(&mut BufReader::new(key_file)).or(Err("bad key"))?; + let mut config = rustls::ServerConfig::new(rustls::NoClientAuth::new()); + config.set_single_cert(certs, keys.remove(0))?; + Ok(TlsAcceptor::from(Arc::new(config))) +} + async fn connection(stream: TcpStream) -> Result { - let mut stream = TLS_ACCEPTOR.accept(stream).await?; + let mut stream = ACCEPTOR.accept(stream).await?; let url = match parse_request(&mut stream).await { Ok(url) => url, Err(e) => { @@ -59,20 +89,6 @@ async fn connection(stream: TcpStream) -> Result { Ok(()) } -lazy_static! { - static ref TLS_ACCEPTOR: TlsAcceptor = { - let cert_file = File::open("tests/cert.pem").unwrap(); - let certs = certs(&mut BufReader::new(cert_file)).unwrap(); - - let key_file = File::open("tests/key.rsa").unwrap(); - let mut keys = pkcs8_private_keys(&mut BufReader::new(key_file)).unwrap(); - - let mut config = rustls::ServerConfig::new(rustls::NoClientAuth::new()); - config.set_single_cert(certs, keys.remove(0)).unwrap(); - TlsAcceptor::from(Arc::new(config)) - }; -} - async fn parse_request(stream: &mut TlsStream<TcpStream>) -> Result<Url> { let mut stream = async_std::io::BufReader::new(stream); let mut request = String::new(); @@ -82,9 +98,9 @@ async fn parse_request(stream: &mut TlsStream<TcpStream>) -> Result<Url> { } fn get(url: &Url) -> Result<Vec<u8>> { - let path: PathBuf = url.path_segments().unwrap().collect(); - let path = Path::new(".").join(path).canonicalize()?; - if !path.starts_with(std::env::current_dir()?) { + let mut path = PathBuf::from(&ARGS.content_dir); + path.extend(url.path_segments().unwrap()); + if !path.starts_with(&ARGS.content_dir) { Err("invalid path")? } let response = read(path)?;