From 0095cba872a66bf1bb8444789a0f0c9e636a0ea4 Mon Sep 17 00:00:00 2001 From: Nikith Shetty Date: Fri, 8 Aug 2025 17:33:14 +0530 Subject: [PATCH 1/2] added FleetInfo pydantic class --- amnex-live-data-server.py | 40 ++++++++++++++++++++++++--------------- 1 file changed, 25 insertions(+), 15 deletions(-) diff --git a/amnex-live-data-server.py b/amnex-live-data-server.py index 6aa9b12..61139a4 100644 --- a/amnex-live-data-server.py +++ b/amnex-live-data-server.py @@ -14,6 +14,7 @@ from sqlalchemy.pool import QueuePool from typing import Optional, List from concurrent.futures import ThreadPoolExecutor +from pydantic import BaseModel import math import traceback from geopy.distance import geodesic @@ -774,7 +775,13 @@ def set(self, key: str, value, ttl: Optional[int] = None): # Create single cache instance cache = SimpleCache() -def get_fleet_info(device_id: str, current_lat: float = None, current_lon: float = None, timestamp: int = None, provider: str = None) -> dict: +class FleetInfo(BaseModel): + """Pydantic model for fleet information returned by get_fleet_info function""" + vehicle_no: str + device_id: str + route_id: str + +def get_fleet_info(device_id: str, current_lat: float = None, current_lon: float = None, timestamp: int = None, provider: str = None) -> List[FleetInfo]: """Get both fleet number and route ID for a device""" cache_key = f"fleetInfo:{device_id}" cache_key_saved = cache_key + ":saved" @@ -784,10 +791,11 @@ def get_fleet_info(device_id: str, current_lat: float = None, current_lon: float # Check cache first fleet_info_str = redis_client.get(cache_key) if fleet_info_str is not None: - fleet_infos = json.loads(fleet_info_str) + fleet_infos_data = json.loads(fleet_info_str) + fleet_infos = [FleetInfo(**fleet_info) for fleet_info in fleet_infos_data] for fleet_info in fleet_infos: if current_lat is not None and current_lon is not None: - store_vehicle_location_history(fleet_info['vehicle_no'], current_lat, current_lon, timestamp) + 
store_vehicle_location_history(fleet_info.vehicle_no, current_lat, current_lon, timestamp) return fleet_infos try: @@ -798,11 +806,11 @@ def get_fleet_info(device_id: str, current_lat: float = None, current_lon: float # Get route for fleet route_ids = get_route_ids_from_waybills(vehicle_no, current_lat, current_lon, timestamp, provider) for route_id in route_ids: - val = { - 'vehicle_no': vehicle_no, - 'device_id': device_id, - 'route_id': route_id - } + fleet_info = FleetInfo( + vehicle_no=vehicle_no, + device_id=device_id, + route_id=route_id + ) try: fleet_info_saved = redis_client.get(cache_key_saved) if fleet_info_saved is not None: @@ -815,10 +823,12 @@ def get_fleet_info(device_id: str, current_lat: float = None, current_lon: float clean_redis_key_for_route_info(fleet_info_saved['route_id'], route_key) except Exception as e: logger.error(f"Error cleaning redis key for route info: {e}") - fleet_mapping_values.append(val) + fleet_mapping_values.append(fleet_info) if len(route_ids) > 0: - redis_client.setex(cache_key_saved, BUS_LOCATION_MAX_AGE + BUS_CLEANUP_INTERVAL, json.dumps(fleet_mapping_values)) # hack for cleanup if route changes - redis_client.setex(cache_key, BUS_CLEANUP_INTERVAL, json.dumps(fleet_mapping_values)) + # Convert FleetInfo objects to dicts for JSON serialization to Redis + fleet_mapping_dicts = [fleet_info.model_dump() for fleet_info in fleet_mapping_values] + redis_client.setex(cache_key_saved, BUS_LOCATION_MAX_AGE + BUS_CLEANUP_INTERVAL, json.dumps(fleet_mapping_dicts)) # hack for cleanup if route changes + redis_client.setex(cache_key, BUS_CLEANUP_INTERVAL, json.dumps(fleet_mapping_dicts)) return fleet_mapping_values except Exception as e: print(f"Error querying fleet info for device {device_id}: {e}") @@ -1678,9 +1688,9 @@ def handle_client_data(payload, client_ip, serverTime, isNYGpsDevice = False, se if not fleet_infos: push_to_kafka(entity) for fleet_info in fleet_infos: - entity['routeNumber'] = fleet_info.get('route_id') - if 
fleet_info and 'route_id' in fleet_info and fleet_info["route_id"] != None: - route_id = fleet_info['route_id'] + entity['routeNumber'] = fleet_info.route_id + if fleet_info and fleet_info.route_id is not None: + route_id = fleet_info.route_id stopsInfo = stop_tracker.get_route_stops(route_id) @@ -1698,7 +1708,7 @@ def handle_client_data(payload, client_ip, serverTime, isNYGpsDevice = False, se serverTime, vehicle_id=deviceId, visited_stops=visited_stops, - vehicle_no=fleet_info.get('vehicle_no', deviceId) + vehicle_no=fleet_info.vehicle_no ) if len(visited_stops) > len(before_curr_point_visited_stops): entity['stopId'] = visited_stops[-1] From daf8bc4f8a931937c223385e46093c67065aaec2 Mon Sep 17 00:00:00 2001 From: Nikith Shetty Date: Fri, 8 Aug 2025 20:40:31 +0530 Subject: [PATCH 2/2] refactored amnex-live-data-server.py into sub-modules --- amnex-live-data-server.py | 1107 ++----------------------------------- src/__init__.py | 0 src/cache_utils.py | 116 ++++ src/fleet_management.py | 235 ++++++++ src/geometry_utils.py | 199 +++++++ src/models.py | 65 +++ src/route_matching.py | 139 +++++ src/stop_tracker.py | 429 ++++++++++++++ 8 files changed, 1235 insertions(+), 1055 deletions(-) create mode 100644 src/__init__.py create mode 100644 src/cache_utils.py create mode 100644 src/fleet_management.py create mode 100644 src/geometry_utils.py create mode 100644 src/models.py create mode 100644 src/route_matching.py create mode 100644 src/stop_tracker.py diff --git a/amnex-live-data-server.py b/amnex-live-data-server.py index 61139a4..c9a0044 100644 --- a/amnex-live-data-server.py +++ b/amnex-live-data-server.py @@ -1,27 +1,33 @@ -import socket -import polyline as gpolyline -from confluent_kafka import Producer, KafkaError, KafkaException -import os +import atexit import json -from datetime import datetime, date, timedelta +import logging +import math +import os +import socket import threading -from rediscluster import RedisCluster -import redis import time -from 
sqlalchemy import create_engine, Column, Integer, String, DateTime, Boolean, BigInteger, Text, select, func -from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.orm import sessionmaker -from sqlalchemy.pool import QueuePool -from typing import Optional, List -from concurrent.futures import ThreadPoolExecutor -from pydantic import BaseModel -import math import traceback -from geopy.distance import geodesic -import logging -import atexit +from concurrent.futures import ThreadPoolExecutor +from datetime import datetime, date, timedelta +from typing import Optional, List + import paho.mqtt.client as mqtt +import redis import requests +from confluent_kafka import Producer +from rediscluster import RedisCluster +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker +from sqlalchemy.pool import QueuePool + +# Import refactored modules from src package +from src.cache_utils import SimpleCache, get_vehicle_location_history +from src.fleet_management import ( + get_fleet_info, clean_redis_key_for_route_info, load_device_vehicle_mappings, + clean_outdated_vehicle_mappings, start_vehicle_cleanup_thread +) +from src.models import Base, WaybillsBase +from src.stop_tracker import StopTracker # Configure logging logging.basicConfig( @@ -114,7 +120,6 @@ } ) SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) -Base = declarative_base() # SQLAlchemy setup for waybills database WAYBILLS_DATABASE_URL = f"postgresql://{WAYBILLS_DB_USER}:{WAYBILLS_DB_PASS}@{WAYBILLS_DB_HOST}:{WAYBILLS_DB_PORT}/{WAYBILLS_DB_NAME}" @@ -128,111 +133,10 @@ pool_recycle=1800 ) WaybillsSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=waybills_engine) -WaybillsBase = declarative_base() - -# Update the SQLAlchemy models -class DeviceVehicleMapping(Base): - __tablename__ = "device_vehicle_mapping" - __table_args__ = {'schema': 'atlas_app'} - vehicle_no = Column(Text, index=True) - device_id = Column(Text, index=True, 
primary_key=True) - -class RoutePolyline(Base): - __tablename__ = "route_polylines" - __table_args__ = {'schema': 'atlas_app'} - - route_id = Column(BigInteger, primary_key=True) - polyline = Column(Text) - merchant_operating_city_id = Column(Text, primary_key=True) - -# Waybills database models -class Waybill(Base): - __tablename__ = "waybills" - waybill_id = Column(BigInteger, primary_key=True) - schedule_id = Column(BigInteger) - schedule_trip_id = Column(BigInteger) - deleted = Column(Boolean, nullable=False, default=False) - schedule_no = Column(Text) - schedule_trip_name = Column(Text) - schedule_type = Column(Text) - service_type = Column(Text) - updated_at = Column(DateTime) - status = Column(Text) - vehicle_no = Column(Text) - -class BusSchedule(Base): - __tablename__ = "bus_schedule" - - schedule_id = Column(BigInteger, primary_key=True) - deleted = Column(Boolean, nullable=False, default=False) - route_code = Column(Text) - status = Column(Text) - route_id = Column(BigInteger, nullable=False) - -class BusScheduleTripDetail(Base): - __tablename__ = "bus_schedule_trip_detail" - - schedule_trip_detail_id = Column(BigInteger, primary_key=True) - schedule_trip_id = Column(BigInteger) - deleted = Column(Boolean, nullable=False, default=False) - route_number_id = Column(BigInteger, nullable=False) -def get_route_ids_from_waybills(vehicle_no: str, current_lat: float = None, current_lon: float = None, timestamp: int = None, provider: str = None) -> Optional[str]: - """Get the route_id from waybills database for a given vehicle number""" - try: - with WaybillsSessionLocal() as db: - # First get the active waybill for the vehicle - waybill = db.query(Waybill)\ - .filter( - Waybill.vehicle_no == vehicle_no, - Waybill.deleted == False, - Waybill.status == 'Online' - )\ - .order_by(Waybill.updated_at.desc())\ - .first() - - if not waybill: - return None - - if current_lat is not None and current_lon is not None: - store_vehicle_location_history(vehicle_no, 
current_lat, current_lon, timestamp) - # Add current location to history if provided - location_history = get_vehicle_location_history(vehicle_no) - if len(location_history) < 5: - return None - - # Then get all possible routes from bus_schedule - schedules = db.query(BusScheduleTripDetail)\ - .filter( - BusScheduleTripDetail.schedule_trip_id == waybill.schedule_trip_id, - BusScheduleTripDetail.deleted == False - )\ - .all() # Execute the query to get results - - if len(schedules) == 0: - return None - print(f"Route ID: Bus scheudle len {len(schedules)}") - - best_route_ids = [] - routes_match_score = {} - for schedule in schedules: - if schedule.route_number_id not in routes_match_score: - route_stops = stop_tracker.get_route_stops(str(schedule.route_number_id)) - # Calculate match score using location history - score = calculate_route_match_score(schedule.route_number_id, vehicle_no, route_stops, location_history) - # Ensure score is not None - if score is None: - score = 0.0 - print(f"Route ID: Bus score {vehicle_no} Score for route {schedule.route_number_id}: {score} (Provider: {provider})") - if score > 0.8: - best_route_ids.append(schedule.route_number_id) - routes_match_score[schedule.route_number_id] = score - return best_route_ids - - except Exception as e: - error_details = traceback.format_exc() - logger.error(f"Error querying waybills database for vehicle {vehicle_no} (Provider: {provider}): {e}\nTraceback: {error_details}") - return [] +# Models are now imported from models.py + +# get_route_ids_from_waybills function moved to route_matching.py # Don't create tables since we're using existing tables # Base.metadata.create_all(bind=engine) @@ -249,590 +153,28 @@ def get_route_ids_from_waybills(vehicle_no: str, current_lat: float = None, curr ENABLE_TIMESTAMP_VALIDATION = os.getenv('ENABLE_TIMESTAMP_VALIDATION', 'false').lower() == 'true' # Feature flag for timestamp validation FUTURE_TIMESTAMP_TOLERANCE = int(os.getenv('FUTURE_TIMESTAMP_TOLERANCE', 
'300')) # 5 minutes tolerance for future timestamps -class StopTracker: - def __init__(self, db_engine, redis_client, use_osrm=USE_OSRM, - osrm_url=OSRM_URL, google_api_key=GOOGLE_API_KEY, - cache_ttl=ROUTE_CACHE_TTL): - self.db_engine = db_engine - self.redis_client = redis_client - self.use_osrm = use_osrm - self.osrm_url = osrm_url - self.google_api_key = google_api_key - self.cache_ttl = cache_ttl - self.stop_visit_radius = float(os.getenv('STOP_VISIT_RADIUS', '0.05')) # 50 meters in km - print(f"StopTracker initialized with {'OSRM' if use_osrm else 'Google Maps'}") - - def get_route_stops(self, route_id): - """Get all stops for a route ordered by sequence, including the route polyline if available""" - cache_key = f"route_stops_info:{route_id}" - - # Check cache - cached = cache.get(cache_key) - if cached: - return cached - - try: - # Get stops for the route from API - stops_api_url = f"{ROUTE_STOP_MAPPING_API_URL}/route-stop-mapping/{GTFS_ID}/route/{route_id}" - response = requests.get(stops_api_url) - response.raise_for_status() # Raise HTTPError for bad responses (4xx or 5xx) - stops_data = response.json() - - # Get the route polyline from DB - route_polyline = None - with SessionLocal() as db: - polyline_info = db.query(RoutePolyline)\ - .filter(RoutePolyline.route_id == str(route_id), RoutePolyline.merchant_operating_city_id == MERCHANT_OPERATING_CITY_ID)\ - .first() - if polyline_info and polyline_info.polyline: - route_polyline = polyline_info.polyline - - if not stops_data: - return { - 'stops': [], - 'polyline': None - } - - # Format results - resultStops = [ - { - 'stop_id': stop['stopCode'], - 'sequence': stop['sequenceNum'], - 'name': stop['stopName'], - 'stop_lat': float(stop['stopPoint']['lat']), - 'stop_lon': float(stop['stopPoint']['lon']) - } - for stop in stops_data - ] - result = { - 'stops': resultStops, - 'polyline': route_polyline - } - # Cache result - cache.set(cache_key, result, 3600) - return result - except 
requests.exceptions.RequestException as e: - print(f"Error fetching route stops or polyline from API for route {route_id}: {e}") - return { - 'stops': [], - 'polyline': None - } - except json.JSONDecodeError as e: - print(f"Error decoding JSON response for route {route_id}: {e}") - return { - 'stops': [], - 'polyline': None - } - except Exception as e: - print(f"An unexpected error occurred getting stops for route {route_id}: {e}") - return { - 'stops': [], - 'polyline': None - } - - def get_visited_stops(self, route_id, vehicle_id): - """Get list of stops already visited by this vehicle on this route""" - visit_key = f"visited_stops:{route_id}:{vehicle_id}" - try: - visited_stops = self.redis_client.get(visit_key) - if visited_stops: - return json.loads(visited_stops) - return [] - except Exception as e: - logger.error(f"Error getting visited stops: {e}") - return [] - - def update_visited_stops(self, route_id, vehicle_id, stop_id): - """Add a stop to the visited stops list""" - visit_key = f"visited_stops:{route_id}:{vehicle_id}" - try: - visited_stops = self.get_visited_stops(route_id, vehicle_id) - if stop_id not in visited_stops: - visited_stops.append(stop_id) - self.redis_client.setex( - visit_key, - 7200, # 2 hour TTL - json.dumps(visited_stops) - ) - return visited_stops - except Exception as e: - logger.error(f"Error updating visited stops: {e}") - return [] - - def reset_visited_stops(self, route_id, vehicle_id, vehicle_no): - """Reset the visited stops list for a vehicle""" - visit_key = f"visited_stops:{route_id}:{vehicle_id}" - history_key = f"vehicle_history:{vehicle_no}" - try: - self.redis_client.delete(visit_key) - self.redis_client.delete(history_key) - logger.info(f"Reset visited stops for vehicle {vehicle_id} on route {route_id}") - return True - except Exception as e: - logger.error(f"Error resetting visited stops: {e}") - return False - - def check_if_at_stop(self, stop, vehicle_lat, vehicle_lon): - """Check if vehicle is within radius of a 
stop""" - # Calculate distance using haversine formula - lat1, lon1 = math.radians(vehicle_lat), math.radians(vehicle_lon) - lat2, lon2 = math.radians(float(stop['stop_lat'])), math.radians(float(stop['stop_lon'])) - - # Haversine formula - dlon = lon2 - lon1 - dlat = lat2 - lat1 - a = math.sin(dlat/2)**2 + math.cos(lat1) * math.cos(lat2) * math.sin(dlon/2)**2 - c = 2 * math.asin(math.sqrt(a)) - distance = 6371 * c # Radius of earth in kilometers - - return distance <= self.stop_visit_radius, distance - - def find_next_stop(self, stops, visited_stops, vehicle_lat, vehicle_lon): - """Find the next stop in sequence after the last visited stop""" - if not visited_stops: - # If no stops visited yet, find the nearest stop as the next stop - nearest_stop = None - min_distance = float('inf') - for stop in stops: - distance, _ = self.check_if_at_stop(stop, vehicle_lat, vehicle_lon) - if distance < min_distance: - min_distance = distance - nearest_stop = stop - return (nearest_stop, min_distance) - - # Get the last visited stop ID - last_visited_id = visited_stops[-1] - - # Find its index in the stops list - last_index = -1 - for i, stop in enumerate(stops): - if stop['stop_id'] == last_visited_id: - last_index = i - break - - # If we found the last stop and it's not the last in the route - if last_index >= 0 and last_index < len(stops) - 1: - return (stops[last_index + 1], None) - elif last_index == len(stops) - 1: - # We're at the last stop of the route - return (None, None) - - # If we couldn't find the last visited stop in the list - # (this shouldn't happen but just in case) - return (stops[0] if stops else None ,None) - - def find_closest_stop(self, stops, vehicle_lat, vehicle_lon): - """Find the closest stop to the given coordinates""" - if not stops: - return None, float('inf') - - closest_stop = None - min_distance = float('inf') - - for stop in stops: - # Calculate distance using haversine formula - lat1, lon1 = math.radians(vehicle_lat), math.radians(vehicle_lon) 
- lat2, lon2 = math.radians(float(stop['stop_lat'])), math.radians(float(stop['stop_lon'])) - - # Haversine formula - dlon = lon2 - lon1 - dlat = lat2 - lat1 - a = math.sin(dlat/2)**2 + math.cos(lat1) * math.cos(lat2) * math.sin(dlon/2)**2 - c = 2 * math.asin(math.sqrt(a)) - distance = 6371 * c # Radius of earth in kilometers - - if distance < min_distance: - min_distance = distance - closest_stop = stop - - return closest_stop, min_distance - - def get_travel_duration(self, origin_id, dest_id, origin_lat, origin_lon, dest_lat, dest_lon): - """Get travel duration between two stops with caching""" - # Try to get from cache - cache_key = f"route_segment:{origin_id}:{dest_id}" - try: - if origin_id != 0: - cached = self.redis_client.get(cache_key) - if cached: - data = json.loads(cached) - return data.get('duration') - except Exception as e: - print(f"Redis error: {e}") - - # Not in cache, calculate using routing API - try: - duration = None - # Fallback to simple estimation (30 km/h) - # Calculate distance using haversine - lat1, lon1 = math.radians(origin_lat), math.radians(origin_lon) - lat2, lon2 = math.radians(dest_lat), math.radians(dest_lon) - - dlon = lon2 - lon1 - dlat = lat2 - lat1 - a = math.sin(dlat/2)**2 + math.cos(lat1) * math.cos(lat2) * math.sin(dlon/2)**2 - c = 2 * math.asin(math.sqrt(a)) - distance = 6371000 * c # Radius of earth in meters - - # Estimate duration: distance / speed (30 km/h = 8.33 m/s) - duration = distance / 8.33 - - # Cache the fallback estimation - cache_data = { - 'duration': duration, - 'timestamp': datetime.now().isoformat(), - 'estimated': True - } - if origin_id != 0: - self.redis_client.setex(cache_key, self.cache_ttl, json.dumps(cache_data)) - - return duration - except Exception as e: - print(f"Error calculating travel duration: {e}") - return None - +# StopTracker class moved to stop_tracker.py - def check_if_crossed_stop(self, prev_location, current_location, stop_location, threshold_meters=20): - """ - Check if a vehicle 
has crossed a stop between its previous and current location. - - This function determines if a stop was passed by checking if the stop is near - the path between the vehicle's previous and current positions. - - Args: - prev_location (tuple): (lat, lon) of previous vehicle location - current_location (tuple): (lat, lon) of current vehicle location - stop_location (tuple): (lat, lon) of the stop - threshold_meters (float): Maximum distance in meters from the path to consider the stop crossed - - Returns: - bool: True if the stop was crossed, False otherwise - """ - # If any of the locations are None, return False - if any(loc is None for loc in [prev_location, current_location, stop_location]): - return False - # 1. First check: Is the stop close enough to either the current or previous position? - # This handles the case where the vehicle might have temporarily stopped at the bus stop - dist_to_prev = geodesic(prev_location, stop_location).meters - dist_to_curr = geodesic(current_location, stop_location).meters - - if dist_to_prev < threshold_meters or dist_to_curr < threshold_meters: - return True - - path_distance = geodesic(prev_location, current_location).meters - - if path_distance < 5: # 5 meters threshold for significant movement - return False - - # Calculate distances from prev to stop and from stop to current - dist_prev_to_stop = geodesic(prev_location, stop_location).meters - dist_stop_to_curr = geodesic(stop_location, current_location).meters - - # Check if the stop is roughly on the path (within reasonable error margin) - # due to GPS inaccuracy and road curvature - is_on_path = abs(dist_prev_to_stop + dist_stop_to_curr - path_distance) < threshold_meters - - # 3. 
Third check: Direction verification - # We need to verify the vehicle is moving toward the stop and then away from it - - # Calculate bearings - def calculate_bearing(point1, point2): - """Calculate the bearing between two points.""" - lat1, lon1 = math.radians(point1[0]), math.radians(point1[1]) - lat2, lon2 = math.radians(point2[0]), math.radians(point2[1]) - - dlon = lon2 - lon1 - - y = math.sin(dlon) * math.cos(lat2) - x = math.cos(lat1) * math.sin(lat2) - math.sin(lat1) * math.cos(lat2) * math.cos(dlon) - - bearing = math.atan2(y, x) - # Convert to degrees - bearing = math.degrees(bearing) - # Normalize to 0-360 - bearing = (bearing + 360) % 360 - - - return bearing - - # Get bearings - bearing_prev_to_curr = calculate_bearing(prev_location, current_location) - bearing_prev_to_stop = calculate_bearing(prev_location, stop_location) - bearing_stop_to_curr = calculate_bearing(stop_location, current_location) - - # Check if the bearings are roughly aligned - def angle_diff(a, b): - """Calculate the absolute difference between two angles in degrees.""" - return min(abs(a - b), 360 - abs(a - b)) - - alignment_prev_to_stop = angle_diff(bearing_prev_to_curr, bearing_prev_to_stop) < 60 - alignment_stop_to_curr = angle_diff(bearing_prev_to_curr, bearing_stop_to_curr) < 60 - - # 4. 
Combine all checks: - # - The stop should be roughly on the path - # - The bearings should be aligned - # - The distance from prev to stop and then to curr should be in increasing order of sequence - return (is_on_path and - alignment_prev_to_stop and - alignment_stop_to_curr and - dist_prev_to_stop < path_distance and - dist_stop_to_curr < path_distance) - - def calculate_eta(self, stopsInfo, route_id, vehicle_lat, vehicle_lon, current_time, vehicle_id, visited_stops=[], vehicle_no=None): - """Calculate ETA for all upcoming stops from current position""" - # Get all stops for the route - stops = stopsInfo.get('stops') - if not stops: - return None - - next_stop = None - closest_stop = None - distance = float('inf') - calculation_method = "realtime" - - # Check if the vehicle is at a stop now - for stop in stops: - # Check if vehicle is at the stop based on current position - is_at_stop, _ = self.check_if_at_stop(stop, vehicle_lat, vehicle_lon) - - # Get the vehicle's previous location from history - # Check if we crossed the stop between last position and current position - if not is_at_stop: - location_history = get_vehicle_location_history(vehicle_no) - if len(location_history) > 0: - last_point = location_history[-1] # Most recent point in history - # Check if the stop is between the last point and current point - crossed_stop = self.check_if_crossed_stop( - (last_point['lat'], last_point['lon']), - (vehicle_lat, vehicle_lon), - (float(stop['stop_lat']), float(stop['stop_lon'])) - ) - if crossed_stop: - is_at_stop = True - if is_at_stop: - # Vehicle is at this stop - if stop['stop_id'] not in visited_stops: - # Add to visited stops if not already there - self.update_visited_stops(route_id, vehicle_id, stop['stop_id']) - visited_stops.append(stop['stop_id']) - calculation_method = "visited_stops" - break - - # Find next stop based on visited stops - (next_stop, distance) = self.find_next_stop(stops, visited_stops, vehicle_lat, vehicle_lon) - if next_stop: - if 
not distance: - _, distance = self.check_if_at_stop(next_stop, vehicle_lat, vehicle_lon) - closest_stop = next_stop - calculation_method = "sequence_based" - else: - # We're at the end of the route, reset visited stops - self.reset_visited_stops(route_id, vehicle_id, vehicle_no) - # Fall back to closest stop method - closest_stop, distance = self.find_closest_stop(stops, vehicle_lat, vehicle_lon) - calculation_method = "distance_based_fallback" - - if not closest_stop: - return None - - # Find the index of the closest/next stop in the route - closest_index = -1 - for i, stop in enumerate(stops): - if stop['stop_id'] == closest_stop['stop_id']: - closest_index = i - break - - if closest_index == -1: - # Something went wrong, stop not found in the list - return None - - # Calculate ETAs for the closest stop and all upcoming stops - eta_list = [] - cumulative_time = 0 - current_lat, current_lon = vehicle_lat, vehicle_lon - - # First, calculate ETA for the closest/next stop - if distance <= 0.01: # 10 meters in km - we're practically at the stop - arrival_time = current_time - calculation_method = "immediate" - else: - # Calculate time to reach the stop - duration = self.get_travel_duration( - 0, closest_stop['stop_id'], - current_lat, current_lon, - closest_stop['stop_lat'], closest_stop['stop_lon'] - ) - - if duration: - arrival_time = current_time + timedelta(seconds=duration) - cumulative_time = duration - calculation_method = "estimated" - else: - # Fallback estimation - duration = distance / 8.33 # distance / (30 km/h in m/s) - arrival_time = current_time + timedelta(seconds=duration) - cumulative_time = duration - calculation_method = "estimated" - - # Add closest/next stop to the ETA list - eta_list.append({ - 'stop_id': closest_stop['stop_id'], - 'stop_seq': closest_stop['sequence'], - 'stop_name': closest_stop['name'], - 'stop_lat': closest_stop['stop_lat'], - 'stop_lon': closest_stop['stop_lon'], - 'arrival_time': int(arrival_time.timestamp()), - 
'calculation_method': calculation_method - }) - - # Then calculate ETAs for all remaining stops (everything after closest_index) - for i in range(closest_index + 1, len(stops)): - prev_stop = stops[i-1] - current_stop = stops[i] - - # Calculate duration between stops - duration = self.get_travel_duration( - prev_stop['stop_id'], current_stop['stop_id'], - prev_stop['stop_lat'], prev_stop['stop_lon'], - current_stop['stop_lat'], current_stop['stop_lon'] - ) - - if duration: - cumulative_time += duration - arrival_time = current_time + timedelta(seconds=cumulative_time) - - calculation_method = "estimated" - - eta_list.append({ - 'stop_id': current_stop['stop_id'], - 'stop_seq': current_stop['sequence'], - 'stop_name': current_stop['name'], - 'stop_lat': current_stop['stop_lat'], - 'stop_lon': current_stop['stop_lon'], - 'arrival_time': int(arrival_time.timestamp()), - 'calculation_method': calculation_method - }) - else: - # If we couldn't calculate duration, use estimated method - calculation_method = "estimated" - - return { - 'route_id': route_id, - 'current_time': int(current_time.timestamp()), - 'closest_stop': { - 'stop_id': closest_stop['stop_id'], - 'stop_name': closest_stop['name'], - 'distance': distance - }, - 'calculation_method': calculation_method, - 'eta': eta_list - } - -# Create instance -stop_tracker = StopTracker(engine, redis_client) - -class SimpleCache: - def __init__(self): - self.cache = {} +# Create instance with updated parameters +stop_tracker = StopTracker( + db_engine=engine, + redis_client=redis_client, + use_osrm=USE_OSRM, + osrm_url=OSRM_URL, + google_api_key=GOOGLE_API_KEY, + cache_ttl=ROUTE_CACHE_TTL, + route_stop_mapping_api_url=ROUTE_STOP_MAPPING_API_URL, + gtfs_id=GTFS_ID, + merchant_operating_city_id=MERCHANT_OPERATING_CITY_ID +) - def get(self, key: str): - res = self.cache.get(key) - if res: - value, expiry_timestamp = res - if expiry_timestamp is not None and expiry_timestamp < time.time(): - del self.cache[key] # Expired - 
res = None - else: - return value - - if res == None: - res_from_redis = redis_client.get(f"simpleCache:{key}") - if res_from_redis: - parsed_res = json.loads(res_from_redis) - # When loading from Redis, get the TTL from Redis and apply it to the in-memory cache - redis_ttl = redis_client.ttl(f"simpleCache:{key}") - in_memory_expiry_timestamp = None - if redis_ttl is not None and redis_ttl > -1: # -1 means no expire, -2 means key doesn't exist - in_memory_expiry_timestamp = time.time() + redis_ttl - - self.cache[key] = (parsed_res, in_memory_expiry_timestamp) - return parsed_res - else: - return None - return res - - def set(self, key: str, value, ttl: Optional[int] = None): - expiry_timestamp = None - if ttl is not None: - expiry_timestamp = time.time() + ttl - self.cache[key] = (value, expiry_timestamp) - - if ttl is None: - redis_client.set(f"simpleCache:{key}", json.dumps(value)) - else: - redis_client.setex(f"simpleCache:{key}", ttl, json.dumps(value)) # Create single cache instance -cache = SimpleCache() +cache = SimpleCache(redis_client) -class FleetInfo(BaseModel): - """Pydantic model for fleet information returned by get_fleet_info function""" - vehicle_no: str - device_id: str - route_id: str +# FleetInfo model moved to models.py -def get_fleet_info(device_id: str, current_lat: float = None, current_lon: float = None, timestamp: int = None, provider: str = None) -> List[FleetInfo]: - """Get both fleet number and route ID for a device""" - cache_key = f"fleetInfo:{device_id}" - cache_key_saved = cache_key + ":saved" - - fleet_mapping_values = [] # response values - - # Check cache first - fleet_info_str = redis_client.get(cache_key) - if fleet_info_str is not None: - fleet_infos_data = json.loads(fleet_info_str) - fleet_infos = [FleetInfo(**fleet_info) for fleet_info in fleet_infos_data] - for fleet_info in fleet_infos: - if current_lat is not None and current_lon is not None: - store_vehicle_location_history(fleet_info.vehicle_no, current_lat, 
current_lon, timestamp) - return fleet_infos - - try: - vehicle_no = device_vehicle_map.get(device_id) - if not vehicle_no: - return [] - - # Get route for fleet - route_ids = get_route_ids_from_waybills(vehicle_no, current_lat, current_lon, timestamp, provider) - for route_id in route_ids: - fleet_info = FleetInfo( - vehicle_no=vehicle_no, - device_id=device_id, - route_id=route_id - ) - try: - fleet_info_saved = redis_client.get(cache_key_saved) - if fleet_info_saved is not None: - fleet_info_saved = json.loads(fleet_info_saved) - print("going to delete route info") - if ('route_id' in fleet_info_saved and - fleet_info_saved['route_id'] is not None and - route_id != fleet_info_saved['route_id']): - route_key = "route:" + fleet_info_saved['route_id'] - clean_redis_key_for_route_info(fleet_info_saved['route_id'], route_key) - except Exception as e: - logger.error(f"Error cleaning redis key for route info: {e}") - fleet_mapping_values.append(fleet_info) - if len(route_ids) > 0: - # Convert FleetInfo objects to dicts for JSON serialization to Redis - fleet_mapping_dicts = [fleet_info.model_dump() for fleet_info in fleet_mapping_values] - redis_client.setex(cache_key_saved, BUS_LOCATION_MAX_AGE + BUS_CLEANUP_INTERVAL, json.dumps(fleet_mapping_dicts)) # hack for cleanup if route changes - redis_client.setex(cache_key, BUS_CLEANUP_INTERVAL, json.dumps(fleet_mapping_dicts)) - return fleet_mapping_values - except Exception as e: - print(f"Error querying fleet info for device {device_id}: {e}") - return fleet_mapping_values +# get_fleet_info function moved to fleet_management.py def date_to_unix(d: date) -> int: return int(d.timestamp()) @@ -1178,348 +520,13 @@ def forward_to_tcp(data_str): tcp_client.queue_message(data_str) return True - # Use the library's implementation when available -def decode_polyline(polyline_str): - """Wrapper for polyline library's decoder""" - if not polyline_str: - return [] - try: - return gpolyline.decode(polyline_str) - except Exception as 
e: - print(f"Error decoding polyline: {e}") - return [] +# Geometry utilities and vehicle history functions moved to geometry_utils.py and cache_utils.py -def calculate_distance(lat1, lon1, lat2, lon2): - """ - Calculate the great circle distance between two points - using the haversine formula - """ - # Convert decimal degrees to radians - lat1, lon1, lat2, lon2 = map(math.radians, [lat1, lon1, lat2, lon2]) - - # Haversine formula - dlon = lon2 - lon1 - dlat = lat2 - lat1 - a = math.sin(dlat/2)**2 + math.cos(lat1) * math.cos(lat2) * math.sin(dlon/2)**2 - c = 2 * math.asin(math.sqrt(a)) - - # Radius of earth in kilometers - r = 6371 - - return c * r +# clean_redis_key_for_route_info function moved to fleet_management.py -def is_point_near_polyline(point_lat, point_lon, polyline_points, max_distance_meter=50): - """ - Simpler function to check if a point is within max_distance_meter of any - segment of the polyline. - """ - if not polyline_points or len(polyline_points) < 2: - return False, float('inf'), None - - min_distance = float('inf') - - min_segment = None - - # Check each segment of the polyline - for i in range(len(polyline_points) - 1): - # Start and end points of current segment - p1_lat, p1_lon = polyline_points[i] - p2_lat, p2_lon = polyline_points[i + 1] - - # Calculate distance to this segment using a simple approximation - # For short segments, this is reasonable and much simpler - - # Calculate distances to segment endpoints - d1 = calculate_distance(point_lat, point_lon, p1_lat, p1_lon) - d2 = calculate_distance(point_lat, point_lon, p2_lat, p2_lon) - - # Calculate length of segment - segment_length = calculate_distance(p1_lat, p1_lon, p2_lat, p2_lon) - - # Use the simplified distance formula (works well for short segments) - if segment_length > 0: - # Projection calculation - # Vector from p1 to p2 - v1x = p2_lon - p1_lon - v1y = p2_lat - p1_lat - - # Vector from p1 to point - v2x = point_lon - p1_lon - v2y = point_lat - p1_lat - - # Dot product - 
dot = v1x * v2x + v1y * v2y - - # Squared length of segment - len_sq = v1x * v1x + v1y * v1y - - # Projection parameter (t) - t = max(0, min(1, dot / len_sq)) - - # Projected point - proj_x = p1_lon + t * v1x - proj_y = p1_lat + t * v1y - - # Distance to projection - distance = calculate_distance(point_lat, point_lon, proj_y, proj_x) - else: - # If segment is very short, just use distance to p1 - distance = d1 - - # Update minimum distance - if distance < min_distance: - min_segment = i - min_distance = distance - - # Check if within threshold (convert meters to kilometers) - max_distance_km = max_distance_meter / 1000 - return min_distance <= max_distance_km, min_distance, min_segment +# Vehicle cleanup functions moved to fleet_management.py -def store_vehicle_location_history(device_id: str, lat: float, lon: float, timestamp: int, max_points: int = 25): - """Store vehicle location history in Redis with TTL""" - history = None - try: - history_key = f"vehicle_history:{device_id}" - point = { - "lat": lat, - "lon": lon, - "timestamp": int(timestamp if timestamp else time.time()) - } - - # Get existing history - history = redis_client.get(history_key) - if history: - points = json.loads(history) or [] - else: - points = [] - if len(points) > 0: - lastPoint = points[-1] - if calculate_distance(lastPoint['lat'], lastPoint['lon'], point['lat'], point['lon']) < 0.002: - return - - # Add new point - points.append(point) - - # Keep only last max_points - if len(points) > max_points: - points = points[-max_points:] - - points.sort(key=lambda x: x['timestamp']) - # Store updated history with 1 hour TTL - redis_client.setex(history_key, 3600, json.dumps(points)) - - except Exception as e: - error_details = traceback.format_exc() - logger.error(f"Error storing vehicle history for {device_id}: {e}\nHistory value: {history}\nTraceback: {error_details}") - -def get_vehicle_location_history(device_id: str) -> List[dict]: - """Get vehicle location history from Redis""" - try: - 
history_key = f"vehicle_history:{device_id}" - history = redis_client.get(history_key) - if history: - value = json.loads(history) - if value: - return value - return [] - except Exception as e: - logger.error(f"Error getting vehicle history for {device_id}: {e}") - return [] - -def clean_redis_key_for_route_info(route_id, redis_key): - current_time = int(time.time()) - prod_vehicle_data = prod_redis_client.hgetall(redis_key) - vehicle_data = redis_client.hgetall(redis_key) - # Merge prod_vehicle_data and vehicle_data so that all vehicles from both are considered. - # If a vehicle_id exists in both, prefer the one from prod_vehicle_data. - merged_vehicle_data = dict(vehicle_data) if vehicle_data else {} - if prod_vehicle_data: - merged_vehicle_data.update(prod_vehicle_data) - vehicle_data = merged_vehicle_data - if not vehicle_data: - return - - vehicles_to_remove = [] - removed_count = 0 - - # Check each vehicle's timestamp - for vehicle_id, data_json in merged_vehicle_data.items(): - try: - data = json.loads(data_json) - # First check serverTime if available - if 'serverTime' in data: - timestamp = data.get('serverTime') - # Otherwise use timestamp - else: - timestamp = data.get('timestamp') - - # If no valid timestamp, skip - if not timestamp: - continue - - age = current_time - int(timestamp) - print("Error age", vehicle_id, route_id,age, current_time, int(timestamp), current_time - int(timestamp)) - - # If older than threshold, mark for removal - if age > BUS_LOCATION_MAX_AGE: - vehicles_to_remove.append(vehicle_id) - logger.debug(f"Vehicle {vehicle_id} on route {route_id} outdated by {age}s, marking for removal") - except (json.JSONDecodeError, KeyError, TypeError, ValueError) as e: - logger.error(f"Error parsing data for vehicle {vehicle_id}: {e}") - # Mark invalid entries for removal - vehicles_to_remove.append(vehicle_id) - - # Remove outdated vehicles - if vehicles_to_remove: - redis_client.hdel(redis_key, *vehicles_to_remove) - 
prod_redis_client.hdel(redis_key, *vehicles_to_remove) - removed_count = len(vehicles_to_remove) - logger.info(f"Removed {removed_count} outdated vehicles from route {route_id}") - - return removed_count - -def clean_outdated_vehicle_mappings(): - """ - Remove outdated vehicle mappings from Redis for all routes. - Uses Redis lock to ensure only one instance runs cleanup at a time. - """ - # Try to acquire lock - lock_key = "vehicle_mappings_cleanup_lock" - lock_acquired = redis_client.set(lock_key, "locked", nx=True, ex=CLEANUP_LOCK_TTL) - - if not lock_acquired: - logger.debug("Vehicle mappings cleanup already running in another pod/process") - return - - try: - logger.info("Starting vehicle mappings cleanup") - # Get all route keys - # Use a more robust approach to get all keys matching the pattern - route_keys = [] - cursor = 0 - prod_cursor = 0 - max_iterations = 100 - iteration_count = 0 - start = True - - while iteration_count < max_iterations: - if (start and cursor == 0) or (not start and cursor != 0): - cursor, keys = redis_client.scan(cursor, match="route:*", count=1000) - route_keys.extend(keys) - if (start and prod_cursor == 0) or (not start and prod_cursor != 0): - prod_cursor, prod_keys = prod_redis_client.scan(cursor, match="route:*", count=1000) - route_keys.extend(prod_keys) - start = False - iteration_count += 1 - if cursor == 0 and prod_cursor == 0: - break - route_keys = list(set(route_keys)) - logger.debug(f"Found {len(route_keys)} route keys for cleanup after {iteration_count} iterations") - if not route_keys: - logger.debug("No route data found for cleanup") - return - - total_routes = len(route_keys) - total_vehicles_removed = 0 - - for redis_key in route_keys: - try: - # Extract route_id from key - route_id = redis_key.split(":", 1)[1] if ":" in redis_key else "unknown" - # Get all vehicles for this route - removed = clean_redis_key_for_route_info(route_id, redis_key) - if removed: - total_vehicles_removed += removed - - except Exception as 
e: - logger.error(f"Error cleaning route {redis_key}: {e}") - - logger.info(f"Completed vehicle mappings cleanup: processed {total_routes} routes, removed {total_vehicles_removed} vehicles") - - except Exception as e: - logger.error(f"Error during vehicle mappings cleanup: {e}") - finally: - # Release the lock - try: - redis_client.delete(lock_key) - except: - pass - -def start_vehicle_cleanup_thread(): - """Start a background thread for vehicle mapping cleanup""" - def cleanup_worker(): - logger.info(f"Vehicle mappings cleanup thread started (interval: {BUS_CLEANUP_INTERVAL}s, max age: {BUS_LOCATION_MAX_AGE}s)") - - while True: - try: - clean_outdated_vehicle_mappings() - except Exception as e: - logger.error(f"Error in vehicle cleanup worker: {e}") - - time.sleep(BUS_CLEANUP_INTERVAL) - - cleanup_thread = threading.Thread(target=cleanup_worker, daemon=True) - cleanup_thread.start() - - return cleanup_thread - -def calculate_route_match_score(route_id, vehicle_no, stops: dict, vehicle_points: List[dict], max_distance_meter: float = 100) -> float: - """ - Calculate how well a route matches a series of vehicle_points, considering direction. - Uses polyline for more accurate route matching when available. - Returns a score between 0 and 1, where 1 is a perfect match. 
- """ - try: - # Check if stops is a dict with polyline and stops keys - if isinstance(stops, dict) and 'stops' in stops and 'polyline' in stops: - route_polyline = stops.get('polyline') - polyline_points = decode_polyline(route_polyline) - min_points_required = 4 - else: - route_polyline = "" - stopsInfo = stops.get('stops') - polyline_points = list(map(lambda x: (x['stop_lat'], x['stop_lon']), stopsInfo)) - min_points_required = 10 - - if not vehicle_points or len(vehicle_points) < min_points_required: - return 0.0 - - # Sort vehicle_points by timestamp to ensure they're in chronological order - vehicle_points = sorted(vehicle_points, key=lambda x: x.get('timestamp', 0)) - if polyline_points: - # Count how many vehicle_points are near the polyline - near_points = [] - total_distance = 0.0 - - min_segments_list = [] - for point in vehicle_points: - try: - is_near, distance, min_segment_start = is_point_near_polyline( - point['lat'], point['lon'], polyline_points, max_distance_meter - ) - if is_near: - if min_segment_start is not None: - min_segments_list.append(min_segment_start) - near_points.append(point) - total_distance += distance - except (KeyError, ValueError, TypeError) as e: - logger.debug(f"Error checking if point is near polyline: {e}, point: {point}") - continue - - # Calculate proximity score (0-1) - proximity_ratio = len(near_points) / len(vehicle_points) if len(vehicle_points) > 0 else 0 - - # Only proceed if enough vehicle_points are near the polyline - if proximity_ratio >= 0.3: - # Convert set to list and sort to check direction - if len(min_segments_list) >= 2 and min(min_segments_list) == min_segments_list[0]: - print(f"Route ID: {vehicle_no} {len(near_points)}/{len(vehicle_points)}, Score: {proximity_ratio:.2f}") - return proximity_ratio - return 0.0 - except Exception as e: - error_details = traceback.format_exc() - logger.error(f"Error calculating route match score: {stops} {e}\nTraceback: {error_details}") - return 0.0 +# 
calculate_route_match_score function moved to route_matching.py def push_to_kafka(entity): max_retries = 3 @@ -1626,16 +633,7 @@ def validate_and_update_timestamp(entity: dict, vehicle_number: str) -> bool: logger.error(f"Error validating timestamp for vehicle {vehicle_number}: {e}") return True # Allow on error to avoid blocking valid data -def load_device_vehicle_mappings(): - global device_vehicle_map - try: - with SessionLocal() as db: - mappings = db.query(DeviceVehicleMapping).all() - for mapping in mappings: - device_vehicle_map[mapping.device_id] = mapping.vehicle_no - logger.info(f"Loaded {len(device_vehicle_map)} device-vehicle mappings at startup.") - except Exception as e: - logger.error(f"Error loading device-vehicle mappings at startup: {e}") +# load_device_vehicle_mappings function moved to fleet_management.py def handle_client_data(payload, client_ip, serverTime, isNYGpsDevice = False, session=None): """Handle client data and send it to Kafka""" @@ -1684,7 +682,7 @@ def handle_client_data(payload, client_ip, serverTime, isNYGpsDevice = False, se return # Get route information for this vehicle - fleet_infos = get_fleet_info(deviceId, vehicle_lat, vehicle_lon, entity.get('timestamp'), entity.get('provider')) + fleet_infos = get_fleet_info(redis_client, device_vehicle_map, WaybillsSessionLocal, deviceId, vehicle_lat, vehicle_lon, entity.get('timestamp'), entity.get('provider'), stop_tracker, BUS_LOCATION_MAX_AGE, BUS_CLEANUP_INTERVAL) if not fleet_infos: push_to_kafka(entity) for fleet_info in fleet_infos: @@ -1996,7 +994,6 @@ def main_server(): # Start MQTT client, no separate thread required # as we already called loop_start() and we already registered a shutdown function mqtt_client_obj = mqtt_client() - load_device_vehicle_mappings() - start_vehicle_cleanup_thread() - main_server() - + device_vehicle_map = load_device_vehicle_mappings(SessionLocal) + start_vehicle_cleanup_thread(redis_client, prod_redis_client, CLEANUP_LOCK_TTL, 
"""
Cache utilities and Redis operations for the GPS tracking system.
Contains the SimpleCache class and vehicle location history management.
"""

import json
import logging
import time
import traceback
from typing import Optional, List, Dict, Any

logger = logging.getLogger('amnex-data-server')


class SimpleCache:
    """Two-tier cache: in-memory dict backed by Redis under ``simpleCache:*`` keys."""

    def __init__(self, redis_client):
        # key -> (value, expiry_timestamp | None); expiry is an absolute epoch time
        self.cache = {}
        self.redis_client = redis_client

    def get(self, key: str) -> Optional[Any]:
        """Return the cached value for *key*, or None on a miss.

        Checks the in-memory cache first; on a miss (or local expiry) falls
        back to Redis and re-populates the in-memory entry, carrying over the
        remaining Redis TTL so both tiers expire together.
        """
        entry = self.cache.get(key)
        if entry is not None:
            value, expiry_timestamp = entry
            if expiry_timestamp is not None and expiry_timestamp < time.time():
                del self.cache[key]  # expired locally; fall through to Redis
            else:
                return value

        res_from_redis = self.redis_client.get(f"simpleCache:{key}")
        if res_from_redis is None:
            return None

        parsed_res = json.loads(res_from_redis)
        # Mirror the remaining Redis TTL onto the in-memory entry.
        redis_ttl = self.redis_client.ttl(f"simpleCache:{key}")
        in_memory_expiry_timestamp = None
        if redis_ttl is not None and redis_ttl > -1:  # -1 = no expire, -2 = key doesn't exist
            in_memory_expiry_timestamp = time.time() + redis_ttl

        self.cache[key] = (parsed_res, in_memory_expiry_timestamp)
        return parsed_res

    def set(self, key: str, value: Any, ttl: Optional[int] = None) -> None:
        """Store *value* in both tiers, optionally expiring after *ttl* seconds.

        The value is JSON-serialized up front so a non-serializable value
        raises before either tier is touched, keeping the tiers consistent.
        """
        serialized = json.dumps(value)

        expiry_timestamp = time.time() + ttl if ttl is not None else None
        self.cache[key] = (value, expiry_timestamp)

        if ttl is None:
            self.redis_client.set(f"simpleCache:{key}", serialized)
        else:
            self.redis_client.setex(f"simpleCache:{key}", ttl, serialized)


def store_vehicle_location_history(redis_client, device_id: str, lat: float, lon: float,
                                   timestamp: int, max_points: int = 25) -> None:
    """Append a GPS point to the vehicle's Redis-backed history (1 hour TTL).

    Points closer than ~2 m to the previous stored point are dropped so a
    stationary vehicle doesn't flood the history. Only the most recent
    *max_points* points are kept, sorted by timestamp. Errors are logged,
    never raised (best-effort bookkeeping).
    """
    from .geometry_utils import calculate_distance  # local import avoids an import cycle

    history = None
    try:
        history_key = f"vehicle_history:{device_id}"
        point = {
            "lat": lat,
            "lon": lon,
            # Fall back to "now" when the device didn't report a timestamp.
            "timestamp": int(timestamp if timestamp else time.time())
        }

        # Get existing history
        history = redis_client.get(history_key)
        points = (json.loads(history) or []) if history else []

        if points:
            last_point = points[-1]
            # calculate_distance returns km; 0.002 km ~= 2 m.
            if calculate_distance(last_point['lat'], last_point['lon'], point['lat'], point['lon']) < 0.002:
                return

        points.append(point)

        # Keep only the last max_points
        if len(points) > max_points:
            points = points[-max_points:]

        points.sort(key=lambda x: x['timestamp'])
        # Store updated history with 1 hour TTL
        redis_client.setex(history_key, 3600, json.dumps(points))

    except Exception as e:
        error_details = traceback.format_exc()
        logger.error(f"Error storing vehicle history for {device_id}: {e}\nHistory value: {history}\nTraceback: {error_details}")


def get_vehicle_location_history(redis_client, device_id: str) -> List[Dict[str, Any]]:
    """Return the stored history points for *device_id*, or [] on miss/error."""
    try:
        history = redis_client.get(f"vehicle_history:{device_id}")
        if history:
            value = json.loads(history)
            if value:
                return value
        return []
    except Exception as e:
        logger.error(f"Error getting vehicle history for {device_id}: {e}")
        return []
"""
Fleet management and device mapping functionality.
Contains fleet information retrieval and device-vehicle mapping logic.
"""

import json
import logging
import threading
import time
from typing import List, Optional

from .cache_utils import store_vehicle_location_history
from .models import FleetInfo
from .route_matching import get_route_ids_from_waybills

logger = logging.getLogger('amnex-data-server')


def get_fleet_info(redis_client, device_vehicle_map: dict, waybills_session_local,
                   device_id: str, current_lat: float = None, current_lon: float = None,
                   timestamp: int = None, provider: str = None, stop_tracker=None,
                   bus_location_max_age: int = 120, bus_cleanup_interval: int = 180,
                   prod_redis_client=None) -> List[FleetInfo]:
    """Resolve the fleet (vehicle) number and active route ID(s) for a device.

    Results are cached in Redis under ``fleetInfo:<device_id>`` for
    *bus_cleanup_interval* seconds; a longer-lived ``:saved`` copy is kept so
    stale per-route vehicle hashes can be purged when the route changes.

    Args:
        redis_client: primary Redis connection.
        device_vehicle_map: in-memory device_id -> vehicle_no mapping.
        waybills_session_local: session factory for the waybills database.
        device_id: GPS device identifier.
        current_lat/current_lon/timestamp/provider: latest fix, if available.
        stop_tracker: passed through to route resolution.
        bus_location_max_age: staleness threshold (seconds) used by cleanup.
        bus_cleanup_interval: cache TTL (seconds) for the fleet-info entry.
        prod_redis_client: optional prod Redis mirror; when given, stale
            route hashes are cleaned there as well.

    Returns:
        List of FleetInfo entries (possibly empty).
    """
    cache_key = f"fleetInfo:{device_id}"
    cache_key_saved = cache_key + ":saved"

    fleet_mapping_values = []  # response values

    # Fast path: serve from cache, still recording the location history.
    fleet_info_str = redis_client.get(cache_key)
    if fleet_info_str is not None:
        fleet_infos = [FleetInfo(**item) for item in json.loads(fleet_info_str)]
        if current_lat is not None and current_lon is not None:
            for fleet_info in fleet_infos:
                store_vehicle_location_history(redis_client, fleet_info.vehicle_no, current_lat, current_lon, timestamp)
        return fleet_infos

    try:
        vehicle_no = device_vehicle_map.get(device_id)
        if not vehicle_no:
            return []

        # Resolve the currently-active route(s) for this vehicle.
        route_ids = get_route_ids_from_waybills(waybills_session_local, vehicle_no, current_lat, current_lon, timestamp, provider, stop_tracker)
        for route_id in route_ids:
            fleet_info = FleetInfo(
                vehicle_no=vehicle_no,
                device_id=device_id,
                route_id=route_id
            )
            try:
                fleet_info_saved = redis_client.get(cache_key_saved)
                if fleet_info_saved is not None:
                    # The saved entry is a JSON *list* of fleet-info dicts
                    # (a single legacy dict payload is tolerated too).
                    saved_entries = json.loads(fleet_info_saved)
                    if isinstance(saved_entries, dict):
                        saved_entries = [saved_entries]
                    for saved in saved_entries:
                        saved_route_id = saved.get('route_id')
                        # Route changed: purge this vehicle's stale entries
                        # from the previous route's hash.
                        if saved_route_id is not None and route_id != saved_route_id:
                            logger.debug("going to delete route info")
                            route_key = "route:" + saved_route_id
                            clean_redis_key_for_route_info(redis_client, prod_redis_client, saved_route_id, route_key, bus_location_max_age)
            except Exception as e:
                logger.error(f"Error cleaning redis key for route info: {e}")
            fleet_mapping_values.append(fleet_info)

        if route_ids:
            # Convert FleetInfo objects to dicts for JSON serialization to Redis.
            fleet_mapping_dicts = [fi.model_dump() for fi in fleet_mapping_values]
            payload = json.dumps(fleet_mapping_dicts)
            # Longer TTL on the saved copy: hack for cleanup if route changes.
            redis_client.setex(cache_key_saved, bus_location_max_age + bus_cleanup_interval, payload)
            redis_client.setex(cache_key, bus_cleanup_interval, payload)
        return fleet_mapping_values
    except Exception as e:
        logger.error(f"Error querying fleet info for device {device_id}: {e}")
        return fleet_mapping_values


def clean_redis_key_for_route_info(redis_client, prod_redis_client, route_id: str, redis_key: str,
                                   bus_location_max_age: int) -> int:
    """Remove vehicles with stale timestamps from one route hash.

    Merges entries from the prod mirror (when given) with the primary,
    preferring prod data on conflict, then deletes every vehicle whose last
    report is older than *bus_location_max_age* seconds, plus any entry
    whose payload fails to parse.

    Returns the number of vehicles removed.
    """
    current_time = int(time.time())
    # The prod mirror is optional; callers without one pass None.
    prod_vehicle_data = prod_redis_client.hgetall(redis_key) if prod_redis_client is not None else {}
    vehicle_data = redis_client.hgetall(redis_key)

    # Merge so vehicles present in either store are considered; on conflict
    # the prod entry wins.
    merged_vehicle_data = dict(vehicle_data) if vehicle_data else {}
    if prod_vehicle_data:
        merged_vehicle_data.update(prod_vehicle_data)

    if not merged_vehicle_data:
        return 0

    vehicles_to_remove = []

    # Check each vehicle's timestamp
    for vehicle_id, data_json in merged_vehicle_data.items():
        try:
            data = json.loads(data_json)
            # Prefer serverTime; fall back to the device-reported timestamp.
            timestamp = data.get('serverTime') if 'serverTime' in data else data.get('timestamp')

            # If no valid timestamp, skip
            if not timestamp:
                continue

            age = current_time - int(timestamp)
            logger.debug(f"Vehicle {vehicle_id} on route {route_id}: age {age}s (now {current_time}, ts {int(timestamp)})")

            # If older than threshold, mark for removal
            if age > bus_location_max_age:
                vehicles_to_remove.append(vehicle_id)
                logger.debug(f"Vehicle {vehicle_id} on route {route_id} outdated by {age}s, marking for removal")
        except (json.JSONDecodeError, KeyError, TypeError, ValueError) as e:
            logger.error(f"Error parsing data for vehicle {vehicle_id}: {e}")
            # Mark invalid entries for removal
            vehicles_to_remove.append(vehicle_id)

    removed_count = 0
    if vehicles_to_remove:
        redis_client.hdel(redis_key, *vehicles_to_remove)
        if prod_redis_client is not None:
            prod_redis_client.hdel(redis_key, *vehicles_to_remove)
        removed_count = len(vehicles_to_remove)
        logger.info(f"Removed {removed_count} outdated vehicles from route {route_id}")

    return removed_count


def load_device_vehicle_mappings(session_local) -> dict:
    """Load the device_id -> vehicle_no mapping from the database.

    Returns an empty dict on failure so startup can proceed without it.
    """
    from .models import DeviceVehicleMapping

    device_vehicle_map = {}
    try:
        with session_local() as db:
            for mapping in db.query(DeviceVehicleMapping).all():
                device_vehicle_map[mapping.device_id] = mapping.vehicle_no
        logger.info(f"Loaded {len(device_vehicle_map)} device-vehicle mappings at startup.")
    except Exception as e:
        logger.error(f"Error loading device-vehicle mappings at startup: {e}")

    return device_vehicle_map


def clean_outdated_vehicle_mappings(redis_client, prod_redis_client, CLEANUP_LOCK_TTL,
                                    bus_location_max_age: int = 120):
    """Sweep every ``route:*`` hash and drop outdated vehicle entries.

    A Redis lock (SET NX with TTL) ensures only one pod/process runs the
    sweep at a time.
    """
    # Try to acquire lock
    lock_key = "vehicle_mappings_cleanup_lock"
    lock_acquired = redis_client.set(lock_key, "locked", nx=True, ex=CLEANUP_LOCK_TTL)

    if not lock_acquired:
        logger.debug("Vehicle mappings cleanup already running in another pod/process")
        return

    try:
        logger.info("Starting vehicle mappings cleanup")
        # Collect route keys from both stores with bounded SCAN loops.
        route_keys = []
        cursor = 0
        prod_cursor = 0
        max_iterations = 100
        iteration_count = 0
        start = True

        while iteration_count < max_iterations:
            if (start and cursor == 0) or (not start and cursor != 0):
                cursor, keys = redis_client.scan(cursor, match="route:*", count=1000)
                route_keys.extend(keys)
            if (start and prod_cursor == 0) or (not start and prod_cursor != 0):
                # Must continue from prod_cursor, not the primary cursor.
                prod_cursor, prod_keys = prod_redis_client.scan(prod_cursor, match="route:*", count=1000)
                route_keys.extend(prod_keys)
            start = False
            iteration_count += 1
            if cursor == 0 and prod_cursor == 0:
                break

        route_keys = list(set(route_keys))
        logger.debug(f"Found {len(route_keys)} route keys for cleanup after {iteration_count} iterations")
        if not route_keys:
            logger.debug("No route data found for cleanup")
            return

        total_routes = len(route_keys)
        total_vehicles_removed = 0

        for redis_key in route_keys:
            try:
                # Key format is "route:<route_id>".
                route_id = redis_key.split(":", 1)[1] if ":" in redis_key else "unknown"
                removed = clean_redis_key_for_route_info(redis_client, prod_redis_client, route_id, redis_key, bus_location_max_age)
                if removed:
                    total_vehicles_removed += removed
            except Exception as e:
                logger.error(f"Error cleaning route {redis_key}: {e}")

        logger.info(f"Completed vehicle mappings cleanup: processed {total_routes} routes, removed {total_vehicles_removed} vehicles")

    except Exception as e:
        logger.error(f"Error during vehicle mappings cleanup: {e}")
    finally:
        # Release the lock (best effort).
        try:
            redis_client.delete(lock_key)
        except Exception:
            pass


def start_vehicle_cleanup_thread(redis_client, prod_redis_client, CLEANUP_LOCK_TTL, BUS_CLEANUP_INTERVAL, BUS_LOCATION_MAX_AGE):
    """Start a daemon thread that periodically runs the mappings cleanup."""
    def cleanup_worker():
        logger.info(f"Vehicle mappings cleanup thread started (interval: {BUS_CLEANUP_INTERVAL}s, max age: {BUS_LOCATION_MAX_AGE}s)")

        while True:
            try:
                clean_outdated_vehicle_mappings(redis_client, prod_redis_client, CLEANUP_LOCK_TTL, BUS_LOCATION_MAX_AGE)
            except Exception as e:
                logger.error(f"Error in vehicle cleanup worker: {e}")

            time.sleep(BUS_CLEANUP_INTERVAL)

    cleanup_thread = threading.Thread(target=cleanup_worker, daemon=True)
    cleanup_thread.start()

    return cleanup_thread
"""
Geometry and distance calculation utilities for GPS tracking.
Contains polyline decoding, distance calculations, and spatial analysis functions.
"""

import math
from typing import List, Tuple, Optional


def decode_polyline(polyline_str: str) -> List[Tuple[float, float]]:
    """Decode an encoded polyline string into a list of (lat, lon) tuples.

    Returns [] for empty input or on any decode error.
    """
    if not polyline_str:
        return []
    try:
        # Imported lazily so the pure-math helpers in this module remain
        # usable when the optional `polyline` package is not installed.
        import polyline as gpolyline
        return gpolyline.decode(polyline_str)
    except Exception as e:
        print(f"Error decoding polyline: {e}")
        return []


def calculate_distance(lat1: float, lon1: float, lat2: float, lon2: float) -> float:
    """Great-circle distance between two points in kilometers (haversine)."""
    # Convert decimal degrees to radians
    lat1, lon1, lat2, lon2 = map(math.radians, [lat1, lon1, lat2, lon2])

    # Haversine formula
    dlon = lon2 - lon1
    dlat = lat2 - lat1
    a = math.sin(dlat/2)**2 + math.cos(lat1) * math.cos(lat2) * math.sin(dlon/2)**2
    c = 2 * math.asin(math.sqrt(a))

    # Mean radius of earth in kilometers
    r = 6371

    return c * r


def is_point_near_polyline(point_lat: float, point_lon: float, polyline_points: List[Tuple[float, float]],
                           max_distance_meter: float = 50) -> Tuple[bool, float, Optional[int]]:
    """Check whether a point lies within *max_distance_meter* of a polyline.

    Uses a simple projection onto each segment (adequate for short GPS
    segments, where lat/lon can be treated as planar).

    Returns:
        Tuple of (is_near, min_distance_km, index_of_closest_segment_start);
        (False, inf, None) when the polyline has fewer than 2 points.
    """
    if not polyline_points or len(polyline_points) < 2:
        return False, float('inf'), None

    min_distance = float('inf')
    min_segment = None

    # Check each segment of the polyline
    for i in range(len(polyline_points) - 1):
        p1_lat, p1_lon = polyline_points[i]
        p2_lat, p2_lon = polyline_points[i + 1]

        segment_length = calculate_distance(p1_lat, p1_lon, p2_lat, p2_lon)

        if segment_length > 0:
            # Project the point onto the segment in lon/lat space.
            # Vector from p1 to p2
            v1x = p2_lon - p1_lon
            v1y = p2_lat - p1_lat

            # Vector from p1 to point
            v2x = point_lon - p1_lon
            v2y = point_lat - p1_lat

            # Clamped projection parameter t in [0, 1]
            dot = v1x * v2x + v1y * v2y
            len_sq = v1x * v1x + v1y * v1y
            t = max(0, min(1, dot / len_sq))

            # Projected point and distance to it
            proj_x = p1_lon + t * v1x
            proj_y = p1_lat + t * v1y
            distance = calculate_distance(point_lat, point_lon, proj_y, proj_x)
        else:
            # Degenerate (zero-length) segment: use distance to its start
            distance = calculate_distance(point_lat, point_lon, p1_lat, p1_lon)

        if distance < min_distance:
            min_segment = i
            min_distance = distance

    # Check if within threshold (convert meters to kilometers)
    max_distance_km = max_distance_meter / 1000
    return min_distance <= max_distance_km, min_distance, min_segment


def check_if_crossed_stop(prev_location: Tuple[float, float], current_location: Tuple[float, float],
                          stop_location: Tuple[float, float], threshold_meters: float = 20) -> bool:
    """
    Check if a vehicle has crossed a stop between its previous and current location.

    The stop counts as crossed when it is very close to either endpoint, or
    when it lies roughly on the path between them AND the travel bearings
    are aligned with the prev->stop and stop->current directions.

    Args:
        prev_location: (lat, lon) of previous vehicle location
        current_location: (lat, lon) of current vehicle location
        stop_location: (lat, lon) of the stop
        threshold_meters: Maximum distance in meters from the path to consider the stop crossed

    Returns:
        bool: True if the stop was crossed, False otherwise
    """
    from geopy.distance import geodesic  # lazy: optional dependency

    # If any of the locations are None, return False
    if any(loc is None for loc in (prev_location, current_location, stop_location)):
        return False

    # 1. Is the stop close enough to either endpoint? Handles a vehicle
    # that temporarily stopped at the bus stop. Computed once and reused
    # below (the original recomputed the same geodesics twice).
    dist_prev_to_stop = geodesic(prev_location, stop_location).meters
    dist_stop_to_curr = geodesic(stop_location, current_location).meters

    if dist_prev_to_stop < threshold_meters or dist_stop_to_curr < threshold_meters:
        return True

    path_distance = geodesic(prev_location, current_location).meters

    # No significant movement between fixes.
    if path_distance < 5:
        return False

    # 2. Is the stop roughly on the path (within an error margin covering
    # GPS inaccuracy and road curvature)?
    is_on_path = abs(dist_prev_to_stop + dist_stop_to_curr - path_distance) < threshold_meters

    # 3. Direction verification: the vehicle should move toward the stop
    # and then away from it.
    def calculate_bearing(point1: Tuple[float, float], point2: Tuple[float, float]) -> float:
        """Calculate the bearing between two points, normalized to 0-360 degrees."""
        lat1, lon1 = math.radians(point1[0]), math.radians(point1[1])
        lat2, lon2 = math.radians(point2[0]), math.radians(point2[1])

        dlon = lon2 - lon1

        y = math.sin(dlon) * math.cos(lat2)
        x = math.cos(lat1) * math.sin(lat2) - math.sin(lat1) * math.cos(lat2) * math.cos(dlon)

        bearing = math.degrees(math.atan2(y, x))
        return (bearing + 360) % 360

    bearing_prev_to_curr = calculate_bearing(prev_location, current_location)
    bearing_prev_to_stop = calculate_bearing(prev_location, stop_location)
    bearing_stop_to_curr = calculate_bearing(stop_location, current_location)

    def angle_diff(a: float, b: float) -> float:
        """Absolute difference between two angles in degrees (0-180)."""
        return min(abs(a - b), 360 - abs(a - b))

    alignment_prev_to_stop = angle_diff(bearing_prev_to_curr, bearing_prev_to_stop) < 60
    alignment_stop_to_curr = angle_diff(bearing_prev_to_curr, bearing_stop_to_curr) < 60

    # 4. Combine all checks: on-path, bearings aligned, and both legs
    # shorter than the full path.
    return (is_on_path and
            alignment_prev_to_stop and
            alignment_stop_to_curr and
            dist_prev_to_stop < path_distance and
            dist_stop_to_curr < path_distance)
"""
Data models for the GPS tracking system.
Contains SQLAlchemy models and Pydantic models.
"""

from pydantic import BaseModel
from sqlalchemy import Column, String, DateTime, Boolean, BigInteger, Text
from sqlalchemy.ext.declarative import declarative_base

# Two independent declarative bases: the main application DB (schema
# `atlas_app`) and the external waybills DB must keep separate metadata
# collections so neither `create_all`/reflection touches the other's tables.
Base = declarative_base()
WaybillsBase = declarative_base()


# ---- SQLAlchemy models: main application database (schema `atlas_app`) ----

class DeviceVehicleMapping(Base):
    """Maps a GPS device id to the vehicle registration number it is fitted to."""
    __tablename__ = "device_vehicle_mapping"
    __table_args__ = {'schema': 'atlas_app'}

    vehicle_no = Column(Text, index=True)
    device_id = Column(Text, index=True, primary_key=True)


class RoutePolyline(Base):
    """Encoded polyline geometry for a route, scoped to an operating city."""
    __tablename__ = "route_polylines"
    __table_args__ = {'schema': 'atlas_app'}

    route_id = Column(BigInteger, primary_key=True)
    polyline = Column(Text)  # encoded polyline string; may be NULL
    merchant_operating_city_id = Column(Text, primary_key=True)


# ---- SQLAlchemy models: waybills database ----
# BUGFIX: these previously inherited from `Base`, leaving `WaybillsBase`
# unused and registering waybills tables in the main app's metadata.

class Waybill(WaybillsBase):
    """A vehicle's duty assignment record from the waybills database."""
    __tablename__ = "waybills"

    waybill_id = Column(BigInteger, primary_key=True)
    schedule_id = Column(BigInteger)
    schedule_trip_id = Column(BigInteger)
    deleted = Column(Boolean, nullable=False, default=False)  # soft-delete flag
    schedule_no = Column(Text)
    schedule_trip_name = Column(Text)
    schedule_type = Column(Text)
    service_type = Column(Text)
    updated_at = Column(DateTime)
    status = Column(Text)  # e.g. 'Online' — see route matching query
    vehicle_no = Column(Text)


