agate

Simple gemini server for static files
git clone https://github.com/mbrubeck/agate.git
Log | Files | Refs | README

commit aa17b5bc17d318928db0dd574d378bf34f71ec91
parent 21486a0d11fa16932a075282ac9158fb0117baa7
Author: Johann150 <johann.galle@protonmail.com>
Date:   Sun, 24 Jan 2021 20:10:46 +0100

add RequestHandle struct

Diffstat:
Msrc/main.rs | 57++++++++++++++++++++++++++++++++-------------------------
1 file changed, 32 insertions(+), 25 deletions(-)

diff --git a/src/main.rs b/src/main.rs @@ -117,15 +117,21 @@ fn check_path(s: String) -> Result<String, String> { } } +struct RequestHandle { + pub stream: TlsStream<TcpStream>, +} + /// Handle a single client session (request + response). async fn handle_request(stream: TcpStream) -> Result { - let stream = &mut TLS.accept(stream).await?; + let stream = TLS.accept(stream).await?; + + let mut handle = RequestHandle { stream }; - match parse_request(stream).await { - Ok(url) => send_response(url, stream).await?, - Err((status, msg)) => send_header(stream, status, msg).await?, + match parse_request(&mut handle).await { + Ok(url) => send_response(url, &mut handle).await?, + Err((status, msg)) => send_header(&mut handle, status, msg).await?, } - stream.shutdown().await?; + handle.stream.shutdown().await?; Ok(()) } @@ -145,7 +151,7 @@ fn acceptor() -> Result<TlsAcceptor> { } /// Return the URL requested by the client. -async fn parse_request(stream: &mut TlsStream<TcpStream>) -> std::result::Result<Url, (u8, &'static str)> { +async fn parse_request(handle: &mut RequestHandle) -> std::result::Result<Url, (u8, &'static str)> { // Because requests are limited to 1024 bytes (plus 2 bytes for CRLF), we // can use a fixed-sized buffer on the stack, avoiding allocations and // copying, and stopping bad clients from making us use too much memory. @@ -155,7 +161,7 @@ async fn parse_request(stream: &mut TlsStream<TcpStream>) -> std::result::Result // Read until CRLF, end-of-stream, or there's no buffer space left. loop { - let bytes_read = stream.read(buf).await.or(Err((59, "Request ended unexpectedly")))?; + let bytes_read = handle.stream.read(buf).await.or(Err((59, "Request ended unexpectedly")))?; len += bytes_read; if request[..len].ends_with(b"\r\n") { break; @@ -169,7 +175,8 @@ async fn parse_request(stream: &mut TlsStream<TcpStream>) -> std::result::Result log::info!( "Got request for {:?} from {}", request, - stream + handle + .stream .get_ref() .0 .peer_addr() @@ -193,7 +200,7 @@ async fn parse_request(stream: &mut TlsStream<TcpStream>) -> std::result::Result } if let Some(port) = url.port() { // Validate that the port in the URL is the same as for the stream this request came in on. - if port != stream.get_ref().0.local_addr().unwrap().port() { + if port != handle.stream.get_ref().0.local_addr().unwrap().port() { return Err((53, "proxy request refused")); } } @@ -201,13 +208,13 @@ async fn parse_request(stream: &mut TlsStream<TcpStream>) -> std::result::Result } /// Send the client the file located at the requested URL. -async fn send_response(url: Url, stream: &mut TlsStream<TcpStream>) -> Result { +async fn send_response(url: Url, handle: &mut RequestHandle) -> Result { let mut path = std::path::PathBuf::from(&ARGS.content_dir); if let Some(segments) = url.path_segments() { for segment in segments { if !ARGS.serve_secret && segment.starts_with('.') { // Do not serve anything that looks like a hidden file. - return send_header(stream, 52, "If I told you, it would not be a secret.").await; + return send_header(handle, 52, "If I told you, it would not be a secret.").await; } path.push(&*percent_decode_str(segment).decode_utf8()?); } @@ -221,13 +228,13 @@ async fn send_response(url: Url, stream: &mut TlsStream<TcpStream>) -> Result { path.push("index.gmi"); if !path.exists() && path.with_file_name(".directory-listing-ok").exists() { path.pop(); - return list_directory(stream, &path).await; + return list_directory(handle, &path).await; } } else { // if client is not redirected, links may not work as expected without trailing slash let mut url = url; url.set_path(&format!("{}/", url.path())); - return send_header(stream, 31, url.as_str()).await; + return send_header(handle, 31, url.as_str()).await; } } } @@ -236,32 +243,32 @@ async fn send_response(url: Url, stream: &mut TlsStream<TcpStream>) -> Result { let mut file = match tokio::fs::File::open(&path).await { Ok(file) => file, Err(e) => { - send_header(stream, 51, "Not found, sorry.").await?; + send_header(handle, 51, "Not found, sorry.").await?; Err(e)? } }; // Send header. if path.extension() == Some(OsStr::new("gmi")) { - send_text_gemini_header(stream).await?; + send_text_gemini_header(handle).await?; } else { let mime = mime_guess::from_path(&path).first_or_octet_stream(); - send_header(stream, 20, mime.essence_str()).await?; + send_header(handle, 20, mime.essence_str()).await?; } // Send body. - tokio::io::copy(&mut file, stream).await?; + tokio::io::copy(&mut file, &mut handle.stream).await?; Ok(()) } -async fn list_directory(stream: &mut TlsStream<TcpStream>, path: &Path) -> Result { +async fn list_directory(handle: &mut RequestHandle, path: &Path) -> Result { // https://url.spec.whatwg.org/#path-percent-encode-set const ENCODE_SET: AsciiSet = CONTROLS.add(b' ') .add(b'"').add(b'#').add(b'<').add(b'>') .add(b'?').add(b'`').add(b'{').add(b'}'); log::info!("Listing directory {:?}", path); - send_text_gemini_header(stream).await?; + send_text_gemini_header(handle).await?; let mut entries = tokio::fs::read_dir(path).await?; let mut lines = vec![]; while let Some(entry) = entries.next_entry().await? { @@ -280,25 +287,25 @@ async fn list_directory(stream: &mut TlsStream<TcpStream>, path: &Path) -> Resul } lines.sort(); for line in lines { - stream.write_all(line.as_bytes()).await?; + handle.stream.write_all(line.as_bytes()).await?; } Ok(()) } -async fn send_header(stream: &mut TlsStream<TcpStream>, status: u8, meta: &str) -> Result { +async fn send_header(handle: &mut RequestHandle, status: u8, meta: &str) -> Result { use std::fmt::Write; let mut response = String::with_capacity(64); write!(response, "{} {}", status, meta)?; log::info!("Responding with status {:?}", response); response.push_str("\r\n"); - stream.write_all(response.as_bytes()).await?; + handle.stream.write_all(response.as_bytes()).await?; Ok(()) } -async fn send_text_gemini_header(stream: &mut TlsStream<TcpStream>) -> Result { +async fn send_text_gemini_header(handle: &mut RequestHandle) -> Result { if let Some(lang) = ARGS.language.as_deref() { - send_header(stream, 20, &format!("text/gemini;lang={}", lang)).await + send_header(handle, 20, &format!("text/gemini;lang={}", lang)).await } else { - send_header(stream, 20, "text/gemini").await + send_header(handle, 20, "text/gemini").await } }