Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 59 additions & 43 deletions src/main.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
mod apis;
mod backends;
mod utils;
use crate::utils::core::{split_at_first_slash, StreamingResponse};
use crate::utils::core::{split_at_first_slash, RangeRequest, StreamingResponse};
use actix_cors::Cors;
use actix_web::body::{BodySize, BoxBody, MessageBody};
use actix_web::error::ErrorInternalServerError;
use actix_web::{
delete, get, head, http::header::CONTENT_TYPE, http::header::RANGE, middleware, post, put, web,
App, HttpRequest, HttpResponse, HttpServer, Responder,
delete, get, head,
http::header::{CONTENT_TYPE, RANGE},
middleware, post, put, web, App, HttpRequest, HttpResponse, HttpServer, Responder,
};

use apis::source::{RepositoryPermission, SourceApi};
Expand Down Expand Up @@ -56,27 +57,11 @@ async fn get_object(
user_identity: web::ReqData<UserIdentity>,
) -> Result<impl Responder, BackendError> {
let (account_id, repository_id, key) = path.into_inner();
let headers = req.headers();
let mut range_start = 0;
let mut is_range_request = false;

let range = headers
let range_info: Option<RangeRequest> = req
.headers()
.get(RANGE)
.and_then(|h| h.to_str().ok())
.and_then(|r| r.strip_prefix("bytes="))
.and_then(|bytes_range| bytes_range.split_once('-'))
.and_then(|(start, end)| {
start.parse::<u64>().ok().map(|s| {
range_start = s;
if end.is_empty() || end.parse::<u64>().is_ok() {
is_range_request = true;
Some(format!("bytes={start}-{end}"))
} else {
None
}
})
})
.flatten();
.and_then(|v| String::from_utf8(v.as_ref().to_vec()).ok())
.and_then(|s| s.parse().ok());

let client = api_client
.get_backend_client(&account_id, &repository_id)
Expand All @@ -92,13 +77,13 @@ async fn get_object(
.await?;

// Found the repository, now try to get the object
let res = client.get_object(key.clone(), range).await?;
let res = client
.get_object(key.clone(), range_info.map(Into::into))
.await?;

let mut content_length = String::from("*");
// Remove this if statement to increase performance since it's making an extra request just to get the total content-length
// This is only needed for range requests and in theory, you can return a * in the Content-Range header to indicate that the content length is unknown
if is_range_request {
content_length = client
let mut total_content_length = String::from("*");
if range_info.is_some() {
total_content_length = client
.head_object(key.clone())
.await?
.content_length
Expand All @@ -110,7 +95,7 @@ async fn get_object(
.map(|result| result.map_err(|e| ErrorInternalServerError(e.to_string())));

let streaming_response = StreamingResponse::new(stream, res.content_length);
let mut response = if is_range_request {
let mut response = if range_info.is_some() {
HttpResponse::PartialContent()
} else {
HttpResponse::Ok()
Expand All @@ -124,15 +109,15 @@ async fn get_object(
.insert_header(("Content-Length", res.content_length.to_string()))
.insert_header(("ETag", res.etag));

if is_range_request {
if let Some(ref range) = range_info {
response = response
.insert_header((
"Content-Range",
format!(
"bytes {}-{}/{}",
range_start,
range_start + res.content_length - 1,
content_length
range.start,
range.start + res.content_length - 1,
total_content_length
),
))
.insert_header((
Expand Down Expand Up @@ -336,6 +321,7 @@ async fn post_handler(
#[head("/{account_id}/{repository_id}/{key:.*}")]
async fn head_object(
api_client: web::Data<SourceApi>,
req: HttpRequest,
path: web::Path<(String, String, String)>,
user_identity: web::ReqData<UserIdentity>,
) -> Result<impl Responder, BackendError> {
Expand All @@ -355,15 +341,45 @@ async fn head_object(
.await?;

let res = client.head_object(key.clone()).await?;
Ok(HttpResponse::Ok()
.insert_header(("Accept-Ranges", "bytes"))
.insert_header(("Access-Control-Expose-Headers", "Accept-Ranges"))
.insert_header(("Content-Type", res.content_type))
.insert_header(("Last-Modified", res.last_modified))
.insert_header(("ETag", res.etag))
.body(BoxBody::new(FakeBody {
size: res.content_length as usize,
})))
let total_size = res.content_length;
let range_info: Option<RangeRequest> = req
.headers()
.get(RANGE)
.and_then(|v| v.to_str().ok())
.and_then(|s| s.parse().ok());

let response = if let Some(ref range) = range_info {
let end = range.end.unwrap();
let content_length = end - range.start + 1;
HttpResponse::PartialContent()
.insert_header(("Accept-Ranges", "bytes"))
.insert_header((
"Content-Range",
format!("bytes {}-{}/{}", range.start, end, total_size),
))
.insert_header((
"Access-Control-Expose-Headers",
"Accept-Ranges, Content-Range",
))
.insert_header(("Content-Type", res.content_type))
.insert_header(("Last-Modified", res.last_modified))
.insert_header(("ETag", res.etag))
.body(BoxBody::new(FakeBody {
size: content_length as usize,
}))
} else {
HttpResponse::Ok()
.insert_header(("Accept-Ranges", "bytes"))
.insert_header(("Access-Control-Expose-Headers", "Accept-Ranges"))
.insert_header(("Content-Type", res.content_type))
.insert_header(("Last-Modified", res.last_modified))
.insert_header(("ETag", res.etag))
.body(BoxBody::new(FakeBody {
size: total_size as usize,
}))
};

Ok(response)
}

#[derive(Deserialize)]
Expand Down
149 changes: 148 additions & 1 deletion src/utils/core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ use actix_web::{
};
use futures::Stream;
use pin_project_lite::pin_project;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::{pin::Pin, str::FromStr};

pin_project! {
pub struct StreamingResponse<S> {
Expand Down Expand Up @@ -97,3 +97,150 @@ pub fn split_at_first_slash(input: &str) -> (&str, &str) {
None => (input, ""),
}
}

/// Parsed range request information.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct RangeRequest {
/// The start byte offset
pub start: u64,
/// The end byte offset (inclusive), or `None` for open-ended ranges (e.g. "bytes=100-")
pub end: Option<u64>,
}

impl From<RangeRequest> for String {
fn from(r: RangeRequest) -> Self {
match r.end {
Some(end) => format!("bytes={}-{}", r.start, end),
None => format!("bytes={}-", r.start),
}
}
}

impl FromStr for RangeRequest {
type Err = ();

fn from_str(s: &str) -> Result<Self, Self::Err> {
let bytes_range = s.strip_prefix("bytes=").ok_or(())?;
let (start_str, end_str) = bytes_range.split_once('-').ok_or(())?;
let start = start_str.parse::<u64>().map_err(|_| ())?;

let end = if end_str.is_empty() {
None
} else {
let end = end_str.parse::<u64>().map_err(|_| ())?;
if start > end {
return Err(());
}
Some(end)
};

Ok(RangeRequest { start, end })
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_parse_range_full_range() {
let result: RangeRequest = "bytes=0-1023".parse().unwrap();
assert_eq!(
result,
RangeRequest {
start: 0,
end: Some(1023),
}
);
}

#[test]
fn test_parse_range_open_ended() {
let result = "bytes=100-".parse::<RangeRequest>().unwrap();
assert_eq!(
result,
RangeRequest {
start: 100,
end: None,
}
);
}

#[test]
fn test_parse_range_open_ended_no_total_size() {
// Without with_total_size, end remains None
let result: RangeRequest = "bytes=100-".parse().unwrap();
assert_eq!(result.end, None);
}

#[test]
fn test_parse_range_missing_prefix() {
assert!("invalid=0-100".parse::<RangeRequest>().is_err());
}

#[test]
fn test_parse_range_non_numeric_start() {
assert!("bytes=abc-100".parse::<RangeRequest>().is_err());
}

#[test]
fn test_parse_range_non_numeric_end() {
assert!("bytes=0-abc".parse::<RangeRequest>().is_err());
}

#[test]
fn test_parse_range_start_greater_than_end() {
assert!("bytes=500-100".parse::<RangeRequest>().is_err());
}

#[test]
fn test_parse_range_start_beyond_total_size() {
// Parsing succeeds; validation against total_size is the caller's responsibility
let result: RangeRequest = "bytes=1000-1023".parse().unwrap();
assert_eq!(result.start, 1000);
assert_eq!(result.end, Some(1023));
}

#[test]
fn test_parse_range_single_byte() {
let result: RangeRequest = "bytes=0-0".parse().unwrap();
assert_eq!(
result,
RangeRequest {
start: 0,
end: Some(0),
}
);
}

#[test]
fn test_parse_range_large_file() {
let rr: RangeRequest = "bytes=0-1023".parse().unwrap();
assert_eq!(rr.start, 0);
assert_eq!(rr.end, Some(1023));
assert_eq!(rr.end.unwrap() - rr.start + 1, 1024);
}

#[test]
fn test_parse_range_no_hyphen() {
assert!("bytes=100".parse::<RangeRequest>().is_err());
}

#[test]
fn test_parse_range_content_length_calculation() {
let result: RangeRequest = "bytes=0-1023".parse().unwrap();
assert_eq!(result.end.unwrap() - result.start + 1, 1024);
}

#[test]
fn test_parse_range_content_range_format() {
let result: RangeRequest = "bytes=0-1023".parse().unwrap();
let content_range = format!(
"bytes {}-{}/{}",
result.start,
result.end.unwrap(),
3515053862u64
);
assert_eq!(content_range, "bytes 0-1023/3515053862");
}
}