class BusSchedule(WaybillsBase):
    """Schedule header row linking a schedule to a route."""
    __tablename__ = "bus_schedule"

    schedule_id = Column(BigInteger, primary_key=True)
    deleted = Column(Boolean, nullable=False, default=False)  # soft-delete flag
    route_code = Column(Text)
    status = Column(Text)
    route_id = Column(BigInteger, nullable=False)


class BusScheduleTripDetail(WaybillsBase):
    """Per-trip detail row mapping a schedule trip to a route number."""
    __tablename__ = "bus_schedule_trip_detail"

    schedule_trip_detail_id = Column(BigInteger, primary_key=True)
    schedule_trip_id = Column(BigInteger)
    deleted = Column(Boolean, nullable=False, default=False)  # soft-delete flag
    route_number_id = Column(BigInteger, nullable=False)


# ---- Pydantic models ----

class FleetInfo(BaseModel):
    """Pydantic model for fleet information returned by get_fleet_info function"""
    vehicle_no: str
    device_id: str
    route_id: str
"""
Route matching algorithm and related components.
Contains the main route matching logic and helper functions.
"""

import json
import logging
import traceback
from typing import List, Dict, Any, Optional

from .geometry_utils import decode_polyline, is_point_near_polyline

logger = logging.getLogger('amnex-data-server')


