diff --git a/src/main.rs b/src/main.rs index 6abce68..3f87297 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,13 +1,14 @@ mod apis; mod backends; mod utils; -use crate::utils::core::{split_at_first_slash, StreamingResponse}; +use crate::utils::core::{split_at_first_slash, RangeRequest, StreamingResponse}; use actix_cors::Cors; use actix_web::body::{BodySize, BoxBody, MessageBody}; use actix_web::error::ErrorInternalServerError; use actix_web::{ - delete, get, head, http::header::CONTENT_TYPE, http::header::RANGE, middleware, post, put, web, - App, HttpRequest, HttpResponse, HttpServer, Responder, + delete, get, head, + http::header::{CONTENT_TYPE, RANGE}, + middleware, post, put, web, App, HttpRequest, HttpResponse, HttpServer, Responder, }; use apis::source::{RepositoryPermission, SourceApi}; @@ -56,27 +57,11 @@ async fn get_object( user_identity: web::ReqData, ) -> Result { let (account_id, repository_id, key) = path.into_inner(); - let headers = req.headers(); - let mut range_start = 0; - let mut is_range_request = false; - - let range = headers + let range_info: Option = req + .headers() .get(RANGE) - .and_then(|h| h.to_str().ok()) - .and_then(|r| r.strip_prefix("bytes=")) - .and_then(|bytes_range| bytes_range.split_once('-')) - .and_then(|(start, end)| { - start.parse::().ok().map(|s| { - range_start = s; - if end.is_empty() || end.parse::().is_ok() { - is_range_request = true; - Some(format!("bytes={start}-{end}")) - } else { - None - } - }) - }) - .flatten(); + .and_then(|v| String::from_utf8(v.as_ref().to_vec()).ok()) + .and_then(|s| s.parse().ok()); let client = api_client .get_backend_client(&account_id, &repository_id) @@ -92,13 +77,13 @@ async fn get_object( .await?; // Found the repository, now try to get the object - let res = client.get_object(key.clone(), range).await?; + let res = client + .get_object(key.clone(), range_info.map(Into::into)) + .await?; - let mut content_length = String::from("*"); - // Remove this if statement to increase performance since it's making an extra request just to get the total content-length - // This is only needed for range requests and in theory, you can return a * in the Content-Range header to indicate that the content length is unknown - if is_range_request { - content_length = client + let mut total_content_length = String::from("*"); + if range_info.is_some() { + total_content_length = client .head_object(key.clone()) .await? .content_length @@ -110,7 +95,7 @@ async fn get_object( .map(|result| result.map_err(|e| ErrorInternalServerError(e.to_string()))); let streaming_response = StreamingResponse::new(stream, res.content_length); - let mut response = if is_range_request { + let mut response = if range_info.is_some() { HttpResponse::PartialContent() } else { HttpResponse::Ok() @@ -124,15 +109,15 @@ async fn get_object( .insert_header(("Content-Length", res.content_length.to_string())) .insert_header(("ETag", res.etag)); - if is_range_request { + if let Some(ref range) = range_info { response = response .insert_header(( "Content-Range", format!( "bytes {}-{}/{}", - range_start, - range_start + res.content_length - 1, - content_length + range.start, + range.start + res.content_length - 1, + total_content_length ), )) .insert_header(( @@ -336,6 +321,7 @@ async fn post_handler( #[head("/{account_id}/{repository_id}/{key:.*}")] async fn head_object( api_client: web::Data, + req: HttpRequest, path: web::Path<(String, String, String)>, user_identity: web::ReqData, ) -> Result { @@ -355,15 +341,45 @@ async fn head_object( .await?; let res = client.head_object(key.clone()).await?; - Ok(HttpResponse::Ok() - .insert_header(("Accept-Ranges", "bytes")) - .insert_header(("Access-Control-Expose-Headers", "Accept-Ranges")) - .insert_header(("Content-Type", res.content_type)) - .insert_header(("Last-Modified", res.last_modified)) - .insert_header(("ETag", res.etag)) - .body(BoxBody::new(FakeBody { - size: res.content_length as usize, - }))) + let total_size = res.content_length; + let range_info: Option = req + .headers() + .get(RANGE) + .and_then(|v| v.to_str().ok()) + .and_then(|s| s.parse().ok()); + + let response = if let Some(ref range) = range_info { + let end = range.end.unwrap(); + let content_length = end - range.start + 1; + HttpResponse::PartialContent() + .insert_header(("Accept-Ranges", "bytes")) + .insert_header(( + "Content-Range", + format!("bytes {}-{}/{}", range.start, end, total_size), + )) + .insert_header(( + "Access-Control-Expose-Headers", + "Accept-Ranges, Content-Range", + )) + .insert_header(("Content-Type", res.content_type)) + .insert_header(("Last-Modified", res.last_modified)) + .insert_header(("ETag", res.etag)) + .body(BoxBody::new(FakeBody { + size: content_length as usize, + })) + } else { + HttpResponse::Ok() + .insert_header(("Accept-Ranges", "bytes")) + .insert_header(("Access-Control-Expose-Headers", "Accept-Ranges")) + .insert_header(("Content-Type", res.content_type)) + .insert_header(("Last-Modified", res.last_modified)) + .insert_header(("ETag", res.etag)) + .body(BoxBody::new(FakeBody { + size: total_size as usize, + })) + }; + + Ok(response) } #[derive(Deserialize)] diff --git a/src/utils/core.rs b/src/utils/core.rs index 6d2d3f2..723c6dc 100644 --- a/src/utils/core.rs +++ b/src/utils/core.rs @@ -4,8 +4,8 @@ use actix_web::{ }; use futures::Stream; use pin_project_lite::pin_project; -use std::pin::Pin; use std::task::{Context, Poll}; +use std::{pin::Pin, str::FromStr}; pin_project! { pub struct StreamingResponse { @@ -97,3 +97,150 @@ pub fn split_at_first_slash(input: &str) -> (&str, &str) { None => (input, ""), } } + +/// Parsed range request information. +#[derive(Debug, Clone, Copy, PartialEq)] +pub struct RangeRequest { + /// The start byte offset + pub start: u64, + /// The end byte offset (inclusive), or `None` for open-ended ranges (e.g. "bytes=100-") + pub end: Option, +} + +impl From for String { + fn from(r: RangeRequest) -> Self { + match r.end { + Some(end) => format!("bytes={}-{}", r.start, end), + None => format!("bytes={}-", r.start), + } + } +} + +impl FromStr for RangeRequest { + type Err = (); + + fn from_str(s: &str) -> Result { + let bytes_range = s.strip_prefix("bytes=").ok_or(())?; + let (start_str, end_str) = bytes_range.split_once('-').ok_or(())?; + let start = start_str.parse::().map_err(|_| ())?; + + let end = if end_str.is_empty() { + None + } else { + let end = end_str.parse::().map_err(|_| ())?; + if start > end { + return Err(()); + } + Some(end) + }; + + Ok(RangeRequest { start, end }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_parse_range_full_range() { + let result: RangeRequest = "bytes=0-1023".parse().unwrap(); + assert_eq!( + result, + RangeRequest { + start: 0, + end: Some(1023), + } + ); + } + + #[test] + fn test_parse_range_open_ended() { + let result = "bytes=100-".parse::().unwrap(); + assert_eq!( + result, + RangeRequest { + start: 100, + end: None, + } + ); + } + + #[test] + fn test_parse_range_open_ended_no_total_size() { + // Without with_total_size, end remains None + let result: RangeRequest = "bytes=100-".parse().unwrap(); + assert_eq!(result.end, None); + } + + #[test] + fn test_parse_range_missing_prefix() { + assert!("invalid=0-100".parse::().is_err()); + } + + #[test] + fn test_parse_range_non_numeric_start() { + assert!("bytes=abc-100".parse::().is_err()); + } + + #[test] + fn test_parse_range_non_numeric_end() { + assert!("bytes=0-abc".parse::().is_err()); + } + + #[test] + fn test_parse_range_start_greater_than_end() { + assert!("bytes=500-100".parse::().is_err()); + } + + #[test] + fn test_parse_range_start_beyond_total_size() { + // Parsing succeeds; validation against total_size is the caller's responsibility + let result: RangeRequest = "bytes=1000-1023".parse().unwrap(); + assert_eq!(result.start, 1000); + assert_eq!(result.end, Some(1023)); + } + + #[test] + fn test_parse_range_single_byte() { + let result: RangeRequest = "bytes=0-0".parse().unwrap(); + assert_eq!( + result, + RangeRequest { + start: 0, + end: Some(0), + } + ); + } + + #[test] + fn test_parse_range_large_file() { + let rr: RangeRequest = "bytes=0-1023".parse().unwrap(); + assert_eq!(rr.start, 0); + assert_eq!(rr.end, Some(1023)); + assert_eq!(rr.end.unwrap() - rr.start + 1, 1024); + } + + #[test] + fn test_parse_range_no_hyphen() { + assert!("bytes=100".parse::().is_err()); + } + + #[test] + fn test_parse_range_content_length_calculation() { + let result: RangeRequest = "bytes=0-1023".parse().unwrap(); + assert_eq!(result.end.unwrap() - result.start + 1, 1024); + } + + #[test] + fn test_parse_range_content_range_format() { + let result: RangeRequest = "bytes=0-1023".parse().unwrap(); + let content_range = format!( + "bytes {}-{}/{}", + result.start, + result.end.unwrap(), + 3515053862u64 + ); + assert_eq!(content_range, "bytes 0-1023/3515053862"); + } +}