diff --git a/.github/workflows/homebrew-bump.yml b/.github/workflows/homebrew-bump.yml deleted file mode 100644 index 4a00f0b..0000000 --- a/.github/workflows/homebrew-bump.yml +++ /dev/null @@ -1,45 +0,0 @@ -name: Update Homebrew Formula - -on: - release: - types: [published] - -jobs: - update-formula: - runs-on: ubuntu-latest - steps: - - name: Checkout Repository - uses: actions/checkout@v4 - - - name: Get Latest PyPI Version - id: get_version - run: | - VERSION=$(curl -s https://pypi.org/pypi/corgea-cli/json | jq -r .info.version) - echo "Latest version: $VERSION" - echo "version=$VERSION" >> $GITHUB_ENV - - - name: Get Latest Source Tarball URL - id: get_tarball - run: | - URL=$(curl -s https://pypi.org/pypi/corgea-cli/json | jq -r '.urls[] | select(.packagetype=="sdist") | .url') - echo "Tarball URL: $URL" - echo "tarball_url=$URL" >> $GITHUB_ENV - - - name: Get SHA256 Hash - id: get_sha - run: | - curl -o corgea-cli.tar.gz ${{ env.tarball_url }} - SHA256=$(shasum -a 256 corgea-cli.tar.gz | awk '{print $1}') - echo "SHA256: $SHA256" - echo "sha256=$SHA256" >> $GITHUB_ENV - - - name: Update Homebrew Formula - run: | - brew bump-formula-pr --strict corgea-cli \ - --url=${{ env.tarball_url }} \ - --sha256=${{ env.sha256 }} \ - --no-browse \ - --no-fork \ - --force - env: - HOMEBREW_GITHUB_API_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml new file mode 100644 index 0000000..b1248a7 --- /dev/null +++ b/.github/workflows/test.yml @@ -0,0 +1,26 @@ +name: Test + +on: + push: + branches: + - main + - master + pull_request: + +jobs: + rust-tests: + runs-on: ubuntu-latest + + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Setup Rust + uses: dtolnay/rust-toolchain@stable + + - name: Cache cargo + uses: Swatinem/rust-cache@v2 + + + - name: Run unit tests + run: cargo test diff --git a/Cargo.toml b/Cargo.toml index 5ee85bb..d68431c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "corgea" -version = "1.8.0" +version = "1.8.1" edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html diff --git a/src/authorize.rs b/src/authorize.rs index 4d0475e..39b5df3 100644 --- a/src/authorize.rs +++ b/src/authorize.rs @@ -137,32 +137,36 @@ async fn start_callback_server( }; loop { - let (stream, _) = listener.accept().await?; - let io = TokioIo::new(stream); - let auth_code_clone = auth_code.clone(); - - let service = service_fn(move |req| { - handle_callback(req, auth_code_clone.clone()) - }); - - tokio::task::spawn(async move { - if let Err(err) = hyper::server::conn::http1::Builder::new() - .serve_connection(io, service) - .await - { - eprintln!("Error serving connection: {:?}", err); + tokio::select! { + accept_result = listener.accept() => { + let (stream, _) = accept_result?; + let io = TokioIo::new(stream); + let auth_code_clone = auth_code.clone(); + + let service = service_fn(move |req| { + handle_callback(req, auth_code_clone.clone()) + }); + + tokio::task::spawn(async move { + if let Err(err) = hyper::server::conn::http1::Builder::new() + .serve_connection(io, service) + .await + { + eprintln!("Error serving connection: {:?}", err); + } + }); } - }); - - // Check if we got the code + _ = tokio::time::sleep(Duration::from_millis(100)) => {} + } + + // Check if we got the code. + // We must do this outside of `accept()` blocking so we don't miss a code + // that was set by the request task after a single callback request. if let Ok(code_guard) = auth_code.lock() { if let Some(code) = code_guard.as_ref() { return Ok(code.clone()); } } - - // Add a small delay to prevent busy waiting - tokio::time::sleep(Duration::from_millis(100)).await; } } @@ -523,3 +527,183 @@ fn parse_query_params(query: &str) -> HashMap { } + +#[cfg(test)] +mod tests { + use super::*; + use std::io::{Read, Write}; + use std::net::{TcpListener as StdTcpListener, TcpStream}; + use std::sync::mpsc; + use std::thread; + use std::time::Duration as StdDuration; + use tokio::runtime::Runtime; + use tokio::time::{timeout, Duration}; + + fn reserve_ephemeral_port() -> u16 { + let listener = StdTcpListener::bind("127.0.0.1:0").expect("failed to bind ephemeral port"); + listener.local_addr().expect("failed to get local addr").port() + } + + fn spawn_callback_server( + port: u16, + auth_code: Arc>>, + ) -> mpsc::Receiver> { + let (tx, rx) = mpsc::channel(); + thread::spawn(move || { + let runtime = Runtime::new().expect("failed to create tokio runtime"); + let result = runtime + .block_on(start_callback_server(port, auth_code)) + .map_err(|e| e.to_string()); + tx.send(result).expect("failed to send callback result"); + }); + + rx + } + + fn send_http_get(port: u16, path: &str) -> (u16, String) { + let mut stream = None; + + for _ in 0..50 { + match TcpStream::connect(("127.0.0.1", port)) { + Ok(s) => { + stream = Some(s); + break; + } + Err(_) => thread::sleep(StdDuration::from_millis(20)), + } + } + + let mut stream = stream.expect("failed to connect to callback server"); + let request = format!("GET {} HTTP/1.0\r\n\r\n", path); + + stream + .write_all(request.as_bytes()) + .expect("failed to write request"); + + let mut raw_response = String::new(); + stream + .read_to_string(&mut raw_response) + .expect("failed to read response"); + + let mut sections = raw_response.splitn(2, "\r\n\r\n"); + let headers = sections.next().expect("response headers missing"); + let body = sections.next().unwrap_or_default().to_string(); + let status_line = headers.lines().next().expect("status line missing"); + let status = status_line + .split_whitespace() + .nth(1) + .expect("status code missing") + .parse::() + .expect("invalid status code"); + + (status, body) + } + + #[test] + fn parse_query_params_decodes_values() { + let params = parse_query_params("code=a%20b&error_description=needs%2Blogin"); + + assert_eq!(params.get("code"), Some(&"a b".to_string())); + assert_eq!(params.get("error_description"), Some(&"needs+login".to_string())); + } + + #[test] + fn parse_query_params_ignores_malformed_pairs() { + let params = parse_query_params("valid=ok&invalid&also_invalid="); + + assert_eq!(params.get("valid"), Some(&"ok".to_string())); + assert_eq!(params.get("invalid"), None); + assert_eq!(params.get("also_invalid"), Some(&"".to_string())); + } + + #[test] + fn port_is_available_reflects_current_port_usage() { + let listener = StdTcpListener::bind("127.0.0.1:0").expect("failed to bind ephemeral port"); + let port = listener + .local_addr() + .expect("failed to get listener addr") + .port(); + + assert!(!port_is_available(port)); + drop(listener); + assert!(port_is_available(port)); + } + + #[test] + fn find_available_port_skips_ports_that_are_in_use() { + let listener = StdTcpListener::bind("127.0.0.1:0").expect("failed to bind ephemeral port"); + let occupied_port = listener + .local_addr() + .expect("failed to get listener addr") + .port(); + + let found_port = find_available_port(occupied_port).expect("should find an available port"); + + assert_ne!(found_port, occupied_port); + } + + #[tokio::test] + async fn start_callback_server_returns_without_waiting_for_second_connection() { + let port = reserve_ephemeral_port(); + let auth_code = Arc::new(Mutex::new(Some("test-code".to_string()))); + + let returned_code = timeout( + Duration::from_millis(300), + start_callback_server(port, auth_code), + ) + .await + .expect("callback server timed out") + .expect("callback server should return code"); + + assert_eq!(returned_code, "test-code"); + } + + #[test] + fn start_callback_server_returns_bind_error_if_port_is_occupied() { + let listener = StdTcpListener::bind("127.0.0.1:0").expect("failed to bind ephemeral port"); + let occupied_port = listener + .local_addr() + .expect("failed to get listener addr") + .port(); + + let runtime = Runtime::new().expect("failed to create runtime"); + let result = runtime.block_on(start_callback_server( + occupied_port, + Arc::new(Mutex::new(None::)), + )); + + assert!(result.is_err()); + let error = result.err().expect("expected bind error").to_string(); + assert!(error.contains("Failed to bind")); + } + + #[test] + fn callback_server_serves_waiting_error_and_success_pages_then_returns_code() { + let port = reserve_ephemeral_port(); + let auth_code = Arc::new(Mutex::new(None::)); + let result_rx = spawn_callback_server(port, auth_code); + + let (waiting_status, waiting_body) = send_http_get(port, "/"); + assert_eq!(waiting_status, 200); + assert!(waiting_body.contains("Waiting for Authorization")); + + let (error_status, error_body) = send_http_get( + port, + "/?error=access_denied&error_description=user%20cancelled", + ); + assert_eq!(error_status, 400); + assert!(error_body.contains("Authorization Failed")); + assert!(error_body.contains("access_denied")); + + let (success_status, success_body) = send_http_get(port, "/?code=abc123"); + assert_eq!(success_status, 200); + assert!(success_body.contains("Successfully Signed In")); + + let returned_code = result_rx + .recv_timeout(StdDuration::from_secs(2)) + .expect("callback server should return in time") + .expect("callback server should return code"); + + assert_eq!(returned_code, "abc123"); + } +}