From 12d0c8fbd012328f633a9fb7292a6ff8931e703e Mon Sep 17 00:00:00 2001 From: Kim Kakeya Date: Mon, 1 Dec 2025 17:26:35 -0300 Subject: [PATCH 01/12] style: fix tests typing --- .python-version | 1 + tests/__init__.py | 1 - .../test_user_history_repository_adapter.py | 18 +++++----- tests/core/test_history_service.py | 10 +++--- tests/core/test_route_service.py | 14 ++++---- tests/core/test_trip_service_basic.py | 32 ++++++++--------- tests/core/test_trip_service_example.py | 36 +++++++++---------- tests/integration/__init__.py | 1 + tests/web/test_route_controller.py | 34 +++++++++++------- 9 files changed, 79 insertions(+), 68 deletions(-) create mode 100644 .python-version create mode 100644 tests/integration/__init__.py diff --git a/.python-version b/.python-version new file mode 100644 index 0000000..e4fba21 --- /dev/null +++ b/.python-version @@ -0,0 +1 @@ +3.12 diff --git a/tests/__init__.py b/tests/__init__.py index fae6326..e69de29 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1 +0,0 @@ -"""Test package initialization.""" diff --git a/tests/adapters/test_user_history_repository_adapter.py b/tests/adapters/test_user_history_repository_adapter.py index 1d65335..2011a95 100644 --- a/tests/adapters/test_user_history_repository_adapter.py +++ b/tests/adapters/test_user_history_repository_adapter.py @@ -12,9 +12,11 @@ from collections.abc import AsyncGenerator from datetime import datetime +from typing import Any from unittest.mock import AsyncMock import pytest +from sqlalchemy.ext.asyncio import AsyncSession from src.adapters.repositories.history_repository_adapter import ( UserHistoryRepositoryAdapter, @@ -39,15 +41,15 @@ def __init__(self, email: str, trips: list[_DummyTrip] | None): class _DummyResult: - def __init__(self, value): + def __init__(self, value: Any) -> None: self._value = value - def scalar_one_or_none(self): + def scalar_one_or_none(self) -> Any: return self._value @pytest.fixture(scope="function") -async def db_session_transactional() -> AsyncGenerator: +async def db_session_transactional() -> AsyncGenerator[AsyncSession, None]: """Provide an AsyncSession inside a transaction that will be rolled back. This fixture opens a connection from the project's engine, begins an @@ -85,9 +87,9 @@ async def test_get_user_history_returns_history_using_autospec() -> None: user_db = _DummyUser(email="alice@example.com", trips=[trip]) # Mock execute to be async and return an object with scalar_one_or_none() - session.execute = AsyncMock(return_value=_DummyResult(user_db)) # type: ignore[attr-defined] + session.execute = AsyncMock(return_value=_DummyResult(user_db)) - adapter = UserHistoryRepositoryAdapter(session) # type: ignore[arg-type] + adapter = UserHistoryRepositoryAdapter(session) # Act history = await adapter.get_user_history("alice@example.com") @@ -106,7 +108,7 @@ async def test_get_user_history_returns_none_when_missing_or_no_trips() -> None: # Arrange: case where execute returns None session_none = AsyncMock() session_none.execute = AsyncMock(return_value=_DummyResult(None)) - adapter_none = UserHistoryRepositoryAdapter(session_none) # type: ignore[arg-type] + adapter_none = UserHistoryRepositoryAdapter(session_none) # Act & Assert history_none = await adapter_none.get_user_history("noone@example.com") @@ -115,8 +117,8 @@ async def test_get_user_history_returns_none_when_missing_or_no_trips() -> None: # Arrange: user exists but has empty trips session_empty = AsyncMock() user_empty = _DummyUser(email="empty@example.com", trips=[]) - session_empty.execute = AsyncMock(return_value=_DummyResult(user_empty)) # type: ignore[attr-defined] - adapter_empty = UserHistoryRepositoryAdapter(session_empty) # type: ignore[arg-type] + session_empty.execute = AsyncMock(return_value=_DummyResult(user_empty)) + adapter_empty = UserHistoryRepositoryAdapter(session_empty) # Act & Assert history_empty = await adapter_empty.get_user_history("empty@example.com") diff --git a/tests/core/test_history_service.py b/tests/core/test_history_service.py index 83e41e7..3cf2b16 100644 --- a/tests/core/test_history_service.py +++ b/tests/core/test_history_service.py @@ -23,7 +23,7 @@ async def test_get_user_history_summary_no_data() -> None: history_repo = create_autospec(UserHistoryRepository, instance=True) history_repo.get_user_history = AsyncMock(return_value=None) - service = HistoryService(history_repo) # type: ignore[arg-type] + service = HistoryService(history_repo) # Act summary_dates, summary_scores = await service.get_user_history_summary("noone@example.com") @@ -55,7 +55,7 @@ async def test_get_user_history_summary_single_entry() -> None: history_repo.get_user_history = AsyncMock(return_value=user_history) - service = HistoryService(history_repo) # type: ignore[arg-type] + service = HistoryService(history_repo) # Act summary_dates, summary_scores = await service.get_user_history_summary("test@example.com") @@ -87,7 +87,7 @@ async def test_get_user_history_summary_timezone_aware() -> None: history_repo.get_user_history = AsyncMock(return_value=user_history) - service = HistoryService(history_repo) # type: ignore[arg-type] + service = HistoryService(history_repo) # Act dates, scores = await service.get_user_history_summary("tz@example.com") @@ -105,7 +105,7 @@ async def test_get_user_history_no_data() -> None: history_repo = create_autospec(UserHistoryRepository, instance=True) history_repo.get_user_history = AsyncMock(return_value=None) - service = HistoryService(history_repo) # type: ignore[arg-type] + service = HistoryService(history_repo) # Act result = await service.get_user_history("noone@example.com") @@ -136,7 +136,7 @@ async def test_get_user_history_single_entry() -> None: history_repo.get_user_history = AsyncMock(return_value=user_history) - service = HistoryService(history_repo) # type: ignore[arg-type] + service = HistoryService(history_repo) # Act result = await service.get_user_history("test@example.com") diff --git a/tests/core/test_route_service.py b/tests/core/test_route_service.py index 2e254af..af5f519 100644 --- a/tests/core/test_route_service.py +++ b/tests/core/test_route_service.py @@ -49,7 +49,7 @@ async def test_get_bus_positions_calls_auth_and_provider() -> None: ] # configurando retorno tipado do mock - raw_provider.get_bus_positions.return_value = expected_positions # type: ignore[assignment] + raw_provider.get_bus_positions.return_value = expected_positions gtfs_repo = create_autospec(GTFSRepositoryPort, instance=True) service: RouteService = RouteService(bus_provider=bus_provider, gtfs_repository=gtfs_repo) @@ -85,7 +85,7 @@ async def test_get_route_details_calls_auth_and_provider() -> None: expected_routes: list[BusRoute] = [expected_bus_route] # agora o provider também retorna lista - raw_provider.get_route_details.return_value = expected_routes # type: ignore[assignment] + raw_provider.get_route_details.return_value = expected_routes gtfs_repo = create_autospec(GTFSRepositoryPort, instance=True) service: RouteService = RouteService(bus_provider=bus_provider, gtfs_repository=gtfs_repo) @@ -182,7 +182,7 @@ def test_get_route_shape_found() -> None: gtfs_repo.get_route_shape.return_value = mock_shape - service = RouteService(bus_provider, gtfs_repo) # type: ignore[arg-type] + service = RouteService(bus_provider, gtfs_repo) # Act result = service.get_route_shape("1012-10") @@ -203,7 +203,7 @@ def test_get_route_shape_not_found() -> None: gtfs_repo.get_route_shape.return_value = None - service = RouteService(bus_provider, gtfs_repo) # type: ignore[arg-type] + service = RouteService(bus_provider, gtfs_repo) # Act result = service.get_route_shape("nonexistent-route") @@ -233,7 +233,7 @@ def test_get_route_shape_with_many_points() -> None: gtfs_repo.get_route_shape.return_value = mock_shape - service = RouteService(bus_provider, gtfs_repo) # type: ignore[arg-type] + service = RouteService(bus_provider, gtfs_repo) # Act result = service.get_route_shape("long-route") @@ -266,7 +266,7 @@ def test_get_route_shape_with_special_characters() -> None: gtfs_repo.get_route_shape.return_value = mock_shape - service = RouteService(bus_provider, gtfs_repo) # type: ignore[arg-type] + service = RouteService(bus_provider, gtfs_repo) # Act result = service.get_route_shape("route-with-special_chars@123") @@ -297,7 +297,7 @@ def test_get_route_shape_independent_of_bus_provider() -> None: gtfs_repo.get_route_shape.return_value = mock_shape - service = RouteService(bus_provider, gtfs_repo) # type: ignore[arg-type] + service = RouteService(bus_provider, gtfs_repo) # Act result = service.get_route_shape("test-route") diff --git a/tests/core/test_trip_service_basic.py b/tests/core/test_trip_service_basic.py index dfbd015..10cd874 100644 --- a/tests/core/test_trip_service_basic.py +++ b/tests/core/test_trip_service_basic.py @@ -28,7 +28,7 @@ async def test_create_trip_no_user() -> None: trip_repo.save_trip = AsyncMock() user_repo.add_user_score = AsyncMock() - service = TripService(trip_repo, user_repo) # type: ignore[arg-type] + service = TripService(trip_repo, user_repo) # Act / Assert with pytest.raises(ValueError, match="not found"): @@ -41,8 +41,8 @@ async def test_create_trip_no_user() -> None: ) user_repo.get_user_by_email.assert_awaited_once_with("missing@example.com") - trip_repo.save_trip.assert_not_awaited() # type: ignore[attr-defined] - user_repo.add_user_score.assert_not_awaited() # type: ignore[attr-defined] + trip_repo.save_trip.assert_not_awaited() + user_repo.add_user_score.assert_not_awaited() @pytest.mark.asyncio @@ -55,9 +55,9 @@ async def test_create_trip_single_user() -> None: test_user = User(name="Test", email="user@example.com", score=0, password="hash") user_repo.get_user_by_email = AsyncMock(return_value=test_user) user_repo.add_user_score = AsyncMock(return_value=test_user) - trip_repo.save_trip = AsyncMock(side_effect=lambda t: t) # type: ignore[misc] + trip_repo.save_trip = AsyncMock(side_effect=lambda t: t) - service = TripService(trip_repo, user_repo) # type: ignore[arg-type] + service = TripService(trip_repo, user_repo) distance = 1500 # metros expected_score = (distance // 1000) * 77 # 77 pontos por km inteiro @@ -77,8 +77,8 @@ async def test_create_trip_single_user() -> None: assert trip.email == "user@example.com" user_repo.get_user_by_email.assert_awaited_once_with("user@example.com") - trip_repo.save_trip.assert_awaited_once() # type: ignore[attr-defined] - user_repo.add_user_score.assert_awaited_once_with("user@example.com", expected_score) # type: ignore[attr-defined] + trip_repo.save_trip.assert_awaited_once() + user_repo.add_user_score.assert_awaited_once_with("user@example.com", expected_score) @pytest.mark.asyncio @@ -91,9 +91,9 @@ async def test_create_trip_zero_distance() -> None: test_user = User(name="Zero", email="zero@example.com", score=0, password="hash") user_repo.get_user_by_email = AsyncMock(return_value=test_user) user_repo.add_user_score = AsyncMock(return_value=test_user) - trip_repo.save_trip = AsyncMock(side_effect=lambda t: t) # type: ignore[misc] + trip_repo.save_trip = AsyncMock(side_effect=lambda t: t) - service = TripService(trip_repo, user_repo) # type: ignore[arg-type] + service = TripService(trip_repo, user_repo) # Act trip = await service.create_trip( @@ -107,7 +107,7 @@ async def test_create_trip_zero_distance() -> None: # Assert assert isinstance(trip, Trip) assert trip.score == 0 - user_repo.add_user_score.assert_awaited_once_with("zero@example.com", 0) # type: ignore[attr-defined] + user_repo.add_user_score.assert_awaited_once_with("zero@example.com", 0) @pytest.mark.asyncio @@ -125,7 +125,7 @@ async def test_create_trip_negative_distance() -> None: user_repo.add_user_score = AsyncMock() trip_repo.save_trip = AsyncMock() - service = TripService(trip_repo, user_repo) # type: ignore[arg-type] + service = TripService(trip_repo, user_repo) # Act & Assert with pytest.raises(ValueError, match="distance"): @@ -138,8 +138,8 @@ async def test_create_trip_negative_distance() -> None: ) # Ensure repository save/add were not called - trip_repo.save_trip.assert_not_awaited() # type: ignore[attr-defined] - user_repo.add_user_score.assert_not_awaited() # type: ignore[attr-defined] + trip_repo.save_trip.assert_not_awaited() + user_repo.add_user_score.assert_not_awaited() @pytest.mark.asyncio @@ -152,9 +152,9 @@ async def test_create_trip_very_large_distance() -> None: test_user = User(name="Big", email="big@example.com", score=0, password="hash") user_repo.get_user_by_email = AsyncMock(return_value=test_user) user_repo.add_user_score = AsyncMock(return_value=test_user) - trip_repo.save_trip = AsyncMock(side_effect=lambda t: t) # type: ignore[misc] + trip_repo.save_trip = AsyncMock(side_effect=lambda t: t) - service = TripService(trip_repo, user_repo) # type: ignore[arg-type] + service = TripService(trip_repo, user_repo) big_distance = 10_000_000 # 10 million meters @@ -170,4 +170,4 @@ async def test_create_trip_very_large_distance() -> None: # Assert expected_score = (big_distance // 1000) * 77 # 77 pontos por km inteiro assert trip.score == expected_score - user_repo.add_user_score.assert_awaited_once_with("big@example.com", expected_score) # type: ignore[attr-defined] + user_repo.add_user_score.assert_awaited_once_with("big@example.com", expected_score) diff --git a/tests/core/test_trip_service_example.py b/tests/core/test_trip_service_example.py index 2d33e03..871b104 100644 --- a/tests/core/test_trip_service_example.py +++ b/tests/core/test_trip_service_example.py @@ -39,9 +39,9 @@ async def test_create_trip_calculates_score_correctly() -> None: ) user_repo.get_user_by_email = AsyncMock(return_value=test_user) user_repo.add_user_score = AsyncMock(return_value=test_user) - trip_repo.save_trip = AsyncMock(side_effect=lambda t: t) # type: ignore[misc] + trip_repo.save_trip = AsyncMock(side_effect=lambda t: t) - service = TripService(trip_repo, user_repo) # type: ignore[arg-type] + service = TripService(trip_repo, user_repo) distance = 1000 @@ -80,9 +80,9 @@ async def test_create_trip_with_pytest_mock(mocker: "MockerFixture") -> None: ) user_repo.get_user_by_email = AsyncMock(return_value=test_user) user_repo.add_user_score = AsyncMock(return_value=test_user) - trip_repo.save_trip = AsyncMock(side_effect=lambda t: t) # type: ignore[misc] + trip_repo.save_trip = AsyncMock(side_effect=lambda t: t) - service = TripService(trip_repo, user_repo) # type: ignore[arg-type] + service = TripService(trip_repo, user_repo) distance = 2500 @@ -98,7 +98,7 @@ async def test_create_trip_with_pytest_mock(mocker: "MockerFixture") -> None: # Assert expected_score = (distance // 1000) * 77 assert trip.score == expected_score - user_repo.add_user_score.assert_awaited_once_with("alice@example.com", expected_score) # type: ignore[attr-defined] + user_repo.add_user_score.assert_awaited_once_with("alice@example.com", expected_score) @pytest.mark.asyncio @@ -112,7 +112,7 @@ async def test_create_trip_fails_for_nonexistent_user(mocker: "MockerFixture") - user_repo.add_user_score = AsyncMock() trip_repo.save_trip = AsyncMock() - service = TripService(trip_repo, user_repo) # type: ignore[arg-type] + service = TripService(trip_repo, user_repo) # Act & Assert with pytest.raises(ValueError, match="not found"): @@ -124,9 +124,9 @@ async def test_create_trip_fails_for_nonexistent_user(mocker: "MockerFixture") - trip_date=datetime.now(), ) - user_repo.get_user_by_email.assert_awaited_once_with("ghost@example.com") # type: ignore[attr-defined] - trip_repo.save_trip.assert_not_awaited() # type: ignore[attr-defined] - user_repo.add_user_score.assert_not_awaited() # type: ignore[attr-defined] + user_repo.get_user_by_email.assert_awaited_once_with("ghost@example.com") + trip_repo.save_trip.assert_not_awaited() + user_repo.add_user_score.assert_not_awaited() @pytest.mark.asyncio @@ -144,9 +144,9 @@ async def test_multiple_trips(mocker: "MockerFixture") -> None: ) user_repo.get_user_by_email = AsyncMock(return_value=test_user) user_repo.add_user_score = AsyncMock(return_value=test_user) - trip_repo.save_trip = AsyncMock(side_effect=lambda t: t) # type: ignore[misc] + trip_repo.save_trip = AsyncMock(side_effect=lambda t: t) - service = TripService(trip_repo, user_repo) # type: ignore[arg-type] + service = TripService(trip_repo, user_repo) # Act trip1 = await service.create_trip( @@ -168,10 +168,10 @@ async def test_multiple_trips(mocker: "MockerFixture") -> None: # Assert assert trip1.score == 0 assert trip2.score == 77 - assert trip_repo.save_trip.await_count == 2 # type: ignore[attr-defined] - assert user_repo.add_user_score.await_count == 2 # type: ignore[attr-defined] - user_repo.add_user_score.assert_any_await("bob@example.com", 0) # type: ignore[attr-defined] - user_repo.add_user_score.assert_any_await("bob@example.com", 77) # type: ignore[attr-defined] + assert trip_repo.save_trip.await_count == 2 + assert user_repo.add_user_score.await_count == 2 + user_repo.add_user_score.assert_any_await("bob@example.com", 0) + user_repo.add_user_score.assert_any_await("bob@example.com", 77) @pytest.mark.asyncio @@ -191,7 +191,7 @@ async def test_handles_repository_save_error(mocker: "MockerFixture") -> None: user_repo.add_user_score = AsyncMock() trip_repo.save_trip = AsyncMock(side_effect=RuntimeError("Database connection lost!")) - service = TripService(trip_repo, user_repo) # type: ignore[arg-type] + service = TripService(trip_repo, user_repo) # Act & Assert with pytest.raises(RuntimeError, match="Database connection lost"): @@ -203,5 +203,5 @@ async def test_handles_repository_save_error(mocker: "MockerFixture") -> None: trip_date=datetime.now(), ) - trip_repo.save_trip.assert_awaited_once() # type: ignore[attr-defined] - user_repo.add_user_score.assert_not_awaited() # type: ignore[attr-defined] + trip_repo.save_trip.assert_awaited_once() + user_repo.add_user_score.assert_not_awaited() diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py new file mode 100644 index 0000000..65873f3 --- /dev/null +++ b/tests/integration/__init__.py @@ -0,0 +1 @@ +"""Integration tests for BusSP API.""" diff --git a/tests/web/test_route_controller.py b/tests/web/test_route_controller.py index 0d853d3..739dc2b 100644 --- a/tests/web/test_route_controller.py +++ b/tests/web/test_route_controller.py @@ -1,5 +1,6 @@ from __future__ import annotations +from collections.abc import Generator from datetime import UTC, datetime from unittest.mock import AsyncMock @@ -25,13 +26,16 @@ def mock_service() -> RouteService: mas com métodos assíncronos (AsyncMock). """ service = AsyncMock(spec=RouteService) - - typed_service: RouteService = service # type: ignore[assignment] + # Cast to RouteService to satisfy type checker + typed_service: RouteService = service + # Set up return values with proper types - these are AsyncMock instances + typed_service.get_route_details = AsyncMock() # type: ignore[method-assign] + typed_service.get_bus_positions = AsyncMock() # type: ignore[method-assign] return typed_service @pytest.fixture(autouse=True) -def override_dependency(mock_service: RouteService) -> None: +def override_dependency(mock_service: RouteService) -> Generator[None, None, None]: """ Override da dependência get_route_service para usar o mock. """ @@ -46,7 +50,9 @@ def override_dependency(mock_service: RouteService) -> None: @pytest.mark.asyncio -async def test_details_endpoint_success(client: TestClient, mock_service: RouteService) -> None: +async def test_details_endpoint_success( + client: TestClient, mock_service: RouteService +) -> None: """ Testa o endpoint POST /routes/details garantindo que: - Ele chama RouteService.get_route_details() @@ -60,7 +66,7 @@ async def test_details_endpoint_success(client: TestClient, mock_service: RouteS bus_route_2 = BusRoute(route_id=34812, route=route_identifier) # get_route_details agora retorna list[BusRoute] - mock_service.get_route_details.return_value = [bus_route_1, bus_route_2] # type: ignore[assignment] + mock_service.get_route_details.return_value = [bus_route_1, bus_route_2] # type: ignore[attr-defined] payload = { "routes": [ @@ -87,8 +93,8 @@ async def test_details_endpoint_success(client: TestClient, mock_service: RouteS assert routes[1]["route"]["bus_line"] == "8075" # garante que o service foi chamado uma vez - mock_service.get_route_details.assert_awaited_once() - called_arg = mock_service.get_route_details.await_args.args[0] + mock_service.get_route_details.assert_awaited_once() # type: ignore[attr-defined] + called_arg = mock_service.get_route_details.await_args.args[0] # type: ignore[attr-defined] assert isinstance(called_arg, RouteIdentifier) assert called_arg.bus_line == "8075" # direção padrão que estamos usando @@ -104,7 +110,7 @@ async def test_details_endpoint_error_returns_500( em /routes/details. """ - mock_service.get_route_details.side_effect = RuntimeError("boom") # type: ignore[assignment] + mock_service.get_route_details.side_effect = RuntimeError("boom") # type: ignore[attr-defined] payload = {"routes": [{"bus_line": "8075"}]} @@ -121,7 +127,9 @@ async def test_details_endpoint_error_returns_500( @pytest.mark.asyncio -async def test_positions_endpoint_success(client: TestClient, mock_service: RouteService) -> None: +async def test_positions_endpoint_success( + client: TestClient, mock_service: RouteService +) -> None: """ Testa o endpoint POST /routes/positions garantindo que: - Ele chama RouteService.get_bus_positions() @@ -137,7 +145,7 @@ async def test_positions_endpoint_success(client: TestClient, mock_service: Rout time_updated=datetime.now(UTC), ) - mock_service.get_bus_positions.return_value = [position] # type: ignore[assignment] + mock_service.get_bus_positions.return_value = [position] # type: ignore[attr-defined] payload = { "routes": [ @@ -166,8 +174,8 @@ async def test_positions_endpoint_success(client: TestClient, mock_service: Rout assert "longitude" in bus["position"] assert "time_updated" in bus - mock_service.get_bus_positions.assert_awaited_once() - called_arg = mock_service.get_bus_positions.await_args.args[0] + mock_service.get_bus_positions.assert_awaited_once() # type: ignore[attr-defined] + called_arg = mock_service.get_bus_positions.await_args.args[0] # type: ignore[attr-defined] assert isinstance(called_arg, BusRoute) assert called_arg.route.bus_line == "8075" assert called_arg.route.bus_direction == 1 @@ -183,7 +191,7 @@ async def test_positions_endpoint_error_returns_500( em /routes/positions. """ - mock_service.get_bus_positions.side_effect = RuntimeError("boom") # type: ignore[assignment] + mock_service.get_bus_positions.side_effect = RuntimeError("boom") # type: ignore[attr-defined] payload = { "routes": [ From 66d87ebd69d147021878094ba11f496dd1758e12 Mon Sep 17 00:00:00 2001 From: Kim Kakeya Date: Mon, 1 Dec 2025 17:27:24 -0300 Subject: [PATCH 02/12] feat: integration tests --- tests/integration/conftest.py | 149 ++++++++++++++ tests/integration/test_bus_position.py | 266 +++++++++++++++++++++++++ tests/integration/test_login.py | 238 ++++++++++++++++++++++ tests/integration/test_ranking.py | 151 ++++++++++++++ tests/integration/test_trip.py | 241 ++++++++++++++++++++++ tests/integration/test_user_history.py | 143 +++++++++++++ 6 files changed, 1188 insertions(+) create mode 100644 tests/integration/conftest.py create mode 100644 tests/integration/test_bus_position.py create mode 100644 tests/integration/test_login.py create mode 100644 tests/integration/test_ranking.py create mode 100644 tests/integration/test_trip.py create mode 100644 tests/integration/test_user_history.py diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py new file mode 100644 index 0000000..0762f05 --- /dev/null +++ b/tests/integration/conftest.py @@ -0,0 +1,149 @@ +""" +Integration tests configuration and fixtures. + +This module provides shared fixtures for all integration tests, +including test database setup, test client, and helper functions. +""" + +from collections.abc import AsyncGenerator +from typing import Any + +import pytest +from httpx import ASGITransport, AsyncClient +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine + +from src.adapters.database.connection import Base, get_db +from src.adapters.database.models import UserDB +from src.adapters.security.hashing import PasslibPasswordHasher +from src.main import app + +IN_MEMORY_TEST_DATABASE_URL = "sqlite+aiosqlite:///:memory:" + +test_engine = create_async_engine( + IN_MEMORY_TEST_DATABASE_URL, + echo=False, + future=True, +) + +TestAsyncSessionLocal = async_sessionmaker( + test_engine, + class_=AsyncSession, + expire_on_commit=False, +) + + +async def override_get_db() -> AsyncGenerator[AsyncSession, None]: + """Override database dependency for tests.""" + async with TestAsyncSessionLocal() as session: + try: + yield session + await session.commit() + except Exception: + await session.rollback() + raise + finally: + await session.close() + + +@pytest.fixture +async def test_db() -> AsyncGenerator[AsyncSession, None]: + """ + Create test database tables before each test and drop them after. + + Yields: + AsyncSession: Test database session + """ + async with test_engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + + app.dependency_overrides[get_db] = override_get_db + + async with TestAsyncSessionLocal() as session: + yield session + + async with test_engine.begin() as conn: + await conn.run_sync(Base.metadata.drop_all) + + app.dependency_overrides.clear() + + +@pytest.fixture +async def client(test_db: AsyncSession) -> AsyncGenerator[AsyncClient, None]: + """ + Create test HTTP client. + + Args: + test_db: Test database session (ensures DB is set up) + + Yields: + AsyncClient: HTTP client for testing + """ + transport = ASGITransport(app=app) # type: ignore[arg-type] + async with AsyncClient(transport=transport, base_url="http://test") as ac: + yield ac + +async def create_user_and_login( + client: AsyncClient, + user_data: dict[str, str], +) -> dict[str, Any]: + """ + Helper to create a user and login, returning token info. + + Args: + client: HTTP client + user_data: User registration data + + Returns: + dict: Contains access_token and user info + """ + # Register user + await client.post("/users/register", json=user_data) + + # Login + login_response = await client.post( + "/users/login", + data={ + "username": user_data["email"], + "password": user_data["password"], + }, + ) + token_data = login_response.json() + + return { + "access_token": token_data["access_token"], + "headers": {"Authorization": f"Bearer {token_data['access_token']}"}, + "user": user_data, + } + + +async def create_test_user_in_db( + session: AsyncSession, + email: str, + score: int = 0, + password: str = "password123", +) -> UserDB: + """ + Helper to create a user directly in the database. + + Args: + session: Database session + email: User email + score: User score + password: User password + + Returns: + UserDB: Created user object + """ + hasher = PasslibPasswordHasher() + hashed_password = hasher.hash(password) + + user = UserDB( + name=email.split("@")[0], + email=email, + password=hashed_password, + score=score, + ) + session.add(user) + await session.commit() + await session.refresh(user) + return user diff --git a/tests/integration/test_bus_position.py b/tests/integration/test_bus_position.py new file mode 100644 index 0000000..5cd65ba --- /dev/null +++ b/tests/integration/test_bus_position.py @@ -0,0 +1,266 @@ +from datetime import datetime, timezone +from unittest.mock import AsyncMock, patch + +import pytest +from httpx import AsyncClient + +from src.core.models.bus import BusPosition, RouteIdentifier +from src.core.models.coordinate import Coordinate +from src.web.schemas import BusRoutesDetailsRequest, RouteIdentifierSchema + + +class TestBusPositions: + """Test bus position queries.""" + + @pytest.mark.asyncio + async def test_get_bus_position_returns_successfully( + self, + client: AsyncClient, + ) -> None: + """Test that getting bus positions works when the API returns data.""" + + mock_positions = [ + BusPosition( + route=RouteIdentifier(bus_line="8000", bus_direction=1), + position=Coordinate(latitude=-23.550520, longitude=-46.633308), + time_updated=datetime.now(timezone.utc), + ), + BusPosition( + route=RouteIdentifier(bus_line="8000", bus_direction=1), + position=Coordinate(latitude=-23.551234, longitude=-46.634567), + time_updated=datetime.now(timezone.utc), + ), + ] + + # Mock the SpTransAdapter methods + with ( + patch( + "src.adapters.external.sptrans_adapter.SpTransAdapter.authenticate", + new_callable=AsyncMock, + return_value=True, + ), + patch( + "src.adapters.external.sptrans_adapter.SpTransAdapter.get_bus_positions", + new_callable=AsyncMock, + return_value=mock_positions, + ), + ): + request_data = BusRoutesDetailsRequest( + routes=[ + RouteIdentifierSchema(bus_line="8000", bus_direction=1), + ] + ) + + response = await client.post( + "/routes/positions", json=request_data.model_dump() + ) + + assert response.status_code == 200 + data = response.json() + + assert "buses" in data + assert len(data["buses"]) == 2 + + # Verify first bus position structure + first_bus = data["buses"][0] + assert "route" in first_bus + assert first_bus["route"]["bus_line"] == "8000" + assert first_bus["route"]["bus_direction"] == 1 + assert "position" in first_bus + assert "latitude" in first_bus["position"] + assert "longitude" in first_bus["position"] + assert "time_updated" in first_bus + + @pytest.mark.asyncio + async def test_get_bus_position_returns_404_when_line_not_found( + self, + client: AsyncClient, + ) -> None: + """ + Test that getting bus positions returns error when bus line is not found. + + Note: Since the current implementation catches exceptions and returns 500, + we test that behavior. In a production system, you might want to + distinguish between "line not found" (404) and "API error" (500). + """ + # Mock the adapter to simulate line not found + with ( + patch( + "src.adapters.external.sptrans_adapter.SpTransAdapter.authenticate", + new_callable=AsyncMock, + return_value=True, + ), + patch( + "src.adapters.external.sptrans_adapter.SpTransAdapter.get_bus_positions", + new_callable=AsyncMock, + side_effect=ValueError("Line INVALID123 not found"), + ), + ): + request_data = BusRoutesDetailsRequest( + routes=[ + RouteIdentifierSchema(bus_line="INVALID123", bus_direction=1), + ] + ) + + response = await client.post( + "/routes/positions", json=request_data.model_dump() + ) + + assert response.status_code == 404 + assert "Failed to retrieve bus positions" in response.json()["detail"] + + @pytest.mark.asyncio + async def test_get_bus_position_returns_empty_when_no_buses_on_line( + self, + client: AsyncClient, + ) -> None: + """ + Test that getting bus positions returns empty list when SPTrans + returns no buses for a valid line (e.g., no buses currently running). + """ + # Mock empty response (valid line but no buses currently) + with ( + patch( + "src.adapters.external.sptrans_adapter.SpTransAdapter.authenticate", + new_callable=AsyncMock, + return_value=True, + ), + patch( + "src.adapters.external.sptrans_adapter.SpTransAdapter.get_bus_positions", + new_callable=AsyncMock, + return_value=[], + ), + ): + request_data = BusRoutesDetailsRequest( + routes=[ + RouteIdentifierSchema(bus_line="8000", bus_direction=1), + ] + ) + + response = await client.post( + "/routes/positions", json=request_data.model_dump() + ) + + assert response.status_code == 200 + data = response.json() + + assert "buses" in data + assert len(data["buses"]) == 0 + + @pytest.mark.asyncio + async def test_get_bus_position_works_with_multiple_routes( + self, + client: AsyncClient, + ) -> None: + """Test that querying multiple routes returns positions for all of them.""" + mock_positions = [ + BusPosition( + route=RouteIdentifier(bus_line="8000", bus_direction=1), + position=Coordinate(latitude=-23.550520, longitude=-46.633308), + time_updated=datetime.now(timezone.utc), + ), + BusPosition( + route=RouteIdentifier(bus_line="9000", bus_direction=2), + position=Coordinate(latitude=-23.560520, longitude=-46.643308), + time_updated=datetime.now(timezone.utc), + ), + ] + + with ( + patch( + "src.adapters.external.sptrans_adapter.SpTransAdapter.authenticate", + new_callable=AsyncMock, + return_value=True, + ), + patch( + "src.adapters.external.sptrans_adapter.SpTransAdapter.get_bus_positions", + new_callable=AsyncMock, + return_value=mock_positions, + ), + ): + request_data = BusRoutesDetailsRequest( + routes=[ + RouteIdentifierSchema(bus_line="8000", bus_direction=1), + RouteIdentifierSchema(bus_line="9000", bus_direction=2), + ] + ) + + response = await client.post( + "/routes/positions", json=request_data.model_dump() + ) + + assert response.status_code == 200 + data = response.json() + + assert len(data["buses"]) == 2 + bus_lines = [bus["route"]["bus_line"] for bus in data["buses"]] + assert "8000" in bus_lines + assert "9000" in bus_lines + + @pytest.mark.asyncio + async def test_get_bus_position_returns_500_when_authentication_failure( + self, + client: AsyncClient, + ) -> None: + """Test behavior when SPTrans authentication fails.""" + with patch( + "src.adapters.external.sptrans_adapter.SpTransAdapter.authenticate", + new_callable=AsyncMock, + side_effect=RuntimeError("Authentication failed"), + ): + request_data = BusRoutesDetailsRequest( + routes=[ + RouteIdentifierSchema(bus_line="8000", bus_direction=1), + ] + ) + + response = await client.post( + "/routes/positions", json=request_data.model_dump() + ) + + assert response.status_code == 500 + + @pytest.mark.asyncio + async def test_get_bus_position_returns_422_when_invalid_direction( + self, + client: AsyncClient, + ) -> None: + """Test that invalid bus direction fails validation.""" + request_data = BusRoutesDetailsRequest( + routes=[ + RouteIdentifierSchema(bus_line="8000", bus_direction=3), # Invalid + ] + ) + + response = await client.post( + "/routes/positions", json=request_data.model_dump() + ) + + assert response.status_code == 422 # Validation error + + @pytest.mark.asyncio + async def test_get_bus_position_returns_successfully_with_empty_routes_list( + self, + client: AsyncClient, + ) -> None: + """Test behavior with empty routes list.""" + with ( + patch( + "src.adapters.external.sptrans_adapter.SpTransAdapter.authenticate", + new_callable=AsyncMock, + return_value=True, + ), + patch( + "src.adapters.external.sptrans_adapter.SpTransAdapter.get_bus_positions", + new_callable=AsyncMock, + return_value=[], + ), + ): + request_data = BusRoutesDetailsRequest(routes=[]) + + response = await client.post( + "/routes/positions", json=request_data.model_dump() + ) + + assert response.status_code == 200 + assert response.json()["buses"] == [] diff --git a/tests/integration/test_login.py b/tests/integration/test_login.py new file mode 100644 index 0000000..05f1971 --- /dev/null +++ b/tests/integration/test_login.py @@ -0,0 +1,238 @@ +import pytest +from httpx import AsyncClient + +from .conftest import create_user_and_login + + +class TestUserRegistration: + + @pytest.mark.asyncio + async def test_create_account_should_work( + self, + client: AsyncClient, + ) -> None: + user_data = { + "name": "Test User", + "email": "test@example.com", + "password": "securepassword123", + } + response = await client.post("/users/register", json=user_data) + + assert response.status_code == 201 + data = response.json() + + assert data["name"] == user_data["name"] + assert data["email"] == user_data["email"] + assert data["score"] == 0 + # Password should not be returned + assert "password" not in data + + @pytest.mark.asyncio + async def test_create_account_duplicate_email_fails( + self, + client: AsyncClient, + ) -> None: + user_data = { + "name": "Test User", + "email": "test@example.com", + "password": "securepassword123", + } + # Create first user + response1 = await client.post("/users/register", json=user_data) + assert response1.status_code == 201 + + # Try to create second user with same email + response2 = await client.post("/users/register", json=user_data) + assert response2.status_code == 400 + + @pytest.mark.asyncio + async def test_create_account_invalid_email_fails( + self, + client: AsyncClient, + ) -> None: + invalid_data = { + "name": "Test User", + "email": "not-an-email", + "password": "securepassword123", + } + response = await client.post("/users/register", json=invalid_data) + assert response.status_code == 422 # Validation error + + @pytest.mark.asyncio + async def test_create_account_short_password_fails( + self, + client: AsyncClient, + ) -> None: + invalid_data = { + "name": "Test User", + "email": "test@example.com", + "password": "short", + } + response = await client.post("/users/register", json=invalid_data) + assert response.status_code == 422 # Validation error + + +class TestUserLogin: + + @pytest.mark.asyncio + async def test_login_should_work( + self, + client: AsyncClient, + ) -> None: + user_data = { + "name": "Test User", + "email": "test@example.com", + "password": "securepassword123", + } + + # First, create the user + await client.post("/users/register", json=user_data) + + # Then login + response = await client.post( + "/users/login", + data={ + "username": user_data["email"], + "password": user_data["password"], + }, + ) + + assert response.status_code == 200 + data = response.json() + + assert "access_token" in data + assert data["token_type"] == "bearer" + assert len(data["access_token"]) > 0 + + @pytest.mark.asyncio + async def test_login_wrong_password_fails( + self, + client: AsyncClient, + ) -> None: + user_data = { + "name": "Test User", + "email": "test@example.com", + "password": "securepassword123", + } + + # Create user + await client.post("/users/register", json=user_data) + + # Try login with wrong password + response = await client.post( + "/users/login", + data={ + "username": user_data["email"], + "password": "wrongpassword", + }, + ) + + assert response.status_code == 401 + assert "Invalid email or password" in response.json()["detail"] + + @pytest.mark.asyncio + async def test_login_nonexistent_user_fails( + self, + client: AsyncClient, + ) -> None: + + response = await client.post( + "/users/login", + data={ + "username": "nonexistent@example.com", + "password": "somepassword", + }, + ) + + assert response.status_code == 401 + + +class TestGetUserInfo: + """Test getting user information.""" + + @pytest.mark.asyncio + async def test_get_user_info_should_work( + self, + client: AsyncClient, + ) -> None: + """Test that authenticated user can get their own info.""" + user_data = { + "name": "Test User", + "email": "test@example.com", + "password": "securepassword123", + } + auth = await create_user_and_login(client, user_data) + + response = await client.get("/users/me", headers=auth["headers"]) + + assert response.status_code == 200 + data = response.json() + + assert data["name"] == user_data["name"] + assert data["email"] == user_data["email"] + assert data["score"] == 0 + + @pytest.mark.asyncio + async def test_get_user_info_without_token_fails( + self, + client: AsyncClient, + ) -> None: + response = await client.get("/users/me") + + assert response.status_code == 401 + + @pytest.mark.asyncio + async def test_get_user_info_invalid_token_fails( + self, + client: AsyncClient, + ) -> None: + response = await client.get( + "/users/me", + headers={"Authorization": "Bearer invalid-token"}, + ) + + assert response.status_code == 401 + + +class TestGetOtherUserInfo: + @pytest.mark.asyncio + async def test_get_other_user_info_should_work( + self, + client: AsyncClient, + ) -> None: + user_data = { + "name": "Test User", + "email": "test@example.com", + "password": "securepassword123", + } + second_user_data = { + "name": "Second User", + "email": "second@example.com", + "password": "anotherpassword123", + } + # Create first user + await client.post("/users/register", json=user_data) + + # Create second user + await client.post("/users/register", json=second_user_data) + + # Get first user's info by email (public endpoint) + response = await client.get(f"/users/{user_data['email']}") + + assert response.status_code == 200 + data = response.json() + + assert data["name"] == user_data["name"] + assert data["email"] == user_data["email"] + assert data["score"] == 0 + + @pytest.mark.asyncio + async def test_get_nonexistent_user_info_fails( + self, + client: AsyncClient, + ) -> None: + """Test that getting non-existent user's info returns 404.""" + response = await client.get("/users/nonexistent@example.com") + + assert response.status_code == 404 + assert "User not found" in response.json()["detail"] diff --git a/tests/integration/test_ranking.py b/tests/integration/test_ranking.py new file mode 100644 index 0000000..b1b6a10 --- /dev/null +++ b/tests/integration/test_ranking.py @@ -0,0 +1,151 @@ +from datetime import datetime, timezone + +import pytest +from httpx import AsyncClient +from sqlalchemy.ext.asyncio import AsyncSession + +from src.web.schemas import CreateTripRequest, RouteIdentifierSchema + +from .conftest import create_user_and_login, create_test_user_in_db + + +class TestUserRankPosition: + @pytest.mark.asyncio + async def test_get_user_rank_position_should_work( + self, + client: AsyncClient, + test_db: AsyncSession, + ) -> None: + user_data = { + "name": "Test User", + "email": "test@example.com", + "password": "securepassword123", + } + auth = await create_user_and_login(client, user_data) + + request_data = {"email": user_data["email"]} + response = await client.post( + "/rank/user", + json=request_data, + headers=auth["headers"], + ) + + assert response.status_code == 200 + data = response.json() + + assert "position" in data + assert data["position"] == 1 + + @pytest.mark.asyncio + async def test_get_user_rank_position_with_multiple_users( + self, + client: AsyncClient, + test_db: AsyncSession, + ) -> None: + await create_test_user_in_db(test_db, "top@example.com", score=1000) + await create_test_user_in_db(test_db, "middle@example.com", score=500) + await create_test_user_in_db(test_db, "bottom@example.com", score=100) + + user_data = { + "name": "Test User", + "email": "test@example.com", + "password": "securepassword123", + } + auth = await create_user_and_login(client, user_data) + + trip_data = CreateTripRequest( + email=user_data["email"], + route=RouteIdentifierSchema(bus_line="8000", bus_direction=1), + distance=0, + data=datetime.now(timezone.utc), + ) + await client.post( + "/trips/", json=trip_data.model_dump(mode="json"), headers=auth["headers"] + ) + + request_data = {"email": user_data["email"]} + response = await client.post( + "/rank/user", + json=request_data, + headers=auth["headers"], + ) + + assert response.status_code == 200 + data = response.json() + assert data["position"] == 3 + + @pytest.mark.asyncio + async def test_get_user_rank_position_without_auth_fails( + self, + client: AsyncClient, + ) -> None: + user_data = { + "name": "Test User", + "email": "test@example.com", + "password": "securepassword123", + } + request_data = {"email": user_data["email"]} + response = await client.post("/rank/user", json=request_data) + + assert response.status_code == 401 + + +class TestGlobalRanking: + + @pytest.mark.asyncio + async def test_get_global_ranking_should_work( + self, + client: AsyncClient, + test_db: AsyncSession, + ) -> None: + await create_test_user_in_db(test_db, "first@example.com", score=1000) + await create_test_user_in_db(test_db, "second@example.com", score=500) + await create_test_user_in_db(test_db, "third@example.com", score=100) + + response = await client.get("/rank/global") + + assert response.status_code == 200 + data = response.json() + + assert "users" in data + assert len(data["users"]) == 3 + + scores = [user["score"] for user in data["users"]] + assert scores == [1000, 500, 100] + + first_user = data["users"][0] + assert "name" in first_user + assert "email" in first_user + assert "score" in first_user + assert first_user["email"] == "first@example.com" + + @pytest.mark.asyncio + async def test_get_global_ranking_returns_empty_when_no_users( + self, + client: AsyncClient, + test_db: AsyncSession, + ) -> None: + response = await client.get("/rank/global") + + assert response.status_code == 200 + data = response.json() + + assert "users" in data + assert len(data["users"]) == 0 + + @pytest.mark.asyncio + async def test_get_global_ranking_with_single_user( + self, + client: AsyncClient, + test_db: AsyncSession, + ) -> None: + await create_test_user_in_db(test_db, "solo@example.com", score=500) + + response = await client.get("/rank/global") + + assert response.status_code == 200 + data = response.json() + + assert len(data["users"]) == 1 + assert data["users"][0]["email"] == "solo@example.com" + assert data["users"][0]["score"] == 500 diff --git a/tests/integration/test_trip.py b/tests/integration/test_trip.py new file mode 100644 index 0000000..2f2366b --- /dev/null +++ b/tests/integration/test_trip.py @@ -0,0 +1,241 @@ +from datetime import datetime, timezone + +import pytest +from httpx import AsyncClient +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from src.adapters.database.models import TripDB +from src.web.schemas import CreateTripRequest, RouteIdentifierSchema + +from .conftest import create_user_and_login + + +class TestCreateTrip: + @pytest.mark.asyncio + async def test_create_trip_should_return_successfully_and_save_to_database( + self, + client: AsyncClient, + test_db: AsyncSession, + ) -> None: + user_data = { + "name": "Test User", + "email": "test@example.com", + "password": "securepassword123", + } + auth = await create_user_and_login(client, user_data) + + trip_data = CreateTripRequest( + email=user_data["email"], + route=RouteIdentifierSchema( + bus_line="8000", + bus_direction=1, + ), + distance=5000, + data=datetime.now(timezone.utc), + ) + + response = await client.post( + "/trips/", + json=trip_data.model_dump(mode="json"), + headers=auth["headers"], + ) + + assert response.status_code == 201 + data = response.json() + + assert "score" in data + + result = await test_db.execute( + select(TripDB).where(TripDB.email == user_data["email"]) + ) + trip = result.scalar_one_or_none() + + assert trip is not None + assert trip.email == user_data["email"] + assert trip.bus_line == "8000" + assert trip.bus_direction == 1 + assert trip.distance == 5000 + assert trip.score == data["score"] + + @pytest.mark.asyncio + async def test_create_trip_updates_user_score( + self, + client: AsyncClient, + ) -> None: + user_data = { + "name": "Test User", + "email": "test@example.com", + "password": "securepassword123", + } + auth = await create_user_and_login(client, user_data) + + trip_data = CreateTripRequest( + email=user_data["email"], + route=RouteIdentifierSchema( + bus_line="8000", + bus_direction=1, + ), + distance=1000, + data=datetime.now(timezone.utc), + ) + second_trip_data = CreateTripRequest( + email=user_data["email"], + route=RouteIdentifierSchema( + bus_line="8000", + bus_direction=1, + ), + distance=2000, + data=datetime.now(timezone.utc), + ) + + resp1 = await client.post( + "/trips/", json=trip_data.model_dump(mode="json"), headers=auth["headers"] + ) + assert resp1.status_code == 201 + data1 = resp1.json() + score1 = data1["score"] + + resp2 = await client.post( + "/trips/", + json=second_trip_data.model_dump(mode="json"), + headers=auth["headers"], + ) + assert resp2.status_code == 201 + data2 = resp2.json() + score2 = data2["score"] + + user_response = await client.get("/users/me", headers=auth["headers"]) + assert user_response.status_code == 200 + user_data_response = user_response.json() + + assert user_data_response["score"] == score1 + score2 + + @pytest.mark.asyncio + async def test_create_trip_without_authentication_fails( + self, + client: AsyncClient, + ) -> None: + user_data = { + "name": "Test User", + "email": "test@example.com", + "password": "securepassword123", + } + trip_data = CreateTripRequest( + email=user_data["email"], + route=RouteIdentifierSchema( + bus_line="8000", + bus_direction=1, + ), + distance=1000, + data=datetime.now(timezone.utc), + ) + + response = await client.post("/trips/", json=trip_data.model_dump(mode="json")) + + assert response.status_code == 401 + + @pytest.mark.asyncio + async def test_create_trip_zero_distance( + self, + client: AsyncClient, + ) -> None: + user_data = { + "name": "Test User", + "email": "test@example.com", + "password": "securepassword123", + } + auth = await create_user_and_login(client, user_data) + + trip_data = CreateTripRequest( + email=user_data["email"], + route=RouteIdentifierSchema( + bus_line="9000", + bus_direction=2, + ), + distance=0, + data=datetime.now(timezone.utc), + ) + + response = await client.post( + "/trips/", + json=trip_data.model_dump(mode="json"), + headers=auth["headers"], + ) + + assert response.status_code == 201 + assert response.json()["score"] == 0 + + @pytest.mark.asyncio + async def test_create_trip_negative_distance_fails( + self, + client: AsyncClient, + test_db: AsyncSession, + ) -> None: + user_data = { + "name": "Test User", + "email": "test@example.com", + "password": "securepassword123", + } + auth = await create_user_and_login(client, user_data) + + trip_data = CreateTripRequest( + email=user_data["email"], + route=RouteIdentifierSchema( + bus_line="8000", + bus_direction=1, + ), + distance=-1000, + data=datetime.now(timezone.utc), + ) + + response = await client.post( + "/trips/", + json=trip_data.model_dump(mode="json"), + headers=auth["headers"], + ) + + assert response.status_code == 422 + + result = await test_db.execute( + select(TripDB).where(TripDB.email == user_data["email"]) + ) + trip = result.scalar_one_or_none() + assert trip is None + + @pytest.mark.asyncio + async def test_create_trip_invalid_route_identifier_fails( + self, + client: AsyncClient, + test_db: AsyncSession, + ) -> None: + user_data = { + "name": "Test User", + "email": "test@example.com", + "password": "securepassword123", + } + auth = await create_user_and_login(client, user_data) + + trip_data = CreateTripRequest( + email=user_data["email"], + route=RouteIdentifierSchema( + bus_line="8000", + bus_direction=3, + ), + distance=1000, + data=datetime.now(timezone.utc), + ) + + response = await client.post( + "/trips/", + json=trip_data.model_dump(mode="json"), + headers=auth["headers"], + ) + + assert response.status_code == 422 + + result = await test_db.execute( + select(TripDB).where(TripDB.email == user_data["email"]) + ) + trip = result.scalar_one_or_none() + assert trip is None diff --git a/tests/integration/test_user_history.py b/tests/integration/test_user_history.py new file mode 100644 index 0000000..a048c76 --- /dev/null +++ b/tests/integration/test_user_history.py @@ -0,0 +1,143 @@ +from datetime import datetime, timezone + +import pytest +from httpx import AsyncClient +from sqlalchemy.ext.asyncio import AsyncSession + +from src.web.schemas import ( + CreateTripRequest, + HistoryRequest, + HistoryResponse, + RouteIdentifierSchema, +) + +from .conftest import create_user_and_login + + +class TestUserHistory: + @pytest.mark.asyncio + async def test_get_user_history_should_work( + self, + client: AsyncClient, + ) -> None: + user_data = { + "name": "Test User", + "email": "test@example.com", + "password": "securepassword123", + } + auth = await create_user_and_login(client, user_data) + + trip_dates: list[datetime] = [ + datetime(2025, 11, 1, 8, 0, 0, tzinfo=timezone.utc), + datetime(2025, 11, 15, 12, 0, 0, tzinfo=timezone.utc), + datetime(2025, 11, 29, 18, 0, 0, tzinfo=timezone.utc), + ] + + scores: list[int] = [] + for i, trip_date in enumerate(trip_dates): + trip_request = CreateTripRequest( + email=user_data["email"], + route=RouteIdentifierSchema(bus_line=f"800{i}", bus_direction=1), + distance=(i + 1) * 1000, + data=trip_date, + ) + response = await client.post( + "/trips/", + json=trip_request.model_dump(mode="json"), + headers=auth["headers"], + ) + scores.append(response.json()["score"]) + + history_request = HistoryRequest(email=user_data["email"]) + response = await client.post( + "/history/", + json=history_request.model_dump(), + headers=auth["headers"], + ) + + assert response.status_code == 200 + history_response = HistoryResponse.model_validate(response.json()) + + assert len(history_response.trips) == 3 + + for trip in history_response.trips: + assert isinstance(trip.date, datetime) + assert isinstance(trip.score, int) + + assert sorted(scores) == sorted([trip.score for trip in history_response.trips]) + + @pytest.mark.asyncio + async def test_get_user_history_returns_empty_when_no_trips( + self, + client: AsyncClient, + ) -> None: + user_data = { + "email": "test@example.com", + "password": "secure_password_123", + } + + auth = await create_user_and_login(client, user_data) + + history_request = HistoryRequest(email=user_data["email"]) + + response = await client.post( + "/history/", + json=history_request.model_dump(), + headers=auth["headers"], + ) + + assert response.status_code == 200 + history_response = HistoryResponse.model_validate(response.json()) + assert isinstance(history_response.trips, list) + assert len(history_response.trips) == 0 + + @pytest.mark.asyncio + async def test_get_user_history_without_authentication_fails( + self, + client: AsyncClient, + ) -> None: + history_request = HistoryRequest(email="test@example.com") + response = await client.post("/history/", json=history_request.model_dump()) + + assert response.status_code == 401 + + @pytest.mark.asyncio + async def test_get_history_includes_correct_dates( + self, + client: AsyncClient, + ) -> None: + user_data = { + "name": "Test User", + "email": "test@example.com", + "password": "securepassword123", + } + + auth = await create_user_and_login(client, user_data) + + specific_date: datetime = datetime(2025, 6, 15, 10, 30, 0, tzinfo=timezone.utc) + trip_request = CreateTripRequest( + email=user_data["email"], + route=RouteIdentifierSchema(bus_line="8000", bus_direction=1), + distance=1000, + data=specific_date, + ) + await client.post( + "/trips/", + json=trip_request.model_dump(mode="json"), + headers=auth["headers"], + ) + + history_request = HistoryRequest(email=user_data["email"]) + response = await client.post( + "/history/", + json=history_request.model_dump(), + headers=auth["headers"], + ) + + assert response.status_code == 200 + history_response = HistoryResponse.model_validate(response.json()) + + trip_date = history_response.trips[0].date + assert trip_date.year == 2025 + assert trip_date.month == 6 + assert trip_date.day == 15 \ No newline at end of file From f3a271d23df56961ffd70fe8301c04d7d90fa27e Mon Sep 17 00:00:00 2001 From: Kim Kakeya Date: Mon, 1 Dec 2025 20:23:46 -0300 Subject: [PATCH 03/12] feat: use only pyproject.toml instead of mypy.ini --- mypy.ini | 36 ------------------------------------ pyproject.toml | 21 +++++++++++++++++++++ 2 files changed, 21 insertions(+), 36 deletions(-) delete mode 100644 mypy.ini diff --git a/mypy.ini b/mypy.ini deleted file mode 100644 index 6d0f4d3..0000000 --- a/mypy.ini +++ /dev/null @@ -1,36 +0,0 @@ -[mypy] -python_version = 3.13 -strict = True -warn_return_any = True -warn_unused_configs = True -disallow_untyped_defs = True -disallow_any_generics = True -check_untyped_defs = True -no_implicit_optional = True -warn_redundant_casts = True -warn_unused_ignores = True -warn_no_return = True -follow_imports = normal -plugins = pydantic.mypy - -[pydantic-mypy] -init_forbid_extra = True -init_typed = True -warn_required_dynamic_aliases = True - -# Ignore third-party packages without stubs -[mypy-sqlalchemy.*] -ignore_missing_imports = True - -[mypy-aiosqlite.*] -ignore_missing_imports = True - -[mypy-httpx.*] -ignore_missing_imports = True - -[mypy-uvicorn.*] -ignore_missing_imports = True - -# Less strict for tests -[mypy-tests.*] -disallow_untyped_defs = False diff --git a/pyproject.toml b/pyproject.toml index 5dcc114..c5e9a27 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,10 +48,31 @@ follow_imports = "normal" plugins = ["pydantic.mypy"] exclude = ["^\\.venv/", "^build/", "^dist/"] +[tool.pydantic-mypy] +init_forbid_extra = true +init_typed = true +warn_required_dynamic_aliases = true + [[tool.mypy.overrides]] module = "tests.*" disallow_untyped_defs = false +[[tool.mypy.overrides]] +module = "sqlalchemy.*" +ignore_missing_imports = true + +[[tool.mypy.overrides]] +module = "aiosqlite.*" +ignore_missing_imports = true + +[[tool.mypy.overrides]] +module = "httpx.*" +ignore_missing_imports = true + +[[tool.mypy.overrides]] +module = "uvicorn.*" +ignore_missing_imports = true + [tool.ruff] line-length = 100 target-version = "py313" From deccd402f2f1959aa21de9216b0473cf781974f6 Mon Sep 17 00:00:00 2001 From: Kim Kakeya Date: Mon, 1 Dec 2025 20:26:44 -0300 Subject: [PATCH 04/12] feat: make history controller return empty list when on history --- src/web/controllers/history_controller.py | 22 ++++++++-------------- 1 file changed, 8 insertions(+), 14 deletions(-) diff --git a/src/web/controllers/history_controller.py b/src/web/controllers/history_controller.py index e5cee3a..c2973df 100644 --- a/src/web/controllers/history_controller.py +++ b/src/web/controllers/history_controller.py @@ -4,11 +4,13 @@ This controller handles queries for user trip history. """ -from fastapi import APIRouter, Depends, HTTPException, status +from fastapi import APIRouter, Depends from sqlalchemy.ext.asyncio import AsyncSession from ...adapters.database.connection import get_db -from ...adapters.repositories.history_repository_adapter import UserHistoryRepositoryAdapter +from ...adapters.repositories.history_repository_adapter import ( + UserHistoryRepositoryAdapter, +) from ...core.models.user import User from ...core.services.history_service import HistoryService from ..schemas import HistoryRequest, HistoryResponse, TripHistoryEntry @@ -48,22 +50,14 @@ async def get_user_history( current_user: Authenticated user (from JWT token) Returns: - User's trip history with dates and scores - - Raises: - HTTPException: If user not found or has no history + User's trip history with dates and scores (empty list if no history) """ dates, scores = await history_service.get_user_history_summary(request.email) - if not dates: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="No history found for user", - ) - - # Combine dates and scores into trip entries + # Combine dates and scores into trip entries (returns empty list if no history) trips = [ - TripHistoryEntry(date=date, score=score) for date, score in zip(dates, scores, strict=False) + TripHistoryEntry(date=date, score=score) + for date, score in zip(dates, scores, strict=False) ] return HistoryResponse(trips=trips) From 1a4d3e075d1bc562153c66f27fdadf92f28026b2 Mon Sep 17 00:00:00 2001 From: Kim Kakeya Date: Mon, 1 Dec 2025 20:27:03 -0300 Subject: [PATCH 05/12] feat: improve integration tests --- tests/integration/conftest.py | 20 +- tests/integration/test_bus_position.py | 266 ------------- tests/integration/test_ranking.py | 2 +- tests/integration/test_route.py | 512 +++++++++++++++++++++++++ tests/integration/test_trip.py | 42 +- tests/integration/test_user_history.py | 5 +- 6 files changed, 557 insertions(+), 290 deletions(-) delete mode 100644 tests/integration/test_bus_position.py create mode 100644 tests/integration/test_route.py diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 0762f05..66b84ff 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -68,7 +68,9 @@ async def test_db() -> AsyncGenerator[AsyncSession, None]: @pytest.fixture -async def client(test_db: AsyncSession) -> AsyncGenerator[AsyncClient, None]: +async def client( + test_db: AsyncSession +) -> AsyncGenerator[AsyncClient, None]: """ Create test HTTP client. @@ -82,6 +84,7 @@ async def client(test_db: AsyncSession) -> AsyncGenerator[AsyncClient, None]: async with AsyncClient(transport=transport, base_url="http://test") as ac: yield ac + async def create_user_and_login( client: AsyncClient, user_data: dict[str, str], @@ -116,6 +119,21 @@ async def create_user_and_login( } +@pytest.fixture +def set_sptrans_api_token(monkeypatch: pytest.MonkeyPatch) -> None: + """ + Set a fake SPTrans API token for testing. + + This patches the settings object directly since it's already instantiated + at module load time. Patching the environment variable alone won't work + because settings.sptrans_api_token is already resolved. + """ + monkeypatch.setattr( + "src.config.settings.sptrans_api_token", + "fake-test-token", + ) + + async def create_test_user_in_db( session: AsyncSession, email: str, diff --git a/tests/integration/test_bus_position.py b/tests/integration/test_bus_position.py deleted file mode 100644 index 5cd65ba..0000000 --- a/tests/integration/test_bus_position.py +++ /dev/null @@ -1,266 +0,0 @@ -from datetime import datetime, timezone -from unittest.mock import AsyncMock, patch - -import pytest -from httpx import AsyncClient - -from src.core.models.bus import BusPosition, RouteIdentifier -from src.core.models.coordinate import Coordinate -from src.web.schemas import BusRoutesDetailsRequest, RouteIdentifierSchema - - -class TestBusPositions: - """Test bus position queries.""" - - @pytest.mark.asyncio - async def test_get_bus_position_returns_successfully( - self, - client: AsyncClient, - ) -> None: - """Test that getting bus positions works when the API returns data.""" - - mock_positions = [ - BusPosition( - route=RouteIdentifier(bus_line="8000", bus_direction=1), - position=Coordinate(latitude=-23.550520, longitude=-46.633308), - time_updated=datetime.now(timezone.utc), - ), - BusPosition( - route=RouteIdentifier(bus_line="8000", bus_direction=1), - position=Coordinate(latitude=-23.551234, longitude=-46.634567), - time_updated=datetime.now(timezone.utc), - ), - ] - - # Mock the SpTransAdapter methods - with ( - patch( - "src.adapters.external.sptrans_adapter.SpTransAdapter.authenticate", - new_callable=AsyncMock, - return_value=True, - ), - patch( - "src.adapters.external.sptrans_adapter.SpTransAdapter.get_bus_positions", - new_callable=AsyncMock, - return_value=mock_positions, - ), - ): - request_data = BusRoutesDetailsRequest( - routes=[ - RouteIdentifierSchema(bus_line="8000", bus_direction=1), - ] - ) - - response = await client.post( - "/routes/positions", json=request_data.model_dump() - ) - - assert response.status_code == 200 - data = response.json() - - assert "buses" in data - assert len(data["buses"]) == 2 - - # Verify first bus position structure - first_bus = data["buses"][0] - assert "route" in first_bus - assert first_bus["route"]["bus_line"] == "8000" - assert first_bus["route"]["bus_direction"] == 1 - assert "position" in first_bus - assert "latitude" in first_bus["position"] - assert "longitude" in first_bus["position"] - assert "time_updated" in first_bus - - @pytest.mark.asyncio - async def test_get_bus_position_returns_404_when_line_not_found( - self, - client: AsyncClient, - ) -> None: - """ - Test that getting bus positions returns error when bus line is not found. - - Note: Since the current implementation catches exceptions and returns 500, - we test that behavior. In a production system, you might want to - distinguish between "line not found" (404) and "API error" (500). - """ - # Mock the adapter to simulate line not found - with ( - patch( - "src.adapters.external.sptrans_adapter.SpTransAdapter.authenticate", - new_callable=AsyncMock, - return_value=True, - ), - patch( - "src.adapters.external.sptrans_adapter.SpTransAdapter.get_bus_positions", - new_callable=AsyncMock, - side_effect=ValueError("Line INVALID123 not found"), - ), - ): - request_data = BusRoutesDetailsRequest( - routes=[ - RouteIdentifierSchema(bus_line="INVALID123", bus_direction=1), - ] - ) - - response = await client.post( - "/routes/positions", json=request_data.model_dump() - ) - - assert response.status_code == 404 - assert "Failed to retrieve bus positions" in response.json()["detail"] - - @pytest.mark.asyncio - async def test_get_bus_position_returns_empty_when_no_buses_on_line( - self, - client: AsyncClient, - ) -> None: - """ - Test that getting bus positions returns empty list when SPTrans - returns no buses for a valid line (e.g., no buses currently running). - """ - # Mock empty response (valid line but no buses currently) - with ( - patch( - "src.adapters.external.sptrans_adapter.SpTransAdapter.authenticate", - new_callable=AsyncMock, - return_value=True, - ), - patch( - "src.adapters.external.sptrans_adapter.SpTransAdapter.get_bus_positions", - new_callable=AsyncMock, - return_value=[], - ), - ): - request_data = BusRoutesDetailsRequest( - routes=[ - RouteIdentifierSchema(bus_line="8000", bus_direction=1), - ] - ) - - response = await client.post( - "/routes/positions", json=request_data.model_dump() - ) - - assert response.status_code == 200 - data = response.json() - - assert "buses" in data - assert len(data["buses"]) == 0 - - @pytest.mark.asyncio - async def test_get_bus_position_works_with_multiple_routes( - self, - client: AsyncClient, - ) -> None: - """Test that querying multiple routes returns positions for all of them.""" - mock_positions = [ - BusPosition( - route=RouteIdentifier(bus_line="8000", bus_direction=1), - position=Coordinate(latitude=-23.550520, longitude=-46.633308), - time_updated=datetime.now(timezone.utc), - ), - BusPosition( - route=RouteIdentifier(bus_line="9000", bus_direction=2), - position=Coordinate(latitude=-23.560520, longitude=-46.643308), - time_updated=datetime.now(timezone.utc), - ), - ] - - with ( - patch( - "src.adapters.external.sptrans_adapter.SpTransAdapter.authenticate", - new_callable=AsyncMock, - return_value=True, - ), - patch( - "src.adapters.external.sptrans_adapter.SpTransAdapter.get_bus_positions", - new_callable=AsyncMock, - return_value=mock_positions, - ), - ): - request_data = BusRoutesDetailsRequest( - routes=[ - RouteIdentifierSchema(bus_line="8000", bus_direction=1), - RouteIdentifierSchema(bus_line="9000", bus_direction=2), - ] - ) - - response = await client.post( - "/routes/positions", json=request_data.model_dump() - ) - - assert response.status_code == 200 - data = response.json() - - assert len(data["buses"]) == 2 - bus_lines = [bus["route"]["bus_line"] for bus in data["buses"]] - assert "8000" in bus_lines - assert "9000" in bus_lines - - @pytest.mark.asyncio - async def test_get_bus_position_returns_500_when_authentication_failure( - self, - client: AsyncClient, - ) -> None: - """Test behavior when SPTrans authentication fails.""" - with patch( - "src.adapters.external.sptrans_adapter.SpTransAdapter.authenticate", - new_callable=AsyncMock, - side_effect=RuntimeError("Authentication failed"), - ): - request_data = BusRoutesDetailsRequest( - routes=[ - RouteIdentifierSchema(bus_line="8000", bus_direction=1), - ] - ) - - response = await client.post( - "/routes/positions", json=request_data.model_dump() - ) - - assert response.status_code == 500 - - @pytest.mark.asyncio - async def test_get_bus_position_returns_422_when_invalid_direction( - self, - client: AsyncClient, - ) -> None: - """Test that invalid bus direction fails validation.""" - request_data = BusRoutesDetailsRequest( - routes=[ - RouteIdentifierSchema(bus_line="8000", bus_direction=3), # Invalid - ] - ) - - response = await client.post( - "/routes/positions", json=request_data.model_dump() - ) - - assert response.status_code == 422 # Validation error - - @pytest.mark.asyncio - async def test_get_bus_position_returns_successfully_with_empty_routes_list( - self, - client: AsyncClient, - ) -> None: - """Test behavior with empty routes list.""" - with ( - patch( - "src.adapters.external.sptrans_adapter.SpTransAdapter.authenticate", - new_callable=AsyncMock, - return_value=True, - ), - patch( - "src.adapters.external.sptrans_adapter.SpTransAdapter.get_bus_positions", - new_callable=AsyncMock, - return_value=[], - ), - ): - request_data = BusRoutesDetailsRequest(routes=[]) - - response = await client.post( - "/routes/positions", json=request_data.model_dump() - ) - - assert response.status_code == 200 - assert response.json()["buses"] == [] diff --git a/tests/integration/test_ranking.py b/tests/integration/test_ranking.py index b1b6a10..8964472 100644 --- a/tests/integration/test_ranking.py +++ b/tests/integration/test_ranking.py @@ -72,7 +72,7 @@ async def test_get_user_rank_position_with_multiple_users( assert response.status_code == 200 data = response.json() - assert data["position"] == 3 + assert data["position"] == 4 @pytest.mark.asyncio async def test_get_user_rank_position_without_auth_fails( diff --git a/tests/integration/test_route.py b/tests/integration/test_route.py new file mode 100644 index 0000000..bc937c3 --- /dev/null +++ b/tests/integration/test_route.py @@ -0,0 +1,512 @@ +from datetime import datetime, timezone +from unittest.mock import AsyncMock, patch + +import pytest +from httpx import AsyncClient + +from src.core.models.bus import BusPosition, BusRoute, RouteIdentifier +from src.core.models.coordinate import Coordinate +from src.web.schemas import ( + BusPositionsRequest, + BusRouteSchema, + BusRoutesDetailsRequest, + RouteIdentifierSchema, +) + + +class TestRouteDetails: + @pytest.mark.asyncio + async def test_get_route_details_returns_successfully( + self, + client: AsyncClient, + ) -> None: + mock_bus_routes = [ + BusRoute( + route_id=12345, + route=RouteIdentifier(bus_line="8000", bus_direction=1), + ) + ] + + with ( + patch( + "src.adapters.external.sptrans_adapter.SpTransAdapter.authenticate", + new_callable=AsyncMock, + return_value=True, + ), + patch( + "src.adapters.external.sptrans_adapter.SpTransAdapter.get_route_details", + new_callable=AsyncMock, + return_value=mock_bus_routes, + ), + ): + request_data = BusRoutesDetailsRequest( + routes=[ + RouteIdentifierSchema(bus_line="8000", bus_direction=1), + ] + ) + + response = await client.post( + "/routes/details", json=request_data.model_dump() + ) + + assert response.status_code == 200 + data = response.json() + + assert "routes" in data + assert len(data["routes"]) == 1 + + # Verify route structure + first_route = data["routes"][0] + assert "route_id" in first_route + assert first_route["route_id"] == 12345 + assert "route" in first_route + assert first_route["route"]["bus_line"] == "8000" + assert first_route["route"]["bus_direction"] == 1 + + @pytest.mark.asyncio + async def test_get_route_details_with_multiple_lines( + self, + client: AsyncClient, + ) -> None: + mock_routes_8000 = [ + BusRoute( + route_id=12345, + route=RouteIdentifier(bus_line="8000", bus_direction=1), + ), + ] + mock_routes_9000 = [ + BusRoute( + route_id=67890, + route=RouteIdentifier(bus_line="9000", bus_direction=1), + ), + ] + + with ( + patch( + "src.adapters.external.sptrans_adapter.SpTransAdapter.authenticate", + new_callable=AsyncMock, + return_value=True, + ), + patch( + "src.adapters.external.sptrans_adapter.SpTransAdapter.get_route_details", + new_callable=AsyncMock, + side_effect=[mock_routes_8000, mock_routes_9000], + ), + ): + request_data = BusRoutesDetailsRequest( + routes=[ + RouteIdentifierSchema(bus_line="8000", bus_direction=1), + RouteIdentifierSchema(bus_line="9000", bus_direction=1), + ] + ) + + response = await client.post( + "/routes/details", json=request_data.model_dump() + ) + + assert response.status_code == 200 + data = response.json() + + assert len(data["routes"]) == 2 + route_ids = [r["route_id"] for r in data["routes"]] + assert 12345 in route_ids + assert 67890 in route_ids + + @pytest.mark.asyncio + async def test_get_route_details_returns_empty_for_unknown_line( + self, + client: AsyncClient, + ) -> None: + with ( + patch( + "src.adapters.external.sptrans_adapter.SpTransAdapter.authenticate", + new_callable=AsyncMock, + return_value=True, + ), + patch( + "src.adapters.external.sptrans_adapter.SpTransAdapter.get_route_details", + new_callable=AsyncMock, + return_value=[], + ), + ): + request_data = BusRoutesDetailsRequest( + routes=[ + RouteIdentifierSchema(bus_line="UNKNOWN", bus_direction=1), + ] + ) + + response = await client.post( + "/routes/details", json=request_data.model_dump() + ) + + assert response.status_code == 200 + data = response.json() + assert data["routes"] == [] + + @pytest.mark.asyncio + async def test_get_route_details_returns_500_on_api_error( + self, + client: AsyncClient, + ) -> None: + with ( + patch( + "src.adapters.external.sptrans_adapter.SpTransAdapter.authenticate", + new_callable=AsyncMock, + return_value=True, + ), + patch( + "src.adapters.external.sptrans_adapter.SpTransAdapter.get_route_details", + new_callable=AsyncMock, + side_effect=RuntimeError("API unavailable"), + ), + ): + request_data = BusRoutesDetailsRequest( + routes=[ + RouteIdentifierSchema(bus_line="8000", bus_direction=1), + ] + ) + + response = await client.post( + "/routes/details", json=request_data.model_dump() + ) + + assert response.status_code == 500 + assert "Failed to retrieve route details" in response.json()["detail"] + + @pytest.mark.asyncio + async def test_get_route_details_with_empty_routes_list( + self, + client: AsyncClient, + ) -> None: + with patch( + "src.adapters.external.sptrans_adapter.SpTransAdapter.authenticate", + new_callable=AsyncMock, + return_value=True, + ): + request_data = BusRoutesDetailsRequest(routes=[]) + + response = await client.post( + "/routes/details", json=request_data.model_dump() + ) + + assert response.status_code == 200 + assert response.json()["routes"] == [] + + +class TestBusPositions: + @pytest.mark.asyncio + async def test_get_bus_position_returns_successfully( + self, + client: AsyncClient, + ) -> None: + mock_positions = [ + BusPosition( + route=RouteIdentifier(bus_line="8000", bus_direction=1), + position=Coordinate(latitude=-23.550520, longitude=-46.633308), + time_updated=datetime.now(timezone.utc), + ), + BusPosition( + route=RouteIdentifier(bus_line="8000", bus_direction=1), + position=Coordinate(latitude=-23.551234, longitude=-46.634567), + time_updated=datetime.now(timezone.utc), + ), + ] + + with ( + patch( + "src.adapters.external.sptrans_adapter.SpTransAdapter.authenticate", + new_callable=AsyncMock, + return_value=True, + ), + patch( + "src.adapters.external.sptrans_adapter.SpTransAdapter.get_bus_positions", + new_callable=AsyncMock, + return_value=mock_positions, + ), + ): + request_data = BusPositionsRequest( + routes=[ + BusRouteSchema( + route_id=12345, + route=RouteIdentifierSchema(bus_line="8000", bus_direction=1), + ), + ] + ) + + response = await client.post( + "/routes/positions", json=request_data.model_dump() + ) + + assert response.status_code == 200 + data = response.json() + + assert "buses" in data + assert len(data["buses"]) == 2 + + first_bus = data["buses"][0] + assert "route" in first_bus + assert first_bus["route"]["bus_line"] == "8000" + assert first_bus["route"]["bus_direction"] == 1 + assert "position" in first_bus + assert "latitude" in first_bus["position"] + assert "longitude" in first_bus["position"] + assert "time_updated" in first_bus + + @pytest.mark.asyncio + async def test_get_bus_position_returns_500_when_error( + self, + client: AsyncClient, + ) -> None: + """Test that getting bus positions returns 500 on API errors.""" + with ( + patch( + "src.adapters.external.sptrans_adapter.SpTransAdapter.authenticate", + new_callable=AsyncMock, + return_value=True, + ), + patch( + "src.adapters.external.sptrans_adapter.SpTransAdapter.get_bus_positions", + new_callable=AsyncMock, + side_effect=ValueError("Error fetching positions"), + ), + ): + request_data = BusPositionsRequest( + routes=[ + BusRouteSchema( + route_id=99999, + route=RouteIdentifierSchema(bus_line="123", bus_direction=1), + ), + ] + ) + + response = await client.post( + "/routes/positions", json=request_data.model_dump() + ) + + assert response.status_code == 500 + assert "Failed to retrieve bus positions" in response.json()["detail"] + + @pytest.mark.asyncio + async def test_get_bus_position_returns_empty_when_no_buses_on_line( + self, + client: AsyncClient, + ) -> None: + """Test that getting bus positions returns empty list when no buses running.""" + with ( + patch( + "src.adapters.external.sptrans_adapter.SpTransAdapter.authenticate", + new_callable=AsyncMock, + return_value=True, + ), + patch( + "src.adapters.external.sptrans_adapter.SpTransAdapter.get_bus_positions", + new_callable=AsyncMock, + return_value=[], + ), + ): + request_data = BusPositionsRequest( + routes=[ + BusRouteSchema( + route_id=12345, + route=RouteIdentifierSchema(bus_line="8000", bus_direction=1), + ), + ] + ) + + response = await client.post( + "/routes/positions", json=request_data.model_dump() + ) + + assert response.status_code == 200 + data = response.json() + + assert "buses" in data + assert len(data["buses"]) == 0 + + @pytest.mark.asyncio + async def test_get_bus_position_works_with_multiple_routes( + self, + client: AsyncClient, + ) -> None: + """Test that querying multiple routes returns positions for all of them.""" + mock_positions_8000 = [ + BusPosition( + route=RouteIdentifier(bus_line="8000", bus_direction=1), + position=Coordinate(latitude=-23.550520, longitude=-46.633308), + time_updated=datetime.now(timezone.utc), + ), + ] + mock_positions_9000 = [ + BusPosition( + route=RouteIdentifier(bus_line="9000", bus_direction=2), + position=Coordinate(latitude=-23.560520, longitude=-46.643308), + time_updated=datetime.now(timezone.utc), + ), + ] + + with ( + patch( + "src.adapters.external.sptrans_adapter.SpTransAdapter.authenticate", + new_callable=AsyncMock, + return_value=True, + ), + patch( + "src.adapters.external.sptrans_adapter.SpTransAdapter.get_bus_positions", + new_callable=AsyncMock, + side_effect=[mock_positions_8000, mock_positions_9000], + ), + ): + request_data = BusPositionsRequest( + routes=[ + BusRouteSchema( + route_id=12345, + route=RouteIdentifierSchema(bus_line="8000", bus_direction=1), + ), + BusRouteSchema( + route_id=67890, + route=RouteIdentifierSchema(bus_line="9000", bus_direction=2), + ), + ] + ) + + response = await client.post( + "/routes/positions", json=request_data.model_dump() + ) + + assert response.status_code == 200 + data = response.json() + + assert len(data["buses"]) == 2 + bus_lines = [bus["route"]["bus_line"] for bus in data["buses"]] + assert "8000" in bus_lines + assert "9000" in bus_lines + + @pytest.mark.asyncio + async def test_get_bus_position_returns_500_when_authentication_failure( + self, + client: AsyncClient, + ) -> None: + """Test behavior when SPTrans authentication fails.""" + with patch( + "src.adapters.external.sptrans_adapter.SpTransAdapter.authenticate", + new_callable=AsyncMock, + side_effect=RuntimeError("Authentication failed"), + ): + request_data = BusPositionsRequest( + routes=[ + BusRouteSchema( + route_id=12345, + route=RouteIdentifierSchema(bus_line="8000", bus_direction=1), + ), + ] + ) + + response = await client.post( + "/routes/positions", json=request_data.model_dump() + ) + + assert response.status_code == 500 + + @pytest.mark.asyncio + async def test_get_bus_position_returns_422_when_invalid_data( + self, + client: AsyncClient, + ) -> None: + """Test that invalid bus direction fails validation.""" + request_data = { + "routes": [ + { + "route_id": 12345, + "route": {"bus_line": "8000", "bus_direction": 3}, + } + ] + } + + response = await client.post("/routes/positions", json=request_data) + + assert response.status_code == 422 # Validation error + + @pytest.mark.asyncio + async def test_get_bus_position_returns_successfully_with_empty_routes_list( + self, + client: AsyncClient, + ) -> None: + """Test behavior with empty routes list.""" + with ( + patch( + "src.adapters.external.sptrans_adapter.SpTransAdapter.authenticate", + new_callable=AsyncMock, + return_value=True, + ), + patch( + "src.adapters.external.sptrans_adapter.SpTransAdapter.get_bus_positions", + new_callable=AsyncMock, + return_value=[], + ), + ): + request_data = BusPositionsRequest(routes=[]) + + response = await client.post( + "/routes/positions", json=request_data.model_dump() + ) + + assert response.status_code == 200 + assert response.json()["buses"] == [] + + +class TestRouteShape: + """Tests for GET /routes/shape/{route_id} endpoint. + + These tests use the actual GTFS database (gtfs.db) since + GTFSRepositoryAdapter is a local database adapter, not an external service. + """ + + @pytest.mark.asyncio + async def test_get_route_shape_returns_successfully( + self, + client: AsyncClient, + ) -> None: + """Test that getting route shape works for an existing route.""" + response = await client.get("/routes/shape/1012-10") + + assert response.status_code == 200 + data = response.json() + + assert data["route_id"] == "1012-10" + assert "shape_id" in data + assert "points" in data + assert len(data["points"]) > 0 + + first_point = data["points"][0] + assert "latitude" in first_point + assert "longitude" in first_point + assert isinstance(first_point["latitude"], float) + assert isinstance(first_point["longitude"], float) + + @pytest.mark.asyncio + async def test_get_route_shape_returns_404_when_not_found( + self, + client: AsyncClient, + ) -> None: + response = await client.get("/routes/shape/NONEXISTENT-ROUTE-12345") + + assert response.status_code == 404 + assert "not found" in response.json()["detail"].lower() + + @pytest.mark.asyncio + async def test_get_route_shape_points_have_valid_coordinates( + self, + client: AsyncClient, + ) -> None: + """Test that all shape points have valid São Paulo coordinates.""" + response = await client.get("/routes/shape/1012-10") + + assert response.status_code == 200 + points = response.json()["points"] + + # São Paulo approximate bounding box + for point in points: + # Latitude should be around -23 to -24 for São Paulo + assert -25 <= point["latitude"] <= -22 + # Longitude should be around -46 to -47 for São Paulo + assert -48 <= point["longitude"] <= -45 diff --git a/tests/integration/test_trip.py b/tests/integration/test_trip.py index 2f2366b..c1285a6 100644 --- a/tests/integration/test_trip.py +++ b/tests/integration/test_trip.py @@ -179,19 +179,20 @@ async def test_create_trip_negative_distance_fails( } auth = await create_user_and_login(client, user_data) - trip_data = CreateTripRequest( - email=user_data["email"], - route=RouteIdentifierSchema( - bus_line="8000", - bus_direction=1, - ), - distance=-1000, - data=datetime.now(timezone.utc), - ) + # Use raw dict to bypass Pydantic validation and test API validation + trip_data = { + "email": user_data["email"], + "route": { + "bus_line": "8000", + "bus_direction": 1, + }, + "distance": -1000, + "data": datetime.now(timezone.utc).isoformat(), + } response = await client.post( "/trips/", - json=trip_data.model_dump(mode="json"), + json=trip_data, headers=auth["headers"], ) @@ -216,19 +217,20 @@ async def test_create_trip_invalid_route_identifier_fails( } auth = await create_user_and_login(client, user_data) - trip_data = CreateTripRequest( - email=user_data["email"], - route=RouteIdentifierSchema( - bus_line="8000", - bus_direction=3, - ), - distance=1000, - data=datetime.now(timezone.utc), - ) + # Use raw dict to bypass Pydantic validation and test API validation + trip_data = { + "email": user_data["email"], + "route": { + "bus_line": "8000", + "bus_direction": 3, + }, + "distance": 1000, + "data": datetime.now(timezone.utc).isoformat(), + } response = await client.post( "/trips/", - json=trip_data.model_dump(mode="json"), + json=trip_data, headers=auth["headers"], ) diff --git a/tests/integration/test_user_history.py b/tests/integration/test_user_history.py index a048c76..7f76d4a 100644 --- a/tests/integration/test_user_history.py +++ b/tests/integration/test_user_history.py @@ -72,6 +72,7 @@ async def test_get_user_history_returns_empty_when_no_trips( client: AsyncClient, ) -> None: user_data = { + "name": "Test User", "email": "test@example.com", "password": "secure_password_123", } @@ -106,7 +107,7 @@ async def test_get_history_includes_correct_dates( self, client: AsyncClient, ) -> None: - user_data = { + user_data = { "name": "Test User", "email": "test@example.com", "password": "securepassword123", @@ -140,4 +141,4 @@ async def test_get_history_includes_correct_dates( trip_date = history_response.trips[0].date assert trip_date.year == 2025 assert trip_date.month == 6 - assert trip_date.day == 15 \ No newline at end of file + assert trip_date.day == 15 From 77b92e2562f7e961ce2be5bd74b8c72e46a096c4 Mon Sep 17 00:00:00 2001 From: Kim Kakeya Date: Mon, 1 Dec 2025 20:28:18 -0300 Subject: [PATCH 06/12] style: apply ruff styles and formating --- src/web/controllers/history_controller.py | 3 +- tests/integration/conftest.py | 4 +- tests/integration/test_login.py | 3 -- tests/integration/test_ranking.py | 7 ++- tests/integration/test_route.py | 54 +++++++---------------- tests/integration/test_trip.py | 28 +++++------- tests/integration/test_user_history.py | 11 +++-- tests/web/test_route_controller.py | 8 +--- 8 files changed, 39 insertions(+), 79 deletions(-) diff --git a/src/web/controllers/history_controller.py b/src/web/controllers/history_controller.py index c2973df..5873289 100644 --- a/src/web/controllers/history_controller.py +++ b/src/web/controllers/history_controller.py @@ -56,8 +56,7 @@ async def get_user_history( # Combine dates and scores into trip entries (returns empty list if no history) trips = [ - TripHistoryEntry(date=date, score=score) - for date, score in zip(dates, scores, strict=False) + TripHistoryEntry(date=date, score=score) for date, score in zip(dates, scores, strict=False) ] return HistoryResponse(trips=trips) diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 66b84ff..33ade6a 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -68,9 +68,7 @@ async def test_db() -> AsyncGenerator[AsyncSession, None]: @pytest.fixture -async def client( - test_db: AsyncSession -) -> AsyncGenerator[AsyncClient, None]: +async def client(test_db: AsyncSession) -> AsyncGenerator[AsyncClient, None]: """ Create test HTTP client. diff --git a/tests/integration/test_login.py b/tests/integration/test_login.py index 05f1971..a3d14ee 100644 --- a/tests/integration/test_login.py +++ b/tests/integration/test_login.py @@ -5,7 +5,6 @@ class TestUserRegistration: - @pytest.mark.asyncio async def test_create_account_should_work( self, @@ -73,7 +72,6 @@ async def test_create_account_short_password_fails( class TestUserLogin: - @pytest.mark.asyncio async def test_login_should_work( self, @@ -135,7 +133,6 @@ async def test_login_nonexistent_user_fails( self, client: AsyncClient, ) -> None: - response = await client.post( "/users/login", data={ diff --git a/tests/integration/test_ranking.py b/tests/integration/test_ranking.py index 8964472..26e0fd9 100644 --- a/tests/integration/test_ranking.py +++ b/tests/integration/test_ranking.py @@ -1,4 +1,4 @@ -from datetime import datetime, timezone +from datetime import UTC, datetime import pytest from httpx import AsyncClient @@ -6,7 +6,7 @@ from src.web.schemas import CreateTripRequest, RouteIdentifierSchema -from .conftest import create_user_and_login, create_test_user_in_db +from .conftest import create_test_user_in_db, create_user_and_login class TestUserRankPosition: @@ -57,7 +57,7 @@ async def test_get_user_rank_position_with_multiple_users( email=user_data["email"], route=RouteIdentifierSchema(bus_line="8000", bus_direction=1), distance=0, - data=datetime.now(timezone.utc), + data=datetime.now(UTC), ) await client.post( "/trips/", json=trip_data.model_dump(mode="json"), headers=auth["headers"] @@ -91,7 +91,6 @@ async def test_get_user_rank_position_without_auth_fails( class TestGlobalRanking: - @pytest.mark.asyncio async def test_get_global_ranking_should_work( self, diff --git a/tests/integration/test_route.py b/tests/integration/test_route.py index bc937c3..8077979 100644 --- a/tests/integration/test_route.py +++ b/tests/integration/test_route.py @@ -1,4 +1,4 @@ -from datetime import datetime, timezone +from datetime import UTC, datetime from unittest.mock import AsyncMock, patch import pytest @@ -45,9 +45,7 @@ async def test_get_route_details_returns_successfully( ] ) - response = await client.post( - "/routes/details", json=request_data.model_dump() - ) + response = await client.post("/routes/details", json=request_data.model_dump()) assert response.status_code == 200 data = response.json() @@ -100,9 +98,7 @@ async def test_get_route_details_with_multiple_lines( ] ) - response = await client.post( - "/routes/details", json=request_data.model_dump() - ) + response = await client.post("/routes/details", json=request_data.model_dump()) assert response.status_code == 200 data = response.json() @@ -135,9 +131,7 @@ async def test_get_route_details_returns_empty_for_unknown_line( ] ) - response = await client.post( - "/routes/details", json=request_data.model_dump() - ) + response = await client.post("/routes/details", json=request_data.model_dump()) assert response.status_code == 200 data = response.json() @@ -166,9 +160,7 @@ async def test_get_route_details_returns_500_on_api_error( ] ) - response = await client.post( - "/routes/details", json=request_data.model_dump() - ) + response = await client.post("/routes/details", json=request_data.model_dump()) assert response.status_code == 500 assert "Failed to retrieve route details" in response.json()["detail"] @@ -185,9 +177,7 @@ async def test_get_route_details_with_empty_routes_list( ): request_data = BusRoutesDetailsRequest(routes=[]) - response = await client.post( - "/routes/details", json=request_data.model_dump() - ) + response = await client.post("/routes/details", json=request_data.model_dump()) assert response.status_code == 200 assert response.json()["routes"] == [] @@ -203,12 +193,12 @@ async def test_get_bus_position_returns_successfully( BusPosition( route=RouteIdentifier(bus_line="8000", bus_direction=1), position=Coordinate(latitude=-23.550520, longitude=-46.633308), - time_updated=datetime.now(timezone.utc), + time_updated=datetime.now(UTC), ), BusPosition( route=RouteIdentifier(bus_line="8000", bus_direction=1), position=Coordinate(latitude=-23.551234, longitude=-46.634567), - time_updated=datetime.now(timezone.utc), + time_updated=datetime.now(UTC), ), ] @@ -233,9 +223,7 @@ async def test_get_bus_position_returns_successfully( ] ) - response = await client.post( - "/routes/positions", json=request_data.model_dump() - ) + response = await client.post("/routes/positions", json=request_data.model_dump()) assert response.status_code == 200 data = response.json() @@ -279,9 +267,7 @@ async def test_get_bus_position_returns_500_when_error( ] ) - response = await client.post( - "/routes/positions", json=request_data.model_dump() - ) + response = await client.post("/routes/positions", json=request_data.model_dump()) assert response.status_code == 500 assert "Failed to retrieve bus positions" in response.json()["detail"] @@ -313,9 +299,7 @@ async def test_get_bus_position_returns_empty_when_no_buses_on_line( ] ) - response = await client.post( - "/routes/positions", json=request_data.model_dump() - ) + response = await client.post("/routes/positions", json=request_data.model_dump()) assert response.status_code == 200 data = response.json() @@ -333,14 +317,14 @@ async def test_get_bus_position_works_with_multiple_routes( BusPosition( route=RouteIdentifier(bus_line="8000", bus_direction=1), position=Coordinate(latitude=-23.550520, longitude=-46.633308), - time_updated=datetime.now(timezone.utc), + time_updated=datetime.now(UTC), ), ] mock_positions_9000 = [ BusPosition( route=RouteIdentifier(bus_line="9000", bus_direction=2), position=Coordinate(latitude=-23.560520, longitude=-46.643308), - time_updated=datetime.now(timezone.utc), + time_updated=datetime.now(UTC), ), ] @@ -369,9 +353,7 @@ async def test_get_bus_position_works_with_multiple_routes( ] ) - response = await client.post( - "/routes/positions", json=request_data.model_dump() - ) + response = await client.post("/routes/positions", json=request_data.model_dump()) assert response.status_code == 200 data = response.json() @@ -401,9 +383,7 @@ async def test_get_bus_position_returns_500_when_authentication_failure( ] ) - response = await client.post( - "/routes/positions", json=request_data.model_dump() - ) + response = await client.post("/routes/positions", json=request_data.model_dump()) assert response.status_code == 500 @@ -446,9 +426,7 @@ async def test_get_bus_position_returns_successfully_with_empty_routes_list( ): request_data = BusPositionsRequest(routes=[]) - response = await client.post( - "/routes/positions", json=request_data.model_dump() - ) + response = await client.post("/routes/positions", json=request_data.model_dump()) assert response.status_code == 200 assert response.json()["buses"] == [] diff --git a/tests/integration/test_trip.py b/tests/integration/test_trip.py index c1285a6..3ee7c02 100644 --- a/tests/integration/test_trip.py +++ b/tests/integration/test_trip.py @@ -1,4 +1,4 @@ -from datetime import datetime, timezone +from datetime import UTC, datetime import pytest from httpx import AsyncClient @@ -32,7 +32,7 @@ async def test_create_trip_should_return_successfully_and_save_to_database( bus_direction=1, ), distance=5000, - data=datetime.now(timezone.utc), + data=datetime.now(UTC), ) response = await client.post( @@ -46,9 +46,7 @@ async def test_create_trip_should_return_successfully_and_save_to_database( assert "score" in data - result = await test_db.execute( - select(TripDB).where(TripDB.email == user_data["email"]) - ) + result = await test_db.execute(select(TripDB).where(TripDB.email == user_data["email"])) trip = result.scalar_one_or_none() assert trip is not None @@ -77,7 +75,7 @@ async def test_create_trip_updates_user_score( bus_direction=1, ), distance=1000, - data=datetime.now(timezone.utc), + data=datetime.now(UTC), ) second_trip_data = CreateTripRequest( email=user_data["email"], @@ -86,7 +84,7 @@ async def test_create_trip_updates_user_score( bus_direction=1, ), distance=2000, - data=datetime.now(timezone.utc), + data=datetime.now(UTC), ) resp1 = await client.post( @@ -128,7 +126,7 @@ async def test_create_trip_without_authentication_fails( bus_direction=1, ), distance=1000, - data=datetime.now(timezone.utc), + data=datetime.now(UTC), ) response = await client.post("/trips/", json=trip_data.model_dump(mode="json")) @@ -154,7 +152,7 @@ async def test_create_trip_zero_distance( bus_direction=2, ), distance=0, - data=datetime.now(timezone.utc), + data=datetime.now(UTC), ) response = await client.post( @@ -187,7 +185,7 @@ async def test_create_trip_negative_distance_fails( "bus_direction": 1, }, "distance": -1000, - "data": datetime.now(timezone.utc).isoformat(), + "data": datetime.now(UTC).isoformat(), } response = await client.post( @@ -198,9 +196,7 @@ async def test_create_trip_negative_distance_fails( assert response.status_code == 422 - result = await test_db.execute( - select(TripDB).where(TripDB.email == user_data["email"]) - ) + result = await test_db.execute(select(TripDB).where(TripDB.email == user_data["email"])) trip = result.scalar_one_or_none() assert trip is None @@ -225,7 +221,7 @@ async def test_create_trip_invalid_route_identifier_fails( "bus_direction": 3, }, "distance": 1000, - "data": datetime.now(timezone.utc).isoformat(), + "data": datetime.now(UTC).isoformat(), } response = await client.post( @@ -236,8 +232,6 @@ async def test_create_trip_invalid_route_identifier_fails( assert response.status_code == 422 - result = await test_db.execute( - select(TripDB).where(TripDB.email == user_data["email"]) - ) + result = await test_db.execute(select(TripDB).where(TripDB.email == user_data["email"])) trip = result.scalar_one_or_none() assert trip is None diff --git a/tests/integration/test_user_history.py b/tests/integration/test_user_history.py index 7f76d4a..33b0de5 100644 --- a/tests/integration/test_user_history.py +++ b/tests/integration/test_user_history.py @@ -1,8 +1,7 @@ -from datetime import datetime, timezone +from datetime import UTC, datetime import pytest from httpx import AsyncClient -from sqlalchemy.ext.asyncio import AsyncSession from src.web.schemas import ( CreateTripRequest, @@ -28,9 +27,9 @@ async def test_get_user_history_should_work( auth = await create_user_and_login(client, user_data) trip_dates: list[datetime] = [ - datetime(2025, 11, 1, 8, 0, 0, tzinfo=timezone.utc), - datetime(2025, 11, 15, 12, 0, 0, tzinfo=timezone.utc), - datetime(2025, 11, 29, 18, 0, 0, tzinfo=timezone.utc), + datetime(2025, 11, 1, 8, 0, 0, tzinfo=UTC), + datetime(2025, 11, 15, 12, 0, 0, tzinfo=UTC), + datetime(2025, 11, 29, 18, 0, 0, tzinfo=UTC), ] scores: list[int] = [] @@ -115,7 +114,7 @@ async def test_get_history_includes_correct_dates( auth = await create_user_and_login(client, user_data) - specific_date: datetime = datetime(2025, 6, 15, 10, 30, 0, tzinfo=timezone.utc) + specific_date: datetime = datetime(2025, 6, 15, 10, 30, 0, tzinfo=UTC) trip_request = CreateTripRequest( email=user_data["email"], route=RouteIdentifierSchema(bus_line="8000", bus_direction=1), diff --git a/tests/web/test_route_controller.py b/tests/web/test_route_controller.py index 739dc2b..2cac40a 100644 --- a/tests/web/test_route_controller.py +++ b/tests/web/test_route_controller.py @@ -50,9 +50,7 @@ def override_dependency(mock_service: RouteService) -> Generator[None, None, Non @pytest.mark.asyncio -async def test_details_endpoint_success( - client: TestClient, mock_service: RouteService -) -> None: +async def test_details_endpoint_success(client: TestClient, mock_service: RouteService) -> None: """ Testa o endpoint POST /routes/details garantindo que: - Ele chama RouteService.get_route_details() @@ -127,9 +125,7 @@ async def test_details_endpoint_error_returns_500( @pytest.mark.asyncio -async def test_positions_endpoint_success( - client: TestClient, mock_service: RouteService -) -> None: +async def test_positions_endpoint_success(client: TestClient, mock_service: RouteService) -> None: """ Testa o endpoint POST /routes/positions garantindo que: - Ele chama RouteService.get_bus_positions() From ad51938d6d88808394aec856106513a81ba5cc68 Mon Sep 17 00:00:00 2001 From: Kim Kakeya Date: Mon, 1 Dec 2025 21:03:31 -0300 Subject: [PATCH 07/12] style: clean tests --- tests/integration/test_login.py | 11 ----------- tests/integration/test_ranking.py | 2 -- tests/integration/test_route.py | 10 ---------- 3 files changed, 23 deletions(-) diff --git a/tests/integration/test_login.py b/tests/integration/test_login.py index a3d14ee..fe3c32f 100644 --- a/tests/integration/test_login.py +++ b/tests/integration/test_login.py @@ -23,7 +23,6 @@ async def test_create_account_should_work( assert data["name"] == user_data["name"] assert data["email"] == user_data["email"] assert data["score"] == 0 - # Password should not be returned assert "password" not in data @pytest.mark.asyncio @@ -36,11 +35,9 @@ async def test_create_account_duplicate_email_fails( "email": "test@example.com", "password": "securepassword123", } - # Create first user response1 = await client.post("/users/register", json=user_data) assert response1.status_code == 201 - # Try to create second user with same email response2 = await client.post("/users/register", json=user_data) assert response2.status_code == 400 @@ -70,7 +67,6 @@ async def test_create_account_short_password_fails( response = await client.post("/users/register", json=invalid_data) assert response.status_code == 422 # Validation error - class TestUserLogin: @pytest.mark.asyncio async def test_login_should_work( @@ -145,14 +141,11 @@ async def test_login_nonexistent_user_fails( class TestGetUserInfo: - """Test getting user information.""" - @pytest.mark.asyncio async def test_get_user_info_should_work( self, client: AsyncClient, ) -> None: - """Test that authenticated user can get their own info.""" user_data = { "name": "Test User", "email": "test@example.com", @@ -207,13 +200,10 @@ async def test_get_other_user_info_should_work( "email": "second@example.com", "password": "anotherpassword123", } - # Create first user await client.post("/users/register", json=user_data) - # Create second user await client.post("/users/register", json=second_user_data) - # Get first user's info by email (public endpoint) response = await client.get(f"/users/{user_data['email']}") assert response.status_code == 200 @@ -228,7 +218,6 @@ async def test_get_nonexistent_user_info_fails( self, client: AsyncClient, ) -> None: - """Test that getting non-existent user's info returns 404.""" response = await client.get("/users/nonexistent@example.com") assert response.status_code == 404 diff --git a/tests/integration/test_ranking.py b/tests/integration/test_ranking.py index 26e0fd9..427af94 100644 --- a/tests/integration/test_ranking.py +++ b/tests/integration/test_ranking.py @@ -8,13 +8,11 @@ from .conftest import create_test_user_in_db, create_user_and_login - class TestUserRankPosition: @pytest.mark.asyncio async def test_get_user_rank_position_should_work( self, client: AsyncClient, - test_db: AsyncSession, ) -> None: user_data = { "name": "Test User", diff --git a/tests/integration/test_route.py b/tests/integration/test_route.py index 8077979..3d3c374 100644 --- a/tests/integration/test_route.py +++ b/tests/integration/test_route.py @@ -13,7 +13,6 @@ RouteIdentifierSchema, ) - class TestRouteDetails: @pytest.mark.asyncio async def test_get_route_details_returns_successfully( @@ -53,7 +52,6 @@ async def test_get_route_details_returns_successfully( assert "routes" in data assert len(data["routes"]) == 1 - # Verify route structure first_route = data["routes"][0] assert "route_id" in first_route assert first_route["route_id"] == 12345 @@ -245,7 +243,6 @@ async def test_get_bus_position_returns_500_when_error( self, client: AsyncClient, ) -> None: - """Test that getting bus positions returns 500 on API errors.""" with ( patch( "src.adapters.external.sptrans_adapter.SpTransAdapter.authenticate", @@ -277,7 +274,6 @@ async def test_get_bus_position_returns_empty_when_no_buses_on_line( self, client: AsyncClient, ) -> None: - """Test that getting bus positions returns empty list when no buses running.""" with ( patch( "src.adapters.external.sptrans_adapter.SpTransAdapter.authenticate", @@ -312,7 +308,6 @@ async def test_get_bus_position_works_with_multiple_routes( self, client: AsyncClient, ) -> None: - """Test that querying multiple routes returns positions for all of them.""" mock_positions_8000 = [ BusPosition( route=RouteIdentifier(bus_line="8000", bus_direction=1), @@ -368,7 +363,6 @@ async def test_get_bus_position_returns_500_when_authentication_failure( self, client: AsyncClient, ) -> None: - """Test behavior when SPTrans authentication fails.""" with patch( "src.adapters.external.sptrans_adapter.SpTransAdapter.authenticate", new_callable=AsyncMock, @@ -392,7 +386,6 @@ async def test_get_bus_position_returns_422_when_invalid_data( self, client: AsyncClient, ) -> None: - """Test that invalid bus direction fails validation.""" request_data = { "routes": [ { @@ -411,7 +404,6 @@ async def test_get_bus_position_returns_successfully_with_empty_routes_list( self, client: AsyncClient, ) -> None: - """Test behavior with empty routes list.""" with ( patch( "src.adapters.external.sptrans_adapter.SpTransAdapter.authenticate", @@ -444,7 +436,6 @@ async def test_get_route_shape_returns_successfully( self, client: AsyncClient, ) -> None: - """Test that getting route shape works for an existing route.""" response = await client.get("/routes/shape/1012-10") assert response.status_code == 200 @@ -476,7 +467,6 @@ async def test_get_route_shape_points_have_valid_coordinates( self, client: AsyncClient, ) -> None: - """Test that all shape points have valid São Paulo coordinates.""" response = await client.get("/routes/shape/1012-10") assert response.status_code == 200 From 52483b91e286514091741ae2fd6f9e087e2b0ac6 Mon Sep 17 00:00:00 2001 From: Kim Kakeya Date: Mon, 1 Dec 2025 21:06:01 -0300 Subject: [PATCH 08/12] fix: ruff styling --- tests/integration/test_ranking.py | 1 + tests/integration/test_route.py | 1 + 2 files changed, 2 insertions(+) diff --git a/tests/integration/test_ranking.py b/tests/integration/test_ranking.py index 427af94..f179047 100644 --- a/tests/integration/test_ranking.py +++ b/tests/integration/test_ranking.py @@ -8,6 +8,7 @@ from .conftest import create_test_user_in_db, create_user_and_login + class TestUserRankPosition: @pytest.mark.asyncio async def test_get_user_rank_position_should_work( diff --git a/tests/integration/test_route.py b/tests/integration/test_route.py index 3d3c374..f70d7b0 100644 --- a/tests/integration/test_route.py +++ b/tests/integration/test_route.py @@ -13,6 +13,7 @@ RouteIdentifierSchema, ) + class TestRouteDetails: @pytest.mark.asyncio async def test_get_route_details_returns_successfully( From fe9a2a7f7f1ee6c08ddce0f6e3c591e033f40217 Mon Sep 17 00:00:00 2001 From: Kim Kakeya Date: Mon, 1 Dec 2025 21:39:51 -0300 Subject: [PATCH 09/12] feat: fix access to endpoints --- src/web/auth.py | 50 +++++ src/web/controllers/history_controller.py | 21 +- src/web/controllers/rank_controller.py | 26 +-- src/web/controllers/route_controller.py | 12 +- src/web/controllers/trip_controller.py | 10 +- src/web/controllers/user_controller.py | 103 +-------- src/web/schemas.py | 23 +-- tests/integration/test_login.py | 41 +--- tests/integration/test_ranking.py | 61 ++++-- tests/integration/test_route.py | 241 +++++++++++++++++++--- tests/integration/test_trip.py | 26 +-- tests/integration/test_user_history.py | 19 +- tests/web/test_route_controller.py | 25 ++- 13 files changed, 373 insertions(+), 285 deletions(-) create mode 100644 src/web/auth.py diff --git a/src/web/auth.py b/src/web/auth.py new file mode 100644 index 0000000..cc55d75 --- /dev/null +++ b/src/web/auth.py @@ -0,0 +1,50 @@ +from fastapi import Depends, HTTPException, status +from fastapi.security import OAuth2PasswordBearer +from sqlalchemy.ext.asyncio import AsyncSession + +from ..adapters.database.connection import get_db +from ..adapters.repositories.user_repository_adapter import UserRepositoryAdapter +from ..adapters.security.hashing import PasslibPasswordHasher +from ..adapters.security.jwt import verify_token +from ..core.models.user import User +from ..core.ports.password_hasher import PasswordHasherPort +from ..core.services.user_service import UserService + +oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/users/login") + + +def get_password_hasher() -> PasswordHasherPort: + return PasslibPasswordHasher() + + +def get_user_service( + db: AsyncSession = Depends(get_db), + password_hasher: PasswordHasherPort = Depends(get_password_hasher), +) -> UserService: + user_repo = UserRepositoryAdapter(db) + return UserService(user_repo, password_hasher) + + +async def get_current_user( + token: str = Depends(oauth2_scheme), + user_service: UserService = Depends(get_user_service), +) -> User: + credentials_exception = HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Could not validate credentials", + headers={"WWW-Authenticate": "Bearer"}, + ) + + payload = verify_token(token) + if payload is None: + raise credentials_exception + + email: str | None = payload.get("sub") + if email is None: + raise credentials_exception + + user = await user_service.get_user(email) + if user is None: + raise credentials_exception + + return user diff --git a/src/web/controllers/history_controller.py b/src/web/controllers/history_controller.py index 5873289..f541171 100644 --- a/src/web/controllers/history_controller.py +++ b/src/web/controllers/history_controller.py @@ -1,9 +1,3 @@ -""" -History controller - API endpoints for user trip history. - -This controller handles queries for user trip history. -""" - from fastapi import APIRouter, Depends from sqlalchemy.ext.asyncio import AsyncSession @@ -13,13 +7,11 @@ ) from ...core.models.user import User from ...core.services.history_service import HistoryService -from ..schemas import HistoryRequest, HistoryResponse, TripHistoryEntry +from ..auth import get_current_user +from ..schemas import HistoryResponse, TripHistoryEntry router = APIRouter(prefix="/history", tags=["history"]) -# Import get_current_user from user_controller to avoid circular imports -from .user_controller import get_current_user # noqa: E402 - def get_history_service(db: AsyncSession = Depends(get_db)) -> HistoryService: """ @@ -35,9 +27,8 @@ def get_history_service(db: AsyncSession = Depends(get_db)) -> HistoryService: return HistoryService(history_repo) -@router.post("/", response_model=HistoryResponse) +@router.get("/", response_model=HistoryResponse) async def get_user_history( - request: HistoryRequest, history_service: HistoryService = Depends(get_history_service), current_user: User = Depends(get_current_user), ) -> HistoryResponse: @@ -45,18 +36,18 @@ async def get_user_history( Get a user's trip history summary. Args: - request: Request with user email history_service: Injected history service current_user: Authenticated user (from JWT token) Returns: User's trip history with dates and scores (empty list if no history) """ - dates, scores = await history_service.get_user_history_summary(request.email) + dates, scores = await history_service.get_user_history_summary(current_user.email) # Combine dates and scores into trip entries (returns empty list if no history) trips = [ - TripHistoryEntry(date=date, score=score) for date, score in zip(dates, scores, strict=False) + TripHistoryEntry(date=date, score=score) + for date, score in zip(dates, scores, strict=False) ] return HistoryResponse(trips=trips) diff --git a/src/web/controllers/rank_controller.py b/src/web/controllers/rank_controller.py index 92f1831..b0e9bba 100644 --- a/src/web/controllers/rank_controller.py +++ b/src/web/controllers/rank_controller.py @@ -11,14 +11,12 @@ from ...adapters.repositories.user_repository_adapter import UserRepositoryAdapter from ...core.models.user import User from ...core.services.score_service import ScoreService +from ..auth import get_current_user from ..mappers import map_user_domain_list_to_response -from ..schemas import GlobalRankingResponse, UserRankingRequest, UserRankingResponse +from ..schemas import GlobalRankingResponse, UserRankingResponse router = APIRouter(prefix="/rank", tags=["ranking"]) -# Import get_current_user from user_controller to avoid circular imports -from .user_controller import get_current_user # noqa: E402 - def get_score_service(db: AsyncSession = Depends(get_db)) -> ScoreService: """ @@ -34,27 +32,12 @@ def get_score_service(db: AsyncSession = Depends(get_db)) -> ScoreService: return ScoreService(user_repo) -@router.post("/user", response_model=UserRankingResponse) +@router.get("/user", response_model=UserRankingResponse) async def get_user_ranking( - request: UserRankingRequest, score_service: ScoreService = Depends(get_score_service), current_user: User = Depends(get_current_user), ) -> UserRankingResponse: - """ - Get a user's position in the global ranking. - - Args: - request: Request with user email - score_service: Injected score service - current_user: Authenticated user (from JWT token) - - Returns: - User's rank position - - Raises: - HTTPException: If user not found - """ - position = await score_service.get_user_ranking(request.email) + position = await score_service.get_user_ranking(current_user.email) if position is None: raise HTTPException( @@ -68,6 +51,7 @@ async def get_user_ranking( @router.get("/global", response_model=GlobalRankingResponse) async def get_global_ranking( score_service: ScoreService = Depends(get_score_service), + current_user: User = Depends(get_current_user), ) -> GlobalRankingResponse: """ Get the global user ranking. diff --git a/src/web/controllers/route_controller.py b/src/web/controllers/route_controller.py index bf69ff8..6418a01 100644 --- a/src/web/controllers/route_controller.py +++ b/src/web/controllers/route_controller.py @@ -10,7 +10,9 @@ from ...adapters.repositories.gtfs_repository_adapter import GTFSRepositoryAdapter from ...config import settings from ...core.models.bus import BusPosition, BusRoute, RouteIdentifier +from ...core.models.user import User from ...core.services.route_service import RouteService +from ..auth import get_current_user from ..mappers import ( map_bus_position_list_to_schema, map_route_identifier_schema_to_domain, @@ -48,6 +50,7 @@ def get_route_service() -> RouteService: async def get_route_details_endpoint( request: BusRoutesDetailsRequest, route_service: RouteService = Depends(get_route_service), + current_user: User = Depends(get_current_user), ) -> BusRoutesDetailsResponse: """ Resolve, para cada linha solicitada, as rotas concretas do provedor @@ -59,7 +62,8 @@ async def get_route_details_endpoint( try: # Schemas -> domínio (RouteIdentifier) route_identifiers: list[RouteIdentifier] = [ - map_route_identifier_schema_to_domain(route_schema) for route_schema in request.routes + map_route_identifier_schema_to_domain(route_schema) + for route_schema in request.routes ] bus_routes: list[BusRoute] = [] @@ -97,6 +101,7 @@ async def get_route_details_endpoint( async def get_bus_positions( request: BusPositionsRequest, route_service: RouteService = Depends(get_route_service), + current_user: User = Depends(get_current_user), ) -> BusPositionsResponse: """ Recupera as posições dos ônibus para as rotas já resolvidas. @@ -119,7 +124,9 @@ async def get_bus_positions( route=route_identifier, ) - route_positions: list[BusPosition] = await route_service.get_bus_positions(bus_route) + route_positions: list[BusPosition] = await route_service.get_bus_positions( + bus_route + ) all_positions.extend(route_positions) # Domínio -> schemas @@ -138,6 +145,7 @@ async def get_bus_positions( async def get_route_shape( route_id: str, route_service: RouteService = Depends(get_route_service), + current_user: User = Depends(get_current_user), ) -> RouteShapeResponse: """ Get the geographic shape (coordinates) of a route from GTFS data. diff --git a/src/web/controllers/trip_controller.py b/src/web/controllers/trip_controller.py index a11ea43..5e955fb 100644 --- a/src/web/controllers/trip_controller.py +++ b/src/web/controllers/trip_controller.py @@ -12,13 +12,11 @@ from ...adapters.repositories.user_repository_adapter import UserRepositoryAdapter from ...core.models.user import User from ...core.services.trip_service import TripService +from ..auth import get_current_user from ..schemas import CreateTripRequest, CreateTripResponse router = APIRouter(prefix="/trips", tags=["trips"]) -# Import get_current_user from user_controller to avoid circular imports -from .user_controller import get_current_user # noqa: E402 - def get_trip_service(db: AsyncSession = Depends(get_db)) -> TripService: """ @@ -35,7 +33,9 @@ def get_trip_service(db: AsyncSession = Depends(get_db)) -> TripService: return TripService(trip_repo, user_repo) -@router.post("/", response_model=CreateTripResponse, status_code=status.HTTP_201_CREATED) +@router.post( + "/", response_model=CreateTripResponse, status_code=status.HTTP_201_CREATED +) async def create_trip( request: CreateTripRequest, trip_service: TripService = Depends(get_trip_service), @@ -57,7 +57,7 @@ async def create_trip( """ try: trip = await trip_service.create_trip( - email=request.email, + email=current_user.email, bus_line=request.route.bus_line, bus_direction=request.route.bus_direction, distance=request.distance, diff --git a/src/web/controllers/user_controller.py b/src/web/controllers/user_controller.py index d673d97..48815df 100644 --- a/src/web/controllers/user_controller.py +++ b/src/web/controllers/user_controller.py @@ -6,16 +6,12 @@ """ from fastapi import APIRouter, Depends, HTTPException, status -from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm -from sqlalchemy.ext.asyncio import AsyncSession +from fastapi.security import OAuth2PasswordRequestForm -from ...adapters.database.connection import get_db -from ...adapters.repositories.user_repository_adapter import UserRepositoryAdapter -from ...adapters.security.hashing import PasslibPasswordHasher -from ...adapters.security.jwt import create_access_token, verify_token +from ...adapters.security.jwt import create_access_token from ...core.models.user import User -from ...core.ports.password_hasher import PasswordHasherPort from ...core.services.user_service import UserService +from ..auth import get_current_user, get_user_service from ..mappers import map_user_domain_to_response from ..schemas import ( TokenResponse, @@ -25,57 +21,10 @@ router = APIRouter(prefix="/users", tags=["users"]) -oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/users/login") - -def get_password_hasher() -> PasswordHasherPort: - """Dependency that provides a password hasher implementation.""" - return PasslibPasswordHasher() - - -def get_user_service( - db: AsyncSession = Depends(get_db), - password_hasher: PasswordHasherPort = Depends(get_password_hasher), -) -> UserService: - """ - Dependency that provides a UserService instance. - - Args: - db: Database session - - Returns: - Configured UserService instance - """ - user_repo = UserRepositoryAdapter(db) - return UserService(user_repo, password_hasher) - - -async def get_current_user( - token: str = Depends(oauth2_scheme), - user_service: UserService = Depends(get_user_service), -) -> User: - credentials_exception = HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Could not validate credentials", - headers={"WWW-Authenticate": "Bearer"}, - ) - - payload = verify_token(token) - if payload is None: - raise credentials_exception - - email: str | None = payload.get("sub") - if email is None: - raise credentials_exception - - user = await user_service.get_user(email) - if user is None: - raise credentials_exception - - return user - - -@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED) +@router.post( + "/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED +) async def create_user( request: UserCreateAccountRequest, user_service: UserService = Depends(get_user_service), @@ -147,44 +96,4 @@ async def login_user( async def get_current_user_info( current_user: User = Depends(get_current_user), ) -> UserResponse: - """ - Get current authenticated user's information. - - This is a protected route that requires a valid JWT token. - - Args: - current_user: Current authenticated user from token - - Returns: - Current user's information - """ return map_user_domain_to_response(current_user) - - -@router.get("/{email}", response_model=UserResponse) -async def get_user( - email: str, - user_service: UserService = Depends(get_user_service), -) -> UserResponse: - """ - Get user information by email. - - Args: - email: User's email address - user_service: Injected user service - - Returns: - User information - - Raises: - HTTPException: If user not found - """ - user = await user_service.get_user(email) - - if not user: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="User not found", - ) - - return map_user_domain_to_response(user) diff --git a/src/web/schemas.py b/src/web/schemas.py index 452cadf..73a517e 100644 --- a/src/web/schemas.py +++ b/src/web/schemas.py @@ -89,9 +89,6 @@ class TokenResponse(BaseModel): class CreateTripRequest(BaseModel): - """Request schema for creating a new trip.""" - - email: EmailStr = Field(..., description="User's email") route: RouteIdentifierSchema distance: int = Field(..., ge=0, description="Distance traveled in meters") data: datetime = Field(..., description="Trip date and time") @@ -143,21 +140,15 @@ class RouteShapeResponse(BaseModel): route_id: str = Field(..., description="Route identifier") shape_id: str = Field(..., description="GTFS shape identifier") - points: list[CoordinateSchema] = Field(..., description="Ordered list of coordinates") + points: list[CoordinateSchema] = Field( + ..., description="Ordered list of coordinates" + ) # ===== Ranking Schemas ===== -class UserRankingRequest(BaseModel): - """Request schema for getting a user's ranking.""" - - email: EmailStr = Field(..., description="User's email") - - class UserRankingResponse(BaseModel): - """Response schema for user ranking.""" - position: int = Field(..., description="User's rank position") @@ -170,15 +161,7 @@ class GlobalRankingResponse(BaseModel): # ===== History Schemas ===== -class HistoryRequest(BaseModel): - """Request schema for getting user history.""" - - email: EmailStr = Field(..., description="User's email") - - class TripHistoryEntry(BaseModel): - """Schema for a single trip history entry.""" - date: datetime = Field(..., description="Trip date and time") score: int = Field(..., description="Points earned from this trip") diff --git a/tests/integration/test_login.py b/tests/integration/test_login.py index fe3c32f..c901d03 100644 --- a/tests/integration/test_login.py +++ b/tests/integration/test_login.py @@ -67,6 +67,7 @@ async def test_create_account_short_password_fails( response = await client.post("/users/register", json=invalid_data) assert response.status_code == 422 # Validation error + class TestUserLogin: @pytest.mark.asyncio async def test_login_should_work( @@ -182,43 +183,3 @@ async def test_get_user_info_invalid_token_fails( ) assert response.status_code == 401 - - -class TestGetOtherUserInfo: - @pytest.mark.asyncio - async def test_get_other_user_info_should_work( - self, - client: AsyncClient, - ) -> None: - user_data = { - "name": "Test User", - "email": "test@example.com", - "password": "securepassword123", - } - second_user_data = { - "name": "Second User", - "email": "second@example.com", - "password": "anotherpassword123", - } - await client.post("/users/register", json=user_data) - - await client.post("/users/register", json=second_user_data) - - response = await client.get(f"/users/{user_data['email']}") - - assert response.status_code == 200 - data = response.json() - - assert data["name"] == user_data["name"] - assert data["email"] == user_data["email"] - assert data["score"] == 0 - - @pytest.mark.asyncio - async def test_get_nonexistent_user_info_fails( - self, - client: AsyncClient, - ) -> None: - response = await client.get("/users/nonexistent@example.com") - - assert response.status_code == 404 - assert "User not found" in response.json()["detail"] diff --git a/tests/integration/test_ranking.py b/tests/integration/test_ranking.py index f179047..48f683b 100644 --- a/tests/integration/test_ranking.py +++ b/tests/integration/test_ranking.py @@ -22,10 +22,8 @@ async def test_get_user_rank_position_should_work( } auth = await create_user_and_login(client, user_data) - request_data = {"email": user_data["email"]} - response = await client.post( + response = await client.get( "/rank/user", - json=request_data, headers=auth["headers"], ) @@ -53,7 +51,6 @@ async def test_get_user_rank_position_with_multiple_users( auth = await create_user_and_login(client, user_data) trip_data = CreateTripRequest( - email=user_data["email"], route=RouteIdentifierSchema(bus_line="8000", bus_direction=1), distance=0, data=datetime.now(UTC), @@ -62,10 +59,8 @@ async def test_get_user_rank_position_with_multiple_users( "/trips/", json=trip_data.model_dump(mode="json"), headers=auth["headers"] ) - request_data = {"email": user_data["email"]} - response = await client.post( + response = await client.get( "/rank/user", - json=request_data, headers=auth["headers"], ) @@ -78,13 +73,7 @@ async def test_get_user_rank_position_without_auth_fails( self, client: AsyncClient, ) -> None: - user_data = { - "name": "Test User", - "email": "test@example.com", - "password": "securepassword123", - } - request_data = {"email": user_data["email"]} - response = await client.post("/rank/user", json=request_data) + response = await client.get("/rank/user") assert response.status_code == 401 @@ -100,16 +89,23 @@ async def test_get_global_ranking_should_work( await create_test_user_in_db(test_db, "second@example.com", score=500) await create_test_user_in_db(test_db, "third@example.com", score=100) - response = await client.get("/rank/global") + user_data = { + "name": "Test User", + "email": "test@example.com", + "password": "securepassword123", + } + auth = await create_user_and_login(client, user_data) + + response = await client.get("/rank/global", headers=auth["headers"]) assert response.status_code == 200 data = response.json() assert "users" in data - assert len(data["users"]) == 3 + assert len(data["users"]) == 4 scores = [user["score"] for user in data["users"]] - assert scores == [1000, 500, 100] + assert scores == [1000, 500, 100, 0] first_user = data["users"][0] assert "name" in first_user @@ -123,13 +119,20 @@ async def test_get_global_ranking_returns_empty_when_no_users( client: AsyncClient, test_db: AsyncSession, ) -> None: - response = await client.get("/rank/global") + user_data = { + "name": "Test User", + "email": "test@example.com", + "password": "securepassword123", + } + auth = await create_user_and_login(client, user_data) + + response = await client.get("/rank/global", headers=auth["headers"]) assert response.status_code == 200 data = response.json() assert "users" in data - assert len(data["users"]) == 0 + assert len(data["users"]) == 1 @pytest.mark.asyncio async def test_get_global_ranking_with_single_user( @@ -139,11 +142,27 @@ async def test_get_global_ranking_with_single_user( ) -> None: await create_test_user_in_db(test_db, "solo@example.com", score=500) - response = await client.get("/rank/global") + user_data = { + "name": "Test User", + "email": "test@example.com", + "password": "securepassword123", + } + auth = await create_user_and_login(client, user_data) + + response = await client.get("/rank/global", headers=auth["headers"]) assert response.status_code == 200 data = response.json() - assert len(data["users"]) == 1 + assert len(data["users"]) == 2 assert data["users"][0]["email"] == "solo@example.com" assert data["users"][0]["score"] == 500 + + @pytest.mark.asyncio + async def test_get_global_ranking_without_auth_fails( + self, + client: AsyncClient, + ) -> None: + response = await client.get("/rank/global") + + assert response.status_code == 401 diff --git a/tests/integration/test_route.py b/tests/integration/test_route.py index f70d7b0..eae7b3e 100644 --- a/tests/integration/test_route.py +++ b/tests/integration/test_route.py @@ -13,6 +13,8 @@ RouteIdentifierSchema, ) +from .conftest import create_user_and_login + class TestRouteDetails: @pytest.mark.asyncio @@ -20,6 +22,13 @@ async def test_get_route_details_returns_successfully( self, client: AsyncClient, ) -> None: + user_data = { + "name": "Test User", + "email": "test@example.com", + "password": "securepassword123", + } + auth = await create_user_and_login(client, user_data) + mock_bus_routes = [ BusRoute( route_id=12345, @@ -45,7 +54,11 @@ async def test_get_route_details_returns_successfully( ] ) - response = await client.post("/routes/details", json=request_data.model_dump()) + response = await client.post( + "/routes/details", + json=request_data.model_dump(), + headers=auth["headers"], + ) assert response.status_code == 200 data = response.json() @@ -65,6 +78,13 @@ async def test_get_route_details_with_multiple_lines( self, client: AsyncClient, ) -> None: + user_data = { + "name": "Test User", + "email": "test@example.com", + "password": "securepassword123", + } + auth = await create_user_and_login(client, user_data) + mock_routes_8000 = [ BusRoute( route_id=12345, @@ -97,7 +117,11 @@ async def test_get_route_details_with_multiple_lines( ] ) - response = await client.post("/routes/details", json=request_data.model_dump()) + response = await client.post( + "/routes/details", + json=request_data.model_dump(), + headers=auth["headers"], + ) assert response.status_code == 200 data = response.json() @@ -112,6 +136,13 @@ async def test_get_route_details_returns_empty_for_unknown_line( self, client: AsyncClient, ) -> None: + user_data = { + "name": "Test User", + "email": "test@example.com", + "password": "securepassword123", + } + auth = await create_user_and_login(client, user_data) + with ( patch( "src.adapters.external.sptrans_adapter.SpTransAdapter.authenticate", @@ -130,7 +161,11 @@ async def test_get_route_details_returns_empty_for_unknown_line( ] ) - response = await client.post("/routes/details", json=request_data.model_dump()) + response = await client.post( + "/routes/details", + json=request_data.model_dump(), + headers=auth["headers"], + ) assert response.status_code == 200 data = response.json() @@ -141,6 +176,13 @@ async def test_get_route_details_returns_500_on_api_error( self, client: AsyncClient, ) -> None: + user_data = { + "name": "Test User", + "email": "test@example.com", + "password": "securepassword123", + } + auth = await create_user_and_login(client, user_data) + with ( patch( "src.adapters.external.sptrans_adapter.SpTransAdapter.authenticate", @@ -159,7 +201,11 @@ async def test_get_route_details_returns_500_on_api_error( ] ) - response = await client.post("/routes/details", json=request_data.model_dump()) + response = await client.post( + "/routes/details", + json=request_data.model_dump(), + headers=auth["headers"], + ) assert response.status_code == 500 assert "Failed to retrieve route details" in response.json()["detail"] @@ -169,6 +215,13 @@ async def test_get_route_details_with_empty_routes_list( self, client: AsyncClient, ) -> None: + user_data = { + "name": "Test User", + "email": "test@example.com", + "password": "securepassword123", + } + auth = await create_user_and_login(client, user_data) + with patch( "src.adapters.external.sptrans_adapter.SpTransAdapter.authenticate", new_callable=AsyncMock, @@ -176,11 +229,30 @@ async def test_get_route_details_with_empty_routes_list( ): request_data = BusRoutesDetailsRequest(routes=[]) - response = await client.post("/routes/details", json=request_data.model_dump()) + response = await client.post( + "/routes/details", + json=request_data.model_dump(), + headers=auth["headers"], + ) assert response.status_code == 200 assert response.json()["routes"] == [] + @pytest.mark.asyncio + async def test_get_route_details_without_auth_fails( + self, + client: AsyncClient, + ) -> None: + request_data = BusRoutesDetailsRequest( + routes=[ + RouteIdentifierSchema(bus_line="8000", bus_direction=1), + ] + ) + + response = await client.post("/routes/details", json=request_data.model_dump()) + + assert response.status_code == 401 + class TestBusPositions: @pytest.mark.asyncio @@ -188,6 +260,13 @@ async def test_get_bus_position_returns_successfully( self, client: AsyncClient, ) -> None: + user_data = { + "name": "Test User", + "email": "test@example.com", + "password": "securepassword123", + } + auth = await create_user_and_login(client, user_data) + mock_positions = [ BusPosition( route=RouteIdentifier(bus_line="8000", bus_direction=1), @@ -222,7 +301,11 @@ async def test_get_bus_position_returns_successfully( ] ) - response = await client.post("/routes/positions", json=request_data.model_dump()) + response = await client.post( + "/routes/positions", + json=request_data.model_dump(), + headers=auth["headers"], + ) assert response.status_code == 200 data = response.json() @@ -244,6 +327,13 @@ async def test_get_bus_position_returns_500_when_error( self, client: AsyncClient, ) -> None: + user_data = { + "name": "Test User", + "email": "test@example.com", + "password": "securepassword123", + } + auth = await create_user_and_login(client, user_data) + with ( patch( "src.adapters.external.sptrans_adapter.SpTransAdapter.authenticate", @@ -265,7 +355,11 @@ async def test_get_bus_position_returns_500_when_error( ] ) - response = await client.post("/routes/positions", json=request_data.model_dump()) + response = await client.post( + "/routes/positions", + json=request_data.model_dump(), + headers=auth["headers"], + ) assert response.status_code == 500 assert "Failed to retrieve bus positions" in response.json()["detail"] @@ -275,6 +369,13 @@ async def test_get_bus_position_returns_empty_when_no_buses_on_line( self, client: AsyncClient, ) -> None: + user_data = { + "name": "Test User", + "email": "test@example.com", + "password": "securepassword123", + } + auth = await create_user_and_login(client, user_data) + with ( patch( "src.adapters.external.sptrans_adapter.SpTransAdapter.authenticate", @@ -296,7 +397,11 @@ async def test_get_bus_position_returns_empty_when_no_buses_on_line( ] ) - response = await client.post("/routes/positions", json=request_data.model_dump()) + response = await client.post( + "/routes/positions", + json=request_data.model_dump(), + headers=auth["headers"], + ) assert response.status_code == 200 data = response.json() @@ -309,6 +414,13 @@ async def test_get_bus_position_works_with_multiple_routes( self, client: AsyncClient, ) -> None: + user_data = { + "name": "Test User", + "email": "test@example.com", + "password": "securepassword123", + } + auth = await create_user_and_login(client, user_data) + mock_positions_8000 = [ BusPosition( route=RouteIdentifier(bus_line="8000", bus_direction=1), @@ -349,7 +461,11 @@ async def test_get_bus_position_works_with_multiple_routes( ] ) - response = await client.post("/routes/positions", json=request_data.model_dump()) + response = await client.post( + "/routes/positions", + json=request_data.model_dump(), + headers=auth["headers"], + ) assert response.status_code == 200 data = response.json() @@ -364,6 +480,13 @@ async def test_get_bus_position_returns_500_when_authentication_failure( self, client: AsyncClient, ) -> None: + user_data = { + "name": "Test User", + "email": "test@example.com", + "password": "securepassword123", + } + auth = await create_user_and_login(client, user_data) + with patch( "src.adapters.external.sptrans_adapter.SpTransAdapter.authenticate", new_callable=AsyncMock, @@ -378,7 +501,11 @@ async def test_get_bus_position_returns_500_when_authentication_failure( ] ) - response = await client.post("/routes/positions", json=request_data.model_dump()) + response = await client.post( + "/routes/positions", + json=request_data.model_dump(), + headers=auth["headers"], + ) assert response.status_code == 500 @@ -387,6 +514,13 @@ async def test_get_bus_position_returns_422_when_invalid_data( self, client: AsyncClient, ) -> None: + user_data = { + "name": "Test User", + "email": "test@example.com", + "password": "securepassword123", + } + auth = await create_user_and_login(client, user_data) + request_data = { "routes": [ { @@ -396,15 +530,26 @@ async def test_get_bus_position_returns_422_when_invalid_data( ] } - response = await client.post("/routes/positions", json=request_data) + response = await client.post( + "/routes/positions", + json=request_data, + headers=auth["headers"], + ) - assert response.status_code == 422 # Validation error + assert response.status_code == 422 @pytest.mark.asyncio async def test_get_bus_position_returns_successfully_with_empty_routes_list( self, client: AsyncClient, ) -> None: + user_data = { + "name": "Test User", + "email": "test@example.com", + "password": "securepassword123", + } + auth = await create_user_and_login(client, user_data) + with ( patch( "src.adapters.external.sptrans_adapter.SpTransAdapter.authenticate", @@ -419,25 +564,50 @@ async def test_get_bus_position_returns_successfully_with_empty_routes_list( ): request_data = BusPositionsRequest(routes=[]) - response = await client.post("/routes/positions", json=request_data.model_dump()) + response = await client.post( + "/routes/positions", + json=request_data.model_dump(), + headers=auth["headers"], + ) assert response.status_code == 200 assert response.json()["buses"] == [] + @pytest.mark.asyncio + async def test_get_bus_position_without_auth_fails( + self, + client: AsyncClient, + ) -> None: + request_data = BusPositionsRequest( + routes=[ + BusRouteSchema( + route_id=12345, + route=RouteIdentifierSchema(bus_line="8000", bus_direction=1), + ), + ] + ) + + response = await client.post( + "/routes/positions", json=request_data.model_dump() + ) -class TestRouteShape: - """Tests for GET /routes/shape/{route_id} endpoint. + assert response.status_code == 401 - These tests use the actual GTFS database (gtfs.db) since - GTFSRepositoryAdapter is a local database adapter, not an external service. - """ +class TestRouteShape: @pytest.mark.asyncio async def test_get_route_shape_returns_successfully( self, client: AsyncClient, ) -> None: - response = await client.get("/routes/shape/1012-10") + user_data = { + "name": "Test User", + "email": "test@example.com", + "password": "securepassword123", + } + auth = await create_user_and_login(client, user_data) + + response = await client.get("/routes/shape/1012-10", headers=auth["headers"]) assert response.status_code == 200 data = response.json() @@ -458,7 +628,17 @@ async def test_get_route_shape_returns_404_when_not_found( self, client: AsyncClient, ) -> None: - response = await client.get("/routes/shape/NONEXISTENT-ROUTE-12345") + user_data = { + "name": "Test User", + "email": "test@example.com", + "password": "securepassword123", + } + auth = await create_user_and_login(client, user_data) + + response = await client.get( + "/routes/shape/NONEXISTENT-ROUTE-12345", + headers=auth["headers"], + ) assert response.status_code == 404 assert "not found" in response.json()["detail"].lower() @@ -468,14 +648,27 @@ async def test_get_route_shape_points_have_valid_coordinates( self, client: AsyncClient, ) -> None: - response = await client.get("/routes/shape/1012-10") + user_data = { + "name": "Test User", + "email": "test@example.com", + "password": "securepassword123", + } + auth = await create_user_and_login(client, user_data) + + response = await client.get("/routes/shape/1012-10", headers=auth["headers"]) assert response.status_code == 200 points = response.json()["points"] - # São Paulo approximate bounding box for point in points: - # Latitude should be around -23 to -24 for São Paulo assert -25 <= point["latitude"] <= -22 - # Longitude should be around -46 to -47 for São Paulo assert -48 <= point["longitude"] <= -45 + + @pytest.mark.asyncio + async def test_get_route_shape_without_auth_fails( + self, + client: AsyncClient, + ) -> None: + response = await client.get("/routes/shape/1012-10") + + assert response.status_code == 401 diff --git a/tests/integration/test_trip.py b/tests/integration/test_trip.py index 3ee7c02..6f30e28 100644 --- a/tests/integration/test_trip.py +++ b/tests/integration/test_trip.py @@ -26,7 +26,6 @@ async def test_create_trip_should_return_successfully_and_save_to_database( auth = await create_user_and_login(client, user_data) trip_data = CreateTripRequest( - email=user_data["email"], route=RouteIdentifierSchema( bus_line="8000", bus_direction=1, @@ -46,7 +45,9 @@ async def test_create_trip_should_return_successfully_and_save_to_database( assert "score" in data - result = await test_db.execute(select(TripDB).where(TripDB.email == user_data["email"])) + result = await test_db.execute( + select(TripDB).where(TripDB.email == user_data["email"]) + ) trip = result.scalar_one_or_none() assert trip is not None @@ -69,7 +70,6 @@ async def test_create_trip_updates_user_score( auth = await create_user_and_login(client, user_data) trip_data = CreateTripRequest( - email=user_data["email"], route=RouteIdentifierSchema( bus_line="8000", bus_direction=1, @@ -78,7 +78,6 @@ async def test_create_trip_updates_user_score( data=datetime.now(UTC), ) second_trip_data = CreateTripRequest( - email=user_data["email"], route=RouteIdentifierSchema( bus_line="8000", bus_direction=1, @@ -114,13 +113,7 @@ async def test_create_trip_without_authentication_fails( self, client: AsyncClient, ) -> None: - user_data = { - "name": "Test User", - "email": "test@example.com", - "password": "securepassword123", - } trip_data = CreateTripRequest( - email=user_data["email"], route=RouteIdentifierSchema( bus_line="8000", bus_direction=1, @@ -146,7 +139,6 @@ async def test_create_trip_zero_distance( auth = await create_user_and_login(client, user_data) trip_data = CreateTripRequest( - email=user_data["email"], route=RouteIdentifierSchema( bus_line="9000", bus_direction=2, @@ -177,9 +169,7 @@ async def test_create_trip_negative_distance_fails( } auth = await create_user_and_login(client, user_data) - # Use raw dict to bypass Pydantic validation and test API validation trip_data = { - "email": user_data["email"], "route": { "bus_line": "8000", "bus_direction": 1, @@ -196,7 +186,9 @@ async def test_create_trip_negative_distance_fails( assert response.status_code == 422 - result = await test_db.execute(select(TripDB).where(TripDB.email == user_data["email"])) + result = await test_db.execute( + select(TripDB).where(TripDB.email == user_data["email"]) + ) trip = result.scalar_one_or_none() assert trip is None @@ -213,9 +205,7 @@ async def test_create_trip_invalid_route_identifier_fails( } auth = await create_user_and_login(client, user_data) - # Use raw dict to bypass Pydantic validation and test API validation trip_data = { - "email": user_data["email"], "route": { "bus_line": "8000", "bus_direction": 3, @@ -232,6 +222,8 @@ async def test_create_trip_invalid_route_identifier_fails( assert response.status_code == 422 - result = await test_db.execute(select(TripDB).where(TripDB.email == user_data["email"])) + result = await test_db.execute( + select(TripDB).where(TripDB.email == user_data["email"]) + ) trip = result.scalar_one_or_none() assert trip is None diff --git a/tests/integration/test_user_history.py b/tests/integration/test_user_history.py index 33b0de5..0c921ad 100644 --- a/tests/integration/test_user_history.py +++ b/tests/integration/test_user_history.py @@ -5,7 +5,6 @@ from src.web.schemas import ( CreateTripRequest, - HistoryRequest, HistoryResponse, RouteIdentifierSchema, ) @@ -35,7 +34,6 @@ async def test_get_user_history_should_work( scores: list[int] = [] for i, trip_date in enumerate(trip_dates): trip_request = CreateTripRequest( - email=user_data["email"], route=RouteIdentifierSchema(bus_line=f"800{i}", bus_direction=1), distance=(i + 1) * 1000, data=trip_date, @@ -47,10 +45,8 @@ async def test_get_user_history_should_work( ) scores.append(response.json()["score"]) - history_request = HistoryRequest(email=user_data["email"]) - response = await client.post( + response = await client.get( "/history/", - json=history_request.model_dump(), headers=auth["headers"], ) @@ -78,11 +74,8 @@ async def test_get_user_history_returns_empty_when_no_trips( auth = await create_user_and_login(client, user_data) - history_request = HistoryRequest(email=user_data["email"]) - - response = await client.post( + response = await client.get( "/history/", - json=history_request.model_dump(), headers=auth["headers"], ) @@ -96,8 +89,7 @@ async def test_get_user_history_without_authentication_fails( self, client: AsyncClient, ) -> None: - history_request = HistoryRequest(email="test@example.com") - response = await client.post("/history/", json=history_request.model_dump()) + response = await client.get("/history/") assert response.status_code == 401 @@ -116,7 +108,6 @@ async def test_get_history_includes_correct_dates( specific_date: datetime = datetime(2025, 6, 15, 10, 30, 0, tzinfo=UTC) trip_request = CreateTripRequest( - email=user_data["email"], route=RouteIdentifierSchema(bus_line="8000", bus_direction=1), distance=1000, data=specific_date, @@ -127,10 +118,8 @@ async def test_get_history_includes_correct_dates( headers=auth["headers"], ) - history_request = HistoryRequest(email=user_data["email"]) - response = await client.post( + response = await client.get( "/history/", - json=history_request.model_dump(), headers=auth["headers"], ) diff --git a/tests/web/test_route_controller.py b/tests/web/test_route_controller.py index 2cac40a..28eca28 100644 --- a/tests/web/test_route_controller.py +++ b/tests/web/test_route_controller.py @@ -9,8 +9,10 @@ from src.core.models.bus import BusPosition, BusRoute, RouteIdentifier from src.core.models.coordinate import Coordinate +from src.core.models.user import User from src.core.services.route_service import RouteService from src.main import app +from src.web.auth import get_current_user from src.web.controllers.route_controller import get_route_service @@ -26,20 +28,23 @@ def mock_service() -> RouteService: mas com métodos assíncronos (AsyncMock). """ service = AsyncMock(spec=RouteService) - # Cast to RouteService to satisfy type checker typed_service: RouteService = service - # Set up return values with proper types - these are AsyncMock instances typed_service.get_route_details = AsyncMock() # type: ignore[method-assign] typed_service.get_bus_positions = AsyncMock() # type: ignore[method-assign] return typed_service +@pytest.fixture +def mock_current_user() -> User: + return User(name="Test User", email="test@example.com", score=0) + + @pytest.fixture(autouse=True) -def override_dependency(mock_service: RouteService) -> Generator[None, None, None]: - """ - Override da dependência get_route_service para usar o mock. - """ +def override_dependency( + mock_service: RouteService, mock_current_user: User +) -> Generator[None, None, None]: app.dependency_overrides[get_route_service] = lambda: mock_service + app.dependency_overrides[get_current_user] = lambda: mock_current_user yield app.dependency_overrides.clear() @@ -50,7 +55,9 @@ def override_dependency(mock_service: RouteService) -> Generator[None, None, Non @pytest.mark.asyncio -async def test_details_endpoint_success(client: TestClient, mock_service: RouteService) -> None: +async def test_details_endpoint_success( + client: TestClient, mock_service: RouteService +) -> None: """ Testa o endpoint POST /routes/details garantindo que: - Ele chama RouteService.get_route_details() @@ -125,7 +132,9 @@ async def test_details_endpoint_error_returns_500( @pytest.mark.asyncio -async def test_positions_endpoint_success(client: TestClient, mock_service: RouteService) -> None: +async def test_positions_endpoint_success( + client: TestClient, mock_service: RouteService +) -> None: """ Testa o endpoint POST /routes/positions garantindo que: - Ele chama RouteService.get_bus_positions() From 9d809280cfa5accd2ba5ac79769c1a27ab304997 Mon Sep 17 00:00:00 2001 From: Kim Kakeya Date: Mon, 1 Dec 2025 21:40:50 -0300 Subject: [PATCH 10/12] style: ruff styling --- src/web/controllers/history_controller.py | 3 +-- src/web/controllers/route_controller.py | 7 ++----- src/web/controllers/trip_controller.py | 4 +--- src/web/controllers/user_controller.py | 4 +--- src/web/schemas.py | 4 +--- tests/integration/test_route.py | 4 +--- tests/integration/test_trip.py | 12 +++--------- tests/web/test_route_controller.py | 8 ++------ 8 files changed, 12 insertions(+), 34 deletions(-) diff --git a/src/web/controllers/history_controller.py b/src/web/controllers/history_controller.py index f541171..3be6af7 100644 --- a/src/web/controllers/history_controller.py +++ b/src/web/controllers/history_controller.py @@ -46,8 +46,7 @@ async def get_user_history( # Combine dates and scores into trip entries (returns empty list if no history) trips = [ - TripHistoryEntry(date=date, score=score) - for date, score in zip(dates, scores, strict=False) + TripHistoryEntry(date=date, score=score) for date, score in zip(dates, scores, strict=False) ] return HistoryResponse(trips=trips) diff --git a/src/web/controllers/route_controller.py b/src/web/controllers/route_controller.py index 6418a01..92de9c5 100644 --- a/src/web/controllers/route_controller.py +++ b/src/web/controllers/route_controller.py @@ -62,8 +62,7 @@ async def get_route_details_endpoint( try: # Schemas -> domínio (RouteIdentifier) route_identifiers: list[RouteIdentifier] = [ - map_route_identifier_schema_to_domain(route_schema) - for route_schema in request.routes + map_route_identifier_schema_to_domain(route_schema) for route_schema in request.routes ] bus_routes: list[BusRoute] = [] @@ -124,9 +123,7 @@ async def get_bus_positions( route=route_identifier, ) - route_positions: list[BusPosition] = await route_service.get_bus_positions( - bus_route - ) + route_positions: list[BusPosition] = await route_service.get_bus_positions(bus_route) all_positions.extend(route_positions) # Domínio -> schemas diff --git a/src/web/controllers/trip_controller.py b/src/web/controllers/trip_controller.py index 5e955fb..b45365a 100644 --- a/src/web/controllers/trip_controller.py +++ b/src/web/controllers/trip_controller.py @@ -33,9 +33,7 @@ def get_trip_service(db: AsyncSession = Depends(get_db)) -> TripService: return TripService(trip_repo, user_repo) -@router.post( - "/", response_model=CreateTripResponse, status_code=status.HTTP_201_CREATED -) +@router.post("/", response_model=CreateTripResponse, status_code=status.HTTP_201_CREATED) async def create_trip( request: CreateTripRequest, trip_service: TripService = Depends(get_trip_service), diff --git a/src/web/controllers/user_controller.py b/src/web/controllers/user_controller.py index 48815df..ba3ed99 100644 --- a/src/web/controllers/user_controller.py +++ b/src/web/controllers/user_controller.py @@ -22,9 +22,7 @@ router = APIRouter(prefix="/users", tags=["users"]) -@router.post( - "/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED -) +@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED) async def create_user( request: UserCreateAccountRequest, user_service: UserService = Depends(get_user_service), diff --git a/src/web/schemas.py b/src/web/schemas.py index 73a517e..5b10573 100644 --- a/src/web/schemas.py +++ b/src/web/schemas.py @@ -140,9 +140,7 @@ class RouteShapeResponse(BaseModel): route_id: str = Field(..., description="Route identifier") shape_id: str = Field(..., description="GTFS shape identifier") - points: list[CoordinateSchema] = Field( - ..., description="Ordered list of coordinates" - ) + points: list[CoordinateSchema] = Field(..., description="Ordered list of coordinates") # ===== Ranking Schemas ===== diff --git a/tests/integration/test_route.py b/tests/integration/test_route.py index eae7b3e..c168893 100644 --- a/tests/integration/test_route.py +++ b/tests/integration/test_route.py @@ -587,9 +587,7 @@ async def test_get_bus_position_without_auth_fails( ] ) - response = await client.post( - "/routes/positions", json=request_data.model_dump() - ) + response = await client.post("/routes/positions", json=request_data.model_dump()) assert response.status_code == 401 diff --git a/tests/integration/test_trip.py b/tests/integration/test_trip.py index 6f30e28..dcc0bf9 100644 --- a/tests/integration/test_trip.py +++ b/tests/integration/test_trip.py @@ -45,9 +45,7 @@ async def test_create_trip_should_return_successfully_and_save_to_database( assert "score" in data - result = await test_db.execute( - select(TripDB).where(TripDB.email == user_data["email"]) - ) + result = await test_db.execute(select(TripDB).where(TripDB.email == user_data["email"])) trip = result.scalar_one_or_none() assert trip is not None @@ -186,9 +184,7 @@ async def test_create_trip_negative_distance_fails( assert response.status_code == 422 - result = await test_db.execute( - select(TripDB).where(TripDB.email == user_data["email"]) - ) + result = await test_db.execute(select(TripDB).where(TripDB.email == user_data["email"])) trip = result.scalar_one_or_none() assert trip is None @@ -222,8 +218,6 @@ async def test_create_trip_invalid_route_identifier_fails( assert response.status_code == 422 - result = await test_db.execute( - select(TripDB).where(TripDB.email == user_data["email"]) - ) + result = await test_db.execute(select(TripDB).where(TripDB.email == user_data["email"])) trip = result.scalar_one_or_none() assert trip is None diff --git a/tests/web/test_route_controller.py b/tests/web/test_route_controller.py index 28eca28..b6a9b52 100644 --- a/tests/web/test_route_controller.py +++ b/tests/web/test_route_controller.py @@ -55,9 +55,7 @@ def override_dependency( @pytest.mark.asyncio -async def test_details_endpoint_success( - client: TestClient, mock_service: RouteService -) -> None: +async def test_details_endpoint_success(client: TestClient, mock_service: RouteService) -> None: """ Testa o endpoint POST /routes/details garantindo que: - Ele chama RouteService.get_route_details() @@ -132,9 +130,7 @@ async def test_details_endpoint_error_returns_500( @pytest.mark.asyncio -async def test_positions_endpoint_success( - client: TestClient, mock_service: RouteService -) -> None: +async def test_positions_endpoint_success(client: TestClient, mock_service: RouteService) -> None: """ Testa o endpoint POST /routes/positions garantindo que: - Ele chama RouteService.get_bus_positions() From 16a63b2cde8594d367199353829d6018259a554c Mon Sep 17 00:00:00 2001 From: Kim Kakeya Date: Wed, 3 Dec 2025 20:04:18 -0300 Subject: [PATCH 11/12] feat: update sptrans_adapter - Automatically apply auth when not authenticated with provider - Better use of pydantic objects - Search endpoint more explicitly used for search --- src/adapters/external/models/LineInfo.py | 11 - .../external/models/SPTransPosResp.py | 14 - src/adapters/external/sptrans_adapter.py | 207 +++++------ src/adapters/external/sptrans_mappers.py | 104 ++++++ src/adapters/external/sptrans_schemas.py | 40 +-- .../repositories/gtfs_repository_adapter.py | 2 +- src/core/models/bus.py | 23 +- src/core/models/coordinate.py | 7 +- src/core/ports/bus_provider_port.py | 33 +- src/core/services/route_service.py | 47 ++- src/web/controllers/route_controller.py | 118 +++--- src/web/mappers.py | 69 +++- src/web/schemas.py | 22 +- tests/adapters/test_sptrans_adapter.py | 115 +++--- tests/core/test_route_service.py | 128 +++---- tests/integration/test_route.py | 256 ++++--------- tests/web/test_route_controller.py | 336 ++++++++++-------- 17 files changed, 766 insertions(+), 766 deletions(-) delete mode 100644 src/adapters/external/models/LineInfo.py delete mode 100644 src/adapters/external/models/SPTransPosResp.py create mode 100644 src/adapters/external/sptrans_mappers.py diff --git a/src/adapters/external/models/LineInfo.py b/src/adapters/external/models/LineInfo.py deleted file mode 100644 index 2bd46fa..0000000 --- a/src/adapters/external/models/LineInfo.py +++ /dev/null @@ -1,11 +0,0 @@ -from typing import TypedDict - - -class LineInfo(TypedDict): - cl: int # código da linha (código interno SPTrans) - lc: bool # sentido preferencial (circular / comum) - lt: str # número da linha (ex: "8000") - sl: int # sentido (0: ida, 1: volta) - tl: int # tipo da linha (10 = urbana, etc.) - tp: str # ponto inicial - ts: str # ponto final diff --git a/src/adapters/external/models/SPTransPosResp.py b/src/adapters/external/models/SPTransPosResp.py deleted file mode 100644 index c787065..0000000 --- a/src/adapters/external/models/SPTransPosResp.py +++ /dev/null @@ -1,14 +0,0 @@ -from typing import TypedDict - - -class Vehicle(TypedDict): - p: str # prefixo do ônibus - a: bool # acessibilidade - ta: str # timestamp ISO - py: float # latitude - px: float # longitude - - -class SPTransPositionsResponse(TypedDict): - hr: str # horário da resposta - vs: list[Vehicle] # lista de veículos diff --git a/src/adapters/external/sptrans_adapter.py b/src/adapters/external/sptrans_adapter.py index b81adfc..12df761 100644 --- a/src/adapters/external/sptrans_adapter.py +++ b/src/adapters/external/sptrans_adapter.py @@ -5,18 +5,18 @@ with the SPTrans API. """ -from datetime import datetime - import httpx from httpx import Response from src.config import settings -from ...adapters.external.models.LineInfo import LineInfo -from ...adapters.external.models.SPTransPosResp import SPTransPositionsResponse, Vehicle -from ...core.models.bus import BusPosition, BusRoute, RouteIdentifier -from ...core.models.coordinate import Coordinate +from ...core.models.bus import BusPosition, BusRoute from ...core.ports.bus_provider_port import BusProviderPort +from .sptrans_mappers import ( + map_positions_response_to_bus_positions, + map_search_response_to_bus_route_list, +) +from .sptrans_schemas import SPTransLineSearchResponse, SPTransPositionsResponse class SpTransAdapter(BusProviderPort): @@ -29,11 +29,12 @@ def __init__( Initialize the SPTrans adapter. Args: - api_token: API authentication token (optional) - base_url: Base URL for the SPTrans API (optional) + api_token: API authentication token (optional, defaults to settings) + base_url: Base URL for the SPTrans API (optional, defaults to settings) + + Raises: + ValueError: If no API token is provided or configured. """ - # se o parâmetro foi passado, usa ele; - # senão, usa o valor do settings carregado do .env self.api_token = api_token or settings.sptrans_api_token self.base_url = base_url or settings.sptrans_base_url @@ -42,15 +43,27 @@ def __init__( "SPTransAdapter: nenhum token fornecido e SPTRANS_API_TOKEN não definido no ambiente/.env." ) - self.session_token: str | None = None + self._authenticated: bool = False self.client = httpx.AsyncClient(base_url=self.base_url, timeout=30.0) - async def authenticate(self) -> bool: + async def _ensure_authenticated(self) -> None: + """ + Ensure the client is authenticated, authenticating if necessary. + + Raises: + RuntimeError: If authentication fails. + """ + if not self._authenticated: + success = await self._authenticate() + if not success: + raise RuntimeError("SPTrans authentication failed") + + async def _authenticate(self) -> bool: """ Authenticate with the SPTrans API. Returns: - True if authentication successful, False otherwise + True if authentication successful, False otherwise. """ try: response: Response = await self.client.post( @@ -59,131 +72,119 @@ async def authenticate(self) -> bool: ) if response.status_code == 200 and response.text == "true": - # marca que autenticou com sucesso - self.session_token = "authenticated" + self._authenticated = True return True return False - except Exception as e: - exc: Exception = e - print(f"Authentication failed: {exc}") + except Exception: return False - async def get_bus_positions(self, bus_route: BusRoute) -> list[BusPosition]: + def _is_unauthorized_response(self, response: Response) -> bool: """ - Get real-time positions for a specific route using SPTrans data. - - Pré-condição: - O método `authenticate()` deve ter sido chamado com sucesso antes de - usar este método. + Check if response indicates authentication is required. Args: - bus_route: BusRoute(route_id=cl, route=RouteIdentifier) + response: HTTP response to check. Returns: - List of BusPosition objects. + True if response indicates unauthorized access. """ - # Checagem defensiva opcional - if getattr(self, "session_token", None) != "authenticated": - raise RuntimeError("SPTrans client not authenticated. Call `authenticate()` first.") - - positions: list[BusPosition] = [] + if response.status_code != 401: + return False try: - line_code: int = bus_route.route_id + data = response.json() + return "Authorization has been denied" in data.get("Message", "") + except Exception: + return False - response: Response = await self.client.get( - "/Posicao/Linha", - params={"codigoLinha": line_code}, - ) + async def _request_with_auth_retry( + self, + method: str, + url: str, + params: dict[str, str | int] | None = None, + ) -> Response: + """ + Make an HTTP request with automatic authentication retry on 401. - if response.status_code != 200: - raise RuntimeError( - f"SPTrans returned status {response.status_code} for line {bus_route}" - ) + Args: + method: HTTP method (GET, POST, etc.) + url: Request URL path. + params: Query parameters for the request. + + Returns: + HTTP response. - response_data: SPTransPositionsResponse = response.json() + Raises: + RuntimeError: If authentication fails after retry. + """ + await self._ensure_authenticated() - vehicles: list[Vehicle] = response_data["vs"] + response = await self.client.request(method, url, params=params) - for vehicle in vehicles: - pos: BusPosition = BusPosition( - route=bus_route.route, - position=Coordinate( - latitude=vehicle["py"] / 1_000_000, - longitude=vehicle["px"] / 1_000_000, - ), - time_updated=datetime.fromisoformat(vehicle["ta"]), - ) - positions.append(pos) + if self._is_unauthorized_response(response): + self._authenticated = False + await self._ensure_authenticated() + response = await self.client.request(method, url, params=params) - except Exception as e: - exc: Exception = e - print(f"Failed to get positions for bus_route {bus_route}: {exc}") + if self._is_unauthorized_response(response): + raise RuntimeError("SPTrans authentication failed after retry") - return positions + return response - async def get_route_details(self, route: RouteIdentifier) -> list[BusRoute]: + async def get_bus_positions(self, routes: list[BusRoute]) -> list[BusPosition]: """ - Resolve a logical bus line (bus_line) into all SPTrans BusRoute entries using - the `/Linha/Buscar` endpoint. + Get real-time positions for specified routes. Args: - route (RouteIdentifier): Logical bus line (ex: "8000") + routes: List of BusRoute objects containing route info. Returns: - list[BusRoute]: Todas as variantes da linha retornadas pela SPTrans. - - Raises: - RuntimeError: Se a requisição falhar, vier vazia ou inválida. + List of BusPosition objects with route identifier and coordinates. """ + positions: list[BusPosition] = [] - # Verifica se está autenticado - if getattr(self, "session_token", None) != "authenticated": - raise RuntimeError("SPTrans client not authenticated. Call `authenticate()` first.") - - try: - response: Response = await self.client.get( - "/Linha/Buscar", - params={"termosBusca": route.bus_line}, + for bus_route in routes: + response = await self._request_with_auth_retry( + "GET", + "/Posicao/Linha", + params={"codigoLinha": bus_route.route_id}, ) if response.status_code != 200: - raise RuntimeError( - f"SPTrans returned status {response.status_code} for line search." - ) + continue - data: list[LineInfo] = response.json() - - if not isinstance(data, list) or len(data) == 0: - raise RuntimeError(f"No SPTrans line found for line={route.bus_line}") - - bus_routes: list[BusRoute] = [] - - for item in data: - # Validate based on TypedDict keys - if "cl" not in item or "lt" not in item: - continue # Skip invalid entries + response_data = SPTransPositionsResponse.model_validate(response.json()) + route_positions = map_positions_response_to_bus_positions( + response_data, + bus_route.route, + ) + positions.extend(route_positions) - line_code = item["cl"] - line_text = item["lt"] - line_dir = item["sl"] + return positions - bus_routes.append( - BusRoute( - route_id=line_code, - route=RouteIdentifier(bus_line=line_text, bus_direction=line_dir), - ) - ) + async def search_routes(self, query: str) -> list[BusRoute]: + """ + Search for bus routes matching a query string. - if not bus_routes: - raise RuntimeError( - f"Invalid SPTrans response for line={route.bus_line}: " - "missing required fields" - ) + Args: + query: Search term (e.g., "809" or "Vila Nova Conceição"). - return bus_routes + Returns: + List of matching BusRoute objects. + """ + response = await self._request_with_auth_retry( + "GET", + "/Linha/Buscar", + params={"termosBusca": query}, + ) + + if response.status_code != 200: + raise RuntimeError( + f"SPTrans returned status {response.status_code} for search." + ) - except Exception as e: - raise RuntimeError(f"Failed to resolve route details for {route}: {e}") from e + response_data = SPTransLineSearchResponse.model_validate(response.json()) + bus_routes: list[BusRoute] = map_search_response_to_bus_route_list(response_data) + return bus_routes diff --git a/src/adapters/external/sptrans_mappers.py b/src/adapters/external/sptrans_mappers.py new file mode 100644 index 0000000..2c66ff3 --- /dev/null +++ b/src/adapters/external/sptrans_mappers.py @@ -0,0 +1,104 @@ +""" +Mappers for converting SPTrans API schemas to domain models. + +These mappers handle the transformation between external API DTOs +and internal domain objects, keeping the adapter layer clean. +""" + +from ...core.models.bus import BusPosition, BusRoute, RouteIdentifier +from ...core.models.coordinate import Coordinate +from .sptrans_schemas import SPTransLineInfo, SPTransLineSearchResponse, SPTransPositionsResponse, SPTransVehicle + +def map_search_response_to_bus_route_list( + data: SPTransLineSearchResponse, +) -> list[BusRoute]: + """ + Convert API search response to list of BusRoute domain objects. + + Args: + data: SPTransLineSearchResponse object. + + Returns: + List of BusRoute domain objects. + """ + bus_route_list: list[BusRoute] = [] + for item in data.results: + bus_route_list.append(map_line_info_to_bus_route(item)) + return bus_route_list + +def map_line_info_to_bus_route(line_info: SPTransLineInfo) -> BusRoute: + """ + Convert SPTrans line info to domain BusRoute. + + Args: + line_info: SPTrans line information. + + Returns: + Domain BusRoute with route_id and identifier. + """ + return BusRoute( + route_id=line_info.cl, + route=map_line_info_to_route_identifier(line_info), + ) + +def map_line_info_to_route_identifier(line_info: SPTransLineInfo) -> RouteIdentifier: + """ + Convert SPTrans line info to domain RouteIdentifier. + + Args: + line_info: SPTrans line information. + + Returns: + Domain RouteIdentifier with formatted bus_line (lt-tl format). + """ + bus_line = f"{line_info.lt}-{line_info.tl}" + return RouteIdentifier( + bus_line=bus_line, + bus_direction=line_info.sl, + ) + + +def map_positions_response_to_bus_positions( + data: SPTransPositionsResponse, + route: RouteIdentifier, +) -> list[BusPosition]: + """ + Convert API positions response to list of BusPosition domain objects. + + Args: + data: SPTransPositionsResponse object. + route: RouteIdentifier for all vehicles in this response. + + Returns: + List of domain BusPosition objects. + """ + positions: list[BusPosition] = [] + + for vehicle in data.vs: + position = map_vehicle_to_bus_position(vehicle, route) + positions.append(position) + + return positions + +def map_vehicle_to_bus_position( + vehicle: SPTransVehicle, + route: RouteIdentifier, +) -> BusPosition: + """ + Convert SPTrans vehicle to domain BusPosition. + + Args: + vehicle: SPTransVehicle position data. + route: RouteIdentifier for this vehicle's route. + + Returns: + Domain BusPosition with coordinates and route info. + """ + return BusPosition( + route=route, + position=Coordinate( + latitude=vehicle.py, + longitude=vehicle.px, + ), + time_updated=vehicle.ta, + ) diff --git a/src/adapters/external/sptrans_schemas.py b/src/adapters/external/sptrans_schemas.py index bd8d7e3..f125ce6 100644 --- a/src/adapters/external/sptrans_schemas.py +++ b/src/adapters/external/sptrans_schemas.py @@ -9,33 +9,33 @@ from pydantic import BaseModel, Field +class SPTransLineInfo(BaseModel): + """Schema for SPTrans line information response.""" -class SPTransRouteResponse(BaseModel): - """Schema for SPTrans route response.""" + cl: int = Field(..., description="Route code (internal SPTrans code)") + lc: bool = Field(..., description="Is circular route") + lt: str = Field(..., description="Line number (e.g., '8000')") + sl: int = Field(..., description="Direction (1 = ida, 2 = volta)") + tl: int = Field(..., description="Line type (10 = urban, etc.)") + tp: str = Field(..., description="Primary terminal") + ts: str = Field(..., description="Secondary terminal") - cl: int = Field(..., description="Route code") - lc: bool = Field(..., description="Is circular") - lt: str = Field(..., description="Main direction") - tl: int = Field(..., description="Type") - sl: int = Field(..., description="Secondary type") - tp: str = Field(..., description="Terminal principal") - ts: str = Field(..., description="Terminal secundário") +class SPTransLineSearchResponse(BaseModel): + """Schema for SPTrans line search response item.""" + results: list[SPTransLineInfo] = Field(..., description="List of line info results") +class SPTransVehicle(BaseModel): + """Schema for SPTrans vehicle position.""" -class SPTransVehicleResponse(BaseModel): - """Schema for SPTrans vehicle position response.""" - - p: int = Field(..., description="Route code") + p: str = Field(..., description="Vehicle prefix") a: bool = Field(..., description="Is accessible") ta: datetime = Field(..., description="Time updated") - px: int = Field(..., description="Longitude * 10^6") - py: int = Field(..., description="Latitude * 10^6") + py: float = Field(..., description="Latitude") + px: float = Field(..., description="Longitude") class SPTransPositionsResponse(BaseModel): - """Schema for SPTrans positions response.""" - - hr: datetime = Field(..., alias="currentTime", description="Current time") - vs: list[SPTransVehicleResponse] = Field(..., alias="vehicles", description="List of vehicles") + """Schema for SPTrans positions API response.""" - model_config = {"populate_by_name": True} + hr: str = Field(..., description="Response time") + vs: list[SPTransVehicle] = Field(default_factory=list, description="List of vehicles") diff --git a/src/adapters/repositories/gtfs_repository_adapter.py b/src/adapters/repositories/gtfs_repository_adapter.py index 17a9ccb..e534279 100644 --- a/src/adapters/repositories/gtfs_repository_adapter.py +++ b/src/adapters/repositories/gtfs_repository_adapter.py @@ -52,7 +52,7 @@ def get_route_shape(self, route_id: str) -> RouteShape | None: (shape_id,), ) - points = [] + points: list[RouteShapePoint] = [] for row in cursor.fetchall(): point = RouteShapePoint( coordinate=Coordinate( diff --git a/src/core/models/bus.py b/src/core/models/bus.py index 3c497dd..2762dd7 100644 --- a/src/core/models/bus.py +++ b/src/core/models/bus.py @@ -1,13 +1,14 @@ """Bus-related domain models.""" -from dataclasses import dataclass from datetime import datetime +from typing import Literal + +from pydantic import BaseModel, Field from .coordinate import Coordinate -@dataclass -class RouteIdentifier: +class RouteIdentifier(BaseModel): """ Identifier for a bus route. @@ -17,11 +18,14 @@ class RouteIdentifier: """ bus_line: str - bus_direction: int + bus_direction: Literal[1, 2] = Field( + ..., description="Direction (1 = ida, 2 = volta)" + ) + + model_config = {"frozen": True} -@dataclass -class BusRoute: +class BusRoute(BaseModel): """ Bus route information. @@ -33,9 +37,10 @@ class BusRoute: route_id: int route: RouteIdentifier + model_config = {"frozen": True} -@dataclass -class BusPosition: + +class BusPosition(BaseModel): """ Real-time position of a bus. @@ -48,3 +53,5 @@ class BusPosition: route: RouteIdentifier position: Coordinate time_updated: datetime + + model_config = {"frozen": True} diff --git a/src/core/models/coordinate.py b/src/core/models/coordinate.py index d2bb2ea..977aec2 100644 --- a/src/core/models/coordinate.py +++ b/src/core/models/coordinate.py @@ -1,10 +1,9 @@ """Coordinate domain model.""" -from dataclasses import dataclass +from pydantic import BaseModel -@dataclass -class Coordinate: +class Coordinate(BaseModel): """ Geographic coordinate representation. @@ -15,3 +14,5 @@ class Coordinate: latitude: float longitude: float + + model_config = {"frozen": True} diff --git a/src/core/ports/bus_provider_port.py b/src/core/ports/bus_provider_port.py index d3199e5..eb95b43 100644 --- a/src/core/ports/bus_provider_port.py +++ b/src/core/ports/bus_provider_port.py @@ -2,7 +2,7 @@ from abc import ABC, abstractmethod -from ..models.bus import BusPosition, BusRoute, RouteIdentifier +from ..models.bus import BusPosition, BusRoute class BusProviderPort(ABC): @@ -11,46 +11,37 @@ class BusProviderPort(ABC): This port defines the contract for interacting with external bus tracking APIs (e.g., SPTrans, NextBus, etc.). + Authentication is managed internally by implementations. """ @abstractmethod - async def authenticate(self) -> bool: + async def get_bus_positions(self, routes: list[BusRoute]) -> list[BusPosition]: """ - Authenticate with the bus tracking API. - - Returns: - True if authentication successful, False otherwise - """ - pass - - @abstractmethod - async def get_bus_positions(self, bus_route: BusRoute) -> list[BusPosition]: - """ - Get real-time positions for specified bus routes. + Get real-time positions for specified routes. Args: - routes: List of route identifiers to query + routes: List of BusRoute objects containing route info. Returns: - List of current bus positions + List of current bus positions with route identifiers. Raises: - Exception: If API call fails or authentication required + RuntimeError: If API call fails or authentication fails. """ pass @abstractmethod - async def get_route_details(self, route: RouteIdentifier) -> list[BusRoute]: + async def search_routes(self, query: str) -> list[BusRoute]: """ - Get detailed information about a specific route. + Search for bus routes matching a query string. Args: - route: Route identifier to query + query: Search term (e.g., route number or destination name). Returns: - Bus route details + List of matching bus routes. Raises: - Exception: If route not found or API call fails + RuntimeError: If API call fails or authentication fails. """ pass diff --git a/src/core/services/route_service.py b/src/core/services/route_service.py index 6052879..429fcd5 100644 --- a/src/core/services/route_service.py +++ b/src/core/services/route_service.py @@ -1,6 +1,6 @@ """Route service - Business logic for route and bus position queries.""" -from ..models.bus import BusPosition, BusRoute, RouteIdentifier +from ..models.bus import BusPosition, BusRoute from ..models.route_shape import RouteShape from ..ports.bus_provider_port import BusProviderPort from ..ports.gtfs_repository import GTFSRepositoryPort @@ -14,60 +14,57 @@ class RouteService: real-time bus information and GTFS data for route shapes. """ - def __init__(self, bus_provider: BusProviderPort, gtfs_repository: GTFSRepositoryPort): + def __init__( + self, bus_provider: BusProviderPort, gtfs_repository: GTFSRepositoryPort + ): """ Initialize the route service. Args: - bus_provider: Implementation of BusProviderPort - gtfs_repository: Implementation of GTFSRepositoryPort + bus_provider: Implementation of BusProviderPort. + gtfs_repository: Implementation of GTFSRepositoryPort. """ self.bus_provider = bus_provider self.gtfs_repository = gtfs_repository - async def get_bus_positions(self, bus_route: BusRoute) -> list[BusPosition]: + async def get_bus_positions(self, routes: list[BusRoute]) -> list[BusPosition]: """ - Get current positions for specified bus routes. + Get current positions for specified routes. Args: - routes: List of route identifiers to query + routes: List of BusRoute objects containing route info. Returns: - List of current bus positions + List of current bus positions. Raises: - Exception: If API authentication fails or request fails + RuntimeError: If API request fails. """ - # Ensure we're authenticated - await self.bus_provider.authenticate() + return await self.bus_provider.get_bus_positions(routes) - # Get positions - return await self.bus_provider.get_bus_positions(bus_route) - - async def get_route_details(self, route: RouteIdentifier) -> list[BusRoute]: + async def search_routes(self, query: str) -> list[BusRoute]: """ - Get detailed information about a route. + Search for bus routes matching a query string. Args: - route: Route identifier + query: Search term (e.g., "809" or "Vila Nova Conceição"). Returns: - Route details + List of matching bus routes. Raises: - Exception: If route not found or API fails + RuntimeError: If API request fails. """ - await self.bus_provider.authenticate() - return await self.bus_provider.get_route_details(route) + return await self.bus_provider.search_routes(query) - def get_route_shape(self, route_id: str) -> RouteShape | None: + def get_route_shape(self, bus_line: str) -> RouteShape | None: """ Get the geographic shape coordinates of a route from GTFS data. Args: - route_id: Route identifier (e.g., "1012-10") + bus_line: Bus line identifier (e.g., "1012-10"). Returns: - RouteShape with ordered coordinates, or None if route not found + RouteShape with ordered coordinates, or None if route not found. """ - return self.gtfs_repository.get_route_shape(route_id) + return self.gtfs_repository.get_route_shape(bus_line) diff --git a/src/web/controllers/route_controller.py b/src/web/controllers/route_controller.py index 92de9c5..593c610 100644 --- a/src/web/controllers/route_controller.py +++ b/src/web/controllers/route_controller.py @@ -4,27 +4,25 @@ This controller handles queries for real-time bus information and route shapes. """ -from fastapi import APIRouter, Depends, HTTPException, status +from fastapi import APIRouter, Depends, HTTPException, Query, status from ...adapters.external.sptrans_adapter import SpTransAdapter from ...adapters.repositories.gtfs_repository_adapter import GTFSRepositoryAdapter from ...config import settings -from ...core.models.bus import BusPosition, BusRoute, RouteIdentifier +from ...core.models.bus import BusPosition, BusRoute from ...core.models.user import User from ...core.services.route_service import RouteService from ..auth import get_current_user from ..mappers import ( map_bus_position_list_to_schema, - map_route_identifier_schema_to_domain, + map_bus_route_domain_list_to_schema, + map_bus_route_schema_list_to_domain, map_route_shape_to_response, ) from ..schemas import ( BusPositionsRequest, BusPositionsResponse, - BusRouteSchema, - BusRoutesDetailsRequest, - BusRoutesDetailsResponse, - RouteIdentifierSchema, + RouteSearchResponse, RouteShapeResponse, ) @@ -36,7 +34,7 @@ def get_route_service() -> RouteService: Dependency that provides a RouteService instance. Returns: - Configured RouteService instance + Configured RouteService instance. """ bus_provider = SpTransAdapter( api_token=settings.sptrans_api_token, @@ -46,50 +44,33 @@ def get_route_service() -> RouteService: return RouteService(bus_provider, gtfs_repository) -@router.post("/details", response_model=BusRoutesDetailsResponse) -async def get_route_details_endpoint( - request: BusRoutesDetailsRequest, +@router.get("/search", response_model=RouteSearchResponse) +async def search_routes_endpoint( + query: str = Query(..., min_length=1, description="Search term for routes"), route_service: RouteService = Depends(get_route_service), current_user: User = Depends(get_current_user), -) -> BusRoutesDetailsResponse: +) -> RouteSearchResponse: """ - Resolve, para cada linha solicitada, as rotas concretas do provedor - (por exemplo, diferentes variantes/direções internamente). + Search for bus routes matching a query string. - Entrada: lista de RouteIdentifierSchema (apenas bus_line). - Saída: lista "achatada" de BusRouteSchema (route_id + bus_line). - """ - try: - # Schemas -> domínio (RouteIdentifier) - route_identifiers: list[RouteIdentifier] = [ - map_route_identifier_schema_to_domain(route_schema) for route_schema in request.routes - ] - - bus_routes: list[BusRoute] = [] + Args: + query: Search term (e.g., "809" or "Vila Nova Conceição"). + route_service: Injected route service. + current_user: Authenticated user. - for route_identifier in route_identifiers: - # O service retorna uma lista de BusRoute - resolved_routes: list[BusRoute] = await route_service.get_route_details( - route_identifier - ) - bus_routes.extend(resolved_routes) - - # Domínio -> schemas - route_schemas: list[BusRouteSchema] = [ - BusRouteSchema( - route_id=bus_route.route_id, - route=RouteIdentifierSchema( - bus_line=bus_route.route.bus_line, - bus_direction=bus_route.route.bus_direction, - ), - ) - for bus_route in bus_routes - ] + Returns: + List of matching routes with provider IDs. - return BusRoutesDetailsResponse(routes=route_schemas) + Raises: + HTTPException: If search fails. + """ + try: + bus_routes: list[BusRoute] = await route_service.search_routes(query) + route_schemas = map_bus_route_domain_list_to_schema(bus_routes) + return RouteSearchResponse(routes=route_schemas) - except Exception as e: # caminho de erro genérico - detail: str = f"Failed to retrieve route details: {str(e)}" + except Exception as e: + detail: str = f"Failed to search routes: {str(e)}" raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=detail, @@ -103,34 +84,26 @@ async def get_bus_positions( current_user: User = Depends(get_current_user), ) -> BusPositionsResponse: """ - Recupera as posições dos ônibus para as rotas já resolvidas. + Get real-time bus positions for specified routes. - Entrada: lista de BusRouteSchema (tipicamente saída de /routes/details). - Saída: lista de BusPositionSchema. - """ - try: - all_positions: list[BusPosition] = [] - - for route_schema in request.routes: - # Schema -> domínio (BusRoute) - route_identifier = RouteIdentifier( - bus_line=route_schema.route.bus_line, - bus_direction=1, # direção default; SPTrans /Linha/Buscar não usa mais isso - ) - - bus_route = BusRoute( - route_id=route_schema.route_id, - route=route_identifier, - ) + Args: + request: Request containing list of routes. + route_service: Injected route service. + current_user: Authenticated user. - route_positions: list[BusPosition] = await route_service.get_bus_positions(bus_route) - all_positions.extend(route_positions) + Returns: + List of bus positions with route identifiers. - # Domínio -> schemas - position_schemas = map_bus_position_list_to_schema(all_positions) + Raises: + HTTPException: If fetching positions fails. + """ + try: + bus_routes = map_bus_route_schema_list_to_domain(request.routes) + positions: list[BusPosition] = await route_service.get_bus_positions(bus_routes) + position_schemas = map_bus_position_list_to_schema(positions) return BusPositionsResponse(buses=position_schemas) - except Exception as e: # caminho de erro genérico + except Exception as e: detail: str = f"Failed to retrieve bus positions: {str(e)}" raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, @@ -148,14 +121,15 @@ async def get_route_shape( Get the geographic shape (coordinates) of a route from GTFS data. Args: - route_id: Route identifier (e.g., "1012-10") - route_service: Injected route service + route_id: Route identifier (e.g., "1012-10"). + route_service: Injected route service. + current_user: Authenticated user. Returns: - Ordered list of coordinates defining the route shape + Ordered list of coordinates defining the route shape. Raises: - HTTPException: If route not found or database error occurs + HTTPException: If route not found or database error occurs. """ try: shape = route_service.get_route_shape(route_id) diff --git a/src/web/mappers.py b/src/web/mappers.py index 02cad8b..ca10168 100644 --- a/src/web/mappers.py +++ b/src/web/mappers.py @@ -5,12 +5,13 @@ maintaining the separation of concerns. """ -from ..core.models.bus import BusPosition, RouteIdentifier +from ..core.models.bus import BusPosition, BusRoute, RouteIdentifier from ..core.models.coordinate import Coordinate from ..core.models.route_shape import RouteShape from ..core.models.user import User from .schemas import ( BusPositionSchema, + BusRouteSchema, CoordinateSchema, RouteIdentifierSchema, RouteShapeResponse, @@ -89,6 +90,68 @@ def map_route_identifier_domain_to_schema( ) +def map_bus_route_schema_to_domain(schema: BusRouteSchema) -> BusRoute: + """ + Map a BusRouteSchema to a BusRoute domain model. + + Args: + schema: BusRouteSchema from API + + Returns: + BusRoute domain model + """ + return BusRoute( + route_id=schema.route_id, + route=map_route_identifier_schema_to_domain(schema.route), + ) + + +def map_bus_route_schema_list_to_domain( + schemas: list[BusRouteSchema], +) -> list[BusRoute]: + """ + Map a list of BusRouteSchema to BusRoute domain models. + + Args: + schemas: List of BusRouteSchema from API + + Returns: + List of BusRoute domain models + """ + return [map_bus_route_schema_to_domain(schema) for schema in schemas] + + +def map_bus_route_domain_to_schema(bus_route: BusRoute) -> BusRouteSchema: + """ + Map a BusRoute domain model to a BusRouteSchema. + + Args: + bus_route: BusRoute domain model + + Returns: + BusRouteSchema for API + """ + return BusRouteSchema( + route_id=bus_route.route_id, + route=map_route_identifier_domain_to_schema(bus_route.route), + ) + + +def map_bus_route_domain_list_to_schema( + bus_routes: list[BusRoute], +) -> list[BusRouteSchema]: + """ + Map a list of BusRoute domain models to BusRouteSchema list. + + Args: + bus_routes: List of BusRoute domain models + + Returns: + List of BusRouteSchema for API + """ + return [map_bus_route_domain_to_schema(bus_route) for bus_route in bus_routes] + + def map_coordinate_domain_to_schema(coord: Coordinate) -> CoordinateSchema: """ Map a Coordinate domain model to a CoordinateSchema. @@ -153,5 +216,7 @@ def map_route_shape_to_response(shape: RouteShape) -> RouteShapeResponse: return RouteShapeResponse( route_id=shape.route_id, shape_id=shape.shape_id, - points=[map_coordinate_domain_to_schema(point.coordinate) for point in shape.points], + points=[ + map_coordinate_domain_to_schema(point.coordinate) for point in shape.points + ], ) diff --git a/src/web/schemas.py b/src/web/schemas.py index 5b10573..2765905 100644 --- a/src/web/schemas.py +++ b/src/web/schemas.py @@ -109,7 +109,7 @@ class BusPositionsRequest(BaseModel): """Request schema for querying bus positions.""" routes: list[BusRouteSchema] = Field( - ..., description="List of resolved routes (with route_id) to query positions" + ..., description="List of routes to query positions" ) @@ -119,19 +119,21 @@ class BusPositionsResponse(BaseModel): buses: list[BusPositionSchema] = Field(..., description="List of bus positions") -class BusRoutesDetailsRequest(BaseModel): - """Request schema for resolving route details.""" +class RouteSearchRequest(BaseModel): + """Request schema for searching routes.""" - routes: list[RouteIdentifierSchema] = Field( - ..., description="List of routes (line + direction) to resolve" + query: str = Field( + ..., + min_length=1, + description="Search term (e.g., '809' or 'Vila Nova Conceição')", ) -class BusRoutesDetailsResponse(BaseModel): - """Response schema for route details.""" +class RouteSearchResponse(BaseModel): + """Response schema for route search results.""" routes: list[BusRouteSchema] = Field( - ..., description="List of resolved routes with provider IDs" + ..., description="List of matching routes with provider IDs" ) @@ -140,7 +142,9 @@ class RouteShapeResponse(BaseModel): route_id: str = Field(..., description="Route identifier") shape_id: str = Field(..., description="GTFS shape identifier") - points: list[CoordinateSchema] = Field(..., description="Ordered list of coordinates") + points: list[CoordinateSchema] = Field( + ..., description="Ordered list of coordinates" + ) # ===== Ranking Schemas ===== diff --git a/tests/adapters/test_sptrans_adapter.py b/tests/adapters/test_sptrans_adapter.py index 6e2549e..c0173d6 100644 --- a/tests/adapters/test_sptrans_adapter.py +++ b/tests/adapters/test_sptrans_adapter.py @@ -1,82 +1,81 @@ +import os + import pytest from src.adapters.external.sptrans_adapter import SpTransAdapter -from src.core.models.bus import BusPosition, BusRoute, RouteIdentifier +from src.core.models.bus import BusPosition, BusRoute + + +skip_if_no_token = pytest.mark.skipif( + not os.getenv("SPTRANS_API_TOKEN"), + reason="SPTRANS_API_TOKEN not set - skipping integration test", +) +@skip_if_no_token @pytest.mark.asyncio -async def test_authentication() -> None: +async def test_automatic_authentication() -> None: """ - Real authentication test against SPTrans. - Requires SPTRANS_API_TOKEN to be configured. + Test that the adapter authenticates automatically when making requests. """ adapter: SpTransAdapter = SpTransAdapter() - ok: bool = await adapter.authenticate() + routes: list[BusRoute] = await adapter.search_routes("8075") - print("Cookies recebidos:", adapter.client.cookies) - - assert ok is True, "A autenticação real falhou. Verifique seu TOKEN." - assert "apiCredentials" in adapter.client.cookies, "Cookie de credenciais não foi criado." + assert ( + "apiCredentials" in adapter.client.cookies + ), "Cookie de credenciais não foi criado." + assert len(routes) > 0 +@skip_if_no_token @pytest.mark.asyncio -async def test_get_route_details_8075_direction_1() -> None: +async def test_search_routes_number() -> None: """ - Integration test: resolves the internal SPTrans 'codigoLinha' (cl) - for route 8075 using get_route_details(), which now returns a list - of BusRoute entries (diferentes sentidos/variações). + Searches for route number and validates the results. """ adapter: SpTransAdapter = SpTransAdapter() - await adapter.authenticate() - - # RouteIdentifier para linha 8075 (direção ainda existe no domínio) - route: RouteIdentifier = RouteIdentifier(bus_line="8075", bus_direction=1) - - # Agora get_route_details retorna list[BusRoute] - bus_routes: list[BusRoute] = await adapter.get_route_details(route) + bus_routes: list[BusRoute] = await adapter.search_routes("8075") - print("Retrieved BusRoutes:", bus_routes) - - # Validate result assert isinstance(bus_routes, list) assert len(bus_routes) > 0, "Nenhuma rota retornada para 8075" for bus_route in bus_routes: assert isinstance(bus_route.route_id, int), "route_id must be an integer" assert bus_route.route_id > 0, "route_id must be positive" - # a linha deve bater com o que pedimos - assert bus_route.route.bus_line == "8075" + assert "8075" in bus_route.route.bus_line +@skip_if_no_token @pytest.mark.asyncio -async def test_get_bus_positions_8075_direction_1() -> None: +async def test_search_routes_by_destination() -> None: """ - Integration test: fetches real-time positions for one concrete route - of line 8075 (direction 1, se disponível), usando - get_route_details() + get_bus_positions(), sem mocks. + Searches for routes by destination name. """ adapter: SpTransAdapter = SpTransAdapter() - await adapter.authenticate() + bus_routes: list[BusRoute] = await adapter.search_routes("Lapa") - route: RouteIdentifier = RouteIdentifier(bus_line="8075", bus_direction=1) + assert isinstance(bus_routes, list) + assert len(bus_routes) > 0, "Nenhuma rota retornada para Lapa" - # Step 1: resolve SPTrans internal codes (list[BusRoute]) - bus_routes: list[BusRoute] = await adapter.get_route_details(route) - assert len(bus_routes) > 0 - # escolhe uma rota com direction 1, se existir; senão, pega a primeira - chosen_route: BusRoute = next( - (br for br in bus_routes if getattr(br.route, "bus_direction", None) == 1), - bus_routes[0], - ) +@skip_if_no_token +@pytest.mark.asyncio +async def test_get_bus_positions() -> None: + """ + Fetches real-time positions for route 8075. + """ + adapter: SpTransAdapter = SpTransAdapter() + + bus_routes: list[BusRoute] = await adapter.search_routes("8075") + assert len(bus_routes) > 0 + chosen_route: BusRoute = bus_routes[0] assert chosen_route.route_id > 0 - # Step 2: fetch positions usando BusRoute concreto - positions: list[BusPosition] = await adapter.get_bus_positions(chosen_route) + positions: list[BusPosition] = await adapter.get_bus_positions([chosen_route]) assert positions is not None assert isinstance(positions, list) @@ -84,36 +83,22 @@ async def test_get_bus_positions_8075_direction_1() -> None: if positions: pos: BusPosition = positions[0] - # Route info should match what we requested (mesma linha) - assert pos.route.bus_line == "8075" + assert "8075" in pos.route.bus_line + assert pos.route.bus_direction in (1, 2) - # se tiver direction em BusPosition.route, valida também - if hasattr(pos.route, "bus_direction"): - assert pos.route.bus_direction in (1, 2) - - # Coordenadas assert isinstance(pos.position.latitude, float | int) assert isinstance(pos.position.longitude, float | int) +@skip_if_no_token @pytest.mark.asyncio -async def test_get_route_details_without_authentication(): - adapter: SpTransAdapter = SpTransAdapter(api_token="INVALID") - - route = RouteIdentifier(bus_line="8075", bus_direction=1) - - with pytest.raises(RuntimeError, match="not authenticated"): - await adapter.get_route_details(route) - - -@pytest.mark.asyncio -async def test_get_bus_positions_without_authentication(): +async def test_search_routes_returns_empty_for_unknown() -> None: + """ + Test that search returns empty list for unknown routes. + """ adapter: SpTransAdapter = SpTransAdapter() - bus_route = BusRoute( - route_id=8075, - route=RouteIdentifier(bus_line="8075", bus_direction=1), - ) + routes: list[BusRoute] = await adapter.search_routes("XYZNONEXISTENT999") - with pytest.raises(RuntimeError, match="not authenticated"): - await adapter.get_bus_positions(bus_route) + assert isinstance(routes, list) + assert len(routes) == 0 diff --git a/tests/core/test_route_service.py b/tests/core/test_route_service.py index af5f519..8152e46 100644 --- a/tests/core/test_route_service.py +++ b/tests/core/test_route_service.py @@ -1,7 +1,7 @@ """Tests for RouteService. This file contains two groups of tests: -- async tests that exercise the bus provider (authenticate, get_bus_positions, get_route_details) +- async tests that exercise the bus provider (get_bus_positions, search_routes) - sync tests that exercise get_route_shape delegating to a GTFS repository """ @@ -22,24 +22,17 @@ @pytest.mark.asyncio -async def test_get_bus_positions_calls_auth_and_provider() -> None: - # Arrange +async def test_get_bus_positions_calls_provider() -> None: raw_provider: Mock = Mock(spec=BusProviderPort) - raw_provider.authenticate = AsyncMock(return_value=True) raw_provider.get_bus_positions = AsyncMock() bus_provider: BusProviderPort = cast(BusProviderPort, raw_provider) route_identifier: RouteIdentifier = RouteIdentifier( - bus_line="8075", + bus_line="8075-10", bus_direction=1, ) - bus_route: BusRoute = BusRoute( - route_id=1234, - route=route_identifier, - ) - expected_positions: list[BusPosition] = [ BusPosition( route=route_identifier, @@ -48,27 +41,26 @@ async def test_get_bus_positions_calls_auth_and_provider() -> None: ), ] - # configurando retorno tipado do mock raw_provider.get_bus_positions.return_value = expected_positions gtfs_repo = create_autospec(GTFSRepositoryPort, instance=True) - service: RouteService = RouteService(bus_provider=bus_provider, gtfs_repository=gtfs_repo) + service: RouteService = RouteService( + bus_provider=bus_provider, gtfs_repository=gtfs_repo + ) - # Act - result: list[BusPosition] = await service.get_bus_positions(bus_route) + routes = [ + BusRoute(route_id=1234, route=route_identifier), + ] + result: list[BusPosition] = await service.get_bus_positions(routes) - # Assert - raw_provider.authenticate.assert_awaited_once() - raw_provider.get_bus_positions.assert_awaited_once_with(bus_route) + raw_provider.get_bus_positions.assert_awaited_once_with(routes) assert result == expected_positions @pytest.mark.asyncio -async def test_get_route_details_calls_auth_and_provider() -> None: - # Arrange +async def test_search_routes_calls_provider() -> None: raw_provider: Mock = Mock(spec=BusProviderPort) - raw_provider.authenticate = AsyncMock(return_value=True) - raw_provider.get_route_details = AsyncMock() + raw_provider.search_routes = AsyncMock() bus_provider: BusProviderPort = cast(BusProviderPort, raw_provider) @@ -84,85 +76,70 @@ async def test_get_route_details_calls_auth_and_provider() -> None: expected_routes: list[BusRoute] = [expected_bus_route] - # agora o provider também retorna lista - raw_provider.get_route_details.return_value = expected_routes + raw_provider.search_routes.return_value = expected_routes gtfs_repo = create_autospec(GTFSRepositoryPort, instance=True) - service: RouteService = RouteService(bus_provider=bus_provider, gtfs_repository=gtfs_repo) + service: RouteService = RouteService( + bus_provider=bus_provider, gtfs_repository=gtfs_repo + ) - # Act - result: list[BusRoute] = await service.get_route_details(route_identifier) + query = "8075" + result: list[BusRoute] = await service.search_routes(query) - # Assert - raw_provider.authenticate.assert_awaited_once() - raw_provider.get_route_details.assert_awaited_once_with(route_identifier) + raw_provider.search_routes.assert_awaited_once_with(query) assert result == expected_routes @pytest.mark.asyncio async def test_get_bus_positions_propagates_exception_from_provider() -> None: - # Arrange + """Test that exceptions from the provider are propagated.""" raw_provider: Mock = Mock(spec=BusProviderPort) - raw_provider.authenticate = AsyncMock(return_value=True) raw_provider.get_bus_positions = AsyncMock(side_effect=RuntimeError("boom")) bus_provider: BusProviderPort = cast(BusProviderPort, raw_provider) - route_identifier: RouteIdentifier = RouteIdentifier( - bus_line="8075", - bus_direction=1, - ) - - bus_route: BusRoute = BusRoute( - route_id=1234, - route=route_identifier, + gtfs_repo = create_autospec(GTFSRepositoryPort, instance=True) + service: RouteService = RouteService( + bus_provider=bus_provider, gtfs_repository=gtfs_repo ) - gtfs_repo = create_autospec(GTFSRepositoryPort, instance=True) - service: RouteService = RouteService(bus_provider=bus_provider, gtfs_repository=gtfs_repo) + routes = [ + BusRoute( + route_id=1234, + route=RouteIdentifier(bus_line="8075-10", bus_direction=1), + ), + ] - # Act / Assert with pytest.raises(RuntimeError, match="boom"): - await service.get_bus_positions(bus_route) + await service.get_bus_positions(routes) - raw_provider.authenticate.assert_awaited_once() - raw_provider.get_bus_positions.assert_awaited_once_with(bus_route) + raw_provider.get_bus_positions.assert_awaited_once_with(routes) @pytest.mark.asyncio -async def test_get_route_details_propagates_exception_from_authenticate() -> None: - # Arrange +async def test_search_routes_propagates_exception_from_provider() -> None: + """Test that exceptions from search_routes are propagated.""" raw_provider: Mock = Mock(spec=BusProviderPort) - raw_provider.authenticate = AsyncMock( - side_effect=RuntimeError("auth failed"), - ) - raw_provider.get_route_details = AsyncMock() + raw_provider.search_routes = AsyncMock(side_effect=RuntimeError("search failed")) bus_provider: BusProviderPort = cast(BusProviderPort, raw_provider) - route_identifier: RouteIdentifier = RouteIdentifier( - bus_line="8075", - bus_direction=1, - ) - gtfs_repo = create_autospec(GTFSRepositoryPort, instance=True) - service: RouteService = RouteService(bus_provider=bus_provider, gtfs_repository=gtfs_repo) + service: RouteService = RouteService( + bus_provider=bus_provider, gtfs_repository=gtfs_repo + ) - # Act / Assert - with pytest.raises(RuntimeError, match="auth failed"): - await service.get_route_details(route_identifier) + with pytest.raises(RuntimeError, match="search failed"): + await service.search_routes("8075") - raw_provider.authenticate.assert_awaited_once() - raw_provider.get_route_details.assert_not_awaited() + raw_provider.search_routes.assert_awaited_once_with("8075") def test_get_route_shape_found() -> None: """Test getting a route shape when it exists.""" - # Arrange bus_provider = create_autospec(BusProviderPort, instance=True) gtfs_repo = create_autospec(GTFSRepositoryPort, instance=True) - # Create a mock route shape mock_shape = RouteShape( route_id="1012-10", shape_id="84609", @@ -184,10 +161,8 @@ def test_get_route_shape_found() -> None: service = RouteService(bus_provider, gtfs_repo) - # Act result = service.get_route_shape("1012-10") - # Assert assert result is not None assert result.route_id == "1012-10" assert result.shape_id == "84609" @@ -197,7 +172,6 @@ def test_get_route_shape_found() -> None: def test_get_route_shape_not_found() -> None: """Test getting a route shape when it doesn't exist.""" - # Arrange bus_provider = create_autospec(BusProviderPort, instance=True) gtfs_repo = create_autospec(GTFSRepositoryPort, instance=True) @@ -205,24 +179,22 @@ def test_get_route_shape_not_found() -> None: service = RouteService(bus_provider, gtfs_repo) - # Act result = service.get_route_shape("nonexistent-route") - # Assert assert result is None gtfs_repo.get_route_shape.assert_called_once_with("nonexistent-route") def test_get_route_shape_with_many_points() -> None: """Test getting a route shape with many coordinate points.""" - # Arrange bus_provider = create_autospec(BusProviderPort, instance=True) gtfs_repo = create_autospec(GTFSRepositoryPort, instance=True) - # Create a shape with many points points = [ RouteShapePoint( - coordinate=Coordinate(latitude=-23.5505 + i * 0.001, longitude=-46.6333 + i * 0.001), + coordinate=Coordinate( + latitude=-23.5505 + i * 0.001, longitude=-46.6333 + i * 0.001 + ), sequence=i + 1, distance_traveled=float(i * 10), ) @@ -235,10 +207,8 @@ def test_get_route_shape_with_many_points() -> None: service = RouteService(bus_provider, gtfs_repo) - # Act result = service.get_route_shape("long-route") - # Assert assert result is not None assert len(result.points) == 100 assert result.points[0].sequence == 1 @@ -248,7 +218,6 @@ def test_get_route_shape_with_many_points() -> None: def test_get_route_shape_with_special_characters() -> None: """Test getting a route shape with special characters in route ID.""" - # Arrange bus_provider = create_autospec(BusProviderPort, instance=True) gtfs_repo = create_autospec(GTFSRepositoryPort, instance=True) @@ -268,10 +237,8 @@ def test_get_route_shape_with_special_characters() -> None: service = RouteService(bus_provider, gtfs_repo) - # Act result = service.get_route_shape("route-with-special_chars@123") - # Assert assert result is not None assert result.route_id == "route-with-special_chars@123" gtfs_repo.get_route_shape.assert_called_once_with("route-with-special_chars@123") @@ -279,7 +246,6 @@ def test_get_route_shape_with_special_characters() -> None: def test_get_route_shape_independent_of_bus_provider() -> None: """Test that get_route_shape doesn't interact with bus provider.""" - # Arrange bus_provider = create_autospec(BusProviderPort, instance=True) gtfs_repo = create_autospec(GTFSRepositoryPort, instance=True) @@ -299,12 +265,8 @@ def test_get_route_shape_independent_of_bus_provider() -> None: service = RouteService(bus_provider, gtfs_repo) - # Act result = service.get_route_shape("test-route") - # Assert assert result is not None - # Verify bus_provider was not called at all - bus_provider.authenticate.assert_not_called() bus_provider.get_bus_positions.assert_not_called() - bus_provider.get_route_details.assert_not_called() + bus_provider.search_routes.assert_not_called() diff --git a/tests/integration/test_route.py b/tests/integration/test_route.py index c168893..3a07b0e 100644 --- a/tests/integration/test_route.py +++ b/tests/integration/test_route.py @@ -9,16 +9,15 @@ from src.web.schemas import ( BusPositionsRequest, BusRouteSchema, - BusRoutesDetailsRequest, RouteIdentifierSchema, ) from .conftest import create_user_and_login -class TestRouteDetails: +class TestRouteSearch: @pytest.mark.asyncio - async def test_get_route_details_returns_successfully( + async def test_search_routes_returns_successfully( self, client: AsyncClient, ) -> None: @@ -36,27 +35,14 @@ async def test_get_route_details_returns_successfully( ) ] - with ( - patch( - "src.adapters.external.sptrans_adapter.SpTransAdapter.authenticate", - new_callable=AsyncMock, - return_value=True, - ), - patch( - "src.adapters.external.sptrans_adapter.SpTransAdapter.get_route_details", - new_callable=AsyncMock, - return_value=mock_bus_routes, - ), + with patch( + "src.adapters.external.sptrans_adapter.SpTransAdapter.search_routes", + new_callable=AsyncMock, + return_value=mock_bus_routes, ): - request_data = BusRoutesDetailsRequest( - routes=[ - RouteIdentifierSchema(bus_line="8000", bus_direction=1), - ] - ) - - response = await client.post( - "/routes/details", - json=request_data.model_dump(), + response = await client.get( + "/routes/search", + params={"query": "8000"}, headers=auth["headers"], ) @@ -74,7 +60,7 @@ async def test_get_route_details_returns_successfully( assert first_route["route"]["bus_direction"] == 1 @pytest.mark.asyncio - async def test_get_route_details_with_multiple_lines( + async def test_search_routes_returns_multiple_results( self, client: AsyncClient, ) -> None: @@ -85,41 +71,25 @@ async def test_get_route_details_with_multiple_lines( } auth = await create_user_and_login(client, user_data) - mock_routes_8000 = [ + mock_bus_routes = [ BusRoute( route_id=12345, route=RouteIdentifier(bus_line="8000", bus_direction=1), ), - ] - mock_routes_9000 = [ BusRoute( - route_id=67890, - route=RouteIdentifier(bus_line="9000", bus_direction=1), + route_id=12346, + route=RouteIdentifier(bus_line="8000", bus_direction=2), ), ] - with ( - patch( - "src.adapters.external.sptrans_adapter.SpTransAdapter.authenticate", - new_callable=AsyncMock, - return_value=True, - ), - patch( - "src.adapters.external.sptrans_adapter.SpTransAdapter.get_route_details", - new_callable=AsyncMock, - side_effect=[mock_routes_8000, mock_routes_9000], - ), + with patch( + "src.adapters.external.sptrans_adapter.SpTransAdapter.search_routes", + new_callable=AsyncMock, + return_value=mock_bus_routes, ): - request_data = BusRoutesDetailsRequest( - routes=[ - RouteIdentifierSchema(bus_line="8000", bus_direction=1), - RouteIdentifierSchema(bus_line="9000", bus_direction=1), - ] - ) - - response = await client.post( - "/routes/details", - json=request_data.model_dump(), + response = await client.get( + "/routes/search", + params={"query": "8000"}, headers=auth["headers"], ) @@ -129,10 +99,10 @@ async def test_get_route_details_with_multiple_lines( assert len(data["routes"]) == 2 route_ids = [r["route_id"] for r in data["routes"]] assert 12345 in route_ids - assert 67890 in route_ids + assert 12346 in route_ids @pytest.mark.asyncio - async def test_get_route_details_returns_empty_for_unknown_line( + async def test_search_routes_returns_empty_for_unknown_query( self, client: AsyncClient, ) -> None: @@ -143,27 +113,14 @@ async def test_get_route_details_returns_empty_for_unknown_line( } auth = await create_user_and_login(client, user_data) - with ( - patch( - "src.adapters.external.sptrans_adapter.SpTransAdapter.authenticate", - new_callable=AsyncMock, - return_value=True, - ), - patch( - "src.adapters.external.sptrans_adapter.SpTransAdapter.get_route_details", - new_callable=AsyncMock, - return_value=[], - ), + with patch( + "src.adapters.external.sptrans_adapter.SpTransAdapter.search_routes", + new_callable=AsyncMock, + return_value=[], ): - request_data = BusRoutesDetailsRequest( - routes=[ - RouteIdentifierSchema(bus_line="UNKNOWN", bus_direction=1), - ] - ) - - response = await client.post( - "/routes/details", - json=request_data.model_dump(), + response = await client.get( + "/routes/search", + params={"query": "UNKNOWN"}, headers=auth["headers"], ) @@ -172,7 +129,7 @@ async def test_get_route_details_returns_empty_for_unknown_line( assert data["routes"] == [] @pytest.mark.asyncio - async def test_get_route_details_returns_500_on_api_error( + async def test_search_routes_returns_500_on_api_error( self, client: AsyncClient, ) -> None: @@ -183,35 +140,22 @@ async def test_get_route_details_returns_500_on_api_error( } auth = await create_user_and_login(client, user_data) - with ( - patch( - "src.adapters.external.sptrans_adapter.SpTransAdapter.authenticate", - new_callable=AsyncMock, - return_value=True, - ), - patch( - "src.adapters.external.sptrans_adapter.SpTransAdapter.get_route_details", - new_callable=AsyncMock, - side_effect=RuntimeError("API unavailable"), - ), + with patch( + "src.adapters.external.sptrans_adapter.SpTransAdapter.search_routes", + new_callable=AsyncMock, + side_effect=RuntimeError("API unavailable"), ): - request_data = BusRoutesDetailsRequest( - routes=[ - RouteIdentifierSchema(bus_line="8000", bus_direction=1), - ] - ) - - response = await client.post( - "/routes/details", - json=request_data.model_dump(), + response = await client.get( + "/routes/search", + params={"query": "8000"}, headers=auth["headers"], ) assert response.status_code == 500 - assert "Failed to retrieve route details" in response.json()["detail"] + assert "Failed to search routes" in response.json()["detail"] @pytest.mark.asyncio - async def test_get_route_details_with_empty_routes_list( + async def test_search_routes_returns_422_when_query_missing( self, client: AsyncClient, ) -> None: @@ -222,34 +166,19 @@ async def test_get_route_details_with_empty_routes_list( } auth = await create_user_and_login(client, user_data) - with patch( - "src.adapters.external.sptrans_adapter.SpTransAdapter.authenticate", - new_callable=AsyncMock, - return_value=True, - ): - request_data = BusRoutesDetailsRequest(routes=[]) - - response = await client.post( - "/routes/details", - json=request_data.model_dump(), - headers=auth["headers"], - ) + response = await client.get( + "/routes/search", + headers=auth["headers"], + ) - assert response.status_code == 200 - assert response.json()["routes"] == [] + assert response.status_code == 422 @pytest.mark.asyncio - async def test_get_route_details_without_auth_fails( + async def test_search_routes_without_auth_fails( self, client: AsyncClient, ) -> None: - request_data = BusRoutesDetailsRequest( - routes=[ - RouteIdentifierSchema(bus_line="8000", bus_direction=1), - ] - ) - - response = await client.post("/routes/details", json=request_data.model_dump()) + response = await client.get("/routes/search", params={"query": "8000"}) assert response.status_code == 401 @@ -280,17 +209,10 @@ async def test_get_bus_position_returns_successfully( ), ] - with ( - patch( - "src.adapters.external.sptrans_adapter.SpTransAdapter.authenticate", - new_callable=AsyncMock, - return_value=True, - ), - patch( - "src.adapters.external.sptrans_adapter.SpTransAdapter.get_bus_positions", - new_callable=AsyncMock, - return_value=mock_positions, - ), + with patch( + "src.adapters.external.sptrans_adapter.SpTransAdapter.get_bus_positions", + new_callable=AsyncMock, + return_value=mock_positions, ): request_data = BusPositionsRequest( routes=[ @@ -334,17 +256,10 @@ async def test_get_bus_position_returns_500_when_error( } auth = await create_user_and_login(client, user_data) - with ( - patch( - "src.adapters.external.sptrans_adapter.SpTransAdapter.authenticate", - new_callable=AsyncMock, - return_value=True, - ), - patch( - "src.adapters.external.sptrans_adapter.SpTransAdapter.get_bus_positions", - new_callable=AsyncMock, - side_effect=ValueError("Error fetching positions"), - ), + with patch( + "src.adapters.external.sptrans_adapter.SpTransAdapter.get_bus_positions", + new_callable=AsyncMock, + side_effect=ValueError("Error fetching positions"), ): request_data = BusPositionsRequest( routes=[ @@ -376,17 +291,10 @@ async def test_get_bus_position_returns_empty_when_no_buses_on_line( } auth = await create_user_and_login(client, user_data) - with ( - patch( - "src.adapters.external.sptrans_adapter.SpTransAdapter.authenticate", - new_callable=AsyncMock, - return_value=True, - ), - patch( - "src.adapters.external.sptrans_adapter.SpTransAdapter.get_bus_positions", - new_callable=AsyncMock, - return_value=[], - ), + with patch( + "src.adapters.external.sptrans_adapter.SpTransAdapter.get_bus_positions", + new_callable=AsyncMock, + return_value=[], ): request_data = BusPositionsRequest( routes=[ @@ -421,14 +329,12 @@ async def test_get_bus_position_works_with_multiple_routes( } auth = await create_user_and_login(client, user_data) - mock_positions_8000 = [ + mock_positions = [ BusPosition( route=RouteIdentifier(bus_line="8000", bus_direction=1), position=Coordinate(latitude=-23.550520, longitude=-46.633308), time_updated=datetime.now(UTC), ), - ] - mock_positions_9000 = [ BusPosition( route=RouteIdentifier(bus_line="9000", bus_direction=2), position=Coordinate(latitude=-23.560520, longitude=-46.643308), @@ -436,17 +342,10 @@ async def test_get_bus_position_works_with_multiple_routes( ), ] - with ( - patch( - "src.adapters.external.sptrans_adapter.SpTransAdapter.authenticate", - new_callable=AsyncMock, - return_value=True, - ), - patch( - "src.adapters.external.sptrans_adapter.SpTransAdapter.get_bus_positions", - new_callable=AsyncMock, - side_effect=[mock_positions_8000, mock_positions_9000], - ), + with patch( + "src.adapters.external.sptrans_adapter.SpTransAdapter.get_bus_positions", + new_callable=AsyncMock, + return_value=mock_positions, ): request_data = BusPositionsRequest( routes=[ @@ -476,7 +375,7 @@ async def test_get_bus_position_works_with_multiple_routes( assert "9000" in bus_lines @pytest.mark.asyncio - async def test_get_bus_position_returns_500_when_authentication_failure( + async def test_get_bus_position_returns_500_when_api_error( self, client: AsyncClient, ) -> None: @@ -488,9 +387,9 @@ async def test_get_bus_position_returns_500_when_authentication_failure( auth = await create_user_and_login(client, user_data) with patch( - "src.adapters.external.sptrans_adapter.SpTransAdapter.authenticate", + "src.adapters.external.sptrans_adapter.SpTransAdapter.get_bus_positions", new_callable=AsyncMock, - side_effect=RuntimeError("Authentication failed"), + side_effect=RuntimeError("API error"), ): request_data = BusPositionsRequest( routes=[ @@ -521,7 +420,7 @@ async def test_get_bus_position_returns_422_when_invalid_data( } auth = await create_user_and_login(client, user_data) - request_data = { + invalid_request_data: dict[str, list[dict[str, int | dict[str, str | int]]]] = { "routes": [ { "route_id": 12345, @@ -532,7 +431,7 @@ async def test_get_bus_position_returns_422_when_invalid_data( response = await client.post( "/routes/positions", - json=request_data, + json=invalid_request_data, headers=auth["headers"], ) @@ -550,17 +449,10 @@ async def test_get_bus_position_returns_successfully_with_empty_routes_list( } auth = await create_user_and_login(client, user_data) - with ( - patch( - "src.adapters.external.sptrans_adapter.SpTransAdapter.authenticate", - new_callable=AsyncMock, - return_value=True, - ), - patch( - "src.adapters.external.sptrans_adapter.SpTransAdapter.get_bus_positions", - new_callable=AsyncMock, - return_value=[], - ), + with patch( + "src.adapters.external.sptrans_adapter.SpTransAdapter.get_bus_positions", + new_callable=AsyncMock, + return_value=[], ): request_data = BusPositionsRequest(routes=[]) @@ -587,7 +479,9 @@ async def test_get_bus_position_without_auth_fails( ] ) - response = await client.post("/routes/positions", json=request_data.model_dump()) + response = await client.post( + "/routes/positions", json=request_data.model_dump() + ) assert response.status_code == 401 diff --git a/tests/web/test_route_controller.py b/tests/web/test_route_controller.py index b6a9b52..03ab1f8 100644 --- a/tests/web/test_route_controller.py +++ b/tests/web/test_route_controller.py @@ -23,13 +23,9 @@ def client() -> TestClient: @pytest.fixture def mock_service() -> RouteService: - """ - Cria um mock fortemente tipado do RouteService, - mas com métodos assíncronos (AsyncMock). - """ service = AsyncMock(spec=RouteService) typed_service: RouteService = service - typed_service.get_route_details = AsyncMock() # type: ignore[method-assign] + typed_service.search_routes = AsyncMock() # type: ignore[method-assign] typed_service.get_bus_positions = AsyncMock() # type: ignore[method-assign] return typed_service @@ -49,162 +45,206 @@ def override_dependency( app.dependency_overrides.clear() -# ========================= -# /routes/details -# ========================= +class TestSearchRoutes: + @pytest.mark.asyncio + async def test_search_endpoint_success( + self, client: TestClient, mock_service: RouteService + ) -> None: + bus_route_1 = BusRoute(route_id=2044, route=RouteIdentifier(bus_line="8075", bus_direction=1)) + bus_route_2 = BusRoute(route_id=34812, route=RouteIdentifier(bus_line="8075", bus_direction=2)) -@pytest.mark.asyncio -async def test_details_endpoint_success(client: TestClient, mock_service: RouteService) -> None: - """ - Testa o endpoint POST /routes/details garantindo que: - - Ele chama RouteService.get_route_details() - - Ele retorna uma lista achatada de rotas - """ + mock_service.search_routes.return_value = [bus_route_1, bus_route_2] # type: ignore[attr-defined] - # ----- Arrange ----- - # domínio - route_identifier = RouteIdentifier(bus_line="8075", bus_direction=1) - bus_route_1 = BusRoute(route_id=2044, route=route_identifier) - bus_route_2 = BusRoute(route_id=34812, route=route_identifier) + response = client.get("/routes/search", params={"query": "8075"}) - # get_route_details agora retorna list[BusRoute] - mock_service.get_route_details.return_value = [bus_route_1, bus_route_2] # type: ignore[attr-defined] + assert response.status_code == 200 + data = response.json() - payload = { - "routes": [ - {"bus_line": "8075"}, - ] - } + assert "routes" in data + assert len(data["routes"]) == 2 - # ----- Act ----- - response = client.post("/routes/details", json=payload) + routes = data["routes"] - # ----- Assert ----- - assert response.status_code == 200 - data = response.json() + assert routes[0]["route_id"] == 2044 + assert routes[0]["route"]["bus_line"] == "8075" - assert "routes" in data - assert len(data["routes"]) == 2 + assert routes[1]["route_id"] == 34812 + assert routes[1]["route"]["bus_line"] == "8075" - routes = data["routes"] + mock_service.search_routes.assert_awaited_once() # type: ignore[attr-defined] + called_arg = mock_service.search_routes.await_args.args[0] # type: ignore[attr-defined] + assert called_arg == "8075" - assert routes[0]["route_id"] == 2044 - assert routes[0]["route"]["bus_line"] == "8075" + @pytest.mark.asyncio + async def test_search_endpoint_with_destination_name( + self, client: TestClient, mock_service: RouteService + ) -> None: + """Test search with destination name query.""" + route_identifier = RouteIdentifier(bus_line="809", bus_direction=1) + bus_route = BusRoute(route_id=1234, route=route_identifier) - assert routes[1]["route_id"] == 34812 - assert routes[1]["route"]["bus_line"] == "8075" + mock_service.search_routes.return_value = [bus_route] # type: ignore[attr-defined] - # garante que o service foi chamado uma vez - mock_service.get_route_details.assert_awaited_once() # type: ignore[attr-defined] - called_arg = mock_service.get_route_details.await_args.args[0] # type: ignore[attr-defined] - assert isinstance(called_arg, RouteIdentifier) - assert called_arg.bus_line == "8075" - # direção padrão que estamos usando - assert called_arg.bus_direction == 1 + response = client.get("/routes/search", params={"query": "Vila Nova Conceição"}) + assert response.status_code == 200 + data = response.json() -@pytest.mark.asyncio -async def test_details_endpoint_error_returns_500( - client: TestClient, mock_service: RouteService -) -> None: - """ - Testa se o controller retorna 500 caso o service levante exception - em /routes/details. - """ + assert "routes" in data + assert len(data["routes"]) == 1 - mock_service.get_route_details.side_effect = RuntimeError("boom") # type: ignore[attr-defined] + mock_service.search_routes.assert_awaited_once() # type: ignore[attr-defined] + called_arg = mock_service.search_routes.await_args.args[0] # type: ignore[attr-defined] + assert called_arg == "Vila Nova Conceição" - payload = {"routes": [{"bus_line": "8075"}]} - - response = client.post("/routes/details", json=payload) - - assert response.status_code == 500 - body = response.json() - assert "Failed to retrieve route details" in body["detail"] - - -# ========================= -# /routes/positions -# ========================= - - -@pytest.mark.asyncio -async def test_positions_endpoint_success(client: TestClient, mock_service: RouteService) -> None: - """ - Testa o endpoint POST /routes/positions garantindo que: - - Ele chama RouteService.get_bus_positions() - - Ele retorna os dados de posição corretamente - """ - - # ----- Arrange ----- - route_identifier = RouteIdentifier(bus_line="8075", bus_direction=1) - - position = BusPosition( - route=route_identifier, - position=Coordinate(latitude=-23.5, longitude=-46.6), - time_updated=datetime.now(UTC), - ) - - mock_service.get_bus_positions.return_value = [position] # type: ignore[attr-defined] - - payload = { - "routes": [ - { - "route_id": 2044, - "route": {"bus_line": "8075"}, - } - ] - } - - # ----- Act ----- - response = client.post("/routes/positions", json=payload) - - # ----- Assert ----- - assert response.status_code == 200 - data = response.json() - - assert "buses" in data - assert len(data["buses"]) == 1 - - bus = data["buses"][0] - - assert bus["route"]["bus_line"] == "8075" - assert "position" in bus - assert "latitude" in bus["position"] - assert "longitude" in bus["position"] - assert "time_updated" in bus - - mock_service.get_bus_positions.assert_awaited_once() # type: ignore[attr-defined] - called_arg = mock_service.get_bus_positions.await_args.args[0] # type: ignore[attr-defined] - assert isinstance(called_arg, BusRoute) - assert called_arg.route.bus_line == "8075" - assert called_arg.route.bus_direction == 1 - assert called_arg.route_id == 2044 - - -@pytest.mark.asyncio -async def test_positions_endpoint_error_returns_500( - client: TestClient, mock_service: RouteService -) -> None: - """ - Testa se o controller retorna 500 caso o service levante exception - em /routes/positions. - """ - - mock_service.get_bus_positions.side_effect = RuntimeError("boom") # type: ignore[attr-defined] - - payload = { - "routes": [ - { - "route_id": 2044, - "route": {"bus_line": "8075"}, - } - ] - } - - response = client.post("/routes/positions", json=payload) - - assert response.status_code == 500 - body = response.json() - assert "Failed to retrieve bus positions" in body["detail"] + @pytest.mark.asyncio + async def test_search_endpoint_empty_results( + self, client: TestClient, mock_service: RouteService + ) -> None: + """Test search returning empty results.""" + mock_service.search_routes.return_value = [] # type: ignore[attr-defined] + + response = client.get("/routes/search", params={"query": "UNKNOWN"}) + + assert response.status_code == 200 + data = response.json() + assert data["routes"] == [] + + @pytest.mark.asyncio + async def test_search_endpoint_error_returns_500( + self, client: TestClient, mock_service: RouteService + ) -> None: + """Test that service exception returns 500 error.""" + mock_service.search_routes.side_effect = RuntimeError("boom") # type: ignore[attr-defined] + + response = client.get("/routes/search", params={"query": "8075"}) + + assert response.status_code == 500 + body = response.json() + assert "Failed to search routes" in body["detail"] + + +class TestBusPositions: + """Tests for the /routes/positions endpoint.""" + + @pytest.mark.asyncio + async def test_positions_endpoint_success( + self, client: TestClient, mock_service: RouteService + ) -> None: + """Test successful positions retrieval.""" + route_identifier = RouteIdentifier(bus_line="8075-10", bus_direction=1) + + position = BusPosition( + route=route_identifier, + position=Coordinate(latitude=-23.5, longitude=-46.6), + time_updated=datetime.now(UTC), + ) + + mock_service.get_bus_positions.return_value = [position] # type: ignore[attr-defined] + + payload = { + "routes": [ + { + "route_id": 2044, + "route": {"bus_line": "8075-10", "bus_direction": 1}, + } + ] + } + + response = client.post("/routes/positions", json=payload) + + assert response.status_code == 200 + data = response.json() + + assert "buses" in data + assert len(data["buses"]) == 1 + + bus = data["buses"][0] + + assert bus["route"]["bus_line"] == "8075-10" + assert "position" in bus + assert "latitude" in bus["position"] + assert "longitude" in bus["position"] + assert "time_updated" in bus + + mock_service.get_bus_positions.assert_awaited_once() # type: ignore[attr-defined] + called_arg = mock_service.get_bus_positions.await_args.args[0] # type: ignore[attr-defined] + assert len(called_arg) == 1 + assert called_arg[0].route_id == 2044 + assert called_arg[0].route.bus_line == "8075-10" + + @pytest.mark.asyncio + async def test_positions_endpoint_multiple_routes( + self, client: TestClient, mock_service: RouteService + ) -> None: + """Test positions for multiple routes.""" + position1 = BusPosition( + route=RouteIdentifier(bus_line="8075-10", bus_direction=1), + position=Coordinate(latitude=-23.5, longitude=-46.6), + time_updated=datetime.now(UTC), + ) + position2 = BusPosition( + route=RouteIdentifier(bus_line="809-10", bus_direction=1), + position=Coordinate(latitude=-23.6, longitude=-46.7), + time_updated=datetime.now(UTC), + ) + + mock_service.get_bus_positions.return_value = [position1, position2] # type: ignore[attr-defined] + + payload = { + "routes": [ + { + "route_id": 2044, + "route": {"bus_line": "8075-10", "bus_direction": 1}, + }, + { + "route_id": 5678, + "route": {"bus_line": "809-10", "bus_direction": 1}, + }, + ] + } + + response = client.post("/routes/positions", json=payload) + + assert response.status_code == 200 + data = response.json() + + assert len(data["buses"]) == 2 + mock_service.get_bus_positions.assert_awaited_once() # type: ignore[attr-defined] + + @pytest.mark.asyncio + async def test_positions_endpoint_error_returns_500( + self, client: TestClient, mock_service: RouteService + ) -> None: + """Test that service exception returns 500 error.""" + mock_service.get_bus_positions.side_effect = RuntimeError("boom") # type: ignore[attr-defined] + + payload = { + "routes": [ + { + "route_id": 2044, + "route": {"bus_line": "8075-10", "bus_direction": 1}, + } + ] + } + + response = client.post("/routes/positions", json=payload) + + assert response.status_code == 500 + body = response.json() + assert "Failed to retrieve bus positions" in body["detail"] + + @pytest.mark.asyncio + async def test_positions_endpoint_empty_routes( + self, client: TestClient, mock_service: RouteService + ) -> None: + """Test positions with empty routes list.""" + mock_service.get_bus_positions.return_value = [] # type: ignore[attr-defined] + + payload = {"routes": []} + + response = client.post("/routes/positions", json=payload) + + assert response.status_code == 200 + assert response.json()["buses"] == [] From 7e4b905163c4dc38d20384e5241382760012e432 Mon Sep 17 00:00:00 2001 From: Kim Kakeya Date: Wed, 3 Dec 2025 23:37:33 -0300 Subject: [PATCH 12/12] feat: update sptrans objects and how we get multiple positions --- requirements.txt | 1 + src/adapters/external/sptrans_adapter.py | 38 ++++----- src/adapters/external/sptrans_mappers.py | 53 ++++++++---- src/adapters/external/sptrans_schemas.py | 49 ++++++----- src/core/models/bus.py | 17 ++-- src/core/ports/bus_provider_port.py | 9 +- src/core/services/route_service.py | 17 ++-- src/web/controllers/route_controller.py | 12 +-- src/web/mappers.py | 60 ++++++------- src/web/schemas.py | 33 +++++-- tests/adapters/test_sptrans_adapter.py | 10 +-- tests/core/test_route_service.py | 47 +++------- tests/integration/test_route.py | 104 +++++++++++------------ tests/web/test_route_controller.py | 79 ++++++++--------- 14 files changed, 275 insertions(+), 254 deletions(-) diff --git a/requirements.txt b/requirements.txt index bfc3d74..b5d8ccb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -18,5 +18,6 @@ pytest==8.3.0 pytest-asyncio==0.24.0 pytest-mock==3.14.0 pytest-cov==7.0.0 +pytest-dotenv==0.5.2 mypy==1.11.0 ruff==0.6.0 diff --git a/src/adapters/external/sptrans_adapter.py b/src/adapters/external/sptrans_adapter.py index 12df761..f9036fe 100644 --- a/src/adapters/external/sptrans_adapter.py +++ b/src/adapters/external/sptrans_adapter.py @@ -133,36 +133,30 @@ async def _request_with_auth_retry( return response - async def get_bus_positions(self, routes: list[BusRoute]) -> list[BusPosition]: + async def get_bus_positions( + self, + route_id: int, + ) -> list[BusPosition]: """ Get real-time positions for specified routes. Args: - routes: List of BusRoute objects containing route info. + route_ids: List of provider-specific route IDs. Returns: - List of BusPosition objects with route identifier and coordinates. + List of BusPosition objects with route_id and coordinates. """ - positions: list[BusPosition] = [] - - for bus_route in routes: - response = await self._request_with_auth_retry( - "GET", - "/Posicao/Linha", - params={"codigoLinha": bus_route.route_id}, - ) - if response.status_code != 200: - continue + response = await self._request_with_auth_retry( + "GET", + "/Posicao/Linha", + params={"codigoLinha": route_id}, + ) - response_data = SPTransPositionsResponse.model_validate(response.json()) - route_positions = map_positions_response_to_bus_positions( - response_data, - bus_route.route, - ) - positions.extend(route_positions) + response_data = SPTransPositionsResponse.model_validate(response.json()) + route_positions = map_positions_response_to_bus_positions(response_data, route_id) - return positions + return route_positions async def search_routes(self, query: str) -> list[BusRoute]: """ @@ -181,9 +175,7 @@ async def search_routes(self, query: str) -> list[BusRoute]: ) if response.status_code != 200: - raise RuntimeError( - f"SPTrans returned status {response.status_code} for search." - ) + raise RuntimeError(f"SPTrans returned status {response.status_code} for search.") response_data = SPTransLineSearchResponse.model_validate(response.json()) bus_routes: list[BusRoute] = map_search_response_to_bus_route_list(response_data) diff --git a/src/adapters/external/sptrans_mappers.py b/src/adapters/external/sptrans_mappers.py index 2c66ff3..8148f5f 100644 --- a/src/adapters/external/sptrans_mappers.py +++ b/src/adapters/external/sptrans_mappers.py @@ -5,9 +5,17 @@ and internal domain objects, keeping the adapter layer clean. """ -from ...core.models.bus import BusPosition, BusRoute, RouteIdentifier +from typing import cast + +from ...core.models.bus import BusDirection, BusPosition, BusRoute, RouteIdentifier from ...core.models.coordinate import Coordinate -from .sptrans_schemas import SPTransLineInfo, SPTransLineSearchResponse, SPTransPositionsResponse, SPTransVehicle +from .sptrans_schemas import ( + SPTransLineInfo, + SPTransLineSearchResponse, + SPTransPositionsResponse, + SPTransVehicle, +) + def map_search_response_to_bus_route_list( data: SPTransLineSearchResponse, @@ -22,10 +30,11 @@ def map_search_response_to_bus_route_list( List of BusRoute domain objects. """ bus_route_list: list[BusRoute] = [] - for item in data.results: + for item in data.root: bus_route_list.append(map_line_info_to_bus_route(item)) return bus_route_list + def map_line_info_to_bus_route(line_info: SPTransLineInfo) -> BusRoute: """ Convert SPTrans line info to domain BusRoute. @@ -36,11 +45,18 @@ def map_line_info_to_bus_route(line_info: SPTransLineInfo) -> BusRoute: Returns: Domain BusRoute with route_id and identifier. """ + # Terminal name: primary if direction=1 (ida), secondary if direction=2 (volta) + terminal_name = ( + line_info.primary_terminal if line_info.direction == 1 else line_info.secondary_terminal + ) return BusRoute( - route_id=line_info.cl, + route_id=line_info.route_id, route=map_line_info_to_route_identifier(line_info), + is_circular=line_info.is_circular, + terminal_name=terminal_name, ) + def map_line_info_to_route_identifier(line_info: SPTransLineInfo) -> RouteIdentifier: """ Convert SPTrans line info to domain RouteIdentifier. @@ -49,56 +65,57 @@ def map_line_info_to_route_identifier(line_info: SPTransLineInfo) -> RouteIdenti line_info: SPTrans line information. Returns: - Domain RouteIdentifier with formatted bus_line (lt-tl format). + Domain RouteIdentifier with formatted bus_line (line_number-line_sufix format). """ - bus_line = f"{line_info.lt}-{line_info.tl}" + bus_line = f"{line_info.line_number}-{line_info.line_sufix}" return RouteIdentifier( bus_line=bus_line, - bus_direction=line_info.sl, + bus_direction=cast(BusDirection, line_info.direction), ) def map_positions_response_to_bus_positions( data: SPTransPositionsResponse, - route: RouteIdentifier, + route_id: int, ) -> list[BusPosition]: """ Convert API positions response to list of BusPosition domain objects. Args: data: SPTransPositionsResponse object. - route: RouteIdentifier for all vehicles in this response. + route_id: Provider-specific route identifier. Returns: List of domain BusPosition objects. """ positions: list[BusPosition] = [] - for vehicle in data.vs: - position = map_vehicle_to_bus_position(vehicle, route) + for vehicle in data.vehicles: + position = map_vehicle_to_bus_position(vehicle, route_id) positions.append(position) return positions + def map_vehicle_to_bus_position( vehicle: SPTransVehicle, - route: RouteIdentifier, + route_id: int, ) -> BusPosition: """ Convert SPTrans vehicle to domain BusPosition. Args: vehicle: SPTransVehicle position data. - route: RouteIdentifier for this vehicle's route. + route_id: Provider-specific route identifier. Returns: - Domain BusPosition with coordinates and route info. + Domain BusPosition with coordinates and route_id. """ return BusPosition( - route=route, + route_id=route_id, position=Coordinate( - latitude=vehicle.py, - longitude=vehicle.px, + latitude=vehicle.latitude, + longitude=vehicle.longitude, ), - time_updated=vehicle.ta, + time_updated=vehicle.time_updated, ) diff --git a/src/adapters/external/sptrans_schemas.py b/src/adapters/external/sptrans_schemas.py index f125ce6..05cb8a6 100644 --- a/src/adapters/external/sptrans_schemas.py +++ b/src/adapters/external/sptrans_schemas.py @@ -1,5 +1,4 @@ -""" -SPTrans API-specific schemas. +"""SPTrans API-specific schemas. These DTOs are used exclusively for communication with the SPTrans API and should not leak into the domain or web layers. @@ -7,35 +6,47 @@ from datetime import datetime -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, RootModel + class SPTransLineInfo(BaseModel): """Schema for SPTrans line information response.""" - cl: int = Field(..., description="Route code (internal SPTrans code)") - lc: bool = Field(..., description="Is circular route") - lt: str = Field(..., description="Line number (e.g., '8000')") - sl: int = Field(..., description="Direction (1 = ida, 2 = volta)") - tl: int = Field(..., description="Line type (10 = urban, etc.)") - tp: str = Field(..., description="Primary terminal") - ts: str = Field(..., description="Secondary terminal") + route_id: int = Field(..., alias="cl", description="Route code (internal SPTrans code)") + is_circular: bool = Field(..., alias="lc", description="Is circular route") + line_number: str = Field(..., alias="lt", description="Line number (e.g., '8000')") + direction: int = Field(..., alias="sl", description="Direction (1 = ida, 2 = volta)") + line_sufix: int = Field(..., alias="tl", description="Line type (10 = urban, etc.)") + primary_terminal: str = Field(..., alias="tp", description="Primary terminal") + secondary_terminal: str = Field(..., alias="ts", description="Secondary terminal") + + model_config = {"populate_by_name": True} + -class SPTransLineSearchResponse(BaseModel): +class SPTransLineSearchResponse(RootModel[list[SPTransLineInfo]]): """Schema for SPTrans line search response item.""" - results: list[SPTransLineInfo] = Field(..., description="List of line info results") + + root: list[SPTransLineInfo] = Field(..., description="List of line info results") + class SPTransVehicle(BaseModel): """Schema for SPTrans vehicle position.""" - p: str = Field(..., description="Vehicle prefix") - a: bool = Field(..., description="Is accessible") - ta: datetime = Field(..., description="Time updated") - py: float = Field(..., description="Latitude") - px: float = Field(..., description="Longitude") + vehicle_prefix: str = Field(..., alias="p", description="Vehicle prefix") + is_accessible: bool = Field(..., alias="a", description="Is accessible") + time_updated: datetime = Field(..., alias="ta", description="Time updated") + latitude: float = Field(..., alias="py", description="Latitude") + longitude: float = Field(..., alias="px", description="Longitude") + + model_config = {"populate_by_name": True} class SPTransPositionsResponse(BaseModel): """Schema for SPTrans positions API response.""" - hr: str = Field(..., description="Response time") - vs: list[SPTransVehicle] = Field(default_factory=list, description="List of vehicles") + response_time: str = Field(..., alias="hr", description="Response time") + vehicles: list[SPTransVehicle] = Field( + default_factory=list, alias="vs", description="List of vehicles" + ) + + model_config = {"populate_by_name": True} diff --git a/src/core/models/bus.py b/src/core/models/bus.py index 2762dd7..19e3ebe 100644 --- a/src/core/models/bus.py +++ b/src/core/models/bus.py @@ -7,6 +7,8 @@ from .coordinate import Coordinate +BusDirection = Literal[1, 2] + class RouteIdentifier(BaseModel): """ @@ -18,9 +20,7 @@ class RouteIdentifier(BaseModel): """ bus_line: str - bus_direction: Literal[1, 2] = Field( - ..., description="Direction (1 = ida, 2 = volta)" - ) + bus_direction: BusDirection = Field(..., description="Direction (1 = ida, 2 = volta)") model_config = {"frozen": True} @@ -32,10 +32,17 @@ class BusRoute(BaseModel): Attributes: route_id: Unique identifier for this route route: Route identifier containing line and direction + is_circular: Whether the route is circular + terminal_name: Terminal name (primary if direction=1, secondary if direction=2) """ route_id: int route: RouteIdentifier + is_circular: bool = Field(..., description="Whether the route is circular") + terminal_name: str = Field( + ..., + description="Terminal name (primary if direction=1, secondary if direction=2)", + ) model_config = {"frozen": True} @@ -45,12 +52,12 @@ class BusPosition(BaseModel): Real-time position of a bus. Attributes: - route: Route identifier for this bus + route_id: Provider-specific route identifier position: Geographic coordinate of the bus time_updated: Last update timestamp """ - route: RouteIdentifier + route_id: int position: Coordinate time_updated: datetime diff --git a/src/core/ports/bus_provider_port.py b/src/core/ports/bus_provider_port.py index eb95b43..d7fff0d 100644 --- a/src/core/ports/bus_provider_port.py +++ b/src/core/ports/bus_provider_port.py @@ -15,15 +15,18 @@ class BusProviderPort(ABC): """ @abstractmethod - async def get_bus_positions(self, routes: list[BusRoute]) -> list[BusPosition]: + async def get_bus_positions( + self, + route_id: int, + ) -> list[BusPosition]: """ Get real-time positions for specified routes. Args: - routes: List of BusRoute objects containing route info. + route_ids: List of provider-specific route IDs. Returns: - List of current bus positions with route identifiers. + List of current bus positions with route_id. Raises: RuntimeError: If API call fails or authentication fails. diff --git a/src/core/services/route_service.py b/src/core/services/route_service.py index 429fcd5..e1eb7b1 100644 --- a/src/core/services/route_service.py +++ b/src/core/services/route_service.py @@ -14,9 +14,7 @@ class RouteService: real-time bus information and GTFS data for route shapes. """ - def __init__( - self, bus_provider: BusProviderPort, gtfs_repository: GTFSRepositoryPort - ): + def __init__(self, bus_provider: BusProviderPort, gtfs_repository: GTFSRepositoryPort): """ Initialize the route service. @@ -27,12 +25,15 @@ def __init__( self.bus_provider = bus_provider self.gtfs_repository = gtfs_repository - async def get_bus_positions(self, routes: list[BusRoute]) -> list[BusPosition]: + async def get_bus_positions( + self, + route_ids: list[int], + ) -> list[BusPosition]: """ Get current positions for specified routes. Args: - routes: List of BusRoute objects containing route info. + route_ids: List of provider-specific route IDs. Returns: List of current bus positions. @@ -40,7 +41,11 @@ async def get_bus_positions(self, routes: list[BusRoute]) -> list[BusPosition]: Raises: RuntimeError: If API request fails. """ - return await self.bus_provider.get_bus_positions(routes) + positions: list[BusPosition] = [] + for route_id in route_ids: + route_positions: list[BusPosition] = await self.bus_provider.get_bus_positions(route_id) + positions.extend(route_positions) + return positions async def search_routes(self, query: str) -> list[BusRoute]: """ diff --git a/src/web/controllers/route_controller.py b/src/web/controllers/route_controller.py index 593c610..96ce612 100644 --- a/src/web/controllers/route_controller.py +++ b/src/web/controllers/route_controller.py @@ -16,7 +16,7 @@ from ..mappers import ( map_bus_position_list_to_schema, map_bus_route_domain_list_to_schema, - map_bus_route_schema_list_to_domain, + map_bus_route_request_list, map_route_shape_to_response, ) from ..schemas import ( @@ -87,19 +87,21 @@ async def get_bus_positions( Get real-time bus positions for specified routes. Args: - request: Request containing list of routes. + request: Request containing list of route_ids. route_service: Injected route service. current_user: Authenticated user. Returns: - List of bus positions with route identifiers. + List of bus positions with route_id. Raises: HTTPException: If fetching positions fails. """ try: - bus_routes = map_bus_route_schema_list_to_domain(request.routes) - positions: list[BusPosition] = await route_service.get_bus_positions(bus_routes) + # Extract route_ids from request + route_ids = map_bus_route_request_list(request.routes) + + positions: list[BusPosition] = await route_service.get_bus_positions(route_ids) position_schemas = map_bus_position_list_to_schema(positions) return BusPositionsResponse(buses=position_schemas) diff --git a/src/web/mappers.py b/src/web/mappers.py index ca10168..b20fead 100644 --- a/src/web/mappers.py +++ b/src/web/mappers.py @@ -5,13 +5,16 @@ maintaining the separation of concerns. """ -from ..core.models.bus import BusPosition, BusRoute, RouteIdentifier +from typing import cast + +from ..core.models.bus import BusDirection, BusPosition, BusRoute, RouteIdentifier from ..core.models.coordinate import Coordinate from ..core.models.route_shape import RouteShape from ..core.models.user import User from .schemas import ( BusPositionSchema, - BusRouteSchema, + BusRouteRequestSchema, + BusRouteResponseSchema, CoordinateSchema, RouteIdentifierSchema, RouteShapeResponse, @@ -68,7 +71,7 @@ def map_route_identifier_schema_to_domain( """ return RouteIdentifier( bus_line=schema.bus_line, - bus_direction=schema.bus_direction, + bus_direction=cast(BusDirection, schema.bus_direction), ) @@ -90,64 +93,65 @@ def map_route_identifier_domain_to_schema( ) -def map_bus_route_schema_to_domain(schema: BusRouteSchema) -> BusRoute: +def map_bus_route_request_to_route_id( + schema: BusRouteRequestSchema, +) -> int: """ - Map a BusRouteSchema to a BusRoute domain model. + Extract route_id from a BusRouteRequestSchema. Args: - schema: BusRouteSchema from API + schema: BusRouteRequestSchema from API request Returns: - BusRoute domain model + route_id (int) """ - return BusRoute( - route_id=schema.route_id, - route=map_route_identifier_schema_to_domain(schema.route), - ) + return schema.route_id -def map_bus_route_schema_list_to_domain( - schemas: list[BusRouteSchema], -) -> list[BusRoute]: +def map_bus_route_request_list( + schemas: list[BusRouteRequestSchema], +) -> list[int]: """ - Map a list of BusRouteSchema to BusRoute domain models. + Map a list of BusRouteRequestSchema to a list of route_ids. Args: - schemas: List of BusRouteSchema from API + schemas: List of BusRouteRequestSchema from API request Returns: - List of BusRoute domain models + List of route_ids """ - return [map_bus_route_schema_to_domain(schema) for schema in schemas] + return [schema.route_id for schema in schemas] -def map_bus_route_domain_to_schema(bus_route: BusRoute) -> BusRouteSchema: +def map_bus_route_domain_to_schema(bus_route: BusRoute) -> BusRouteResponseSchema: """ - Map a BusRoute domain model to a BusRouteSchema. + Map a BusRoute domain model to a BusRouteResponseSchema. Args: bus_route: BusRoute domain model Returns: - BusRouteSchema for API + BusRouteResponseSchema for API response """ - return BusRouteSchema( + return BusRouteResponseSchema( route_id=bus_route.route_id, route=map_route_identifier_domain_to_schema(bus_route.route), + is_circular=bus_route.is_circular, + terminal_name=bus_route.terminal_name, ) def map_bus_route_domain_list_to_schema( bus_routes: list[BusRoute], -) -> list[BusRouteSchema]: +) -> list[BusRouteResponseSchema]: """ - Map a list of BusRoute domain models to BusRouteSchema list. + Map a list of BusRoute domain models to BusRouteResponseSchema list. Args: bus_routes: List of BusRoute domain models Returns: - List of BusRouteSchema for API + List of BusRouteResponseSchema for API response """ return [map_bus_route_domain_to_schema(bus_route) for bus_route in bus_routes] @@ -179,7 +183,7 @@ def map_bus_position_domain_to_schema(position: BusPosition) -> BusPositionSchem BusPositionSchema for API """ return BusPositionSchema( - route=map_route_identifier_domain_to_schema(position.route), + route_id=position.route_id, position=map_coordinate_domain_to_schema(position.position), time_updated=position.time_updated, ) @@ -216,7 +220,5 @@ def map_route_shape_to_response(shape: RouteShape) -> RouteShapeResponse: return RouteShapeResponse( route_id=shape.route_id, shape_id=shape.shape_id, - points=[ - map_coordinate_domain_to_schema(point.coordinate) for point in shape.points - ], + points=[map_coordinate_domain_to_schema(point.coordinate) for point in shape.points], ) diff --git a/src/web/schemas.py b/src/web/schemas.py index 2765905..07c43d5 100644 --- a/src/web/schemas.py +++ b/src/web/schemas.py @@ -36,18 +36,35 @@ class CoordinateSchema(BaseModel): class BusPositionSchema(BaseModel): """Schema for bus position information.""" - route: RouteIdentifierSchema + route_id: int = Field(..., description="Provider-specific route identifier") position: CoordinateSchema time_updated: datetime = Field(..., description="Last update timestamp") model_config = {"populate_by_name": True} -class BusRouteSchema(BaseModel): - """Schema for a resolved bus route (provider-specific ID + identifier).""" +class BusRouteRequestSchema(BaseModel): + """Schema for bus route in requests (input from user). + + Only requires route_id for querying positions. + """ + + route_id: int = Field(..., description="Provider-specific route identifier") + + +class BusRouteResponseSchema(BaseModel): + """Schema for bus route in responses (output to user). + + Includes full route information with metadata. + """ route_id: int = Field(..., description="Provider-specific route identifier") route: RouteIdentifierSchema + is_circular: bool = Field(..., description="Whether the route is circular") + terminal_name: str = Field( + ..., + description="Terminal name (primary if direction=1, secondary if direction=2)", + ) # ===== User Management Schemas ===== @@ -108,8 +125,8 @@ class CreateTripResponse(BaseModel): class BusPositionsRequest(BaseModel): """Request schema for querying bus positions.""" - routes: list[BusRouteSchema] = Field( - ..., description="List of routes to query positions" + routes: list[BusRouteRequestSchema] = Field( + ..., description="List of routes to query positions for" ) @@ -132,7 +149,7 @@ class RouteSearchRequest(BaseModel): class RouteSearchResponse(BaseModel): """Response schema for route search results.""" - routes: list[BusRouteSchema] = Field( + routes: list[BusRouteResponseSchema] = Field( ..., description="List of matching routes with provider IDs" ) @@ -142,9 +159,7 @@ class RouteShapeResponse(BaseModel): route_id: str = Field(..., description="Route identifier") shape_id: str = Field(..., description="GTFS shape identifier") - points: list[CoordinateSchema] = Field( - ..., description="Ordered list of coordinates" - ) + points: list[CoordinateSchema] = Field(..., description="Ordered list of coordinates") # ===== Ranking Schemas ===== diff --git a/tests/adapters/test_sptrans_adapter.py b/tests/adapters/test_sptrans_adapter.py index c0173d6..55a9e3f 100644 --- a/tests/adapters/test_sptrans_adapter.py +++ b/tests/adapters/test_sptrans_adapter.py @@ -5,7 +5,6 @@ from src.adapters.external.sptrans_adapter import SpTransAdapter from src.core.models.bus import BusPosition, BusRoute - skip_if_no_token = pytest.mark.skipif( not os.getenv("SPTRANS_API_TOKEN"), reason="SPTRANS_API_TOKEN not set - skipping integration test", @@ -22,9 +21,7 @@ async def test_automatic_authentication() -> None: routes: list[BusRoute] = await adapter.search_routes("8075") - assert ( - "apiCredentials" in adapter.client.cookies - ), "Cookie de credenciais não foi criado." + assert "apiCredentials" in adapter.client.cookies, "Cookie de credenciais não foi criado." assert len(routes) > 0 @@ -75,7 +72,7 @@ async def test_get_bus_positions() -> None: chosen_route: BusRoute = bus_routes[0] assert chosen_route.route_id > 0 - positions: list[BusPosition] = await adapter.get_bus_positions([chosen_route]) + positions: list[BusPosition] = await adapter.get_bus_positions(chosen_route.route_id) assert positions is not None assert isinstance(positions, list) @@ -83,9 +80,6 @@ async def test_get_bus_positions() -> None: if positions: pos: BusPosition = positions[0] - assert "8075" in pos.route.bus_line - assert pos.route.bus_direction in (1, 2) - assert isinstance(pos.position.latitude, float | int) assert isinstance(pos.position.longitude, float | int) diff --git a/tests/core/test_route_service.py b/tests/core/test_route_service.py index 8152e46..df8558e 100644 --- a/tests/core/test_route_service.py +++ b/tests/core/test_route_service.py @@ -28,14 +28,9 @@ async def test_get_bus_positions_calls_provider() -> None: bus_provider: BusProviderPort = cast(BusProviderPort, raw_provider) - route_identifier: RouteIdentifier = RouteIdentifier( - bus_line="8075-10", - bus_direction=1, - ) - expected_positions: list[BusPosition] = [ BusPosition( - route=route_identifier, + route_id=1234, position=Coordinate(latitude=-23.0, longitude=-46.0), time_updated=datetime.now(UTC), ), @@ -44,16 +39,11 @@ async def test_get_bus_positions_calls_provider() -> None: raw_provider.get_bus_positions.return_value = expected_positions gtfs_repo = create_autospec(GTFSRepositoryPort, instance=True) - service: RouteService = RouteService( - bus_provider=bus_provider, gtfs_repository=gtfs_repo - ) + service: RouteService = RouteService(bus_provider=bus_provider, gtfs_repository=gtfs_repo) - routes = [ - BusRoute(route_id=1234, route=route_identifier), - ] - result: list[BusPosition] = await service.get_bus_positions(routes) + result: list[BusPosition] = await service.get_bus_positions([1234]) - raw_provider.get_bus_positions.assert_awaited_once_with(routes) + raw_provider.get_bus_positions.assert_awaited_once_with(1234) assert result == expected_positions @@ -72,6 +62,8 @@ async def test_search_routes_calls_provider() -> None: expected_bus_route: BusRoute = BusRoute( route_id=1234, route=route_identifier, + is_circular=False, + terminal_name="Terminal A", ) expected_routes: list[BusRoute] = [expected_bus_route] @@ -79,9 +71,7 @@ async def test_search_routes_calls_provider() -> None: raw_provider.search_routes.return_value = expected_routes gtfs_repo = create_autospec(GTFSRepositoryPort, instance=True) - service: RouteService = RouteService( - bus_provider=bus_provider, gtfs_repository=gtfs_repo - ) + service: RouteService = RouteService(bus_provider=bus_provider, gtfs_repository=gtfs_repo) query = "8075" result: list[BusRoute] = await service.search_routes(query) @@ -99,21 +89,12 @@ async def test_get_bus_positions_propagates_exception_from_provider() -> None: bus_provider: BusProviderPort = cast(BusProviderPort, raw_provider) gtfs_repo = create_autospec(GTFSRepositoryPort, instance=True) - service: RouteService = RouteService( - bus_provider=bus_provider, gtfs_repository=gtfs_repo - ) - - routes = [ - BusRoute( - route_id=1234, - route=RouteIdentifier(bus_line="8075-10", bus_direction=1), - ), - ] + service: RouteService = RouteService(bus_provider=bus_provider, gtfs_repository=gtfs_repo) with pytest.raises(RuntimeError, match="boom"): - await service.get_bus_positions(routes) + await service.get_bus_positions([1234]) - raw_provider.get_bus_positions.assert_awaited_once_with(routes) + raw_provider.get_bus_positions.assert_awaited_once_with(1234) @pytest.mark.asyncio @@ -125,9 +106,7 @@ async def test_search_routes_propagates_exception_from_provider() -> None: bus_provider: BusProviderPort = cast(BusProviderPort, raw_provider) gtfs_repo = create_autospec(GTFSRepositoryPort, instance=True) - service: RouteService = RouteService( - bus_provider=bus_provider, gtfs_repository=gtfs_repo - ) + service: RouteService = RouteService(bus_provider=bus_provider, gtfs_repository=gtfs_repo) with pytest.raises(RuntimeError, match="search failed"): await service.search_routes("8075") @@ -192,9 +171,7 @@ def test_get_route_shape_with_many_points() -> None: points = [ RouteShapePoint( - coordinate=Coordinate( - latitude=-23.5505 + i * 0.001, longitude=-46.6333 + i * 0.001 - ), + coordinate=Coordinate(latitude=-23.5505 + i * 0.001, longitude=-46.6333 + i * 0.001), sequence=i + 1, distance_traveled=float(i * 10), ) diff --git a/tests/integration/test_route.py b/tests/integration/test_route.py index 3a07b0e..9a72b6c 100644 --- a/tests/integration/test_route.py +++ b/tests/integration/test_route.py @@ -8,8 +8,7 @@ from src.core.models.coordinate import Coordinate from src.web.schemas import ( BusPositionsRequest, - BusRouteSchema, - RouteIdentifierSchema, + BusRouteRequestSchema, ) from .conftest import create_user_and_login @@ -31,7 +30,12 @@ async def test_search_routes_returns_successfully( mock_bus_routes = [ BusRoute( route_id=12345, - route=RouteIdentifier(bus_line="8000", bus_direction=1), + route=RouteIdentifier( + bus_line="8000", + bus_direction=1, + ), + is_circular=False, + terminal_name="Terminal A", ) ] @@ -74,11 +78,21 @@ async def test_search_routes_returns_multiple_results( mock_bus_routes = [ BusRoute( route_id=12345, - route=RouteIdentifier(bus_line="8000", bus_direction=1), + route=RouteIdentifier( + bus_line="8000", + bus_direction=1, + ), + is_circular=False, + terminal_name="Terminal A", ), BusRoute( route_id=12346, - route=RouteIdentifier(bus_line="8000", bus_direction=2), + route=RouteIdentifier( + bus_line="8000", + bus_direction=2, + ), + is_circular=False, + terminal_name="Terminal B", ), ] @@ -198,12 +212,12 @@ async def test_get_bus_position_returns_successfully( mock_positions = [ BusPosition( - route=RouteIdentifier(bus_line="8000", bus_direction=1), + route_id=12345, position=Coordinate(latitude=-23.550520, longitude=-46.633308), time_updated=datetime.now(UTC), ), BusPosition( - route=RouteIdentifier(bus_line="8000", bus_direction=1), + route_id=12345, position=Coordinate(latitude=-23.551234, longitude=-46.634567), time_updated=datetime.now(UTC), ), @@ -216,10 +230,7 @@ async def test_get_bus_position_returns_successfully( ): request_data = BusPositionsRequest( routes=[ - BusRouteSchema( - route_id=12345, - route=RouteIdentifierSchema(bus_line="8000", bus_direction=1), - ), + BusRouteRequestSchema(route_id=12345), ] ) @@ -236,9 +247,8 @@ async def test_get_bus_position_returns_successfully( assert len(data["buses"]) == 2 first_bus = data["buses"][0] - assert "route" in first_bus - assert first_bus["route"]["bus_line"] == "8000" - assert first_bus["route"]["bus_direction"] == 1 + assert "route_id" in first_bus + assert first_bus["route_id"] == 12345 assert "position" in first_bus assert "latitude" in first_bus["position"] assert "longitude" in first_bus["position"] @@ -263,10 +273,7 @@ async def test_get_bus_position_returns_500_when_error( ): request_data = BusPositionsRequest( routes=[ - BusRouteSchema( - route_id=99999, - route=RouteIdentifierSchema(bus_line="123", bus_direction=1), - ), + BusRouteRequestSchema(route_id=99999), ] ) @@ -298,10 +305,7 @@ async def test_get_bus_position_returns_empty_when_no_buses_on_line( ): request_data = BusPositionsRequest( routes=[ - BusRouteSchema( - route_id=12345, - route=RouteIdentifierSchema(bus_line="8000", bus_direction=1), - ), + BusRouteRequestSchema(route_id=12345), ] ) @@ -329,34 +333,37 @@ async def test_get_bus_position_works_with_multiple_routes( } auth = await create_user_and_login(client, user_data) - mock_positions = [ + mock_position_12345 = [ BusPosition( - route=RouteIdentifier(bus_line="8000", bus_direction=1), + route_id=12345, position=Coordinate(latitude=-23.550520, longitude=-46.633308), time_updated=datetime.now(UTC), ), + ] + mock_position_67890 = [ BusPosition( - route=RouteIdentifier(bus_line="9000", bus_direction=2), + route_id=67890, position=Coordinate(latitude=-23.560520, longitude=-46.643308), time_updated=datetime.now(UTC), ), ] + def mock_get_positions(route_id: int) -> list[BusPosition]: + if route_id == 12345: + return mock_position_12345 + elif route_id == 67890: + return mock_position_67890 + return [] + with patch( "src.adapters.external.sptrans_adapter.SpTransAdapter.get_bus_positions", new_callable=AsyncMock, - return_value=mock_positions, + side_effect=mock_get_positions, ): request_data = BusPositionsRequest( routes=[ - BusRouteSchema( - route_id=12345, - route=RouteIdentifierSchema(bus_line="8000", bus_direction=1), - ), - BusRouteSchema( - route_id=67890, - route=RouteIdentifierSchema(bus_line="9000", bus_direction=2), - ), + BusRouteRequestSchema(route_id=12345), + BusRouteRequestSchema(route_id=67890), ] ) @@ -370,9 +377,9 @@ async def test_get_bus_position_works_with_multiple_routes( data = response.json() assert len(data["buses"]) == 2 - bus_lines = [bus["route"]["bus_line"] for bus in data["buses"]] - assert "8000" in bus_lines - assert "9000" in bus_lines + route_ids = [bus["route_id"] for bus in data["buses"]] + assert 12345 in route_ids + assert 67890 in route_ids @pytest.mark.asyncio async def test_get_bus_position_returns_500_when_api_error( @@ -393,10 +400,7 @@ async def test_get_bus_position_returns_500_when_api_error( ): request_data = BusPositionsRequest( routes=[ - BusRouteSchema( - route_id=12345, - route=RouteIdentifierSchema(bus_line="8000", bus_direction=1), - ), + BusRouteRequestSchema(route_id=12345), ] ) @@ -420,13 +424,8 @@ async def test_get_bus_position_returns_422_when_invalid_data( } auth = await create_user_and_login(client, user_data) - invalid_request_data: dict[str, list[dict[str, int | dict[str, str | int]]]] = { - "routes": [ - { - "route_id": 12345, - "route": {"bus_line": "8000", "bus_direction": 3}, - } - ] + invalid_request_data: dict[str, list[dict[str, str]]] = { + "routes": [{"route_id": "not_an_int"}] } response = await client.post( @@ -472,16 +471,11 @@ async def test_get_bus_position_without_auth_fails( ) -> None: request_data = BusPositionsRequest( routes=[ - BusRouteSchema( - route_id=12345, - route=RouteIdentifierSchema(bus_line="8000", bus_direction=1), - ), + BusRouteRequestSchema(route_id=12345), ] ) - response = await client.post( - "/routes/positions", json=request_data.model_dump() - ) + response = await client.post("/routes/positions", json=request_data.model_dump()) assert response.status_code == 401 diff --git a/tests/web/test_route_controller.py b/tests/web/test_route_controller.py index 03ab1f8..a963d7e 100644 --- a/tests/web/test_route_controller.py +++ b/tests/web/test_route_controller.py @@ -46,13 +46,28 @@ def override_dependency( class TestSearchRoutes: - @pytest.mark.asyncio async def test_search_endpoint_success( self, client: TestClient, mock_service: RouteService ) -> None: - bus_route_1 = BusRoute(route_id=2044, route=RouteIdentifier(bus_line="8075", bus_direction=1)) - bus_route_2 = BusRoute(route_id=34812, route=RouteIdentifier(bus_line="8075", bus_direction=2)) + bus_route_1 = BusRoute( + route_id=2044, + route=RouteIdentifier( + bus_line="8075", + bus_direction=1, + ), + is_circular=False, + terminal_name="Terminal A", + ) + bus_route_2 = BusRoute( + route_id=34812, + route=RouteIdentifier( + bus_line="8075", + bus_direction=2, + ), + is_circular=False, + terminal_name="Terminal B", + ) mock_service.search_routes.return_value = [bus_route_1, bus_route_2] # type: ignore[attr-defined] @@ -81,8 +96,16 @@ async def test_search_endpoint_with_destination_name( self, client: TestClient, mock_service: RouteService ) -> None: """Test search with destination name query.""" - route_identifier = RouteIdentifier(bus_line="809", bus_direction=1) - bus_route = BusRoute(route_id=1234, route=route_identifier) + route_identifier = RouteIdentifier( + bus_line="809", + bus_direction=1, + ) + bus_route = BusRoute( + route_id=1234, + route=route_identifier, + is_circular=False, + terminal_name="Vila Nova Conceição", + ) mock_service.search_routes.return_value = [bus_route] # type: ignore[attr-defined] @@ -133,24 +156,15 @@ async def test_positions_endpoint_success( self, client: TestClient, mock_service: RouteService ) -> None: """Test successful positions retrieval.""" - route_identifier = RouteIdentifier(bus_line="8075-10", bus_direction=1) - position = BusPosition( - route=route_identifier, + route_id=2044, position=Coordinate(latitude=-23.5, longitude=-46.6), time_updated=datetime.now(UTC), ) mock_service.get_bus_positions.return_value = [position] # type: ignore[attr-defined] - payload = { - "routes": [ - { - "route_id": 2044, - "route": {"bus_line": "8075-10", "bus_direction": 1}, - } - ] - } + payload = {"routes": [{"route_id": 2044}]} response = client.post("/routes/positions", json=payload) @@ -162,17 +176,17 @@ async def test_positions_endpoint_success( bus = data["buses"][0] - assert bus["route"]["bus_line"] == "8075-10" + assert bus["route_id"] == 2044 assert "position" in bus assert "latitude" in bus["position"] assert "longitude" in bus["position"] assert "time_updated" in bus mock_service.get_bus_positions.assert_awaited_once() # type: ignore[attr-defined] - called_arg = mock_service.get_bus_positions.await_args.args[0] # type: ignore[attr-defined] - assert len(called_arg) == 1 - assert called_arg[0].route_id == 2044 - assert called_arg[0].route.bus_line == "8075-10" + # Service now receives route_ids: list[int] + called_args = mock_service.get_bus_positions.await_args.args # type: ignore[attr-defined] + route_ids = called_args[0] + assert route_ids == [2044] @pytest.mark.asyncio async def test_positions_endpoint_multiple_routes( @@ -180,12 +194,12 @@ async def test_positions_endpoint_multiple_routes( ) -> None: """Test positions for multiple routes.""" position1 = BusPosition( - route=RouteIdentifier(bus_line="8075-10", bus_direction=1), + route_id=2044, position=Coordinate(latitude=-23.5, longitude=-46.6), time_updated=datetime.now(UTC), ) position2 = BusPosition( - route=RouteIdentifier(bus_line="809-10", bus_direction=1), + route_id=5678, position=Coordinate(latitude=-23.6, longitude=-46.7), time_updated=datetime.now(UTC), ) @@ -194,14 +208,8 @@ async def test_positions_endpoint_multiple_routes( payload = { "routes": [ - { - "route_id": 2044, - "route": {"bus_line": "8075-10", "bus_direction": 1}, - }, - { - "route_id": 5678, - "route": {"bus_line": "809-10", "bus_direction": 1}, - }, + {"route_id": 2044}, + {"route_id": 5678}, ] } @@ -220,14 +228,7 @@ async def test_positions_endpoint_error_returns_500( """Test that service exception returns 500 error.""" mock_service.get_bus_positions.side_effect = RuntimeError("boom") # type: ignore[attr-defined] - payload = { - "routes": [ - { - "route_id": 2044, - "route": {"bus_line": "8075-10", "bus_direction": 1}, - } - ] - } + payload = {"routes": [{"route_id": 2044}]} response = client.post("/routes/positions", json=payload)