diff --git a/app/operation/subscription.py b/app/operation/subscription.py index c8c59e3b..cdc5330c 100644 --- a/app/operation/subscription.py +++ b/app/operation/subscription.py @@ -1,4 +1,5 @@ import re +from collections import Counter from datetime import datetime as dt from json import dumps as json_dumps from typing import Any @@ -10,6 +11,7 @@ from app.db.crud.user import get_user_usages, user_sub_update from app.db.models import User from app.models.admin import AdminDetails +from app.models.node import UserIPList from app.models.settings import Application, ConfigFormat, SubRule, Subscription as SubSettings from app.models.stats import Period, UserUsageStatsList from app.models.user import SubscriptionUserResponse, UsersResponseWithInbounds @@ -18,9 +20,12 @@ from app.templates import render_template from config import template_settings -from . import BaseOperation +from . import BaseOperation, OperatorType +from .node import NodeOperation from .user import UserOperation +node_operator = NodeOperation(operator_type=OperatorType.API) + client_config = { ConfigFormat.clash_meta: { "config_format": "clash_meta", @@ -115,7 +120,7 @@ def _format_profile_title( try: return profile_title.format_map(format_variables) - except (ValueError, KeyError): + except ValueError, KeyError: # Invalid format string, return original title return profile_title @@ -127,7 +132,7 @@ def _format_announce(sub_settings: SubSettings, format_variables: dict) -> str: try: return sub_settings.announce.format_map(format_variables) - except (ValueError, KeyError): + except ValueError, KeyError: return sub_settings.announce @staticmethod @@ -206,7 +211,7 @@ def _stringify_rule_header_value(value: Any, format_variables: dict[str, str | i return "" try: return header_value.format_map(format_variables) - except (ValueError, KeyError): + except ValueError, KeyError: return header_value if isinstance(value, (dict, list, tuple, bool, int, float)): @@ -421,6 +426,20 @@ async def user_subscription_apps(self, db: AsyncSession, token: str) -> list[App format_variables = await self.get_format_variables(user) return self._make_apps_import_urls(sub_settings.applications, format_variables) + async def user_subscription_online_ips(self, db: AsyncSession, token: str) -> UserIPList: + """ + Get online IP addresses for the subscription user across all available nodes. + """ + db_user = await self.get_validated_sub(db, token=token) + all_nodes_ips = await node_operator.get_user_ip_list_all_nodes(db=db, username=db_user.username) + + ips: Counter[str] = Counter() + for node_ips in all_nodes_ips.nodes.values(): + if node_ips: + ips.update(node_ips.ips) + + return UserIPList(ips=dict(ips)) + def _make_apps_import_urls(self, applications: list[Application], format_variables: dict) -> list[Application]: apps_with_updated_urls = [] for app in applications: diff --git a/app/routers/subscription.py b/app/routers/subscription.py index 1a0b6f66..484746cb 100644 --- a/app/routers/subscription.py +++ b/app/routers/subscription.py @@ -4,6 +4,7 @@ from fastapi.responses import JSONResponse from app.db import AsyncSession, get_db +from app.models.node import UserIPList from app.models.settings import Application, ConfigFormat from app.models.stats import Period, UserUsageStatsList from app.models.user import SubscriptionUserResponse @@ -60,6 +61,12 @@ async def get_sub_user_usage( return await subscription_operator.get_user_usage(db, token=token, start=start, end=end, period=period) +@router.get("/{token}/online_ips", response_model=UserIPList) +async def user_subscription_online_ips(token: str, db: AsyncSession = Depends(get_db)): + """Retrieves online IP addresses for the user's subscription.""" + return await subscription_operator.user_subscription_online_ips(db, token=token) + + @router.get("/{token}/{client_type}") async def user_subscription_with_client_type( request: Request, diff --git a/tests/api/test_user.py b/tests/api/test_user.py index 0782f5cb..a8e6215a 100644 --- a/tests/api/test_user.py +++ b/tests/api/test_user.py @@ -9,9 +9,11 @@ import asyncio import time from urllib.parse import parse_qs, unquote, urlsplit +from unittest.mock import AsyncMock, MagicMock from fastapi import status +from app.models.node import UserIPList, UserIPListAll from app.models.settings import ConfigFormat, SubRule, Subscription from app.operation.subscription import SubscriptionOperation from app.utils import jwt as jwt_utils @@ -241,6 +243,37 @@ def test_user_subscriptions(access_token): cleanup_groups(access_token, core, groups) +def test_user_subscription_online_ips_uses_subscription_token(access_token, monkeypatch): + core, groups = setup_groups(access_token, 1) + user = create_user( + access_token, + group_ids=[group["id"] for group in groups], + payload={"username": unique_name("sub_online_ips")}, + ) + node_operator_mock = MagicMock() + node_operator_mock.get_user_ip_list_all_nodes = AsyncMock( + return_value=UserIPListAll( + nodes={ + 7: UserIPList(ips={"198.51.100.10": 2, "203.0.113.20": 1}), + 9: UserIPList(ips={"198.51.100.10": 3}), + } + ) + ) + monkeypatch.setattr("app.operation.subscription.node_operator", node_operator_mock) + + try: + response = client.get(f"{user['subscription_url']}/online_ips") + + assert response.status_code == status.HTTP_200_OK + assert response.json() == {"ips": {"198.51.100.10": 5, "203.0.113.20": 1}} + awaited_kwargs = node_operator_mock.get_user_ip_list_all_nodes.await_args.kwargs + assert awaited_kwargs["db"] is not None + assert awaited_kwargs["username"] == user["username"] + finally: + delete_user(access_token, user["username"]) + cleanup_groups(access_token, core, groups) + + def test_user_routes_by_id_and_by_username(access_token): core, groups = setup_groups(access_token, 1) user = create_user(access_token, group_ids=[groups[0]["id"]], payload={"username": unique_name("id_routes_user")})