def calculate_route_match_score(route_id: str, vehicle_no: str, stops: dict,
                                vehicle_points: List[dict], max_distance_meter: float = 100) -> float:
    """
    Calculate how well a route matches a series of vehicle_points, considering direction.

    Uses the route polyline for more accurate matching when available; otherwise
    falls back to the straight-line sequence of stop coordinates.

    Args:
        route_id: Route identifier (used for logging only).
        vehicle_no: Vehicle registration number (used for logging only).
        stops: Route info; either {'stops': [...], 'polyline': str} or a dict
            exposing only 'stops'.
        vehicle_points: GPS points, each a dict with 'lat', 'lon' and
            optionally 'timestamp'.
        max_distance_meter: Max distance from the route geometry for a point
            to count as "near".

    Returns:
        Score in [0, 1]; 0.0 on insufficient data, weak match, or any error.
    """
    try:
        # Prefer the encoded polyline: it is much denser than the stop
        # sequence, so fewer GPS points are required for a confident match.
        if isinstance(stops, dict) and 'stops' in stops and 'polyline' in stops:
            route_polyline = stops.get('polyline')
            polyline_points = decode_polyline(route_polyline)
            min_points_required = 4
        else:
            route_polyline = ""
            stops_info = stops.get('stops')
            polyline_points = list(map(lambda x: (x['stop_lat'], x['stop_lon']), stops_info))
            min_points_required = 10

        if not vehicle_points or len(vehicle_points) < min_points_required:
            return 0.0

        # Chronological order is required for the direction check below.
        vehicle_points = sorted(vehicle_points, key=lambda x: x.get('timestamp', 0))

        if polyline_points:
            # Count vehicle points lying near the route geometry and remember
            # which polyline segment each one matched.
            near_count = 0
            min_segments_list = []
            for point in vehicle_points:
                try:
                    is_near, distance, min_segment_start = is_point_near_polyline(
                        point['lat'], point['lon'], polyline_points, max_distance_meter
                    )
                    if is_near:
                        if min_segment_start is not None:
                            min_segments_list.append(min_segment_start)
                        near_count += 1
                except (KeyError, ValueError, TypeError) as e:
                    logger.debug(f"Error checking if point is near polyline: {e}, point: {point}")
                    continue

            # Proximity score (0-1): fraction of points near the route.
            proximity_ratio = near_count / len(vehicle_points) if len(vehicle_points) > 0 else 0

            # Only proceed if enough points are near the polyline.
            if proximity_ratio >= 0.3:
                # Crude direction check: the first matched segment must be the
                # smallest segment index seen, i.e. the vehicle entered near
                # the route start rather than traversing it backwards.
                if len(min_segments_list) >= 2 and min(min_segments_list) == min_segments_list[0]:
                    logger.info(f"Route ID: {vehicle_no} {near_count}/{len(vehicle_points)}, Score: {proximity_ratio:.2f}")
                    return proximity_ratio
        return 0.0
    except Exception as e:
        error_details = traceback.format_exc()
        logger.error(f"Error calculating route match score: {stops} {e}\nTraceback: {error_details}")
        return 0.0


