diff --git a/docs/ARQUITETURA.md b/docs/ARQUITETURA.md index 9aa74be..14ab2be 100644 --- a/docs/ARQUITETURA.md +++ b/docs/ARQUITETURA.md @@ -159,7 +159,7 @@ class TripService: email: str, bus_line: str, distance: int, - trip_date: datetime, + trip_datetime: datetime, ) -> Trip: # 1. Validar usuário existe user = await self._user_repo.get_user(email) @@ -174,7 +174,7 @@ class TripService: email=email, bus_line=bus_line, score=score, - trip_date=trip_date, + trip_datetime=trip_datetime, ) # 4. Salvar viagem @@ -222,7 +222,7 @@ async def create_trip( email=request.email, bus_line=request.bus_line, distance=request.distance, - trip_date=request.trip_date, + trip_datetime=request.trip_datetime, ) # 2. Mapear domínio → API schema @@ -249,14 +249,14 @@ class CreateTripRequest(BaseModel): email: EmailStr bus_line: str distance: int - trip_date: datetime + trip_datetime: datetime class TripResponse(BaseModel): id: int email: str bus_line: str score: int - trip_date: datetime + trip_datetime: datetime ``` #### `mappers.py` - Conversão de Dados @@ -270,7 +270,7 @@ def trip_response_from_domain(trip: Trip) -> TripResponse: email=trip.email, bus_line=trip.bus_line, score=trip.score, - trip_date=trip.trip_date, + trip_datetime=trip.trip_datetime, ) ``` diff --git a/docs/TESTES.md b/docs/TESTES.md index 38ec7df..2cf2829 100644 --- a/docs/TESTES.md +++ b/docs/TESTES.md @@ -67,7 +67,7 @@ async def test_create_trip_calculates_score_correctly() -> None: bus_line="8000", bus_direction=1, distance=1000, - trip_date=datetime.now() + trip_datetime=datetime.now() ) # 3. Assert (Verificar) - Checar resultados @@ -117,7 +117,7 @@ async def test_create_trip_fails_for_nonexistent_user() -> None: bus_line="8000", bus_direction=1, distance=1000, - trip_date=datetime.now() + trip_datetime=datetime.now() ) # Verifica que save_trip NÃO foi chamado @@ -152,7 +152,7 @@ async def test_multiple_trips(mocker: "MockerFixture") -> None: bus_line="8000", bus_direction=1, distance=500, - trip_date=datetime.now(), + trip_datetime=datetime.now(), ) trip2 = await service.create_trip( @@ -160,7 +160,7 @@ async def test_multiple_trips(mocker: "MockerFixture") -> None: bus_line="8000", bus_direction=2, distance=1500, - trip_date=datetime.now(), + trip_datetime=datetime.now(), ) # Assert @@ -203,7 +203,7 @@ async def test_handles_repository_save_error(mocker: "MockerFixture") -> None: bus_line="8000", bus_direction=1, distance=1000, - trip_date=datetime.now(), + trip_datetime=datetime.now(), ) trip_repo.save_trip.assert_awaited_once() diff --git a/requirements.txt b/requirements.txt index bfc3d74..b5d8ccb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -18,5 +18,6 @@ pytest==8.3.0 pytest-asyncio==0.24.0 pytest-mock==3.14.0 pytest-cov==7.0.0 +pytest-dotenv==0.5.2 mypy==1.11.0 ruff==0.6.0 diff --git a/src/adapters/database/mappers.py b/src/adapters/database/mappers.py index 8cdc165..0e76625 100644 --- a/src/adapters/database/mappers.py +++ b/src/adapters/database/mappers.py @@ -4,7 +4,9 @@ These functions translate between the persistence layer and the domain layer. """ -from ...core.models.bus import RouteIdentifier +from typing import cast + +from ...core.models.bus import BusDirection, RouteIdentifier from ...core.models.trip import Trip from ...core.models.user import User from ...core.models.user_history import UserHistory @@ -79,11 +81,11 @@ def map_trip_db_to_domain(trip_db: TripDB) -> Trip: email=trip_db.email, route=RouteIdentifier( bus_line=trip_db.bus_line, - bus_direction=trip_db.bus_direction, + bus_direction=cast(BusDirection, trip_db.bus_direction), ), distance=trip_db.distance, score=trip_db.score, - trip_date=trip_db.trip_date, + trip_datetime=trip_db.trip_datetime, ) @@ -103,7 +105,7 @@ def map_trip_domain_to_db(trip: Trip) -> TripDB: bus_direction=trip.route.bus_direction, distance=trip.distance, score=trip.score, - trip_date=trip.trip_date, + trip_datetime=trip.trip_datetime, ) diff --git a/src/adapters/database/models.py b/src/adapters/database/models.py index 1045b60..34a110f 100644 --- a/src/adapters/database/models.py +++ b/src/adapters/database/models.py @@ -50,7 +50,7 @@ class TripDB(Base): bus_direction: Mapped[int] = mapped_column(Integer, nullable=False) distance: Mapped[int] = mapped_column(Integer, nullable=False) score: Mapped[int] = mapped_column(Integer, nullable=False) - trip_date: Mapped[datetime] = mapped_column(DateTime, nullable=False) + trip_datetime: Mapped[datetime] = mapped_column(DateTime, nullable=False) user: Mapped["UserDB"] = relationship("UserDB", back_populates="trips") diff --git a/src/adapters/external/models/LineInfo.py b/src/adapters/external/models/LineInfo.py deleted file mode 100644 index 2bd46fa..0000000 --- a/src/adapters/external/models/LineInfo.py +++ /dev/null @@ -1,11 +0,0 @@ -from typing import TypedDict - - -class LineInfo(TypedDict): - cl: int # código da linha (código interno SPTrans) - lc: bool # sentido preferencial (circular / comum) - lt: str # número da linha (ex: "8000") - sl: int # sentido (0: ida, 1: volta) - tl: int # tipo da linha (10 = urbana, etc.) - tp: str # ponto inicial - ts: str # ponto final diff --git a/src/adapters/external/models/SPTransPosResp.py b/src/adapters/external/models/SPTransPosResp.py deleted file mode 100644 index c787065..0000000 --- a/src/adapters/external/models/SPTransPosResp.py +++ /dev/null @@ -1,14 +0,0 @@ -from typing import TypedDict - - -class Vehicle(TypedDict): - p: str # prefixo do ônibus - a: bool # acessibilidade - ta: str # timestamp ISO - py: float # latitude - px: float # longitude - - -class SPTransPositionsResponse(TypedDict): - hr: str # horário da resposta - vs: list[Vehicle] # lista de veículos diff --git a/src/adapters/external/sptrans_adapter.py b/src/adapters/external/sptrans_adapter.py index a704620..f9036fe 100644 --- a/src/adapters/external/sptrans_adapter.py +++ b/src/adapters/external/sptrans_adapter.py @@ -5,18 +5,18 @@ with the SPTrans API. """ -from datetime import datetime - import httpx from httpx import Response from src.config import settings -from ...adapters.external.models.LineInfo import LineInfo -from ...adapters.external.models.SPTransPosResp import SPTransPositionsResponse, Vehicle -from ...core.models.bus import BusPosition, BusRoute, RouteIdentifier -from ...core.models.coordinate import Coordinate +from ...core.models.bus import BusPosition, BusRoute from ...core.ports.bus_provider_port import BusProviderPort +from .sptrans_mappers import ( + map_positions_response_to_bus_positions, + map_search_response_to_bus_route_list, +) +from .sptrans_schemas import SPTransLineSearchResponse, SPTransPositionsResponse class SpTransAdapter(BusProviderPort): @@ -29,11 +29,12 @@ def __init__( Initialize the SPTrans adapter. Args: - api_token: API authentication token (optional) - base_url: Base URL for the SPTrans API (optional) + api_token: API authentication token (optional, defaults to settings) + base_url: Base URL for the SPTrans API (optional, defaults to settings) + + Raises: + ValueError: If no API token is provided or configured. """ - # se o parâmetro foi passado, usa ele; - # senão, usa o valor do settings carregado do .env self.api_token = api_token or settings.sptrans_api_token self.base_url = base_url or settings.sptrans_base_url @@ -42,15 +43,27 @@ def __init__( "SPTransAdapter: nenhum token fornecido e SPTRANS_API_TOKEN não definido no ambiente/.env." ) - self.session_token: str | None = None + self._authenticated: bool = False self.client = httpx.AsyncClient(base_url=self.base_url, timeout=30.0) - async def authenticate(self) -> bool: + async def _ensure_authenticated(self) -> None: + """ + Ensure the client is authenticated, authenticating if necessary. + + Raises: + RuntimeError: If authentication fails. + """ + if not self._authenticated: + success = await self._authenticate() + if not success: + raise RuntimeError("SPTrans authentication failed") + + async def _authenticate(self) -> bool: """ Authenticate with the SPTrans API. Returns: - True if authentication successful, False otherwise + True if authentication successful, False otherwise. """ try: response: Response = await self.client.post( @@ -59,131 +72,111 @@ async def authenticate(self) -> bool: ) if response.status_code == 200 and response.text == "true": - # marca que autenticou com sucesso - self.session_token = "authenticated" + self._authenticated = True return True return False - except Exception as e: - exc: Exception = e - print(f"Authentication failed: {exc}") + except Exception: return False - async def get_bus_positions(self, bus_route: BusRoute) -> list[BusPosition]: + def _is_unauthorized_response(self, response: Response) -> bool: """ - Get real-time positions for a specific route using SPTrans data. - - Pré-condição: - O método `authenticate()` deve ter sido chamado com sucesso antes de - usar este método. + Check if response indicates authentication is required. Args: - bus_route: BusRoute(route_id=cl, route=RouteIdentifier) + response: HTTP response to check. Returns: - List of BusPosition objects. + True if response indicates unauthorized access. """ - # Checagem defensiva opcional - if getattr(self, "session_token", None) != "authenticated": - raise RuntimeError("SPTrans client not authenticated. Call `authenticate()` first.") - - positions: list[BusPosition] = [] + if response.status_code != 401: + return False try: - line_code: int = bus_route.route_id - - response: Response = await self.client.get( - "/Posicao/Linha", - params={"codigoLinha": line_code}, - ) - - if response.status_code != 200: - raise RuntimeError( - f"SPTrans returned status {response.status_code} for line {bus_route}" - ) - - response_data: SPTransPositionsResponse = response.json() - - vehicles: list[Vehicle] = response_data["vs"] - - for vehicle in vehicles: - pos: BusPosition = BusPosition( - route=bus_route.route, - position=Coordinate( - latitude=vehicle["py"], - longitude=vehicle["px"], - ), - time_updated=datetime.fromisoformat(vehicle["ta"]), - ) - positions.append(pos) - - except Exception as e: - exc: Exception = e - print(f"Failed to get positions for bus_route {bus_route}: {exc}") - - return positions + data = response.json() + return "Authorization has been denied" in data.get("Message", "") + except Exception: + return False - async def get_route_details(self, route: RouteIdentifier) -> list[BusRoute]: + async def _request_with_auth_retry( + self, + method: str, + url: str, + params: dict[str, str | int] | None = None, + ) -> Response: """ - Resolve a logical bus line (bus_line) into all SPTrans BusRoute entries using - the `/Linha/Buscar` endpoint. + Make an HTTP request with automatic authentication retry on 401. Args: - route (RouteIdentifier): Logical bus line (ex: "8000") + method: HTTP method (GET, POST, etc.) + url: Request URL path. + params: Query parameters for the request. Returns: - list[BusRoute]: Todas as variantes da linha retornadas pela SPTrans. + HTTP response. Raises: - RuntimeError: Se a requisição falhar, vier vazia ou inválida. + RuntimeError: If authentication fails after retry. """ + await self._ensure_authenticated() - # Verifica se está autenticado - if getattr(self, "session_token", None) != "authenticated": - raise RuntimeError("SPTrans client not authenticated. Call `authenticate()` first.") + response = await self.client.request(method, url, params=params) - try: - response: Response = await self.client.get( - "/Linha/Buscar", - params={"termosBusca": route.bus_line}, - ) + if self._is_unauthorized_response(response): + self._authenticated = False + await self._ensure_authenticated() + response = await self.client.request(method, url, params=params) - if response.status_code != 200: - raise RuntimeError( - f"SPTrans returned status {response.status_code} for line search." - ) + if self._is_unauthorized_response(response): + raise RuntimeError("SPTrans authentication failed after retry") - data: list[LineInfo] = response.json() + return response + + async def get_bus_positions( + self, + route_id: int, + ) -> list[BusPosition]: + """ + Get real-time positions for specified routes. - if not isinstance(data, list) or len(data) == 0: - raise RuntimeError(f"No SPTrans line found for line={route.bus_line}") + Args: + route_ids: List of provider-specific route IDs. - bus_routes: list[BusRoute] = [] + Returns: + List of BusPosition objects with route_id and coordinates. + """ - for item in data: - # Validate based on TypedDict keys - if "cl" not in item or "lt" not in item: - continue # Skip invalid entries + response = await self._request_with_auth_retry( + "GET", + "/Posicao/Linha", + params={"codigoLinha": route_id}, + ) - line_code = item["cl"] - line_text = item["lt"] - line_dir = item["sl"] + response_data = SPTransPositionsResponse.model_validate(response.json()) + route_positions = map_positions_response_to_bus_positions(response_data, route_id) - bus_routes.append( - BusRoute( - route_id=line_code, - route=RouteIdentifier(bus_line=line_text, bus_direction=line_dir), - ) - ) + return route_positions - if not bus_routes: - raise RuntimeError( - f"Invalid SPTrans response for line={route.bus_line}: " - "missing required fields" - ) + async def search_routes(self, query: str) -> list[BusRoute]: + """ + Search for bus routes matching a query string. - return bus_routes + Args: + query: Search term (e.g., "809" or "Vila Nova Conceição"). - except Exception as e: - raise RuntimeError(f"Failed to resolve route details for {route}: {e}") from e + Returns: + List of matching BusRoute objects. + """ + response = await self._request_with_auth_retry( + "GET", + "/Linha/Buscar", + params={"termosBusca": query}, + ) + + if response.status_code != 200: + raise RuntimeError(f"SPTrans returned status {response.status_code} for search.") + + response_data = SPTransLineSearchResponse.model_validate(response.json()) + bus_routes: list[BusRoute] = map_search_response_to_bus_route_list(response_data) + return bus_routes diff --git a/src/adapters/external/sptrans_mappers.py b/src/adapters/external/sptrans_mappers.py new file mode 100644 index 0000000..8148f5f --- /dev/null +++ b/src/adapters/external/sptrans_mappers.py @@ -0,0 +1,121 @@ +""" +Mappers for converting SPTrans API schemas to domain models. + +These mappers handle the transformation between external API DTOs +and internal domain objects, keeping the adapter layer clean. +""" + +from typing import cast + +from ...core.models.bus import BusDirection, BusPosition, BusRoute, RouteIdentifier +from ...core.models.coordinate import Coordinate +from .sptrans_schemas import ( + SPTransLineInfo, + SPTransLineSearchResponse, + SPTransPositionsResponse, + SPTransVehicle, +) + + +def map_search_response_to_bus_route_list( + data: SPTransLineSearchResponse, +) -> list[BusRoute]: + """ + Convert API search response to list of BusRoute domain objects. + + Args: + data: SPTransLineSearchResponse object. + + Returns: + List of BusRoute domain objects. + """ + bus_route_list: list[BusRoute] = [] + for item in data.root: + bus_route_list.append(map_line_info_to_bus_route(item)) + return bus_route_list + + +def map_line_info_to_bus_route(line_info: SPTransLineInfo) -> BusRoute: + """ + Convert SPTrans line info to domain BusRoute. + + Args: + line_info: SPTrans line information. + + Returns: + Domain BusRoute with route_id and identifier. + """ + # Terminal name: primary if direction=1 (ida), secondary if direction=2 (volta) + terminal_name = ( + line_info.primary_terminal if line_info.direction == 1 else line_info.secondary_terminal + ) + return BusRoute( + route_id=line_info.route_id, + route=map_line_info_to_route_identifier(line_info), + is_circular=line_info.is_circular, + terminal_name=terminal_name, + ) + + +def map_line_info_to_route_identifier(line_info: SPTransLineInfo) -> RouteIdentifier: + """ + Convert SPTrans line info to domain RouteIdentifier. + + Args: + line_info: SPTrans line information. + + Returns: + Domain RouteIdentifier with formatted bus_line (line_number-line_sufix format). + """ + bus_line = f"{line_info.line_number}-{line_info.line_sufix}" + return RouteIdentifier( + bus_line=bus_line, + bus_direction=cast(BusDirection, line_info.direction), + ) + + +def map_positions_response_to_bus_positions( + data: SPTransPositionsResponse, + route_id: int, +) -> list[BusPosition]: + """ + Convert API positions response to list of BusPosition domain objects. + + Args: + data: SPTransPositionsResponse object. + route_id: Provider-specific route identifier. + + Returns: + List of domain BusPosition objects. + """ + positions: list[BusPosition] = [] + + for vehicle in data.vehicles: + position = map_vehicle_to_bus_position(vehicle, route_id) + positions.append(position) + + return positions + + +def map_vehicle_to_bus_position( + vehicle: SPTransVehicle, + route_id: int, +) -> BusPosition: + """ + Convert SPTrans vehicle to domain BusPosition. + + Args: + vehicle: SPTransVehicle position data. + route_id: Provider-specific route identifier. + + Returns: + Domain BusPosition with coordinates and route_id. + """ + return BusPosition( + route_id=route_id, + position=Coordinate( + latitude=vehicle.latitude, + longitude=vehicle.longitude, + ), + time_updated=vehicle.time_updated, + ) diff --git a/src/adapters/external/sptrans_schemas.py b/src/adapters/external/sptrans_schemas.py index bd8d7e3..05cb8a6 100644 --- a/src/adapters/external/sptrans_schemas.py +++ b/src/adapters/external/sptrans_schemas.py @@ -1,5 +1,4 @@ -""" -SPTrans API-specific schemas. +"""SPTrans API-specific schemas. These DTOs are used exclusively for communication with the SPTrans API and should not leak into the domain or web layers. @@ -7,35 +6,47 @@ from datetime import datetime -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, RootModel + + +class SPTransLineInfo(BaseModel): + """Schema for SPTrans line information response.""" + + route_id: int = Field(..., alias="cl", description="Route code (internal SPTrans code)") + is_circular: bool = Field(..., alias="lc", description="Is circular route") + line_number: str = Field(..., alias="lt", description="Line number (e.g., '8000')") + direction: int = Field(..., alias="sl", description="Direction (1 = ida, 2 = volta)") + line_sufix: int = Field(..., alias="tl", description="Line type (10 = urban, etc.)") + primary_terminal: str = Field(..., alias="tp", description="Primary terminal") + secondary_terminal: str = Field(..., alias="ts", description="Secondary terminal") + + model_config = {"populate_by_name": True} -class SPTransRouteResponse(BaseModel): - """Schema for SPTrans route response.""" +class SPTransLineSearchResponse(RootModel[list[SPTransLineInfo]]): + """Schema for SPTrans line search response item.""" - cl: int = Field(..., description="Route code") - lc: bool = Field(..., description="Is circular") - lt: str = Field(..., description="Main direction") - tl: int = Field(..., description="Type") - sl: int = Field(..., description="Secondary type") - tp: str = Field(..., description="Terminal principal") - ts: str = Field(..., description="Terminal secundário") + root: list[SPTransLineInfo] = Field(..., description="List of line info results") -class SPTransVehicleResponse(BaseModel): - """Schema for SPTrans vehicle position response.""" +class SPTransVehicle(BaseModel): + """Schema for SPTrans vehicle position.""" - p: int = Field(..., description="Route code") - a: bool = Field(..., description="Is accessible") - ta: datetime = Field(..., description="Time updated") - px: int = Field(..., description="Longitude * 10^6") - py: int = Field(..., description="Latitude * 10^6") + vehicle_prefix: str = Field(..., alias="p", description="Vehicle prefix") + is_accessible: bool = Field(..., alias="a", description="Is accessible") + time_updated: datetime = Field(..., alias="ta", description="Time updated") + latitude: float = Field(..., alias="py", description="Latitude") + longitude: float = Field(..., alias="px", description="Longitude") + + model_config = {"populate_by_name": True} class SPTransPositionsResponse(BaseModel): - """Schema for SPTrans positions response.""" + """Schema for SPTrans positions API response.""" - hr: datetime = Field(..., alias="currentTime", description="Current time") - vs: list[SPTransVehicleResponse] = Field(..., alias="vehicles", description="List of vehicles") + response_time: str = Field(..., alias="hr", description="Response time") + vehicles: list[SPTransVehicle] = Field( + default_factory=list, alias="vs", description="List of vehicles" + ) model_config = {"populate_by_name": True} diff --git a/src/adapters/repositories/gtfs_repository_adapter.py b/src/adapters/repositories/gtfs_repository_adapter.py index 0ca250e..22043a9 100644 --- a/src/adapters/repositories/gtfs_repository_adapter.py +++ b/src/adapters/repositories/gtfs_repository_adapter.py @@ -56,7 +56,7 @@ def get_route_shape(self, route: RouteIdentifier) -> RouteShape | None: (shape_id,), ) - points = [] + points: list[RouteShapePoint] = [] for row in cursor.fetchall(): point = RouteShapePoint( coordinate=Coordinate( diff --git a/src/core/models/bus.py b/src/core/models/bus.py index 3c497dd..19e3ebe 100644 --- a/src/core/models/bus.py +++ b/src/core/models/bus.py @@ -1,13 +1,16 @@ """Bus-related domain models.""" -from dataclasses import dataclass from datetime import datetime +from typing import Literal + +from pydantic import BaseModel, Field from .coordinate import Coordinate +BusDirection = Literal[1, 2] + -@dataclass -class RouteIdentifier: +class RouteIdentifier(BaseModel): """ Identifier for a bus route. @@ -17,34 +20,45 @@ class RouteIdentifier: """ bus_line: str - bus_direction: int + bus_direction: BusDirection = Field(..., description="Direction (1 = ida, 2 = volta)") + model_config = {"frozen": True} -@dataclass -class BusRoute: + +class BusRoute(BaseModel): """ Bus route information. Attributes: route_id: Unique identifier for this route route: Route identifier containing line and direction + is_circular: Whether the route is circular + terminal_name: Terminal name (primary if direction=1, secondary if direction=2) """ route_id: int route: RouteIdentifier + is_circular: bool = Field(..., description="Whether the route is circular") + terminal_name: str = Field( + ..., + description="Terminal name (primary if direction=1, secondary if direction=2)", + ) + + model_config = {"frozen": True} -@dataclass -class BusPosition: +class BusPosition(BaseModel): """ Real-time position of a bus. Attributes: - route: Route identifier for this bus + route_id: Provider-specific route identifier position: Geographic coordinate of the bus time_updated: Last update timestamp """ - route: RouteIdentifier + route_id: int position: Coordinate time_updated: datetime + + model_config = {"frozen": True} diff --git a/src/core/models/coordinate.py b/src/core/models/coordinate.py index d2bb2ea..977aec2 100644 --- a/src/core/models/coordinate.py +++ b/src/core/models/coordinate.py @@ -1,10 +1,9 @@ """Coordinate domain model.""" -from dataclasses import dataclass +from pydantic import BaseModel -@dataclass -class Coordinate: +class Coordinate(BaseModel): """ Geographic coordinate representation. @@ -15,3 +14,5 @@ class Coordinate: latitude: float longitude: float + + model_config = {"frozen": True} diff --git a/src/core/models/trip.py b/src/core/models/trip.py index 00912b3..73c022f 100644 --- a/src/core/models/trip.py +++ b/src/core/models/trip.py @@ -16,11 +16,11 @@ class Trip: route: Route identifier containing bus_line and bus_direction distance: Distance traveled in meters score: Points earned from this trip - trip_date: When the trip occurred + trip_datetime: When the trip occurred """ email: str route: RouteIdentifier distance: int score: int - trip_date: datetime + trip_datetime: datetime diff --git a/src/core/ports/bus_provider_port.py b/src/core/ports/bus_provider_port.py index d3199e5..d7fff0d 100644 --- a/src/core/ports/bus_provider_port.py +++ b/src/core/ports/bus_provider_port.py @@ -2,7 +2,7 @@ from abc import ABC, abstractmethod -from ..models.bus import BusPosition, BusRoute, RouteIdentifier +from ..models.bus import BusPosition, BusRoute class BusProviderPort(ABC): @@ -11,46 +11,40 @@ class BusProviderPort(ABC): This port defines the contract for interacting with external bus tracking APIs (e.g., SPTrans, NextBus, etc.). + Authentication is managed internally by implementations. """ @abstractmethod - async def authenticate(self) -> bool: + async def get_bus_positions( + self, + route_id: int, + ) -> list[BusPosition]: """ - Authenticate with the bus tracking API. - - Returns: - True if authentication successful, False otherwise - """ - pass - - @abstractmethod - async def get_bus_positions(self, bus_route: BusRoute) -> list[BusPosition]: - """ - Get real-time positions for specified bus routes. + Get real-time positions for specified routes. Args: - routes: List of route identifiers to query + route_ids: List of provider-specific route IDs. Returns: - List of current bus positions + List of current bus positions with route_id. Raises: - Exception: If API call fails or authentication required + RuntimeError: If API call fails or authentication fails. """ pass @abstractmethod - async def get_route_details(self, route: RouteIdentifier) -> list[BusRoute]: + async def search_routes(self, query: str) -> list[BusRoute]: """ - Get detailed information about a specific route. + Search for bus routes matching a query string. Args: - route: Route identifier to query + query: Search term (e.g., route number or destination name). Returns: - Bus route details + List of matching bus routes. Raises: - Exception: If route not found or API call fails + RuntimeError: If API call fails or authentication fails. """ pass diff --git a/src/core/services/history_service.py b/src/core/services/history_service.py index c4eeaa3..ca70330 100644 --- a/src/core/services/history_service.py +++ b/src/core/services/history_service.py @@ -35,7 +35,7 @@ async def get_user_history(self, email: str) -> list[HistoryEntry]: return [ HistoryEntry( - date=trip.trip_date, + date=trip.trip_datetime, score=trip.score, route=trip.route, ) diff --git a/src/core/services/route_service.py b/src/core/services/route_service.py index 48858c1..ff8cfce 100644 --- a/src/core/services/route_service.py +++ b/src/core/services/route_service.py @@ -19,46 +19,48 @@ def __init__(self, bus_provider: BusProviderPort, gtfs_repository: GTFSRepositor Initialize the route service. Args: - bus_provider: Implementation of BusProviderPort - gtfs_repository: Implementation of GTFSRepositoryPort + bus_provider: Implementation of BusProviderPort. + gtfs_repository: Implementation of GTFSRepositoryPort. """ self.bus_provider = bus_provider self.gtfs_repository = gtfs_repository - async def get_bus_positions(self, bus_route: BusRoute) -> list[BusPosition]: + async def get_bus_positions( + self, + route_ids: list[int], + ) -> list[BusPosition]: """ - Get current positions for specified bus routes. + Get current positions for specified routes. Args: - routes: List of route identifiers to query + route_ids: List of provider-specific route IDs. Returns: - List of current bus positions + List of current bus positions. Raises: - Exception: If API authentication fails or request fails + RuntimeError: If API request fails. """ - # Ensure we're authenticated - await self.bus_provider.authenticate() + positions: list[BusPosition] = [] + for route_id in route_ids: + route_positions: list[BusPosition] = await self.bus_provider.get_bus_positions(route_id) + positions.extend(route_positions) + return positions - # Get positions - return await self.bus_provider.get_bus_positions(bus_route) - - async def get_route_details(self, route: RouteIdentifier) -> list[BusRoute]: + async def search_routes(self, query: str) -> list[BusRoute]: """ - Get detailed information about a route. + Search for bus routes matching a query string. Args: - route: Route identifier + query: Search term (e.g., "809" or "Vila Nova Conceição"). Returns: - Route details + List of matching bus routes. Raises: - Exception: If route not found or API fails + RuntimeError: If API request fails. """ - await self.bus_provider.authenticate() - return await self.bus_provider.get_route_details(route) + return await self.bus_provider.search_routes(query) def get_route_shapes(self, routes: list[RouteIdentifier]) -> list[RouteShape]: """ diff --git a/src/core/services/trip_service.py b/src/core/services/trip_service.py index 9e17aae..7301bc2 100644 --- a/src/core/services/trip_service.py +++ b/src/core/services/trip_service.py @@ -35,7 +35,7 @@ async def create_trip( email: str, route: RouteIdentifier, distance: int, - trip_date: datetime, + trip_datetime: datetime, ) -> Trip: """ Create a new trip and update user score. @@ -47,7 +47,7 @@ async def create_trip( email: User's email route: Route identifier containing bus_line and bus_direction distance: Distance traveled in meters - trip_date: When the trip occurred + trip_datetime: When the trip occurred Returns: The created trip with calculated score @@ -69,7 +69,7 @@ async def create_trip( route=route, distance=distance, score=score, - trip_date=trip_date, + trip_datetime=trip_datetime, ) saved_trip = await self.trip_repository.save_trip(trip) diff --git a/src/web/controllers/route_controller.py b/src/web/controllers/route_controller.py index 85eba5b..0a104ec 100644 --- a/src/web/controllers/route_controller.py +++ b/src/web/controllers/route_controller.py @@ -4,27 +4,26 @@ This controller handles queries for real-time bus information and route shapes. """ -from fastapi import APIRouter, Depends, HTTPException, status +from fastapi import APIRouter, Depends, HTTPException, Query, status from ...adapters.external.sptrans_adapter import SpTransAdapter from ...adapters.repositories.gtfs_repository_adapter import GTFSRepositoryAdapter from ...config import settings -from ...core.models.bus import BusPosition, BusRoute, RouteIdentifier +from ...core.models.bus import BusPosition, BusRoute from ...core.models.user import User from ...core.services.route_service import RouteService from ..auth import get_current_user from ..mappers import ( map_bus_position_list_to_schema, + map_bus_route_domain_list_to_schema, + map_bus_route_request_list, map_route_identifier_schema_to_domain, map_route_shapes_to_response, ) from ..schemas import ( BusPositionsRequest, BusPositionsResponse, - BusRouteSchema, - BusRoutesDetailsRequest, - BusRoutesDetailsResponse, - RouteIdentifierSchema, + RouteSearchResponse, RouteShapesRequest, RouteShapesResponse, ) @@ -37,7 +36,7 @@ def get_route_service() -> RouteService: Dependency that provides a RouteService instance. Returns: - Configured RouteService instance + Configured RouteService instance. """ bus_provider = SpTransAdapter( api_token=settings.sptrans_api_token, @@ -47,52 +46,33 @@ def get_route_service() -> RouteService: return RouteService(bus_provider, gtfs_repository) -# NOTE: Having `current_user: User = Depends(get_current_user)` as a dependency -# makes this endpoint only accessible to authenticated users (requires valid JWT token). -@router.post("/details", response_model=BusRoutesDetailsResponse) -async def get_route_details_endpoint( - request: BusRoutesDetailsRequest, +@router.get("/search", response_model=RouteSearchResponse) +async def search_routes_endpoint( + query: str = Query(..., min_length=1, description="Search term for routes"), route_service: RouteService = Depends(get_route_service), current_user: User = Depends(get_current_user), -) -> BusRoutesDetailsResponse: +) -> RouteSearchResponse: """ - Resolve, para cada linha solicitada, as rotas concretas do provedor - (por exemplo, diferentes variantes/direções internamente). + Search for bus routes matching a query string. - Entrada: lista de RouteIdentifierSchema (apenas bus_line). - Saída: lista "achatada" de BusRouteSchema (route_id + bus_line). - """ - try: - # Schemas -> domínio (RouteIdentifier) - route_identifiers: list[RouteIdentifier] = [ - map_route_identifier_schema_to_domain(route_schema) for route_schema in request.routes - ] + Args: + query: Search term (e.g., "809" or "Vila Nova Conceição"). + route_service: Injected route service. + current_user: Authenticated user. - bus_routes: list[BusRoute] = [] - - for route_identifier in route_identifiers: - # O service retorna uma lista de BusRoute - resolved_routes: list[BusRoute] = await route_service.get_route_details( - route_identifier - ) - bus_routes.extend(resolved_routes) - - # Domínio -> schemas - route_schemas: list[BusRouteSchema] = [ - BusRouteSchema( - route_id=bus_route.route_id, - route=RouteIdentifierSchema( - bus_line=bus_route.route.bus_line, - bus_direction=bus_route.route.bus_direction, - ), - ) - for bus_route in bus_routes - ] + Returns: + List of matching routes with provider IDs. - return BusRoutesDetailsResponse(routes=route_schemas) + Raises: + HTTPException: If search fails. + """ + try: + bus_routes: list[BusRoute] = await route_service.search_routes(query) + route_schemas = map_bus_route_domain_list_to_schema(bus_routes) + return RouteSearchResponse(routes=route_schemas) - except Exception as e: # caminho de erro genérico - detail: str = f"Failed to retrieve route details: {str(e)}" + except Exception as e: + detail: str = f"Failed to search routes: {str(e)}" raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=detail, @@ -108,34 +88,28 @@ async def get_bus_positions( current_user: User = Depends(get_current_user), ) -> BusPositionsResponse: """ - Recupera as posições dos ônibus para as rotas já resolvidas. - - Entrada: lista de BusRouteSchema (tipicamente saída de /routes/details). - Saída: lista de BusPositionSchema. - """ - try: - all_positions: list[BusPosition] = [] + Get real-time bus positions for specified routes. - for route_schema in request.routes: - # Schema -> domínio (BusRoute) - route_identifier = RouteIdentifier( - bus_line=route_schema.route.bus_line, - bus_direction=1, # direção default; SPTrans /Linha/Buscar não usa mais isso - ) + Args: + request: Request containing list of route_ids. + route_service: Injected route service. + current_user: Authenticated user. - bus_route = BusRoute( - route_id=route_schema.route_id, - route=route_identifier, - ) + Returns: + List of bus positions with route_id. - route_positions: list[BusPosition] = await route_service.get_bus_positions(bus_route) - all_positions.extend(route_positions) + Raises: + HTTPException: If fetching positions fails. + """ + try: + # Extract route_ids from request + route_ids = map_bus_route_request_list(request.routes) - # Domínio -> schemas - position_schemas = map_bus_position_list_to_schema(all_positions) + positions: list[BusPosition] = await route_service.get_bus_positions(route_ids) + position_schemas = map_bus_position_list_to_schema(positions) return BusPositionsResponse(buses=position_schemas) - except Exception as e: # caminho de erro genérico + except Exception as e: detail: str = f"Failed to retrieve bus positions: {str(e)}" raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, diff --git a/src/web/controllers/trip_controller.py b/src/web/controllers/trip_controller.py index e10424e..1a15038 100644 --- a/src/web/controllers/trip_controller.py +++ b/src/web/controllers/trip_controller.py @@ -7,10 +7,11 @@ from fastapi import APIRouter, Depends, HTTPException, status from sqlalchemy.ext.asyncio import AsyncSession +from src.web.mappers import map_route_identifier_schema_to_domain + from ...adapters.database.connection import get_db from ...adapters.repositories.trip_repository_adapter import TripRepositoryAdapter from ...adapters.repositories.user_repository_adapter import UserRepositoryAdapter -from ...core.models.bus import RouteIdentifier from ...core.models.user import User from ...core.services.trip_service import TripService from ..auth import get_current_user @@ -57,15 +58,12 @@ async def create_trip( HTTPException: If user not found or validation fails """ try: - route = RouteIdentifier( - bus_line=request.route.bus_line, - bus_direction=request.route.bus_direction, - ) + route = map_route_identifier_schema_to_domain(request.route) trip = await trip_service.create_trip( email=current_user.email, route=route, distance=request.distance, - trip_date=request.data, + trip_datetime=request.trip_datetime, ) return CreateTripResponse(score=trip.score) diff --git a/src/web/mappers.py b/src/web/mappers.py index 84f3af2..54526d6 100644 --- a/src/web/mappers.py +++ b/src/web/mappers.py @@ -5,13 +5,17 @@ maintaining the separation of concerns. """ -from ..core.models.bus import BusPosition, RouteIdentifier +from typing import cast + +from ..core.models.bus import BusDirection, BusPosition, BusRoute, RouteIdentifier from ..core.models.coordinate import Coordinate from ..core.models.route_shape import RouteShape from ..core.models.user import User from ..core.models.user_history import HistoryEntry from .schemas import ( BusPositionSchema, + BusRouteRequestSchema, + BusRouteResponseSchema, CoordinateSchema, HistoryResponse, RouteIdentifierSchema, @@ -70,7 +74,7 @@ def map_route_identifier_schema_to_domain( """ return RouteIdentifier( bus_line=schema.bus_line, - bus_direction=schema.bus_direction, + bus_direction=cast(BusDirection, schema.bus_direction), ) @@ -92,6 +96,69 @@ def map_route_identifier_domain_to_schema( ) +def map_bus_route_request_to_route_id( + schema: BusRouteRequestSchema, +) -> int: + """ + Extract route_id from a BusRouteRequestSchema. + + Args: + schema: BusRouteRequestSchema from API request + + Returns: + route_id (int) + """ + return schema.route_id + + +def map_bus_route_request_list( + schemas: list[BusRouteRequestSchema], +) -> list[int]: + """ + Map a list of BusRouteRequestSchema to a list of route_ids. + + Args: + schemas: List of BusRouteRequestSchema from API request + + Returns: + List of route_ids + """ + return [schema.route_id for schema in schemas] + + +def map_bus_route_domain_to_schema(bus_route: BusRoute) -> BusRouteResponseSchema: + """ + Map a BusRoute domain model to a BusRouteResponseSchema. + + Args: + bus_route: BusRoute domain model + + Returns: + BusRouteResponseSchema for API response + """ + return BusRouteResponseSchema( + route_id=bus_route.route_id, + route=map_route_identifier_domain_to_schema(bus_route.route), + is_circular=bus_route.is_circular, + terminal_name=bus_route.terminal_name, + ) + + +def map_bus_route_domain_list_to_schema( + bus_routes: list[BusRoute], +) -> list[BusRouteResponseSchema]: + """ + Map a list of BusRoute domain models to BusRouteResponseSchema list. + + Args: + bus_routes: List of BusRoute domain models + + Returns: + List of BusRouteResponseSchema for API response + """ + return [map_bus_route_domain_to_schema(bus_route) for bus_route in bus_routes] + + def map_coordinate_domain_to_schema(coord: Coordinate) -> CoordinateSchema: """ Map a Coordinate domain model to a CoordinateSchema. @@ -119,7 +186,7 @@ def map_bus_position_domain_to_schema(position: BusPosition) -> BusPositionSchem BusPositionSchema for API """ return BusPositionSchema( - route=map_route_identifier_domain_to_schema(position.route), + route_id=position.route_id, position=map_coordinate_domain_to_schema(position.position), time_updated=position.time_updated, ) diff --git a/src/web/schemas.py b/src/web/schemas.py index 32b918b..e23ff6c 100644 --- a/src/web/schemas.py +++ b/src/web/schemas.py @@ -36,18 +36,35 @@ class CoordinateSchema(BaseModel): class BusPositionSchema(BaseModel): """Schema for bus position information.""" - route: RouteIdentifierSchema + route_id: int = Field(..., description="Provider-specific route identifier") position: CoordinateSchema time_updated: datetime = Field(..., description="Last update timestamp") model_config = {"populate_by_name": True} -class BusRouteSchema(BaseModel): - """Schema for a resolved bus route (provider-specific ID + identifier).""" +class BusRouteRequestSchema(BaseModel): + """Schema for bus route in requests (input from user). + + Only requires route_id for querying positions. + """ + + route_id: int = Field(..., description="Provider-specific route identifier") + + +class BusRouteResponseSchema(BaseModel): + """Schema for bus route in responses (output to user). + + Includes full route information with metadata. + """ route_id: int = Field(..., description="Provider-specific route identifier") route: RouteIdentifierSchema + is_circular: bool = Field(..., description="Whether the route is circular") + terminal_name: str = Field( + ..., + description="Terminal name (primary if direction=1, secondary if direction=2)", + ) # ===== User Management Schemas ===== @@ -91,7 +108,7 @@ class TokenResponse(BaseModel): class CreateTripRequest(BaseModel): route: RouteIdentifierSchema distance: int = Field(..., ge=0, description="Distance traveled in meters") - data: datetime = Field(..., description="Trip date and time") + trip_datetime: datetime = Field(..., description="Trip date and time") model_config = {"populate_by_name": True} @@ -108,8 +125,8 @@ class CreateTripResponse(BaseModel): class BusPositionsRequest(BaseModel): """Request schema for querying bus positions.""" - routes: list[BusRouteSchema] = Field( - ..., description="List of resolved routes (with route_id) to query positions" + routes: list[BusRouteRequestSchema] = Field( + ..., description="List of routes to query positions for" ) @@ -119,19 +136,21 @@ class BusPositionsResponse(BaseModel): buses: list[BusPositionSchema] = Field(..., description="List of bus positions") -class BusRoutesDetailsRequest(BaseModel): - """Request schema for resolving route details.""" +class RouteSearchRequest(BaseModel): + """Request schema for searching routes.""" - routes: list[RouteIdentifierSchema] = Field( - ..., description="List of routes (line + direction) to resolve" + query: str = Field( + ..., + min_length=1, + description="Search term (e.g., '809' or 'Vila Nova Conceição')", ) -class BusRoutesDetailsResponse(BaseModel): - """Response schema for route details.""" +class RouteSearchResponse(BaseModel): + """Response schema for route search results.""" - routes: list[BusRouteSchema] = Field( - ..., description="List of resolved routes with provider IDs" + routes: list[BusRouteResponseSchema] = Field( + ..., description="List of matching routes with provider IDs" ) diff --git a/tests/adapters/test_sptrans_adapter.py b/tests/adapters/test_sptrans_adapter.py index 6e2549e..55a9e3f 100644 --- a/tests/adapters/test_sptrans_adapter.py +++ b/tests/adapters/test_sptrans_adapter.py @@ -1,82 +1,78 @@ +import os + import pytest from src.adapters.external.sptrans_adapter import SpTransAdapter -from src.core.models.bus import BusPosition, BusRoute, RouteIdentifier +from src.core.models.bus import BusPosition, BusRoute + +skip_if_no_token = pytest.mark.skipif( + not os.getenv("SPTRANS_API_TOKEN"), + reason="SPTRANS_API_TOKEN not set - skipping integration test", +) +@skip_if_no_token @pytest.mark.asyncio -async def test_authentication() -> None: +async def test_automatic_authentication() -> None: """ - Real authentication test against SPTrans. - Requires SPTRANS_API_TOKEN to be configured. + Test that the adapter authenticates automatically when making requests. """ adapter: SpTransAdapter = SpTransAdapter() - ok: bool = await adapter.authenticate() - - print("Cookies recebidos:", adapter.client.cookies) + routes: list[BusRoute] = await adapter.search_routes("8075") - assert ok is True, "A autenticação real falhou. Verifique seu TOKEN." assert "apiCredentials" in adapter.client.cookies, "Cookie de credenciais não foi criado." + assert len(routes) > 0 +@skip_if_no_token @pytest.mark.asyncio -async def test_get_route_details_8075_direction_1() -> None: +async def test_search_routes_number() -> None: """ - Integration test: resolves the internal SPTrans 'codigoLinha' (cl) - for route 8075 using get_route_details(), which now returns a list - of BusRoute entries (diferentes sentidos/variações). + Searches for route number and validates the results. """ adapter: SpTransAdapter = SpTransAdapter() - await adapter.authenticate() - - # RouteIdentifier para linha 8075 (direção ainda existe no domínio) - route: RouteIdentifier = RouteIdentifier(bus_line="8075", bus_direction=1) - - # Agora get_route_details retorna list[BusRoute] - bus_routes: list[BusRoute] = await adapter.get_route_details(route) - - print("Retrieved BusRoutes:", bus_routes) + bus_routes: list[BusRoute] = await adapter.search_routes("8075") - # Validate result assert isinstance(bus_routes, list) assert len(bus_routes) > 0, "Nenhuma rota retornada para 8075" for bus_route in bus_routes: assert isinstance(bus_route.route_id, int), "route_id must be an integer" assert bus_route.route_id > 0, "route_id must be positive" - # a linha deve bater com o que pedimos - assert bus_route.route.bus_line == "8075" + assert "8075" in bus_route.route.bus_line +@skip_if_no_token @pytest.mark.asyncio -async def test_get_bus_positions_8075_direction_1() -> None: +async def test_search_routes_by_destination() -> None: """ - Integration test: fetches real-time positions for one concrete route - of line 8075 (direction 1, se disponível), usando - get_route_details() + get_bus_positions(), sem mocks. + Searches for routes by destination name. """ adapter: SpTransAdapter = SpTransAdapter() - await adapter.authenticate() + bus_routes: list[BusRoute] = await adapter.search_routes("Lapa") - route: RouteIdentifier = RouteIdentifier(bus_line="8075", bus_direction=1) + assert isinstance(bus_routes, list) + assert len(bus_routes) > 0, "Nenhuma rota retornada para Lapa" - # Step 1: resolve SPTrans internal codes (list[BusRoute]) - bus_routes: list[BusRoute] = await adapter.get_route_details(route) - assert len(bus_routes) > 0 - # escolhe uma rota com direction 1, se existir; senão, pega a primeira - chosen_route: BusRoute = next( - (br for br in bus_routes if getattr(br.route, "bus_direction", None) == 1), - bus_routes[0], - ) +@skip_if_no_token +@pytest.mark.asyncio +async def test_get_bus_positions() -> None: + """ + Fetches real-time positions for route 8075. + """ + adapter: SpTransAdapter = SpTransAdapter() + bus_routes: list[BusRoute] = await adapter.search_routes("8075") + assert len(bus_routes) > 0 + + chosen_route: BusRoute = bus_routes[0] assert chosen_route.route_id > 0 - # Step 2: fetch positions usando BusRoute concreto - positions: list[BusPosition] = await adapter.get_bus_positions(chosen_route) + positions: list[BusPosition] = await adapter.get_bus_positions(chosen_route.route_id) assert positions is not None assert isinstance(positions, list) @@ -84,36 +80,19 @@ async def test_get_bus_positions_8075_direction_1() -> None: if positions: pos: BusPosition = positions[0] - # Route info should match what we requested (mesma linha) - assert pos.route.bus_line == "8075" - - # se tiver direction em BusPosition.route, valida também - if hasattr(pos.route, "bus_direction"): - assert pos.route.bus_direction in (1, 2) - - # Coordenadas assert isinstance(pos.position.latitude, float | int) assert isinstance(pos.position.longitude, float | int) +@skip_if_no_token @pytest.mark.asyncio -async def test_get_route_details_without_authentication(): - adapter: SpTransAdapter = SpTransAdapter(api_token="INVALID") - - route = RouteIdentifier(bus_line="8075", bus_direction=1) - - with pytest.raises(RuntimeError, match="not authenticated"): - await adapter.get_route_details(route) - - -@pytest.mark.asyncio -async def test_get_bus_positions_without_authentication(): +async def test_search_routes_returns_empty_for_unknown() -> None: + """ + Test that search returns empty list for unknown routes. + """ adapter: SpTransAdapter = SpTransAdapter() - bus_route = BusRoute( - route_id=8075, - route=RouteIdentifier(bus_line="8075", bus_direction=1), - ) + routes: list[BusRoute] = await adapter.search_routes("XYZNONEXISTENT999") - with pytest.raises(RuntimeError, match="not authenticated"): - await adapter.get_bus_positions(bus_route) + assert isinstance(routes, list) + assert len(routes) == 0 diff --git a/tests/adapters/test_trip_repository_adapter.py b/tests/adapters/test_trip_repository_adapter.py index 34668dc..0ea1688 100644 --- a/tests/adapters/test_trip_repository_adapter.py +++ b/tests/adapters/test_trip_repository_adapter.py @@ -35,7 +35,7 @@ def _make_domain_trip() -> object: route=RouteIdentifier(bus_line="8000", bus_direction=1), distance=1000, score=10, - trip_date=datetime(2025, 1, 1, 8, 0, 0), + trip_datetime=datetime(2025, 1, 1, 8, 0, 0), ) @@ -51,7 +51,7 @@ async def test_save_trip_unit(monkeypatch) -> None: bus_direction=1, distance=1000, score=10, - trip_date=datetime(2025, 1, 1, 8, 0, 0), + trip_datetime=datetime(2025, 1, 1, 8, 0, 0), ) monkeypatch.setattr(adapter_mod, "map_trip_domain_to_db", lambda t: dummy_db_obj) @@ -67,7 +67,7 @@ async def test_save_trip_unit(monkeypatch) -> None: ), distance=db.distance, score=db.score, - trip_date=db.trip_date, + trip_datetime=db.trip_datetime, ), ) diff --git a/tests/adapters/test_user_history_repository_adapter.py b/tests/adapters/test_user_history_repository_adapter.py index 7cd248d..5c1224d 100644 --- a/tests/adapters/test_user_history_repository_adapter.py +++ b/tests/adapters/test_user_history_repository_adapter.py @@ -20,7 +20,7 @@ def __init__(self, email: str, bus_line: str, bus_direction: int, score: int) -> self.bus_direction = bus_direction self.distance = 100 self.score = score - self.trip_date = datetime(2025, 1, 1) + self.trip_datetime = datetime(2025, 1, 1) class _DummyUser: diff --git a/tests/core/test_history_service.py b/tests/core/test_history_service.py index c299cb8..e008420 100644 --- a/tests/core/test_history_service.py +++ b/tests/core/test_history_service.py @@ -14,13 +14,13 @@ async def test_get_user_history_timezone_aware() -> None: history_repo = create_autospec(UserHistoryRepository, instance=True) - trip_date = datetime(2025, 10, 16, 10, 0, 0, tzinfo=UTC) + trip_datetime = datetime(2025, 10, 16, 10, 0, 0, tzinfo=UTC) trip = Trip( email="tz@example.com", route=RouteIdentifier(bus_line="8000", bus_direction=1), distance=1000, score=10, - trip_date=trip_date, + trip_datetime=trip_datetime, ) user_history = UserHistory(email="tz@example.com", trips=[trip]) @@ -32,7 +32,7 @@ async def test_get_user_history_timezone_aware() -> None: result = await service.get_user_history("tz@example.com") assert len(result) == 1 - assert result[0].date == trip_date + assert result[0].date == trip_datetime assert result[0].score == 10 assert result[0].route.bus_line == "8000" assert result[0].route.bus_direction == 1 @@ -56,13 +56,13 @@ async def test_get_user_history_no_data() -> None: async def test_get_user_history_single_entry() -> None: history_repo = create_autospec(UserHistoryRepository, instance=True) - trip_date = datetime(2025, 10, 16, 10, 0, 0) + trip_datetime = datetime(2025, 10, 16, 10, 0, 0) trip = Trip( email="test@example.com", route=RouteIdentifier(bus_line="8000", bus_direction=1), distance=1000, score=10, - trip_date=trip_date, + trip_datetime=trip_datetime, ) user_history = UserHistory(email="test@example.com", trips=[trip]) @@ -74,7 +74,7 @@ async def test_get_user_history_single_entry() -> None: result = await service.get_user_history("test@example.com") assert len(result) == 1 - assert result[0].date == trip_date + assert result[0].date == trip_datetime assert result[0].score == 10 assert result[0].route.bus_line == "8000" assert result[0].route.bus_direction == 1 @@ -95,21 +95,21 @@ async def test_get_user_history_multiple_entries() -> None: route=RouteIdentifier(bus_line="8000", bus_direction=1), distance=1000, score=10, - trip_date=trip1_date, + trip_datetime=trip1_date, ), Trip( email="multi@example.com", route=RouteIdentifier(bus_line="9000", bus_direction=2), distance=2000, score=20, - trip_date=trip2_date, + trip_datetime=trip2_date, ), Trip( email="multi@example.com", route=RouteIdentifier(bus_line="7000", bus_direction=1), distance=3000, score=30, - trip_date=trip3_date, + trip_datetime=trip3_date, ), ] diff --git a/tests/core/test_route_service.py b/tests/core/test_route_service.py index 8b62587..de69d40 100644 --- a/tests/core/test_route_service.py +++ b/tests/core/test_route_service.py @@ -1,5 +1,3 @@ -from __future__ import annotations - from datetime import UTC, datetime from typing import cast from unittest.mock import AsyncMock, Mock, create_autospec @@ -15,53 +13,35 @@ @pytest.mark.asyncio -async def test_get_bus_positions_calls_auth_and_provider() -> None: - # Arrange +async def test_get_bus_positions_calls_provider() -> None: raw_provider: Mock = Mock(spec=BusProviderPort) - raw_provider.authenticate = AsyncMock(return_value=True) raw_provider.get_bus_positions = AsyncMock() bus_provider: BusProviderPort = cast(BusProviderPort, raw_provider) - route_identifier: RouteIdentifier = RouteIdentifier( - bus_line="8075", - bus_direction=1, - ) - - bus_route: BusRoute = BusRoute( - route_id=1234, - route=route_identifier, - ) - expected_positions: list[BusPosition] = [ BusPosition( - route=route_identifier, + route_id=1234, position=Coordinate(latitude=-23.0, longitude=-46.0), time_updated=datetime.now(UTC), ), ] - # configurando retorno tipado do mock raw_provider.get_bus_positions.return_value = expected_positions gtfs_repo = create_autospec(GTFSRepositoryPort, instance=True) service: RouteService = RouteService(bus_provider=bus_provider, gtfs_repository=gtfs_repo) - # Act - result: list[BusPosition] = await service.get_bus_positions(bus_route) + result: list[BusPosition] = await service.get_bus_positions([1234]) - # Assert - raw_provider.authenticate.assert_awaited_once() - raw_provider.get_bus_positions.assert_awaited_once_with(bus_route) + raw_provider.get_bus_positions.assert_awaited_once_with(1234) assert result == expected_positions @pytest.mark.asyncio -async def test_get_route_details_calls_auth_and_provider() -> None: - # Arrange +async def test_search_routes_calls_provider() -> None: raw_provider: Mock = Mock(spec=BusProviderPort) - raw_provider.authenticate = AsyncMock(return_value=True) - raw_provider.get_route_details = AsyncMock() + raw_provider.search_routes = AsyncMock() bus_provider: BusProviderPort = cast(BusProviderPort, raw_provider) @@ -73,80 +53,56 @@ async def test_get_route_details_calls_auth_and_provider() -> None: expected_bus_route: BusRoute = BusRoute( route_id=1234, route=route_identifier, + is_circular=False, + terminal_name="Terminal A", ) expected_routes: list[BusRoute] = [expected_bus_route] - # agora o provider também retorna lista - raw_provider.get_route_details.return_value = expected_routes + raw_provider.search_routes.return_value = expected_routes gtfs_repo = create_autospec(GTFSRepositoryPort, instance=True) service: RouteService = RouteService(bus_provider=bus_provider, gtfs_repository=gtfs_repo) - # Act - result: list[BusRoute] = await service.get_route_details(route_identifier) + query = "8075" + result: list[BusRoute] = await service.search_routes(query) - # Assert - raw_provider.authenticate.assert_awaited_once() - raw_provider.get_route_details.assert_awaited_once_with(route_identifier) + raw_provider.search_routes.assert_awaited_once_with(query) assert result == expected_routes @pytest.mark.asyncio async def test_get_bus_positions_propagates_exception_from_provider() -> None: - # Arrange + """Test that exceptions from the provider are propagated.""" raw_provider: Mock = Mock(spec=BusProviderPort) - raw_provider.authenticate = AsyncMock(return_value=True) raw_provider.get_bus_positions = AsyncMock(side_effect=RuntimeError("boom")) bus_provider: BusProviderPort = cast(BusProviderPort, raw_provider) - route_identifier: RouteIdentifier = RouteIdentifier( - bus_line="8075", - bus_direction=1, - ) - - bus_route: BusRoute = BusRoute( - route_id=1234, - route=route_identifier, - ) - gtfs_repo = create_autospec(GTFSRepositoryPort, instance=True) service: RouteService = RouteService(bus_provider=bus_provider, gtfs_repository=gtfs_repo) - # Act / Assert with pytest.raises(RuntimeError, match="boom"): - await service.get_bus_positions(bus_route) + await service.get_bus_positions([1234]) - raw_provider.authenticate.assert_awaited_once() - raw_provider.get_bus_positions.assert_awaited_once_with(bus_route) + raw_provider.get_bus_positions.assert_awaited_once_with(1234) @pytest.mark.asyncio -async def test_get_route_details_propagates_exception_from_authenticate() -> None: - # Arrange +async def test_search_routes_propagates_exception_from_provider() -> None: + """Test that exceptions from search_routes are propagated.""" raw_provider: Mock = Mock(spec=BusProviderPort) - raw_provider.authenticate = AsyncMock( - side_effect=RuntimeError("auth failed"), - ) - raw_provider.get_route_details = AsyncMock() + raw_provider.search_routes = AsyncMock(side_effect=RuntimeError("search failed")) bus_provider: BusProviderPort = cast(BusProviderPort, raw_provider) - route_identifier: RouteIdentifier = RouteIdentifier( - bus_line="8075", - bus_direction=1, - ) - gtfs_repo = create_autospec(GTFSRepositoryPort, instance=True) service: RouteService = RouteService(bus_provider=bus_provider, gtfs_repository=gtfs_repo) - # Act / Assert - with pytest.raises(RuntimeError, match="auth failed"): - await service.get_route_details(route_identifier) + with pytest.raises(RuntimeError, match="search failed"): + await service.search_routes("8075") - raw_provider.authenticate.assert_awaited_once() - raw_provider.get_route_details.assert_not_awaited() + raw_provider.search_routes.assert_awaited_once_with("8075") def test_get_route_shape_found() -> None: @@ -181,7 +137,6 @@ def test_get_route_shape_found() -> None: # Act result = service.get_route_shapes([route]) - # Assert assert result is not None assert len(result) == 1 assert result[0].route.bus_line == "1012-10" @@ -236,7 +191,6 @@ def test_get_route_shape_with_many_points() -> None: # Act result = service.get_route_shapes([route]) - # Assert assert result is not None assert len(result) == 1 assert len(result[0].points) == 100 @@ -271,47 +225,11 @@ def test_get_route_shape_with_special_characters() -> None: # Act result = service.get_route_shapes([route]) - # Assert assert result is not None assert result[0].route.bus_line == "route-with-special_chars@123" gtfs_repo.get_route_shape.assert_called_once_with(route) -def test_get_route_shape_independent_of_bus_provider() -> None: - # Arrange - bus_provider = create_autospec(BusProviderPort, instance=True) - gtfs_repo = create_autospec(GTFSRepositoryPort, instance=True) - - route = RouteIdentifier(bus_line="test-route", bus_direction=1) - - mock_shape = RouteShape( - route=route, - shape_id="test-shape", - points=[ - RouteShapePoint( - coordinate=Coordinate(latitude=-23.5505, longitude=-46.6333), - sequence=1, - distance_traveled=0.0, - ) - ], - ) - - gtfs_repo.get_route_shape.return_value = mock_shape - - service = RouteService(bus_provider, gtfs_repo) - - # Act - result = service.get_route_shapes([route]) - - # Assert - assert result is not None - assert len(result) == 1 - # Verify bus_provider was not called at all - bus_provider.authenticate.assert_not_called() - bus_provider.get_bus_positions.assert_not_called() - bus_provider.get_route_details.assert_not_called() - - def test_get_route_shapes_multiple_routes() -> None: # Arrange bus_provider = create_autospec(BusProviderPort, instance=True) diff --git a/tests/core/test_trip_service_basic.py b/tests/core/test_trip_service_basic.py index 247ae8b..bd7e44c 100644 --- a/tests/core/test_trip_service_basic.py +++ b/tests/core/test_trip_service_basic.py @@ -27,7 +27,7 @@ async def test_create_trip_no_user() -> None: email="missing@example.com", route=RouteIdentifier(bus_line="8000", bus_direction=1), distance=1000, - trip_date=datetime.now(), + trip_datetime=datetime.now(), ) user_repo.get_user_by_email.assert_awaited_once_with("missing@example.com") @@ -54,7 +54,7 @@ async def test_create_trip_single_user() -> None: email="user@example.com", route=RouteIdentifier(bus_line="9000", bus_direction=2), distance=distance, - trip_date=datetime(2025, 11, 15, 12, 0, 0), + trip_datetime=datetime(2025, 11, 15, 12, 0, 0), ) assert isinstance(trip, Trip) @@ -83,7 +83,7 @@ async def test_create_trip_zero_distance() -> None: email="zero@example.com", route=RouteIdentifier(bus_line="0000", bus_direction=1), distance=0, - trip_date=datetime(2025, 11, 15, 12, 0, 0), + trip_datetime=datetime(2025, 11, 15, 12, 0, 0), ) assert isinstance(trip, Trip) @@ -109,7 +109,7 @@ async def test_create_trip_negative_distance() -> None: email="neg@example.com", route=RouteIdentifier(bus_line="-100", bus_direction=1), distance=-150, - trip_date=datetime(2025, 11, 15, 12, 0, 0), + trip_datetime=datetime(2025, 11, 15, 12, 0, 0), ) trip_repo.save_trip.assert_not_awaited() @@ -134,7 +134,7 @@ async def test_create_trip_very_large_distance() -> None: email="big@example.com", route=RouteIdentifier(bus_line="BIG", bus_direction=2), distance=big_distance, - trip_date=datetime(2025, 11, 15, 12, 0, 0), + trip_datetime=datetime(2025, 11, 15, 12, 0, 0), ) expected_score = (big_distance // 1000) * 77 @@ -167,7 +167,7 @@ async def capture_trip(t: Trip) -> Trip: email="test@example.com", route=RouteIdentifier(bus_line="8000", bus_direction=1), distance=5000, - trip_date=datetime(2025, 11, 15, 12, 0, 0), + trip_datetime=datetime(2025, 11, 15, 12, 0, 0), ) assert saved_trip is not None diff --git a/tests/core/test_trip_service_example.py b/tests/core/test_trip_service_example.py index 5960081..d0b4bd8 100644 --- a/tests/core/test_trip_service_example.py +++ b/tests/core/test_trip_service_example.py @@ -37,7 +37,7 @@ async def test_create_trip_calculates_score_correctly() -> None: email="test@example.com", route=RouteIdentifier(bus_line="8000", bus_direction=1), distance=distance, - trip_date=datetime(2025, 10, 16, 10, 0, 0), + trip_datetime=datetime(2025, 10, 16, 10, 0, 0), ) expected_score = (distance // 1000) * 77 @@ -66,7 +66,7 @@ async def test_create_trip_fails_for_nonexistent_user(mocker: "MockerFixture") - email="ghost@example.com", route=RouteIdentifier(bus_line="8000", bus_direction=1), distance=1000, - trip_date=datetime.now(), + trip_datetime=datetime.now(), ) user_repo.get_user_by_email.assert_awaited_once_with("ghost@example.com") @@ -95,14 +95,14 @@ async def test_multiple_trips(mocker: "MockerFixture") -> None: email="bob@example.com", route=RouteIdentifier(bus_line="8000", bus_direction=1), distance=500, - trip_date=datetime.now(), + trip_datetime=datetime.now(), ) trip2 = await service.create_trip( email="bob@example.com", route=RouteIdentifier(bus_line="8000", bus_direction=2), distance=1500, - trip_date=datetime.now(), + trip_datetime=datetime.now(), ) assert trip1.score == 0 @@ -137,7 +137,7 @@ async def test_handles_repository_save_error(mocker: "MockerFixture") -> None: email="charlie@example.com", route=RouteIdentifier(bus_line="8000", bus_direction=1), distance=1000, - trip_date=datetime.now(), + trip_datetime=datetime.now(), ) trip_repo.save_trip.assert_awaited_once() diff --git a/tests/integration/test_ranking.py b/tests/integration/test_ranking.py index 7d420dc..5fde5d2 100644 --- a/tests/integration/test_ranking.py +++ b/tests/integration/test_ranking.py @@ -53,7 +53,7 @@ async def test_get_user_rank_position_with_multiple_users( trip_data = CreateTripRequest( route=RouteIdentifierSchema(bus_line="8000", bus_direction=1), distance=0, - data=datetime.now(UTC), + trip_datetime=datetime.now(UTC), ) await client.post( "/trips/", json=trip_data.model_dump(mode="json"), headers=auth["headers"] diff --git a/tests/integration/test_route.py b/tests/integration/test_route.py index 6b6a267..1b51bf3 100644 --- a/tests/integration/test_route.py +++ b/tests/integration/test_route.py @@ -8,17 +8,15 @@ from src.core.models.coordinate import Coordinate from src.web.schemas import ( BusPositionsRequest, - BusRouteSchema, - BusRoutesDetailsRequest, - RouteIdentifierSchema, + BusRouteRequestSchema, ) from .conftest import create_user_and_login -class TestRouteDetails: +class TestRouteSearch: @pytest.mark.asyncio - async def test_get_route_details_returns_successfully( + async def test_search_routes_returns_successfully( self, client: AsyncClient, ) -> None: @@ -32,31 +30,23 @@ async def test_get_route_details_returns_successfully( mock_bus_routes = [ BusRoute( route_id=12345, - route=RouteIdentifier(bus_line="8000", bus_direction=1), + route=RouteIdentifier( + bus_line="8000", + bus_direction=1, + ), + is_circular=False, + terminal_name="Terminal A", ) ] - with ( - patch( - "src.adapters.external.sptrans_adapter.SpTransAdapter.authenticate", - new_callable=AsyncMock, - return_value=True, - ), - patch( - "src.adapters.external.sptrans_adapter.SpTransAdapter.get_route_details", - new_callable=AsyncMock, - return_value=mock_bus_routes, - ), + with patch( + "src.adapters.external.sptrans_adapter.SpTransAdapter.search_routes", + new_callable=AsyncMock, + return_value=mock_bus_routes, ): - request_data = BusRoutesDetailsRequest( - routes=[ - RouteIdentifierSchema(bus_line="8000", bus_direction=1), - ] - ) - - response = await client.post( - "/routes/details", - json=request_data.model_dump(), + response = await client.get( + "/routes/search", + params={"query": "8000"}, headers=auth["headers"], ) @@ -74,7 +64,7 @@ async def test_get_route_details_returns_successfully( assert first_route["route"]["bus_direction"] == 1 @pytest.mark.asyncio - async def test_get_route_details_with_multiple_lines( + async def test_search_routes_returns_multiple_results( self, client: AsyncClient, ) -> None: @@ -85,41 +75,35 @@ async def test_get_route_details_with_multiple_lines( } auth = await create_user_and_login(client, user_data) - mock_routes_8000 = [ + mock_bus_routes = [ BusRoute( route_id=12345, - route=RouteIdentifier(bus_line="8000", bus_direction=1), + route=RouteIdentifier( + bus_line="8000", + bus_direction=1, + ), + is_circular=False, + terminal_name="Terminal A", ), - ] - mock_routes_9000 = [ BusRoute( - route_id=67890, - route=RouteIdentifier(bus_line="9000", bus_direction=1), + route_id=12346, + route=RouteIdentifier( + bus_line="8000", + bus_direction=2, + ), + is_circular=False, + terminal_name="Terminal B", ), ] - with ( - patch( - "src.adapters.external.sptrans_adapter.SpTransAdapter.authenticate", - new_callable=AsyncMock, - return_value=True, - ), - patch( - "src.adapters.external.sptrans_adapter.SpTransAdapter.get_route_details", - new_callable=AsyncMock, - side_effect=[mock_routes_8000, mock_routes_9000], - ), + with patch( + "src.adapters.external.sptrans_adapter.SpTransAdapter.search_routes", + new_callable=AsyncMock, + return_value=mock_bus_routes, ): - request_data = BusRoutesDetailsRequest( - routes=[ - RouteIdentifierSchema(bus_line="8000", bus_direction=1), - RouteIdentifierSchema(bus_line="9000", bus_direction=1), - ] - ) - - response = await client.post( - "/routes/details", - json=request_data.model_dump(), + response = await client.get( + "/routes/search", + params={"query": "8000"}, headers=auth["headers"], ) @@ -129,10 +113,10 @@ async def test_get_route_details_with_multiple_lines( assert len(data["routes"]) == 2 route_ids = [r["route_id"] for r in data["routes"]] assert 12345 in route_ids - assert 67890 in route_ids + assert 12346 in route_ids @pytest.mark.asyncio - async def test_get_route_details_returns_empty_for_unknown_line( + async def test_search_routes_returns_empty_for_unknown_query( self, client: AsyncClient, ) -> None: @@ -143,27 +127,14 @@ async def test_get_route_details_returns_empty_for_unknown_line( } auth = await create_user_and_login(client, user_data) - with ( - patch( - "src.adapters.external.sptrans_adapter.SpTransAdapter.authenticate", - new_callable=AsyncMock, - return_value=True, - ), - patch( - "src.adapters.external.sptrans_adapter.SpTransAdapter.get_route_details", - new_callable=AsyncMock, - return_value=[], - ), + with patch( + "src.adapters.external.sptrans_adapter.SpTransAdapter.search_routes", + new_callable=AsyncMock, + return_value=[], ): - request_data = BusRoutesDetailsRequest( - routes=[ - RouteIdentifierSchema(bus_line="UNKNOWN", bus_direction=1), - ] - ) - - response = await client.post( - "/routes/details", - json=request_data.model_dump(), + response = await client.get( + "/routes/search", + params={"query": "UNKNOWN"}, headers=auth["headers"], ) @@ -172,7 +143,7 @@ async def test_get_route_details_returns_empty_for_unknown_line( assert data["routes"] == [] @pytest.mark.asyncio - async def test_get_route_details_returns_500_on_api_error( + async def test_search_routes_returns_500_on_api_error( self, client: AsyncClient, ) -> None: @@ -183,35 +154,22 @@ async def test_get_route_details_returns_500_on_api_error( } auth = await create_user_and_login(client, user_data) - with ( - patch( - "src.adapters.external.sptrans_adapter.SpTransAdapter.authenticate", - new_callable=AsyncMock, - return_value=True, - ), - patch( - "src.adapters.external.sptrans_adapter.SpTransAdapter.get_route_details", - new_callable=AsyncMock, - side_effect=RuntimeError("API unavailable"), - ), + with patch( + "src.adapters.external.sptrans_adapter.SpTransAdapter.search_routes", + new_callable=AsyncMock, + side_effect=RuntimeError("API unavailable"), ): - request_data = BusRoutesDetailsRequest( - routes=[ - RouteIdentifierSchema(bus_line="8000", bus_direction=1), - ] - ) - - response = await client.post( - "/routes/details", - json=request_data.model_dump(), + response = await client.get( + "/routes/search", + params={"query": "8000"}, headers=auth["headers"], ) assert response.status_code == 500 - assert "Failed to retrieve route details" in response.json()["detail"] + assert "Failed to search routes" in response.json()["detail"] @pytest.mark.asyncio - async def test_get_route_details_with_empty_routes_list( + async def test_search_routes_returns_422_when_query_missing( self, client: AsyncClient, ) -> None: @@ -222,34 +180,19 @@ async def test_get_route_details_with_empty_routes_list( } auth = await create_user_and_login(client, user_data) - with patch( - "src.adapters.external.sptrans_adapter.SpTransAdapter.authenticate", - new_callable=AsyncMock, - return_value=True, - ): - request_data = BusRoutesDetailsRequest(routes=[]) - - response = await client.post( - "/routes/details", - json=request_data.model_dump(), - headers=auth["headers"], - ) + response = await client.get( + "/routes/search", + headers=auth["headers"], + ) - assert response.status_code == 200 - assert response.json()["routes"] == [] + assert response.status_code == 422 @pytest.mark.asyncio - async def test_get_route_details_without_auth_fails( + async def test_search_routes_without_auth_fails( self, client: AsyncClient, ) -> None: - request_data = BusRoutesDetailsRequest( - routes=[ - RouteIdentifierSchema(bus_line="8000", bus_direction=1), - ] - ) - - response = await client.post("/routes/details", json=request_data.model_dump()) + response = await client.get("/routes/search", params={"query": "8000"}) assert response.status_code == 401 @@ -269,35 +212,25 @@ async def test_get_bus_position_returns_successfully( mock_positions = [ BusPosition( - route=RouteIdentifier(bus_line="8000", bus_direction=1), + route_id=12345, position=Coordinate(latitude=-23.550520, longitude=-46.633308), time_updated=datetime.now(UTC), ), BusPosition( - route=RouteIdentifier(bus_line="8000", bus_direction=1), + route_id=12345, position=Coordinate(latitude=-23.551234, longitude=-46.634567), time_updated=datetime.now(UTC), ), ] - with ( - patch( - "src.adapters.external.sptrans_adapter.SpTransAdapter.authenticate", - new_callable=AsyncMock, - return_value=True, - ), - patch( - "src.adapters.external.sptrans_adapter.SpTransAdapter.get_bus_positions", - new_callable=AsyncMock, - return_value=mock_positions, - ), + with patch( + "src.adapters.external.sptrans_adapter.SpTransAdapter.get_bus_positions", + new_callable=AsyncMock, + return_value=mock_positions, ): request_data = BusPositionsRequest( routes=[ - BusRouteSchema( - route_id=12345, - route=RouteIdentifierSchema(bus_line="8000", bus_direction=1), - ), + BusRouteRequestSchema(route_id=12345), ] ) @@ -314,9 +247,8 @@ async def test_get_bus_position_returns_successfully( assert len(data["buses"]) == 2 first_bus = data["buses"][0] - assert "route" in first_bus - assert first_bus["route"]["bus_line"] == "8000" - assert first_bus["route"]["bus_direction"] == 1 + assert "route_id" in first_bus + assert first_bus["route_id"] == 12345 assert "position" in first_bus assert "latitude" in first_bus["position"] assert "longitude" in first_bus["position"] @@ -334,24 +266,14 @@ async def test_get_bus_position_returns_500_when_error( } auth = await create_user_and_login(client, user_data) - with ( - patch( - "src.adapters.external.sptrans_adapter.SpTransAdapter.authenticate", - new_callable=AsyncMock, - return_value=True, - ), - patch( - "src.adapters.external.sptrans_adapter.SpTransAdapter.get_bus_positions", - new_callable=AsyncMock, - side_effect=ValueError("Error fetching positions"), - ), + with patch( + "src.adapters.external.sptrans_adapter.SpTransAdapter.get_bus_positions", + new_callable=AsyncMock, + side_effect=ValueError("Error fetching positions"), ): request_data = BusPositionsRequest( routes=[ - BusRouteSchema( - route_id=99999, - route=RouteIdentifierSchema(bus_line="123", bus_direction=1), - ), + BusRouteRequestSchema(route_id=99999), ] ) @@ -376,24 +298,14 @@ async def test_get_bus_position_returns_empty_when_no_buses_on_line( } auth = await create_user_and_login(client, user_data) - with ( - patch( - "src.adapters.external.sptrans_adapter.SpTransAdapter.authenticate", - new_callable=AsyncMock, - return_value=True, - ), - patch( - "src.adapters.external.sptrans_adapter.SpTransAdapter.get_bus_positions", - new_callable=AsyncMock, - return_value=[], - ), + with patch( + "src.adapters.external.sptrans_adapter.SpTransAdapter.get_bus_positions", + new_callable=AsyncMock, + return_value=[], ): request_data = BusPositionsRequest( routes=[ - BusRouteSchema( - route_id=12345, - route=RouteIdentifierSchema(bus_line="8000", bus_direction=1), - ), + BusRouteRequestSchema(route_id=12345), ] ) @@ -421,43 +333,37 @@ async def test_get_bus_position_works_with_multiple_routes( } auth = await create_user_and_login(client, user_data) - mock_positions_8000 = [ + mock_position_12345 = [ BusPosition( - route=RouteIdentifier(bus_line="8000", bus_direction=1), + route_id=12345, position=Coordinate(latitude=-23.550520, longitude=-46.633308), time_updated=datetime.now(UTC), ), ] - mock_positions_9000 = [ + mock_position_67890 = [ BusPosition( - route=RouteIdentifier(bus_line="9000", bus_direction=2), + route_id=67890, position=Coordinate(latitude=-23.560520, longitude=-46.643308), time_updated=datetime.now(UTC), ), ] - with ( - patch( - "src.adapters.external.sptrans_adapter.SpTransAdapter.authenticate", - new_callable=AsyncMock, - return_value=True, - ), - patch( - "src.adapters.external.sptrans_adapter.SpTransAdapter.get_bus_positions", - new_callable=AsyncMock, - side_effect=[mock_positions_8000, mock_positions_9000], - ), + def mock_get_positions(route_id: int) -> list[BusPosition]: + if route_id == 12345: + return mock_position_12345 + elif route_id == 67890: + return mock_position_67890 + return [] + + with patch( + "src.adapters.external.sptrans_adapter.SpTransAdapter.get_bus_positions", + new_callable=AsyncMock, + side_effect=mock_get_positions, ): request_data = BusPositionsRequest( routes=[ - BusRouteSchema( - route_id=12345, - route=RouteIdentifierSchema(bus_line="8000", bus_direction=1), - ), - BusRouteSchema( - route_id=67890, - route=RouteIdentifierSchema(bus_line="9000", bus_direction=2), - ), + BusRouteRequestSchema(route_id=12345), + BusRouteRequestSchema(route_id=67890), ] ) @@ -471,12 +377,12 @@ async def test_get_bus_position_works_with_multiple_routes( data = response.json() assert len(data["buses"]) == 2 - bus_lines = [bus["route"]["bus_line"] for bus in data["buses"]] - assert "8000" in bus_lines - assert "9000" in bus_lines + route_ids = [bus["route_id"] for bus in data["buses"]] + assert 12345 in route_ids + assert 67890 in route_ids @pytest.mark.asyncio - async def test_get_bus_position_returns_500_when_authentication_failure( + async def test_get_bus_position_returns_500_when_api_error( self, client: AsyncClient, ) -> None: @@ -488,16 +394,13 @@ async def test_get_bus_position_returns_500_when_authentication_failure( auth = await create_user_and_login(client, user_data) with patch( - "src.adapters.external.sptrans_adapter.SpTransAdapter.authenticate", + "src.adapters.external.sptrans_adapter.SpTransAdapter.get_bus_positions", new_callable=AsyncMock, - side_effect=RuntimeError("Authentication failed"), + side_effect=RuntimeError("API error"), ): request_data = BusPositionsRequest( routes=[ - BusRouteSchema( - route_id=12345, - route=RouteIdentifierSchema(bus_line="8000", bus_direction=1), - ), + BusRouteRequestSchema(route_id=12345), ] ) @@ -521,18 +424,13 @@ async def test_get_bus_position_returns_422_when_invalid_data( } auth = await create_user_and_login(client, user_data) - request_data = { - "routes": [ - { - "route_id": 12345, - "route": {"bus_line": "8000", "bus_direction": 3}, - } - ] + invalid_request_data: dict[str, list[dict[str, str]]] = { + "routes": [{"route_id": "not_an_int"}] } response = await client.post( "/routes/positions", - json=request_data, + json=invalid_request_data, headers=auth["headers"], ) @@ -550,17 +448,10 @@ async def test_get_bus_position_returns_successfully_with_empty_routes_list( } auth = await create_user_and_login(client, user_data) - with ( - patch( - "src.adapters.external.sptrans_adapter.SpTransAdapter.authenticate", - new_callable=AsyncMock, - return_value=True, - ), - patch( - "src.adapters.external.sptrans_adapter.SpTransAdapter.get_bus_positions", - new_callable=AsyncMock, - return_value=[], - ), + with patch( + "src.adapters.external.sptrans_adapter.SpTransAdapter.get_bus_positions", + new_callable=AsyncMock, + return_value=[], ): request_data = BusPositionsRequest(routes=[]) @@ -580,10 +471,7 @@ async def test_get_bus_position_without_auth_fails( ) -> None: request_data = BusPositionsRequest( routes=[ - BusRouteSchema( - route_id=12345, - route=RouteIdentifierSchema(bus_line="8000", bus_direction=1), - ), + BusRouteRequestSchema(route_id=12345), ] ) diff --git a/tests/integration/test_trip.py b/tests/integration/test_trip.py index 3bbe4c0..8005f5a 100644 --- a/tests/integration/test_trip.py +++ b/tests/integration/test_trip.py @@ -31,7 +31,7 @@ async def test_create_trip_should_return_successfully_and_save_to_database( bus_direction=1, ), distance=5000, - data=datetime.now(UTC), + trip_datetime=datetime.now(UTC), ) response = await client.post( @@ -73,7 +73,7 @@ async def test_create_trip_updates_user_score( bus_direction=1, ), distance=1000, - data=datetime.now(UTC), + trip_datetime=datetime.now(UTC), ) second_trip_data = CreateTripRequest( route=RouteIdentifierSchema( @@ -81,7 +81,7 @@ async def test_create_trip_updates_user_score( bus_direction=1, ), distance=2000, - data=datetime.now(UTC), + trip_datetime=datetime.now(UTC), ) resp1 = await client.post( @@ -117,7 +117,7 @@ async def test_create_trip_without_authentication_fails( bus_direction=1, ), distance=1000, - data=datetime.now(UTC), + trip_datetime=datetime.now(UTC), ) response = await client.post("/trips/", json=trip_data.model_dump(mode="json")) @@ -142,7 +142,7 @@ async def test_create_trip_zero_distance( bus_direction=2, ), distance=0, - data=datetime.now(UTC), + trip_datetime=datetime.now(UTC), ) response = await client.post( @@ -173,7 +173,7 @@ async def test_create_trip_negative_distance_fails( "bus_direction": 1, }, "distance": -1000, - "data": datetime.now(UTC).isoformat(), + "trip_datetime": datetime.now(UTC).isoformat(), } response = await client.post( @@ -207,7 +207,7 @@ async def test_create_trip_invalid_route_identifier_fails( "bus_direction": 3, }, "distance": 1000, - "data": datetime.now(UTC).isoformat(), + "trip_datetime": datetime.now(UTC).isoformat(), } response = await client.post( @@ -241,7 +241,7 @@ async def test_create_trip_stores_route_identifier( bus_direction=2, ), distance=5000, - data=datetime.now(UTC), + trip_datetime=datetime.now(UTC), ) response = await client.post( diff --git a/tests/integration/test_user_history.py b/tests/integration/test_user_history.py index c49d884..fbf832d 100644 --- a/tests/integration/test_user_history.py +++ b/tests/integration/test_user_history.py @@ -32,11 +32,11 @@ async def test_get_user_history_should_work( ] scores: list[int] = [] - for i, trip_date in enumerate(trip_dates): + for i, trip_datetime in enumerate(trip_dates): trip_request = CreateTripRequest( route=RouteIdentifierSchema(bus_line=f"800{i}", bus_direction=1), distance=(i + 1) * 1000, - data=trip_date, + trip_datetime=trip_datetime, ) response = await client.post( "/trips/", @@ -110,7 +110,7 @@ async def test_get_history_includes_correct_dates( trip_request = CreateTripRequest( route=RouteIdentifierSchema(bus_line="8000", bus_direction=1), distance=1000, - data=specific_date, + trip_datetime=specific_date, ) await client.post( "/trips/", @@ -126,10 +126,10 @@ async def test_get_history_includes_correct_dates( assert response.status_code == 200 history_response = HistoryResponse.model_validate(response.json()) - trip_date = history_response.trips[0].date - assert trip_date.year == 2025 - assert trip_date.month == 6 - assert trip_date.day == 15 + trip_datetime = history_response.trips[0].date + assert trip_datetime.year == 2025 + assert trip_datetime.month == 6 + assert trip_datetime.day == 15 @pytest.mark.asyncio async def test_get_history_includes_route_identifier( @@ -149,7 +149,7 @@ async def test_get_history_includes_route_identifier( trip_request = CreateTripRequest( route=RouteIdentifierSchema(bus_line=bus_line, bus_direction=bus_direction), distance=5000, - data=datetime(2025, 6, 15, 10, 30, 0, tzinfo=UTC), + trip_datetime=datetime(2025, 6, 15, 10, 30, 0, tzinfo=UTC), ) await client.post( "/trips/", diff --git a/tests/web/test_route_controller.py b/tests/web/test_route_controller.py index 31619bc..61e5566 100644 --- a/tests/web/test_route_controller.py +++ b/tests/web/test_route_controller.py @@ -24,15 +24,10 @@ def client() -> TestClient: @pytest.fixture def mock_service() -> RouteService: - """ - Cria um mock fortemente tipado do RouteService, - mas com métodos assíncronos (AsyncMock). - """ service = AsyncMock(spec=RouteService) typed_service: RouteService = service - typed_service.get_route_details = AsyncMock() # type: ignore[method-assign] + typed_service.search_routes = AsyncMock() # type: ignore[method-assign] typed_service.get_bus_positions = AsyncMock() # type: ignore[method-assign] - typed_service.get_route_shape = Mock() # type: ignore[method-assign] typed_service.get_route_shapes = Mock() # type: ignore[method-assign] return typed_service @@ -52,325 +47,372 @@ def override_dependency( app.dependency_overrides.clear() -# ========================= -# /routes/details -# ========================= - - -@pytest.mark.asyncio -async def test_details_endpoint_success(client: TestClient, mock_service: RouteService) -> None: - """ - Testa o endpoint POST /routes/details garantindo que: - - Ele chama RouteService.get_route_details() - - Ele retorna uma lista achatada de rotas - """ - - # ----- Arrange ----- - # domínio - route_identifier = RouteIdentifier(bus_line="8075", bus_direction=1) - bus_route_1 = BusRoute(route_id=2044, route=route_identifier) - bus_route_2 = BusRoute(route_id=34812, route=route_identifier) - - # get_route_details agora retorna list[BusRoute] - mock_service.get_route_details.return_value = [bus_route_1, bus_route_2] # type: ignore[attr-defined] - - payload = { - "routes": [ - {"bus_line": "8075"}, - ] - } - - # ----- Act ----- - response = client.post("/routes/details", json=payload) +class TestSearchRoutes: + @pytest.mark.asyncio + async def test_search_endpoint_success( + self, client: TestClient, mock_service: RouteService + ) -> None: + bus_route_1 = BusRoute( + route_id=2044, + route=RouteIdentifier( + bus_line="8075", + bus_direction=1, + ), + is_circular=False, + terminal_name="Terminal A", + ) + bus_route_2 = BusRoute( + route_id=34812, + route=RouteIdentifier( + bus_line="8075", + bus_direction=2, + ), + is_circular=False, + terminal_name="Terminal B", + ) - # ----- Assert ----- - assert response.status_code == 200 - data = response.json() + mock_service.search_routes.return_value = [bus_route_1, bus_route_2] # type: ignore[attr-defined] - assert "routes" in data - assert len(data["routes"]) == 2 + response = client.get("/routes/search", params={"query": "8075"}) - routes = data["routes"] + assert response.status_code == 200 + data = response.json() - assert routes[0]["route_id"] == 2044 - assert routes[0]["route"]["bus_line"] == "8075" + assert "routes" in data + assert len(data["routes"]) == 2 - assert routes[1]["route_id"] == 34812 - assert routes[1]["route"]["bus_line"] == "8075" + routes = data["routes"] - # garante que o service foi chamado uma vez - mock_service.get_route_details.assert_awaited_once() # type: ignore[attr-defined] - called_arg = mock_service.get_route_details.await_args.args[0] # type: ignore[attr-defined] - assert isinstance(called_arg, RouteIdentifier) - assert called_arg.bus_line == "8075" - # direção padrão que estamos usando - assert called_arg.bus_direction == 1 + assert routes[0]["route_id"] == 2044 + assert routes[0]["route"]["bus_line"] == "8075" + assert routes[1]["route_id"] == 34812 + assert routes[1]["route"]["bus_line"] == "8075" -@pytest.mark.asyncio -async def test_details_endpoint_error_returns_500( - client: TestClient, mock_service: RouteService -) -> None: - """ - Testa se o controller retorna 500 caso o service levante exception - em /routes/details. - """ + mock_service.search_routes.assert_awaited_once() # type: ignore[attr-defined] + called_arg = mock_service.search_routes.await_args.args[0] # type: ignore[attr-defined] + assert called_arg == "8075" - mock_service.get_route_details.side_effect = RuntimeError("boom") # type: ignore[attr-defined] + @pytest.mark.asyncio + async def test_search_endpoint_with_destination_name( + self, client: TestClient, mock_service: RouteService + ) -> None: + """Test search with destination name query.""" + route_identifier = RouteIdentifier( + bus_line="809", + bus_direction=1, + ) + bus_route = BusRoute( + route_id=1234, + route=route_identifier, + is_circular=False, + terminal_name="Vila Nova Conceição", + ) + + mock_service.search_routes.return_value = [bus_route] # type: ignore[attr-defined] + + response = client.get("/routes/search", params={"query": "Vila Nova Conceição"}) + + assert response.status_code == 200 + data = response.json() + + assert "routes" in data + assert len(data["routes"]) == 1 + + mock_service.search_routes.assert_awaited_once() # type: ignore[attr-defined] + called_arg = mock_service.search_routes.await_args.args[0] # type: ignore[attr-defined] + assert called_arg == "Vila Nova Conceição" + + @pytest.mark.asyncio + async def test_search_endpoint_empty_results( + self, client: TestClient, mock_service: RouteService + ) -> None: + """Test search returning empty results.""" + mock_service.search_routes.return_value = [] # type: ignore[attr-defined] + + response = client.get("/routes/search", params={"query": "UNKNOWN"}) + + assert response.status_code == 200 + data = response.json() + assert data["routes"] == [] + + @pytest.mark.asyncio + async def test_search_endpoint_error_returns_500( + self, client: TestClient, mock_service: RouteService + ) -> None: + """Test that service exception returns 500 error.""" + mock_service.search_routes.side_effect = RuntimeError("boom") # type: ignore[attr-defined] + + response = client.get("/routes/search", params={"query": "8075"}) + + assert response.status_code == 500 + body = response.json() + assert "Failed to search routes" in body["detail"] - payload = {"routes": [{"bus_line": "8075"}]} - response = client.post("/routes/details", json=payload) +class TestBusPositions: + """Tests for the /routes/positions endpoint.""" - assert response.status_code == 500 - body = response.json() - assert "Failed to retrieve route details" in body["detail"] + @pytest.mark.asyncio + async def test_positions_endpoint_success( + self, client: TestClient, mock_service: RouteService + ) -> None: + """Test successful positions retrieval.""" + position = BusPosition( + route_id=2044, + position=Coordinate(latitude=-23.5, longitude=-46.6), + time_updated=datetime.now(UTC), + ) + mock_service.get_bus_positions.return_value = [position] # type: ignore[attr-defined] -# ========================= -# /routes/positions -# ========================= + payload = {"routes": [{"route_id": 2044}]} + response = client.post("/routes/positions", json=payload) -@pytest.mark.asyncio -async def test_positions_endpoint_success(client: TestClient, mock_service: RouteService) -> None: - """ - Testa o endpoint POST /routes/positions garantindo que: - - Ele chama RouteService.get_bus_positions() - - Ele retorna os dados de posição corretamente - """ + assert response.status_code == 200 + data = response.json() - # ----- Arrange ----- - route_identifier = RouteIdentifier(bus_line="8075", bus_direction=1) + assert "buses" in data + assert len(data["buses"]) == 1 - position = BusPosition( - route=route_identifier, - position=Coordinate(latitude=-23.5, longitude=-46.6), - time_updated=datetime.now(UTC), - ) + bus = data["buses"][0] - mock_service.get_bus_positions.return_value = [position] # type: ignore[attr-defined] + assert bus["route_id"] == 2044 + assert "position" in bus + assert "latitude" in bus["position"] + assert "longitude" in bus["position"] + assert "time_updated" in bus + + mock_service.get_bus_positions.assert_awaited_once() # type: ignore[attr-defined] + # Service now receives route_ids: list[int] + called_args = mock_service.get_bus_positions.await_args.args # type: ignore[attr-defined] + route_ids = called_args[0] + assert route_ids == [2044] + + @pytest.mark.asyncio + async def test_positions_endpoint_multiple_routes( + self, client: TestClient, mock_service: RouteService + ) -> None: + """Test positions for multiple routes.""" + position1 = BusPosition( + route_id=2044, + position=Coordinate(latitude=-23.5, longitude=-46.6), + time_updated=datetime.now(UTC), + ) + position2 = BusPosition( + route_id=5678, + position=Coordinate(latitude=-23.6, longitude=-46.7), + time_updated=datetime.now(UTC), + ) - payload = { - "routes": [ - { - "route_id": 2044, - "route": {"bus_line": "8075"}, - } - ] - } + mock_service.get_bus_positions.return_value = [position1, position2] # type: ignore[attr-defined] - # ----- Act ----- - response = client.post("/routes/positions", json=payload) + payload = { + "routes": [ + {"route_id": 2044}, + {"route_id": 5678}, + ] + } - # ----- Assert ----- - assert response.status_code == 200 - data = response.json() + response = client.post("/routes/positions", json=payload) - assert "buses" in data - assert len(data["buses"]) == 1 + assert response.status_code == 200 + data = response.json() - bus = data["buses"][0] + assert len(data["buses"]) == 2 + mock_service.get_bus_positions.assert_awaited_once() # type: ignore[attr-defined] - assert bus["route"]["bus_line"] == "8075" - assert "position" in bus - assert "latitude" in bus["position"] - assert "longitude" in bus["position"] - assert "time_updated" in bus + @pytest.mark.asyncio + async def test_positions_endpoint_error_returns_500( + self, client: TestClient, mock_service: RouteService + ) -> None: + """Test that service exception returns 500 error.""" + mock_service.get_bus_positions.side_effect = RuntimeError("boom") # type: ignore[attr-defined] - mock_service.get_bus_positions.assert_awaited_once() # type: ignore[attr-defined] - called_arg = mock_service.get_bus_positions.await_args.args[0] # type: ignore[attr-defined] - assert isinstance(called_arg, BusRoute) - assert called_arg.route.bus_line == "8075" - assert called_arg.route.bus_direction == 1 - assert called_arg.route_id == 2044 + payload = {"routes": [{"route_id": 2044}]} + response = client.post("/routes/positions", json=payload) -@pytest.mark.asyncio -async def test_positions_endpoint_error_returns_500( - client: TestClient, mock_service: RouteService -) -> None: - """ - Testa se o controller retorna 500 caso o service levante exception - em /routes/positions. - """ + assert response.status_code == 500 + body = response.json() + assert "Failed to retrieve bus positions" in body["detail"] - mock_service.get_bus_positions.side_effect = RuntimeError("boom") # type: ignore[attr-defined] + @pytest.mark.asyncio + async def test_positions_endpoint_empty_routes( + self, client: TestClient, mock_service: RouteService + ) -> None: + """Test positions with empty routes list.""" + mock_service.get_bus_positions.return_value = [] # type: ignore[attr-defined] - payload = { - "routes": [ - { - "route_id": 2044, - "route": {"bus_line": "8075"}, - } - ] - } + payload = {"routes": []} - response = client.post("/routes/positions", json=payload) + response = client.post("/routes/positions", json=payload) - assert response.status_code == 500 - body = response.json() - assert "Failed to retrieve bus positions" in body["detail"] + assert response.status_code == 200 + assert response.json()["buses"] == [] # ========================= # /routes/shapes # ========================= - - -@pytest.mark.asyncio -async def test_shapes_endpoint_success(client: TestClient, mock_service: RouteService) -> None: - """ - Testa o endpoint POST /routes/shapes garantindo que: - - Ele chama RouteService.get_route_shapes() - - Ele retorna uma lista de shapes - """ - - # ----- Arrange ----- - route1 = RouteIdentifier(bus_line="8075", bus_direction=1) - route2 = RouteIdentifier(bus_line="8075", bus_direction=2) - - shape1 = RouteShape( - route=route1, - shape_id="shape_8075_1", - points=[ - RouteShapePoint( - coordinate=Coordinate(latitude=-23.5505, longitude=-46.6333), - sequence=1, - distance_traveled=0.0, - ), - RouteShapePoint( - coordinate=Coordinate(latitude=-23.5510, longitude=-46.6340), - sequence=2, - distance_traveled=10.5, - ), - ], - ) - - shape2 = RouteShape( - route=route2, - shape_id="shape_8075_2", - points=[ - RouteShapePoint( - coordinate=Coordinate(latitude=-23.5515, longitude=-46.6345), - sequence=1, - distance_traveled=0.0, - ), - ], - ) - - mock_service.get_route_shapes.return_value = [shape1, shape2] # type: ignore[attr-defined] - - payload = { - "routes": [ - {"bus_line": "8075", "bus_direction": 1}, - {"bus_line": "8075", "bus_direction": 2}, - ] - } - - # ----- Act ----- - response = client.post("/routes/shapes", json=payload) - - # ----- Assert ----- - assert response.status_code == 200 - data = response.json() - - assert "shapes" in data - assert len(data["shapes"]) == 2 - - # First shape - assert data["shapes"][0]["route"]["bus_line"] == "8075" - assert data["shapes"][0]["route"]["bus_direction"] == 1 - assert data["shapes"][0]["shape_id"] == "shape_8075_1" - assert len(data["shapes"][0]["points"]) == 2 - - # Second shape - assert data["shapes"][1]["route"]["bus_line"] == "8075" - assert data["shapes"][1]["route"]["bus_direction"] == 2 - assert data["shapes"][1]["shape_id"] == "shape_8075_2" - assert len(data["shapes"][1]["points"]) == 1 - - # Verify service was called correctly - mock_service.get_route_shapes.assert_called_once() # type: ignore[attr-defined] - called_args = mock_service.get_route_shapes.call_args.args[0] # type: ignore[attr-defined] - assert len(called_args) == 2 - assert called_args[0].bus_line == "8075" - assert called_args[0].bus_direction == 1 - assert called_args[1].bus_line == "8075" - assert called_args[1].bus_direction == 2 - - -@pytest.mark.asyncio -async def test_shapes_endpoint_single_route(client: TestClient, mock_service: RouteService) -> None: - """ - Testa o endpoint POST /routes/shapes com uma única rota. - """ - - # ----- Arrange ----- - route = RouteIdentifier(bus_line="1012", bus_direction=1) - - shape = RouteShape( - route=route, - shape_id="shape_1012_1", - points=[ - RouteShapePoint( - coordinate=Coordinate(latitude=-23.5505, longitude=-46.6333), - sequence=1, - distance_traveled=0.0, - ), - ], - ) - - mock_service.get_route_shapes.return_value = [shape] # type: ignore[attr-defined] - - payload = {"routes": [{"bus_line": "1012", "bus_direction": 1}]} - - # ----- Act ----- - response = client.post("/routes/shapes", json=payload) - - # ----- Assert ----- - assert response.status_code == 200 - data = response.json() - - assert len(data["shapes"]) == 1 - assert data["shapes"][0]["route"]["bus_line"] == "1012" - assert data["shapes"][0]["route"]["bus_direction"] == 1 - - -@pytest.mark.asyncio -async def test_shapes_endpoint_empty_result(client: TestClient, mock_service: RouteService) -> None: - """ - Testa o endpoint POST /routes/shapes quando nenhuma rota é encontrada. - """ - - mock_service.get_route_shapes.return_value = [] # type: ignore[attr-defined] - - payload = {"routes": [{"bus_line": "nonexistent", "bus_direction": 1}]} - - # ----- Act ----- - response = client.post("/routes/shapes", json=payload) - - # ----- Assert ----- - assert response.status_code == 200 - data = response.json() - assert data["shapes"] == [] - - -@pytest.mark.asyncio -async def test_shapes_endpoint_error_returns_500( - client: TestClient, mock_service: RouteService -) -> None: - """ - Testa se o controller retorna 500 caso o service levante exception - em /routes/shapes. - """ - - mock_service.get_route_shapes.side_effect = RuntimeError("boom") # type: ignore[attr-defined] - - payload = {"routes": [{"bus_line": "8075", "bus_direction": 1}]} - - response = client.post("/routes/shapes", json=payload) - - assert response.status_code == 500 - body = response.json() - assert "Failed to retrieve route shapes" in body["detail"] +class TestRouteShapes: + @pytest.mark.asyncio + async def test_shapes_endpoint_success( + self, client: TestClient, mock_service: RouteService + ) -> None: + """ + Testa o endpoint POST /routes/shapes garantindo que: + - Ele chama RouteService.get_route_shapes() + - Ele retorna uma lista de shapes + """ + + # ----- Arrange ----- + route1 = RouteIdentifier(bus_line="8075", bus_direction=1) + route2 = RouteIdentifier(bus_line="8075", bus_direction=2) + + shape1 = RouteShape( + route=route1, + shape_id="shape_8075_1", + points=[ + RouteShapePoint( + coordinate=Coordinate(latitude=-23.5505, longitude=-46.6333), + sequence=1, + distance_traveled=0.0, + ), + RouteShapePoint( + coordinate=Coordinate(latitude=-23.5510, longitude=-46.6340), + sequence=2, + distance_traveled=10.5, + ), + ], + ) + + shape2 = RouteShape( + route=route2, + shape_id="shape_8075_2", + points=[ + RouteShapePoint( + coordinate=Coordinate(latitude=-23.5515, longitude=-46.6345), + sequence=1, + distance_traveled=0.0, + ), + ], + ) + + mock_service.get_route_shapes.return_value = [shape1, shape2] # type: ignore[attr-defined] + + payload = { + "routes": [ + {"bus_line": "8075", "bus_direction": 1}, + {"bus_line": "8075", "bus_direction": 2}, + ] + } + + # ----- Act ----- + response = client.post("/routes/shapes", json=payload) + + # ----- Assert ----- + assert response.status_code == 200 + data = response.json() + + assert "shapes" in data + assert len(data["shapes"]) == 2 + + # First shape + assert data["shapes"][0]["route"]["bus_line"] == "8075" + assert data["shapes"][0]["route"]["bus_direction"] == 1 + assert data["shapes"][0]["shape_id"] == "shape_8075_1" + assert len(data["shapes"][0]["points"]) == 2 + + # Second shape + assert data["shapes"][1]["route"]["bus_line"] == "8075" + assert data["shapes"][1]["route"]["bus_direction"] == 2 + assert data["shapes"][1]["shape_id"] == "shape_8075_2" + assert len(data["shapes"][1]["points"]) == 1 + + # Verify service was called correctly + mock_service.get_route_shapes.assert_called_once() # type: ignore[attr-defined] + called_args = mock_service.get_route_shapes.call_args.args[0] # type: ignore[attr-defined] + assert len(called_args) == 2 + assert called_args[0].bus_line == "8075" + assert called_args[0].bus_direction == 1 + assert called_args[1].bus_line == "8075" + assert called_args[1].bus_direction == 2 + + @pytest.mark.asyncio + async def test_shapes_endpoint_single_route( + self, client: TestClient, mock_service: RouteService + ) -> None: + """ + Testa o endpoint POST /routes/shapes com uma única rota. + """ + + # ----- Arrange ----- + route = RouteIdentifier(bus_line="1012", bus_direction=1) + + shape = RouteShape( + route=route, + shape_id="shape_1012_1", + points=[ + RouteShapePoint( + coordinate=Coordinate(latitude=-23.5505, longitude=-46.6333), + sequence=1, + distance_traveled=0.0, + ), + ], + ) + + mock_service.get_route_shapes.return_value = [shape] # type: ignore[attr-defined] + + payload = {"routes": [{"bus_line": "1012", "bus_direction": 1}]} + + # ----- Act ----- + response = client.post("/routes/shapes", json=payload) + + # ----- Assert ----- + assert response.status_code == 200 + data = response.json() + + assert len(data["shapes"]) == 1 + assert data["shapes"][0]["route"]["bus_line"] == "1012" + assert data["shapes"][0]["route"]["bus_direction"] == 1 + + @pytest.mark.asyncio + async def test_shapes_endpoint_empty_result( + self, client: TestClient, mock_service: RouteService + ) -> None: + """ + Testa o endpoint POST /routes/shapes quando nenhuma rota é encontrada. + """ + + mock_service.get_route_shapes.return_value = [] # type: ignore[attr-defined] + + payload = {"routes": [{"bus_line": "nonexistent", "bus_direction": 1}]} + + # ----- Act ----- + response = client.post("/routes/shapes", json=payload) + + # ----- Assert ----- + assert response.status_code == 200 + data = response.json() + assert data["shapes"] == [] + + @pytest.mark.asyncio + async def test_shapes_endpoint_error_returns_500( + self, client: TestClient, mock_service: RouteService + ) -> None: + """ + Testa se o controller retorna 500 caso o service levante exception + em /routes/shapes. + """ + + mock_service.get_route_shapes.side_effect = RuntimeError("boom") # type: ignore[attr-defined] + + payload = {"routes": [{"bus_line": "8075", "bus_direction": 1}]} + + response = client.post("/routes/shapes", json=payload) + + assert response.status_code == 500 + body = response.json() + assert "Failed to retrieve route shapes" in body["detail"]