commit f0789921e04d1eb260c7dec3da5a66cd71c45d70
parent 116c9fdcb43758b591394c06a66da7d61e08481e
Author: Johann150 <johann.galle@protonmail.com>
Date: Mon, 25 Jan 2021 21:50:59 +0100
make functions into methods of RequestHandle
Diffstat:
M | src/main.rs | | | 365 | +++++++++++++++++++++++++++++++++++++++++++------------------------------------ |
1 file changed, 199 insertions(+), 166 deletions(-)
diff --git a/src/main.rs b/src/main.rs
@@ -34,7 +34,17 @@ fn main() -> Result {
log::info!("Listening on {:?}...", ARGS.addrs);
loop {
let (stream, _) = listener.accept().await?;
- tokio::spawn(async { handle_request(stream).await });
+ tokio::spawn(async {
+ match RequestHandle::new(stream).await {
+ Ok(handle) => match handle.handle().await {
+ Ok(info) => log::info!("{}", info),
+ Err(err) => log::warn!("{}", err),
+ },
+ Err(log_line) => {
+ log::warn!("{}", log_line);
+ }
+ }
+ });
}
})
}
@@ -114,52 +124,6 @@ fn check_path(s: String) -> Result<String, String> {
}
}
-struct RequestHandle {
- pub stream: TlsStream<TcpStream>,
- pub log_line: String,
-}
-
-/// Handle a single client session (request + response) and any errors that
-/// may occur while processing it.
-async fn handle_request(stream: TcpStream) {
- let log_line = format!(
- "{} {}",
- stream.local_addr().unwrap(),
- if ARGS.log_ips {
- stream
- .peer_addr()
- .expect("could not get peer address")
- .to_string()
- } else {
- // Do not log IP address, but something else so columns still line up.
- "-".into()
- }
- );
-
- let stream = match TLS.accept(stream).await {
- Ok(stream) => stream,
- Err(e) => {
- log::warn!("{} error:{}", log_line, e);
- return;
- }
- };
-
- let mut handle = RequestHandle { stream, log_line };
-
- let mut result = match parse_request(&mut handle).await {
- Ok(url) => send_response(url, &mut handle).await,
- Err((status, msg)) => send_header(&mut handle, status, msg).await,
- };
-
- if let Err(e) = result {
- log::warn!("{} error:{}", handle.log_line, e);
- } else if let Err(e) = handle.stream.shutdown().await {
- log::warn!("{} error:{}", handle.log_line, e);
- } else {
- log::info!("{}", handle.log_line);
- }
-}
-
/// TLS configuration.
static TLS: Lazy<TlsAcceptor> = Lazy::new(|| acceptor().unwrap());
@@ -175,152 +139,221 @@ fn acceptor() -> Result<TlsAcceptor> {
Ok(TlsAcceptor::from(Arc::new(config)))
}
-/// Return the URL requested by the client.
-async fn parse_request(handle: &mut RequestHandle) -> std::result::Result<Url, (u8, &'static str)> {
- // Because requests are limited to 1024 bytes (plus 2 bytes for CRLF), we
- // can use a fixed-sized buffer on the stack, avoiding allocations and
- // copying, and stopping bad clients from making us use too much memory.
- let mut request = [0; 1026];
- let mut buf = &mut request[..];
- let mut len = 0;
+struct RequestHandle {
+ stream: TlsStream<TcpStream>,
+ log_line: String,
+}
+
+impl RequestHandle {
+ /// Creates a new request handle for the given stream. If establishing the TLS
+ /// session fails, returns a corresponding log line.
+ async fn new(stream: TcpStream) -> Result<Self, String> {
+ let log_line = format!(
+ "{} {}",
+ stream.local_addr().unwrap(),
+ if ARGS.log_ips {
+ stream
+ .peer_addr()
+ .expect("could not get peer address")
+ .to_string()
+ } else {
+ // Do not log IP address, but something else so columns still line up.
+ "-".into()
+ }
+ );
- // Read until CRLF, end-of-stream, or there's no buffer space left.
- loop {
- let bytes_read = handle.stream.read(buf).await.or(Err((59, "Request ended unexpectedly")))?;
- len += bytes_read;
- if request[..len].ends_with(b"\r\n") {
- break;
- } else if bytes_read == 0 {
- return Err((59, "Request ended unexpectedly"));
+ match TLS.accept(stream).await {
+ Ok(stream) => Ok(Self { stream, log_line }),
+ Err(e) => Err(format!("{} error:{}", log_line, e)),
}
- buf = &mut request[len..];
}
- let request = std::str::from_utf8(&request[..len - 2]).or(Err((59, "Non-UTF-8 request")))?;
-
- // log literal request (might be different from or not an actual URL)
- write!(handle.log_line, " \"{}\"", request).unwrap();
- let url = Url::parse(request).or(Err((59, "Invalid URL")))?;
+ /// Do the necessary actions to handle this request. If the handle is already
+ /// in an error state, does nothing.
+ /// Finally return the generated log line content. If this contains
+ /// the string ` error:`, the handle ended in an error state.
+ async fn handle(mut self) -> Result<String, String> {
+ // not already in error condition
+ let result = match self.parse_request().await {
+ Ok(url) => self.send_response(url).await,
+ Err((status, msg)) => self.send_header(status, msg).await,
+ };
- // Validate the URL, host and port.
- if url.scheme() != "gemini" {
- return Err((53, "Unsupported URL scheme"));
- }
- // TODO: Can be simplified by https://github.com/servo/rust-url/pull/651
- if let (Some(Host::Domain(expected)), Some(Host::Domain(host))) = (url.host(), &ARGS.hostname) {
- if host != expected {
- return Err((53, "Proxy request refused"));
+ if let Err(e) = result {
+ Err(format!("{} error:{}", self.log_line, e))
+ } else if let Err(e) = self.stream.shutdown().await {
+ Err(format!("{} error:{}", self.log_line, e))
+ } else {
+ Ok(self.log_line)
}
}
- if let Some(port) = url.port() {
- // Validate that the port in the URL is the same as for the stream this request came in on.
- if port != handle.stream.get_ref().0.local_addr().unwrap().port() {
- return Err((53, "proxy request refused"));
+
+ /// Return the URL requested by the client.
+ async fn parse_request(&mut self) -> std::result::Result<Url, (u8, &'static str)> {
+ // Because requests are limited to 1024 bytes (plus 2 bytes for CRLF), we
+ // can use a fixed-sized buffer on the stack, avoiding allocations and
+ // copying, and stopping bad clients from making us use too much memory.
+ let mut request = [0; 1026];
+ let mut buf = &mut request[..];
+ let mut len = 0;
+
+ // Read until CRLF, end-of-stream, or there's no buffer space left.
+ loop {
+ let bytes_read = self
+ .stream
+ .read(buf)
+ .await
+ .or(Err((59, "Request ended unexpectedly")))?;
+ len += bytes_read;
+ if request[..len].ends_with(b"\r\n") {
+ break;
+ } else if bytes_read == 0 {
+ return Err((59, "Request ended unexpectedly"));
+ }
+ buf = &mut request[len..];
}
- }
- Ok(url)
-}
+ let request =
+ std::str::from_utf8(&request[..len - 2]).or(Err((59, "Non-UTF-8 request")))?;
+
+ // log literal request (might be different from or not an actual URL)
+ write!(self.log_line, " \"{}\"", request).unwrap();
+
+ let url = Url::parse(request).or(Err((59, "Invalid URL")))?;
-/// Send the client the file located at the requested URL.
-async fn send_response(url: Url, handle: &mut RequestHandle) -> Result {
- let mut path = std::path::PathBuf::from(&ARGS.content_dir);
- if let Some(segments) = url.path_segments() {
- for segment in segments {
- if !ARGS.serve_secret && segment.starts_with('.') {
- // Do not serve anything that looks like a hidden file.
- return send_header(handle, 52, "If I told you, it would not be a secret.").await;
+ // Validate the URL, host and port.
+ if url.scheme() != "gemini" {
+ return Err((53, "Unsupported URL scheme"));
+ }
+ // TODO: Can be simplified by https://github.com/servo/rust-url/pull/651
+ if let (Some(Host::Domain(expected)), Some(Host::Domain(host))) =
+ (url.host(), &ARGS.hostname)
+ {
+ if host != expected {
+ return Err((53, "Proxy request refused"));
+ }
+ }
+ if let Some(port) = url.port() {
+ // Validate that the port in the URL is the same as for the stream this request came in on.
+ if port != self.stream.get_ref().0.local_addr().unwrap().port() {
+ return Err((53, "proxy request refused"));
}
- path.push(&*percent_decode_str(segment).decode_utf8()?);
}
+ Ok(url)
}
- if let Ok(metadata) = tokio::fs::metadata(&path).await {
- if metadata.is_dir() {
- if url.path().ends_with('/') || url.path().is_empty() {
- // if the path ends with a slash or the path is empty, the links will work the same
- // without a redirect
- path.push("index.gmi");
- if !path.exists() && path.with_file_name(".directory-listing-ok").exists() {
- path.pop();
- return list_directory(handle, &path).await;
+ /// Send the client the file located at the requested URL.
+ async fn send_response(&mut self, url: Url) -> Result {
+ let mut path = std::path::PathBuf::from(&ARGS.content_dir);
+ if let Some(segments) = url.path_segments() {
+ for segment in segments {
+ if !ARGS.serve_secret && segment.starts_with('.') {
+ // Do not serve anything that looks like a hidden file.
+ return self
+ .send_header(52, "If I told you, it would not be a secret.")
+ .await;
}
- } else {
- // if client is not redirected, links may not work as expected without trailing slash
- let mut url = url;
- url.set_path(&format!("{}/", url.path()));
- return send_header(handle, 31, url.as_str()).await;
+ path.push(&*percent_decode_str(segment).decode_utf8()?);
}
}
- }
- // Make sure the file opens successfully before sending the success header.
- let mut file = match tokio::fs::File::open(&path).await {
- Ok(file) => file,
- Err(e) => {
- send_header(handle, 51, "Not found, sorry.").await?;
- Err(e)?
+ if let Ok(metadata) = tokio::fs::metadata(&path).await {
+ if metadata.is_dir() {
+ if url.path().ends_with('/') || url.path().is_empty() {
+ // if the path ends with a slash or the path is empty, the links will work the same
+ // without a redirect
+ path.push("index.gmi");
+ if !path.exists() && path.with_file_name(".directory-listing-ok").exists() {
+ path.pop();
+ return self.list_directory(&path).await;
+ }
+ } else {
+ // if client is not redirected, links may not work as expected without trailing slash
+ let mut url = url;
+ url.set_path(&format!("{}/", url.path()));
+ return self.send_header(31, url.as_str()).await;
+ }
+ }
}
- };
- // Send header.
- if path.extension() == Some(OsStr::new("gmi")) {
- send_text_gemini_header(handle).await?;
- } else {
- let mime = mime_guess::from_path(&path).first_or_octet_stream();
- send_header(handle, 20, mime.essence_str()).await?;
- }
+ // Make sure the file opens successfully before sending the success header.
+ let mut file = match tokio::fs::File::open(&path).await {
+ Ok(file) => file,
+ Err(e) => {
+ self.send_header(51, "Not found, sorry.").await?;
+ Err(e)?
+ }
+ };
- // Send body.
- tokio::io::copy(&mut file, &mut handle.stream).await?;
- Ok(())
-}
+ // Send header.
+ if path.extension() == Some(OsStr::new("gmi")) {
+ self.send_text_gemini_header().await?;
+ } else {
+ let mime = mime_guess::from_path(&path).first_or_octet_stream();
+ self.send_header(20, mime.essence_str()).await?;
+ }
-async fn list_directory(handle: &mut RequestHandle, path: &Path) -> Result {
- // https://url.spec.whatwg.org/#path-percent-encode-set
- const ENCODE_SET: AsciiSet = CONTROLS.add(b' ')
- .add(b'"').add(b'#').add(b'<').add(b'>')
- .add(b'?').add(b'`').add(b'{').add(b'}');
+ // Send body.
+ tokio::io::copy(&mut file, &mut self.stream).await?;
+ Ok(())
+ }
- log::info!("Listing directory {:?}", path);
- send_text_gemini_header(handle).await?;
- let mut entries = tokio::fs::read_dir(path).await?;
- let mut lines = vec![];
- while let Some(entry) = entries.next_entry().await? {
- let mut name = entry.file_name().into_string().or(Err("Non-Unicode filename"))?;
- if name.starts_with('.') {
- continue;
+ async fn list_directory(&mut self, path: &Path) -> Result {
+ // https://url.spec.whatwg.org/#path-percent-encode-set
+ const ENCODE_SET: AsciiSet = CONTROLS
+ .add(b' ')
+ .add(b'"')
+ .add(b'#')
+ .add(b'<')
+ .add(b'>')
+ .add(b'?')
+ .add(b'`')
+ .add(b'{')
+ .add(b'}');
+
+ log::info!("Listing directory {:?}", path);
+ self.send_text_gemini_header().await?;
+ let mut entries = tokio::fs::read_dir(path).await?;
+ let mut lines = vec![];
+ while let Some(entry) = entries.next_entry().await? {
+ let mut name = entry
+ .file_name()
+ .into_string()
+ .or(Err("Non-Unicode filename"))?;
+ if name.starts_with('.') {
+ continue;
+ }
+ if entry.file_type().await?.is_dir() {
+ name += "/";
+ }
+ let line = match percent_encode(name.as_bytes(), &ENCODE_SET).into() {
+ Cow::Owned(url) => format!("=> {} {}\n", url, name),
+ Cow::Borrowed(url) => format!("=> {}\n", url), // url and name are identical
+ };
+ lines.push(line);
}
- if entry.file_type().await?.is_dir() {
- name += "/";
+ lines.sort();
+ for line in lines {
+ self.stream.write_all(line.as_bytes()).await?;
}
- let line = match percent_encode(name.as_bytes(), &ENCODE_SET).into() {
- Cow::Owned(url) => format!("=> {} {}\n", url, name),
- Cow::Borrowed(url) => format!("=> {}\n", url), // url and name are identical
- };
- lines.push(line);
+ Ok(())
}
- lines.sort();
- for line in lines {
- handle.stream.write_all(line.as_bytes()).await?;
- }
- Ok(())
-}
-async fn send_header(handle: &mut RequestHandle, status: u8, meta: &str) -> Result {
- // add response status and response meta
- write!(handle.log_line, " {} \"{}\"", status, meta)?;
+ async fn send_header(&mut self, status: u8, meta: &str) -> Result {
+ // add response status and response meta
+ write!(self.log_line, " {} \"{}\"", status, meta)?;
- handle
- .stream
- .write_all(format!("{} {}\r\n", status, meta).as_bytes())
- .await?;
- Ok(())
-}
+ self.stream
+ .write_all(format!("{} {}\r\n", status, meta).as_bytes())
+ .await?;
+ Ok(())
+ }
-async fn send_text_gemini_header(handle: &mut RequestHandle) -> Result {
- if let Some(lang) = ARGS.language.as_deref() {
- send_header(handle, 20, &format!("text/gemini;lang={}", lang)).await
- } else {
- send_header(handle, 20, "text/gemini").await
+ async fn send_text_gemini_header(&mut self) -> Result {
+ if let Some(lang) = ARGS.language.as_deref() {
+ self.send_header(20, &format!("text/gemini;lang={}", lang))
+ .await
+ } else {
+ self.send_header(20, "text/gemini").await
+ }
}
}