agate

Simple gemini server for static files
git clone https://github.com/mbrubeck/agate.git
Log | Files | Refs | README

commit f0789921e04d1eb260c7dec3da5a66cd71c45d70
parent 116c9fdcb43758b591394c06a66da7d61e08481e
Author: Johann150 <johann.galle@protonmail.com>
Date:   Mon, 25 Jan 2021 21:50:59 +0100

make functions into methods of RequestHandle

Diffstat:
Msrc/main.rs | 365+++++++++++++++++++++++++++++++++++++++++++------------------------------------
1 file changed, 199 insertions(+), 166 deletions(-)

diff --git a/src/main.rs b/src/main.rs @@ -34,7 +34,17 @@ fn main() -> Result { log::info!("Listening on {:?}...", ARGS.addrs); loop { let (stream, _) = listener.accept().await?; - tokio::spawn(async { handle_request(stream).await }); + tokio::spawn(async { + match RequestHandle::new(stream).await { + Ok(handle) => match handle.handle().await { + Ok(info) => log::info!("{}", info), + Err(err) => log::warn!("{}", err), + }, + Err(log_line) => { + log::warn!("{}", log_line); + } + } + }); } }) } @@ -114,52 +124,6 @@ fn check_path(s: String) -> Result<String, String> { } } -struct RequestHandle { - pub stream: TlsStream<TcpStream>, - pub log_line: String, -} - -/// Handle a single client session (request + response) and any errors that -/// may occur while processing it. -async fn handle_request(stream: TcpStream) { - let log_line = format!( - "{} {}", - stream.local_addr().unwrap(), - if ARGS.log_ips { - stream - .peer_addr() - .expect("could not get peer address") - .to_string() - } else { - // Do not log IP address, but something else so columns still line up. - "-".into() - } - ); - - let stream = match TLS.accept(stream).await { - Ok(stream) => stream, - Err(e) => { - log::warn!("{} error:{}", log_line, e); - return; - } - }; - - let mut handle = RequestHandle { stream, log_line }; - - let mut result = match parse_request(&mut handle).await { - Ok(url) => send_response(url, &mut handle).await, - Err((status, msg)) => send_header(&mut handle, status, msg).await, - }; - - if let Err(e) = result { - log::warn!("{} error:{}", handle.log_line, e); - } else if let Err(e) = handle.stream.shutdown().await { - log::warn!("{} error:{}", handle.log_line, e); - } else { - log::info!("{}", handle.log_line); - } -} - /// TLS configuration. static TLS: Lazy<TlsAcceptor> = Lazy::new(|| acceptor().unwrap()); @@ -175,152 +139,221 @@ fn acceptor() -> Result<TlsAcceptor> { Ok(TlsAcceptor::from(Arc::new(config))) } -/// Return the URL requested by the client. -async fn parse_request(handle: &mut RequestHandle) -> std::result::Result<Url, (u8, &'static str)> { - // Because requests are limited to 1024 bytes (plus 2 bytes for CRLF), we - // can use a fixed-sized buffer on the stack, avoiding allocations and - // copying, and stopping bad clients from making us use too much memory. - let mut request = [0; 1026]; - let mut buf = &mut request[..]; - let mut len = 0; +struct RequestHandle { + stream: TlsStream<TcpStream>, + log_line: String, +} + +impl RequestHandle { + /// Creates a new request handle for the given stream. If establishing the TLS + /// session fails, returns a corresponding log line. + async fn new(stream: TcpStream) -> Result<Self, String> { + let log_line = format!( + "{} {}", + stream.local_addr().unwrap(), + if ARGS.log_ips { + stream + .peer_addr() + .expect("could not get peer address") + .to_string() + } else { + // Do not log IP address, but something else so columns still line up. + "-".into() + } + ); - // Read until CRLF, end-of-stream, or there's no buffer space left. - loop { - let bytes_read = handle.stream.read(buf).await.or(Err((59, "Request ended unexpectedly")))?; - len += bytes_read; - if request[..len].ends_with(b"\r\n") { - break; - } else if bytes_read == 0 { - return Err((59, "Request ended unexpectedly")); + match TLS.accept(stream).await { + Ok(stream) => Ok(Self { stream, log_line }), + Err(e) => Err(format!("{} error:{}", log_line, e)), } - buf = &mut request[len..]; } - let request = std::str::from_utf8(&request[..len - 2]).or(Err((59, "Non-UTF-8 request")))?; - - // log literal request (might be different from or not an actual URL) - write!(handle.log_line, " \"{}\"", request).unwrap(); - let url = Url::parse(request).or(Err((59, "Invalid URL")))?; + /// Do the necessary actions to handle this request. If the handle is already + /// in an error state, does nothing. + /// Finally return the generated log line content. If this contains + /// the string ` error:`, the handle ended in an error state. + async fn handle(mut self) -> Result<String, String> { + // not already in error condition + let result = match self.parse_request().await { + Ok(url) => self.send_response(url).await, + Err((status, msg)) => self.send_header(status, msg).await, + }; - // Validate the URL, host and port. - if url.scheme() != "gemini" { - return Err((53, "Unsupported URL scheme")); - } - // TODO: Can be simplified by https://github.com/servo/rust-url/pull/651 - if let (Some(Host::Domain(expected)), Some(Host::Domain(host))) = (url.host(), &ARGS.hostname) { - if host != expected { - return Err((53, "Proxy request refused")); + if let Err(e) = result { + Err(format!("{} error:{}", self.log_line, e)) + } else if let Err(e) = self.stream.shutdown().await { + Err(format!("{} error:{}", self.log_line, e)) + } else { + Ok(self.log_line) } } - if let Some(port) = url.port() { - // Validate that the port in the URL is the same as for the stream this request came in on. - if port != handle.stream.get_ref().0.local_addr().unwrap().port() { - return Err((53, "proxy request refused")); + + /// Return the URL requested by the client. + async fn parse_request(&mut self) -> std::result::Result<Url, (u8, &'static str)> { + // Because requests are limited to 1024 bytes (plus 2 bytes for CRLF), we + // can use a fixed-sized buffer on the stack, avoiding allocations and + // copying, and stopping bad clients from making us use too much memory. + let mut request = [0; 1026]; + let mut buf = &mut request[..]; + let mut len = 0; + + // Read until CRLF, end-of-stream, or there's no buffer space left. + loop { + let bytes_read = self + .stream + .read(buf) + .await + .or(Err((59, "Request ended unexpectedly")))?; + len += bytes_read; + if request[..len].ends_with(b"\r\n") { + break; + } else if bytes_read == 0 { + return Err((59, "Request ended unexpectedly")); + } + buf = &mut request[len..]; } - } - Ok(url) -} + let request = + std::str::from_utf8(&request[..len - 2]).or(Err((59, "Non-UTF-8 request")))?; + + // log literal request (might be different from or not an actual URL) + write!(self.log_line, " \"{}\"", request).unwrap(); + + let url = Url::parse(request).or(Err((59, "Invalid URL")))?; -/// Send the client the file located at the requested URL. -async fn send_response(url: Url, handle: &mut RequestHandle) -> Result { - let mut path = std::path::PathBuf::from(&ARGS.content_dir); - if let Some(segments) = url.path_segments() { - for segment in segments { - if !ARGS.serve_secret && segment.starts_with('.') { - // Do not serve anything that looks like a hidden file. - return send_header(handle, 52, "If I told you, it would not be a secret.").await; + // Validate the URL, host and port. + if url.scheme() != "gemini" { + return Err((53, "Unsupported URL scheme")); + } + // TODO: Can be simplified by https://github.com/servo/rust-url/pull/651 + if let (Some(Host::Domain(expected)), Some(Host::Domain(host))) = + (url.host(), &ARGS.hostname) + { + if host != expected { + return Err((53, "Proxy request refused")); + } + } + if let Some(port) = url.port() { + // Validate that the port in the URL is the same as for the stream this request came in on. + if port != self.stream.get_ref().0.local_addr().unwrap().port() { + return Err((53, "proxy request refused")); } - path.push(&*percent_decode_str(segment).decode_utf8()?); } + Ok(url) } - if let Ok(metadata) = tokio::fs::metadata(&path).await { - if metadata.is_dir() { - if url.path().ends_with('/') || url.path().is_empty() { - // if the path ends with a slash or the path is empty, the links will work the same - // without a redirect - path.push("index.gmi"); - if !path.exists() && path.with_file_name(".directory-listing-ok").exists() { - path.pop(); - return list_directory(handle, &path).await; + /// Send the client the file located at the requested URL. + async fn send_response(&mut self, url: Url) -> Result { + let mut path = std::path::PathBuf::from(&ARGS.content_dir); + if let Some(segments) = url.path_segments() { + for segment in segments { + if !ARGS.serve_secret && segment.starts_with('.') { + // Do not serve anything that looks like a hidden file. + return self + .send_header(52, "If I told you, it would not be a secret.") + .await; } - } else { - // if client is not redirected, links may not work as expected without trailing slash - let mut url = url; - url.set_path(&format!("{}/", url.path())); - return send_header(handle, 31, url.as_str()).await; + path.push(&*percent_decode_str(segment).decode_utf8()?); } } - } - // Make sure the file opens successfully before sending the success header. - let mut file = match tokio::fs::File::open(&path).await { - Ok(file) => file, - Err(e) => { - send_header(handle, 51, "Not found, sorry.").await?; - Err(e)? + if let Ok(metadata) = tokio::fs::metadata(&path).await { + if metadata.is_dir() { + if url.path().ends_with('/') || url.path().is_empty() { + // if the path ends with a slash or the path is empty, the links will work the same + // without a redirect + path.push("index.gmi"); + if !path.exists() && path.with_file_name(".directory-listing-ok").exists() { + path.pop(); + return self.list_directory(&path).await; + } + } else { + // if client is not redirected, links may not work as expected without trailing slash + let mut url = url; + url.set_path(&format!("{}/", url.path())); + return self.send_header(31, url.as_str()).await; + } + } } - }; - // Send header. - if path.extension() == Some(OsStr::new("gmi")) { - send_text_gemini_header(handle).await?; - } else { - let mime = mime_guess::from_path(&path).first_or_octet_stream(); - send_header(handle, 20, mime.essence_str()).await?; - } + // Make sure the file opens successfully before sending the success header. + let mut file = match tokio::fs::File::open(&path).await { + Ok(file) => file, + Err(e) => { + self.send_header(51, "Not found, sorry.").await?; + Err(e)? + } + }; - // Send body. - tokio::io::copy(&mut file, &mut handle.stream).await?; - Ok(()) -} + // Send header. + if path.extension() == Some(OsStr::new("gmi")) { + self.send_text_gemini_header().await?; + } else { + let mime = mime_guess::from_path(&path).first_or_octet_stream(); + self.send_header(20, mime.essence_str()).await?; + } -async fn list_directory(handle: &mut RequestHandle, path: &Path) -> Result { - // https://url.spec.whatwg.org/#path-percent-encode-set - const ENCODE_SET: AsciiSet = CONTROLS.add(b' ') - .add(b'"').add(b'#').add(b'<').add(b'>') - .add(b'?').add(b'`').add(b'{').add(b'}'); + // Send body. + tokio::io::copy(&mut file, &mut self.stream).await?; + Ok(()) + } - log::info!("Listing directory {:?}", path); - send_text_gemini_header(handle).await?; - let mut entries = tokio::fs::read_dir(path).await?; - let mut lines = vec![]; - while let Some(entry) = entries.next_entry().await? { - let mut name = entry.file_name().into_string().or(Err("Non-Unicode filename"))?; - if name.starts_with('.') { - continue; + async fn list_directory(&mut self, path: &Path) -> Result { + // https://url.spec.whatwg.org/#path-percent-encode-set + const ENCODE_SET: AsciiSet = CONTROLS + .add(b' ') + .add(b'"') + .add(b'#') + .add(b'<') + .add(b'>') + .add(b'?') + .add(b'`') + .add(b'{') + .add(b'}'); + + log::info!("Listing directory {:?}", path); + self.send_text_gemini_header().await?; + let mut entries = tokio::fs::read_dir(path).await?; + let mut lines = vec![]; + while let Some(entry) = entries.next_entry().await? { + let mut name = entry + .file_name() + .into_string() + .or(Err("Non-Unicode filename"))?; + if name.starts_with('.') { + continue; + } + if entry.file_type().await?.is_dir() { + name += "/"; + } + let line = match percent_encode(name.as_bytes(), &ENCODE_SET).into() { + Cow::Owned(url) => format!("=> {} {}\n", url, name), + Cow::Borrowed(url) => format!("=> {}\n", url), // url and name are identical + }; + lines.push(line); } - if entry.file_type().await?.is_dir() { - name += "/"; + lines.sort(); + for line in lines { + self.stream.write_all(line.as_bytes()).await?; } - let line = match percent_encode(name.as_bytes(), &ENCODE_SET).into() { - Cow::Owned(url) => format!("=> {} {}\n", url, name), - Cow::Borrowed(url) => format!("=> {}\n", url), // url and name are identical - }; - lines.push(line); + Ok(()) } - lines.sort(); - for line in lines { - handle.stream.write_all(line.as_bytes()).await?; - } - Ok(()) -} -async fn send_header(handle: &mut RequestHandle, status: u8, meta: &str) -> Result { - // add response status and response meta - write!(handle.log_line, " {} \"{}\"", status, meta)?; + async fn send_header(&mut self, status: u8, meta: &str) -> Result { + // add response status and response meta + write!(self.log_line, " {} \"{}\"", status, meta)?; - handle - .stream - .write_all(format!("{} {}\r\n", status, meta).as_bytes()) - .await?; - Ok(()) -} + self.stream + .write_all(format!("{} {}\r\n", status, meta).as_bytes()) + .await?; + Ok(()) + } -async fn send_text_gemini_header(handle: &mut RequestHandle) -> Result { - if let Some(lang) = ARGS.language.as_deref() { - send_header(handle, 20, &format!("text/gemini;lang={}", lang)).await - } else { - send_header(handle, 20, "text/gemini").await + async fn send_text_gemini_header(&mut self) -> Result { + if let Some(lang) = ARGS.language.as_deref() { + self.send_header(20, &format!("text/gemini;lang={}", lang)) + .await + } else { + self.send_header(20, "text/gemini").await + } } }