From 2cd3f71c4fd6c5bc3e9462e5b0d5edc6856d8f93 Mon Sep 17 00:00:00 2001 From: AWeirdDev Date: Mon, 23 Jun 2025 16:34:18 +0800 Subject: [PATCH] clean clean clean --- Cargo.lock | 31 +++- Cargo.toml | 7 +- src/main.rs | 185 ++++++++-------------- templates/base.html | 350 +++++++++++++++++++++++++++++++++++++---- templates/error.html | 9 -- templates/index.html | 17 +- templates/success.html | 60 ------- 7 files changed, 421 insertions(+), 238 deletions(-) delete mode 100644 templates/error.html delete mode 100644 templates/success.html diff --git a/Cargo.lock b/Cargo.lock index a51ee12..2cbbe63 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -38,6 +38,12 @@ dependencies = [ "libc", ] +[[package]] +name = "anyhow" +version = "1.0.98" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e16d2d3311acee920a9eb8d33b8cbc1787ce4a264e85f964c2404b969bdcd487" + [[package]] name = "askama" version = "0.12.1" @@ -396,12 +402,6 @@ dependencies = [ "syn", ] -[[package]] -name = "dotenv" -version = "0.15.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "77c90badedccf4105eca100756a0b1289e191f6fcbdadd3cee1d2f614f97da8f" - [[package]] name = "dotenvy" version = "0.15.7" @@ -1883,6 +1883,20 @@ dependencies = [ "tracing", ] +[[package]] +name = "tower-http" +version = "0.6.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "adc82fd73de2a9722ac5da747f12383d2bfdb93591ee6c58486e0097890f05f2" +dependencies = [ + "bitflags", + "bytes", + "http", + "pin-project-lite", + "tower-layer", + "tower-service", +] + [[package]] name = "tower-layer" version = "0.3.3" @@ -2412,14 +2426,17 @@ dependencies = [ name = "yue-lat" version = "0.1.0" dependencies = [ + "anyhow", "askama", "askama_axum", "axum", - "dotenv", + "dotenvy", "rand 0.9.1", "serde", "sqlx", "tokio", + "tower", + "tower-http", "tracing", "tracing-subscriber", ] diff --git a/Cargo.toml b/Cargo.toml index 7839bff..e4eb23b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,7 +5,6 @@ edition = "2024" [dependencies] axum = "0.8.4" -dotenv = "0.15.0" rand = "0.9.1" serde = { version = "1.0.219" , features = ["derive"] } tokio = { version = "1.45.1", features = ["rt-multi-thread"] } @@ -13,4 +12,8 @@ tracing = "0.1.41" tracing-subscriber = "0.3.19" sqlx = { version = "0.8", features = ["runtime-tokio-rustls", "sqlite", "chrono"] } askama = "0.12" -askama_axum = "0.4" \ No newline at end of file +askama_axum = "0.4" +dotenvy = "0.15.7" +anyhow = "1.0.98" +tower-http = { version = "0.6.6", features = ["cors"] } +tower = "0.5.2" diff --git a/src/main.rs b/src/main.rs index 4d8475a..41bf096 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,40 +1,30 @@ +use anyhow::Result; use askama::Template; use axum::{ - Form, Json, Router, - extract::{DefaultBodyLimit, Path, State}, - http::{StatusCode, header}, - response::{Html, Redirect, Response}, - routing::{get, post}, + Json, + Router, + extract::{ DefaultBodyLimit, Path, State }, + http::{ StatusCode, header }, + response::{ Html, Redirect, Response }, + routing::{ get, post }, }; -use dotenv::dotenv; use rand::Rng; -use serde::{Deserialize, Serialize}; -use sqlx::{Row, SqlitePool}; -use std::env; +use serde::{ Deserialize, Serialize }; +use sqlx::{ Row, SqlitePool }; +use tower::ServiceBuilder; +use tower_http::cors::{ Any, CorsLayer }; use tracing_subscriber; // Constants for input validation -const MAX_URL_LENGTH: usize = 2048; // Reasonable limit for URLs -const MAX_SHORT_CODE_LENGTH: usize = 20; // Generous limit for short codes -const MAX_REQUEST_SIZE: usize = 1024 * 1024; // 1MB max request size +const MAX_URL_LENGTH: usize = 2048; +const MAX_SHORT_CODE_LENGTH: usize = 20; +const MAX_REQUEST_SIZE: usize = 1024 * 1024; // Template structs #[derive(Template)] #[template(path = "index.html")] struct IndexTemplate; -#[derive(Template)] -#[template(path = "success.html")] -struct SuccessTemplate { - short_url: String, -} - -#[derive(Template)] -#[template(path = "error.html")] -struct ErrorTemplate { - error_message: String, -} - // Shared state with SQLite connection pool #[derive(Clone)] struct AppState { @@ -43,67 +33,70 @@ struct AppState { } #[tokio::main] -async fn main() { +async fn main() -> Result<()> { tracing_subscriber::fmt::init(); - dotenv().ok(); + dotenvy::dotenv_override()?; - let port: u16 = env::var("PORT") + let port: u16 = dotenvy + ::var("PORT") .unwrap_or_else(|_| "3000".to_string()) .parse() .expect("PORT must be a valid number"); - let base_url = env::var("BASE_URL").unwrap_or_else(|_| "https://yue.lat".to_string()); + let base_url = dotenvy::var("BASE_URL").unwrap_or_else(|_| "https://yue.lat".to_string()); - // Remove trailing slash if present let base_url = base_url.trim_end_matches('/').to_string(); tracing::info!("Using base URL: {}", base_url); - // Initialize SQLite connection pool let db = setup_database().await.expect("Failed to setup database"); let app_state = AppState { db, base_url }; let app = Router::new() - .route("/", get(root).post(create_url_form)) + .route("/", get(root)) .route("/api/v1/shorten", post(create_url)) .route("/favicon.ico", get(favicon)) .route("/{short_code}", get(redirect_url)) - .layer(DefaultBodyLimit::max(MAX_REQUEST_SIZE)) // Limit request body size + .layer( + ServiceBuilder::new() + .layer(DefaultBodyLimit::max(MAX_REQUEST_SIZE)) + .layer(CorsLayer::new().allow_headers(Any).allow_methods(Any)) + ) .with_state(app_state); - let listener = tokio::net::TcpListener::bind(format!("0.0.0.0:{}", port)) - .await - .unwrap(); + let listener = tokio::net::TcpListener::bind(format!("0.0.0.0:{}", port)).await?; tracing::info!("URL Shortener listening on http://0.0.0.0:{}", port); - axum::serve(listener, app).await.unwrap(); + axum::serve(listener, app).await?; + + Ok(()) } async fn setup_database() -> Result { // Create connection pool with configuration let pool = SqlitePool::connect_with( - sqlx::sqlite::SqliteConnectOptions::new() + sqlx::sqlite::SqliteConnectOptions + ::new() .filename("urls.db") .create_if_missing(true) .pragma("journal_mode", "WAL") // Enable WAL mode for better concurrency .pragma("synchronous", "NORMAL") .pragma("cache_size", "1000") .pragma("foreign_keys", "true") - .pragma("temp_store", "memory"), - ) - .await?; + .pragma("temp_store", "memory") + ).await?; // Create table with length constraints - sqlx::query( - "CREATE TABLE IF NOT EXISTS urls ( + sqlx + ::query( + "CREATE TABLE IF NOT EXISTS urls ( short_code TEXT PRIMARY KEY CHECK(length(short_code) <= 20), original_url TEXT NOT NULL CHECK(length(original_url) <= 2048), created_at DATETIME DEFAULT CURRENT_TIMESTAMP - )", - ) - .execute(&pool) - .await?; + )" + ) + .execute(&pool).await?; tracing::info!("Database pool initialized with {} connections", pool.size()); Ok(pool) @@ -112,20 +105,16 @@ async fn setup_database() -> Result { // Input validation functions fn validate_url_length(url: &str) -> Result<(), String> { if url.len() > MAX_URL_LENGTH { - return Err(format!( - "URL too long. Maximum length is {} characters", - MAX_URL_LENGTH - )); + return Err(format!("URL too long. Maximum length is {} characters", MAX_URL_LENGTH)); } Ok(()) } fn validate_short_code_length(short_code: &str) -> Result<(), String> { if short_code.len() > MAX_SHORT_CODE_LENGTH { - return Err(format!( - "Short code too long. Maximum length is {} characters", - MAX_SHORT_CODE_LENGTH - )); + return Err( + format!("Short code too long. Maximum length is {} characters", MAX_SHORT_CODE_LENGTH) + ); } Ok(()) } @@ -146,7 +135,8 @@ fn validate_url_format(url: &str) -> Result<(), String> { async fn favicon() -> Response { // Simple link/chain icon as SVG favicon - let svg_favicon = r##" + let svg_favicon = + r##" @@ -173,9 +163,10 @@ async fn root() -> Result, StatusCode> { async fn create_url( State(app_state): State, - Json(payload): Json, + Json(payload): Json ) -> Result<(StatusCode, Json), (StatusCode, Json)> { let result = shorten_url(app_state, payload.url).await; + match result { Ok(response) => { tracing::info!( @@ -192,69 +183,20 @@ async fn create_url( } } -async fn create_url_form( - State(app_state): State, - Form(payload): Form, -) -> Result, StatusCode> { - let result = shorten_url(app_state, payload.url).await; - - match result { - Ok(response) => { - tracing::info!( - "Created short URL via form: {} -> {}", - response.short_code, - response.original_url - ); - - let success_template = SuccessTemplate { - short_url: response.short_url, - }; - - match success_template.render() { - Ok(html) => Ok(Html(html)), - Err(e) => { - tracing::error!("Template rendering error: {}", e); - Err(StatusCode::INTERNAL_SERVER_ERROR) - } - } - } - Err(error_response) => { - tracing::warn!( - "Failed to create short URL via form: {}", - error_response.error - ); - - let error_template = ErrorTemplate { - error_message: error_response.error, - }; - - match error_template.render() { - Ok(html) => Ok(Html(html)), - Err(e) => { - tracing::error!("Template rendering error: {}", e); - Err(StatusCode::INTERNAL_SERVER_ERROR) - } - } - } - } -} - async fn shorten_url(app_state: AppState, url: String) -> Result { - // Validate URL length first (before any processing) if let Err(error) = validate_url_length(&url) { return Err(ErrorResponse { error }); } + let url = url.trim().to_string(); + // Validate URL not empty - if url.trim().is_empty() { + if url.is_empty() { return Err(ErrorResponse { error: "URL cannot be empty".to_string(), }); } - let url = url.trim().to_string(); - - // Validate URL format if let Err(error) = validate_url_format(&url) { return Err(ErrorResponse { error }); } @@ -265,7 +207,6 @@ async fn shorten_url(app_state: AppState, url: String) -> Result Result { // Short code doesn't exist, try to insert - let result = - sqlx::query("INSERT INTO urls (short_code, original_url) VALUES (?, ?)") - .bind(&short_code) - .bind(&url) - .execute(&app_state.db) - .await; + let result = sqlx + ::query("INSERT INTO urls (short_code, original_url) VALUES (?, ?)") + .bind(&short_code) + .bind(&url) + .execute(&app_state.db).await; match result { Ok(_) => { @@ -328,7 +267,7 @@ async fn shorten_url(app_state: AppState, url: String) -> Result, - State(app_state): State, + State(app_state): State ) -> Result { // Validate short code length to prevent potential attacks if let Err(error) = validate_short_code_length(&short_code) { @@ -342,10 +281,10 @@ async fn redirect_url( return Err(StatusCode::BAD_REQUEST); } - let result = sqlx::query("SELECT original_url FROM urls WHERE short_code = ?") + let result = sqlx + ::query("SELECT original_url FROM urls WHERE short_code = ?") .bind(&short_code) - .fetch_optional(&app_state.db) - .await; + .fetch_optional(&app_state.db).await; match result { Ok(Some(row)) => { diff --git a/templates/base.html b/templates/base.html index 1cc98fb..1d11c8e 100644 --- a/templates/base.html +++ b/templates/base.html @@ -1,30 +1,34 @@ + {% block title %}Maoyue's URL Shortener{% endblock %} - - + + - + - + - + - - + + - + - + - + + - {% block content %}{% endblock %} - -