def get_route_ids_from_waybills(waybills_session_local, vehicle_no: str, current_lat: float = None,
                                current_lon: float = None, timestamp: int = None,
                                provider: str = None, stop_tracker=None) -> List[str]:
    """Get the route id(s) from the waybills database for a given vehicle number.

    Looks up the vehicle's most recent active waybill, scores every candidate
    route of that waybill's schedule trip against the vehicle's recent GPS
    history, and returns the routes scoring above 0.8.

    Args:
        waybills_session_local: SQLAlchemy sessionmaker for the waybills DB.
        vehicle_no: Vehicle registration number.
        current_lat/current_lon: Optional current position, appended to history.
        timestamp: Epoch timestamp of the current position.
        provider: Data provider name (logging only).
        stop_tracker: StopTracker instance supplying Redis access and route stops.

    Returns:
        List of route ids (as strings) with match score > 0.8; [] on any error.
    """
    from .models import Waybill, BusScheduleTripDetail
    from .cache_utils import store_vehicle_location_history, get_vehicle_location_history

    try:
        with waybills_session_local() as db:
            # Most recent active (non-deleted, Online) waybill for this vehicle.
            waybill = db.query(Waybill)\
                .filter(
                    Waybill.vehicle_no == vehicle_no,
                    Waybill.deleted == False,
                    Waybill.status == 'Online'
                )\
                .order_by(Waybill.updated_at.desc())\
                .first()

            if not waybill:
                return []

            if current_lat is not None and current_lon is not None and stop_tracker:
                store_vehicle_location_history(stop_tracker.redis_client, vehicle_no, current_lat, current_lon, timestamp)

            # Route matching needs a minimum amount of recent movement history.
            location_history = get_vehicle_location_history(stop_tracker.redis_client, vehicle_no) if stop_tracker else []
            if len(location_history) < 5:
                return []

            # All candidate route assignments for this waybill's schedule trip.
            schedules = db.query(BusScheduleTripDetail)\
                .filter(
                    BusScheduleTripDetail.schedule_trip_id == waybill.schedule_trip_id,
                    BusScheduleTripDetail.deleted == False
                )\
                .all()

            if len(schedules) == 0:
                return []

            logger.info(f"Route ID: Bus schedule len {len(schedules)}")

            best_route_ids = []
            routes_match_score = {}
            for schedule in schedules:
                # Score each distinct route only once.
                if schedule.route_number_id not in routes_match_score:
                    route_stops = stop_tracker.get_route_stops(str(schedule.route_number_id))
                    score = calculate_route_match_score(schedule.route_number_id, vehicle_no, route_stops, location_history)
                    # Defensive: treat a missing score as no match.
                    if score is None:
                        score = 0.0
                    logger.info(f"Route ID: Bus score {vehicle_no} Score for route {schedule.route_number_id}: {score} (Provider: {provider})")
                    if score > 0.8:
                        # Cast to str to honour the declared List[str] return
                        # type (FleetInfo.route_id is a str field).
                        best_route_ids.append(str(schedule.route_number_id))
                    routes_match_score[schedule.route_number_id] = score
            return best_route_ids

    except Exception as e:
        error_details = traceback.format_exc()
        logger.error(f"Error querying waybills database for vehicle {vehicle_no} (Provider: {provider}): {e}\nTraceback: {error_details}")
        return []
"""
Stop tracking and ETA calculation functionality.
Contains the StopTracker class that handles route stops and ETA calculations.
"""

import json
import logging
import math
import requests
from datetime import datetime, timedelta
from typing import List, Dict, Any, Optional, Tuple

from .cache_utils import get_vehicle_location_history
from .geometry_utils import calculate_distance, check_if_crossed_stop

logger = logging.getLogger('amnex-data-server')

# Mean Earth radius used by the haversine helper.
_EARTH_RADIUS_KM = 6371


class StopTracker:
    """Handles stop tracking, route information, and ETA calculations"""

    def __init__(self, db_engine, redis_client, use_osrm=True,
                 osrm_url='http://router.project-osrm.org', google_api_key='',
                 cache_ttl=3600, route_stop_mapping_api_url='', gtfs_id='',
                 merchant_operating_city_id=''):
        """Store connection handles and configuration; performs no I/O beyond logging."""
        self.db_engine = db_engine
        self.redis_client = redis_client
        self.use_osrm = use_osrm
        self.osrm_url = osrm_url
        self.google_api_key = google_api_key
        self.cache_ttl = cache_ttl
        self.stop_visit_radius = 0.05  # 50 meters, expressed in km
        self.route_stop_mapping_api_url = route_stop_mapping_api_url
        self.gtfs_id = gtfs_id
        self.merchant_operating_city_id = merchant_operating_city_id
        logger.info(f"StopTracker initialized with {'OSRM' if use_osrm else 'Google Maps'}")

    @staticmethod
    def _haversine_km(lat1: float, lon1: float, lat2: float, lon2: float) -> float:
        """Great-circle distance in kilometers between two (lat, lon) degree pairs.

        Extracted helper: this formula was previously duplicated in
        check_if_at_stop, find_closest_stop and get_travel_duration.
        """
        rlat1, rlon1 = math.radians(lat1), math.radians(lon1)
        rlat2, rlon2 = math.radians(lat2), math.radians(lon2)
        dlon = rlon2 - rlon1
        dlat = rlat2 - rlat1
        a = math.sin(dlat / 2) ** 2 + math.cos(rlat1) * math.cos(rlat2) * math.sin(dlon / 2) ** 2
        return _EARTH_RADIUS_KM * 2 * math.asin(math.sqrt(a))

    def get_route_stops(self, route_id: str) -> Dict[str, Any]:
        """Get all stops for a route ordered by sequence, including the route polyline if available.

        Returns a dict {'stops': [...], 'polyline': str | None}; on any failure
        an empty stops list with a None polyline is returned rather than raising.
        Results are cached in Redis for one hour.
        """
        from .models import RoutePolyline
        from sqlalchemy.orm import sessionmaker

        cache_key = f"route_stops_info:{route_id}"

        # Serve from cache when possible.
        cached = self.redis_client.get(f"simpleCache:{cache_key}")
        if cached:
            return json.loads(cached)

        try:
            # Fetch the stop list for this route from the mapping API.
            stops_api_url = f"{self.route_stop_mapping_api_url}/route-stop-mapping/{self.gtfs_id}/route/{route_id}"
            response = requests.get(stops_api_url)
            response.raise_for_status()  # Raise HTTPError for bad responses (4xx or 5xx)
            stops_data = response.json()

            # Fetch the route polyline from the DB, if one is stored.
            route_polyline = None
            SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=self.db_engine)
            with SessionLocal() as db:
                polyline_info = db.query(RoutePolyline)\
                    .filter(RoutePolyline.route_id == str(route_id),
                            RoutePolyline.merchant_operating_city_id == self.merchant_operating_city_id)\
                    .first()
                if polyline_info and polyline_info.polyline:
                    route_polyline = polyline_info.polyline

            if not stops_data:
                return {
                    'stops': [],
                    'polyline': None
                }

            # Normalise the API payload into the internal stop format.
            result_stops = [
                {
                    'stop_id': stop['stopCode'],
                    'sequence': stop['sequenceNum'],
                    'name': stop['stopName'],
                    'stop_lat': float(stop['stopPoint']['lat']),
                    'stop_lon': float(stop['stopPoint']['lon'])
                }
                for stop in stops_data
            ]
            result = {
                'stops': result_stops,
                'polyline': route_polyline
            }
            # Cache result for one hour.
            self.redis_client.setex(f"simpleCache:{cache_key}", 3600, json.dumps(result))
            return result
        except requests.exceptions.RequestException as e:
            logger.error(f"Error fetching route stops or polyline from API for route {route_id}: {e}")
            return {
                'stops': [],
                'polyline': None
            }
        except json.JSONDecodeError as e:
            logger.error(f"Error decoding JSON response for route {route_id}: {e}")
            return {
                'stops': [],
                'polyline': None
            }
        except Exception as e:
            logger.error(f"An unexpected error occurred getting stops for route {route_id}: {e}")
            return {
                'stops': [],
                'polyline': None
            }

    def get_visited_stops(self, route_id: str, vehicle_id: str) -> List[str]:
        """Get list of stops already visited by this vehicle on this route."""
        visit_key = f"visited_stops:{route_id}:{vehicle_id}"
        try:
            visited_stops = self.redis_client.get(visit_key)
            if visited_stops:
                return json.loads(visited_stops)
            return []
        except Exception as e:
            logger.error(f"Error getting visited stops: {e}")
            return []

    def update_visited_stops(self, route_id: str, vehicle_id: str, stop_id: str) -> List[str]:
        """Add a stop to the visited stops list (idempotent); returns the updated list."""
        visit_key = f"visited_stops:{route_id}:{vehicle_id}"
        try:
            visited_stops = self.get_visited_stops(route_id, vehicle_id)
            if stop_id not in visited_stops:
                visited_stops.append(stop_id)
                self.redis_client.setex(
                    visit_key,
                    7200,  # 2 hour TTL
                    json.dumps(visited_stops)
                )
            return visited_stops
        except Exception as e:
            logger.error(f"Error updating visited stops: {e}")
            return []

    def reset_visited_stops(self, route_id: str, vehicle_id: str, vehicle_no: str) -> bool:
        """Reset the visited stops list and location history for a vehicle."""
        visit_key = f"visited_stops:{route_id}:{vehicle_id}"
        history_key = f"vehicle_history:{vehicle_no}"
        try:
            self.redis_client.delete(visit_key)
            self.redis_client.delete(history_key)
            logger.info(f"Reset visited stops for vehicle {vehicle_id} on route {route_id}")
            return True
        except Exception as e:
            logger.error(f"Error resetting visited stops: {e}")
            return False

    def check_if_at_stop(self, stop: Dict[str, Any], vehicle_lat: float, vehicle_lon: float) -> Tuple[bool, float]:
        """Check if vehicle is within the visit radius of a stop.

        Returns:
            (is_at_stop, distance_km) tuple.
        """
        distance = self._haversine_km(vehicle_lat, vehicle_lon,
                                      float(stop['stop_lat']), float(stop['stop_lon']))
        return distance <= self.stop_visit_radius, distance

    def find_next_stop(self, stops: List[Dict[str, Any]], visited_stops: List[str],
                       vehicle_lat: float, vehicle_lon: float) -> Tuple[Optional[Dict[str, Any]], Optional[float]]:
        """Find the next stop in sequence after the last visited stop.

        Returns (stop, distance_km) — distance is only populated for the
        no-history case; otherwise it is None and the caller computes it.
        Returns (None, None) when the vehicle is at the final stop.
        """
        if not visited_stops:
            # No stops visited yet: treat the nearest stop as the next stop.
            nearest_stop = None
            min_distance = float('inf')
            for stop in stops:
                _, distance = self.check_if_at_stop(stop, vehicle_lat, vehicle_lon)
                if distance < min_distance:
                    min_distance = distance
                    nearest_stop = stop
            return (nearest_stop, min_distance)

        # Locate the last visited stop in the route's stop sequence.
        last_visited_id = visited_stops[-1]
        last_index = -1
        for i, stop in enumerate(stops):
            if stop['stop_id'] == last_visited_id:
                last_index = i
                break

        # Found and not the last stop on the route: next in sequence.
        if last_index >= 0 and last_index < len(stops) - 1:
            return (stops[last_index + 1], None)
        elif last_index == len(stops) - 1:
            # Already at the last stop of the route.
            return (None, None)

        # Last visited stop not found in the list (shouldn't happen, but just
        # in case): fall back to the first stop.
        return (stops[0] if stops else None, None)

    def find_closest_stop(self, stops: List[Dict[str, Any]], vehicle_lat: float,
                          vehicle_lon: float) -> Tuple[Optional[Dict[str, Any]], float]:
        """Find the closest stop to the given coordinates; returns (stop, distance_km)."""
        if not stops:
            return None, float('inf')

        closest_stop = None
        min_distance = float('inf')

        for stop in stops:
            distance = self._haversine_km(vehicle_lat, vehicle_lon,
                                          float(stop['stop_lat']), float(stop['stop_lon']))
            if distance < min_distance:
                min_distance = distance
                closest_stop = stop

        return closest_stop, min_distance

    def get_travel_duration(self, origin_id: str, dest_id: str, origin_lat: float, origin_lon: float,
                            dest_lat: float, dest_lon: float) -> Optional[float]:
        """Get travel duration in seconds between two stops, with Redis caching.

        Currently always uses a 30 km/h straight-line estimate; the cache entry
        is marked 'estimated'. origin_id == 0 (vehicle position, not a stop)
        bypasses the cache. Returns None only on unexpected errors.
        """
        cache_key = f"route_segment:{origin_id}:{dest_id}"
        try:
            if origin_id != 0:
                cached = self.redis_client.get(cache_key)
                if cached:
                    data = json.loads(cached)
                    return data.get('duration')
        except Exception as e:
            logger.error(f"Redis error: {e}")

        # Not in cache: simple estimation at 30 km/h (8.33 m/s).
        try:
            distance = self._haversine_km(origin_lat, origin_lon, dest_lat, dest_lon) * 1000  # meters
            duration = distance / 8.33

            # Cache the fallback estimation.
            cache_data = {
                'duration': duration,
                'timestamp': datetime.now().isoformat(),
                'estimated': True
            }
            if origin_id != 0:
                self.redis_client.setex(cache_key, self.cache_ttl, json.dumps(cache_data))

            return duration
        except Exception as e:
            logger.error(f"Error calculating travel duration: {e}")
            return None

    def calculate_eta(self, stops_info: Dict[str, Any], route_id: str, vehicle_lat: float,
                      vehicle_lon: float, current_time: datetime, vehicle_id: str,
                      visited_stops: List[str] = None, vehicle_no: str = None) -> Optional[Dict[str, Any]]:
        """Calculate ETA for all upcoming stops from the current position.

        Marks a stop visited when the vehicle is within the visit radius, or
        when its recent location history shows it crossed the stop, then emits
        an ETA entry for the next stop and every stop after it.

        Returns:
            Dict with route_id, current_time, closest_stop and an 'eta' list,
            or None when the route has no stops or the next stop can't be found.
        """
        if visited_stops is None:
            visited_stops = []

        stops = stops_info.get('stops')
        if not stops:
            return None

        next_stop = None
        closest_stop = None
        distance = float('inf')
        calculation_method = "realtime"

        # Pass 1: detect whether the vehicle is at (or has just crossed) a stop.
        # Location history is fetched from Redis lazily, once — the original
        # fetched it per-stop inside the loop.
        location_history = None
        for stop in stops:
            is_at_stop, _ = self.check_if_at_stop(stop, vehicle_lat, vehicle_lon)

            if not is_at_stop:
                if location_history is None:
                    location_history = get_vehicle_location_history(self.redis_client, vehicle_no)
                if len(location_history) > 0:
                    last_point = location_history[-1]  # Most recent point in history
                    # Did the vehicle pass the stop between its last and current position?
                    crossed_stop = check_if_crossed_stop(
                        (last_point['lat'], last_point['lon']),
                        (vehicle_lat, vehicle_lon),
                        (float(stop['stop_lat']), float(stop['stop_lon']))
                    )
                    if crossed_stop:
                        is_at_stop = True

            if is_at_stop:
                # Vehicle is at this stop; record it once.
                if stop['stop_id'] not in visited_stops:
                    self.update_visited_stops(route_id, vehicle_id, stop['stop_id'])
                    visited_stops.append(stop['stop_id'])
                    calculation_method = "visited_stops"
                    break

        # Pass 2: determine the next stop from the visit sequence.
        (next_stop, distance) = self.find_next_stop(stops, visited_stops, vehicle_lat, vehicle_lon)
        if next_stop:
            # BUGFIX: was `if not distance:`, which also re-triggered on a
            # legitimate 0.0 distance; only recompute when it is truly absent.
            if distance is None:
                _, distance = self.check_if_at_stop(next_stop, vehicle_lat, vehicle_lon)
            closest_stop = next_stop
            calculation_method = "sequence_based"
        else:
            # End of route: reset state and fall back to the closest stop.
            self.reset_visited_stops(route_id, vehicle_id, vehicle_no)
            closest_stop, distance = self.find_closest_stop(stops, vehicle_lat, vehicle_lon)
            calculation_method = "distance_based_fallback"

        if not closest_stop:
            return None

        # Locate the chosen stop in the ordered stop list.
        closest_index = -1
        for i, stop in enumerate(stops):
            if stop['stop_id'] == closest_stop['stop_id']:
                closest_index = i
                break

        if closest_index == -1:
            # Stop not found in the list — inconsistent route data.
            return None

        eta_list = []
        cumulative_time = 0
        current_lat, current_lon = vehicle_lat, vehicle_lon

        # ETA for the closest/next stop.
        if distance <= 0.01:  # 10 meters in km — practically at the stop
            arrival_time = current_time
            calculation_method = "immediate"
        else:
            duration = self.get_travel_duration(
                0, closest_stop['stop_id'],
                current_lat, current_lon,
                closest_stop['stop_lat'], closest_stop['stop_lon']
            )

            if duration:
                arrival_time = current_time + timedelta(seconds=duration)
                cumulative_time = duration
                calculation_method = "estimated"
            else:
                # Fallback estimation at 30 km/h.
                # BUGFIX: distance is in km — convert to meters before
                # dividing by 8.33 m/s (was off by a factor of 1000).
                duration = (distance * 1000) / 8.33
                arrival_time = current_time + timedelta(seconds=duration)
                cumulative_time = duration
                calculation_method = "estimated"

        eta_list.append({
            'stop_id': closest_stop['stop_id'],
            'stop_seq': closest_stop['sequence'],
            'stop_name': closest_stop['name'],
            'stop_lat': closest_stop['stop_lat'],
            'stop_lon': closest_stop['stop_lon'],
            'arrival_time': int(arrival_time.timestamp()),
            'calculation_method': calculation_method
        })

        # ETAs for every remaining stop, accumulating segment durations.
        for i in range(closest_index + 1, len(stops)):
            prev_stop = stops[i - 1]
            current_stop = stops[i]

            duration = self.get_travel_duration(
                prev_stop['stop_id'], current_stop['stop_id'],
                prev_stop['stop_lat'], prev_stop['stop_lon'],
                current_stop['stop_lat'], current_stop['stop_lon']
            )

            if duration:
                cumulative_time += duration
                arrival_time = current_time + timedelta(seconds=cumulative_time)

                calculation_method = "estimated"

                eta_list.append({
                    'stop_id': current_stop['stop_id'],
                    'stop_seq': current_stop['sequence'],
                    'stop_name': current_stop['name'],
                    'stop_lat': current_stop['stop_lat'],
                    'stop_lon': current_stop['stop_lon'],
                    'arrival_time': int(arrival_time.timestamp()),
                    'calculation_method': calculation_method
                })
            else:
                # Duration unavailable: preserve original behavior — the stop
                # is skipped from the ETA list.
                calculation_method = "estimated"

        return {
            'route_id': route_id,
            'current_time': int(current_time.timestamp()),
            'closest_stop': {
                'stop_id': closest_stop['stop_id'],
                'stop_name': closest_stop['name'],
                'distance': distance
            },
            'calculation_method': calculation_method,
            'eta': eta_list
        }