Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 23 additions & 4 deletions app/operation/subscription.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import re
from collections import Counter
from datetime import datetime as dt
from json import dumps as json_dumps
from typing import Any
Expand All @@ -10,6 +11,7 @@
from app.db.crud.user import get_user_usages, user_sub_update
from app.db.models import User
from app.models.admin import AdminDetails
from app.models.node import UserIPList
from app.models.settings import Application, ConfigFormat, SubRule, Subscription as SubSettings
from app.models.stats import Period, UserUsageStatsList
from app.models.user import SubscriptionUserResponse, UsersResponseWithInbounds
Expand All @@ -18,9 +20,12 @@
from app.templates import render_template
from config import template_settings

from . import BaseOperation
from . import BaseOperation, OperatorType
from .node import NodeOperation
from .user import UserOperation

node_operator = NodeOperation(operator_type=OperatorType.API)

client_config = {
ConfigFormat.clash_meta: {
"config_format": "clash_meta",
Expand Down Expand Up @@ -115,7 +120,7 @@ def _format_profile_title(

try:
return profile_title.format_map(format_variables)
except (ValueError, KeyError):
except ValueError, KeyError:
# Invalid format string, return original title
return profile_title

Expand All @@ -127,7 +132,7 @@ def _format_announce(sub_settings: SubSettings, format_variables: dict) -> str:

try:
return sub_settings.announce.format_map(format_variables)
except (ValueError, KeyError):
except ValueError, KeyError:
return sub_settings.announce

@staticmethod
Expand Down Expand Up @@ -206,7 +211,7 @@ def _stringify_rule_header_value(value: Any, format_variables: dict[str, str | i
return ""
try:
return header_value.format_map(format_variables)
except (ValueError, KeyError):
except ValueError, KeyError:
return header_value

if isinstance(value, (dict, list, tuple, bool, int, float)):
Expand Down Expand Up @@ -421,6 +426,20 @@ async def user_subscription_apps(self, db: AsyncSession, token: str) -> list[App
format_variables = await self.get_format_variables(user)
return self._make_apps_import_urls(sub_settings.applications, format_variables)

async def user_subscription_online_ips(self, db: AsyncSession, token: str) -> UserIPList:
"""
Get online IP addresses for the subscription user across all available nodes.
"""
db_user = await self.get_validated_sub(db, token=token)
all_nodes_ips = await node_operator.get_user_ip_list_all_nodes(db=db, username=db_user.username)

ips: Counter[str] = Counter()
for node_ips in all_nodes_ips.nodes.values():
if node_ips:
ips.update(node_ips.ips)

return UserIPList(ips=dict(ips))

def _make_apps_import_urls(self, applications: list[Application], format_variables: dict) -> list[Application]:
apps_with_updated_urls = []
for app in applications:
Expand Down
7 changes: 7 additions & 0 deletions app/routers/subscription.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from fastapi.responses import JSONResponse

from app.db import AsyncSession, get_db
from app.models.node import UserIPList
from app.models.settings import Application, ConfigFormat
from app.models.stats import Period, UserUsageStatsList
from app.models.user import SubscriptionUserResponse
Expand Down Expand Up @@ -60,6 +61,12 @@ async def get_sub_user_usage(
return await subscription_operator.get_user_usage(db, token=token, start=start, end=end, period=period)


@router.get("/{token}/online_ips", response_model=UserIPList)
async def user_subscription_online_ips(token: str, db: AsyncSession = Depends(get_db)):
"""Retrieves online IP addresses for the user's subscription."""
return await subscription_operator.user_subscription_online_ips(db, token=token)


@router.get("/{token}/{client_type}")
async def user_subscription_with_client_type(
request: Request,
Expand Down
33 changes: 33 additions & 0 deletions tests/api/test_user.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,11 @@
import asyncio
import time
from urllib.parse import parse_qs, unquote, urlsplit
from unittest.mock import AsyncMock, MagicMock

from fastapi import status

from app.models.node import UserIPList, UserIPListAll
from app.models.settings import ConfigFormat, SubRule, Subscription
from app.operation.subscription import SubscriptionOperation
from app.utils import jwt as jwt_utils
Expand Down Expand Up @@ -241,6 +243,37 @@ def test_user_subscriptions(access_token):
cleanup_groups(access_token, core, groups)


def test_user_subscription_online_ips_uses_subscription_token(access_token, monkeypatch):
core, groups = setup_groups(access_token, 1)
user = create_user(
access_token,
group_ids=[group["id"] for group in groups],
payload={"username": unique_name("sub_online_ips")},
)
node_operator_mock = MagicMock()
node_operator_mock.get_user_ip_list_all_nodes = AsyncMock(
return_value=UserIPListAll(
nodes={
7: UserIPList(ips={"198.51.100.10": 2, "203.0.113.20": 1}),
9: UserIPList(ips={"198.51.100.10": 3}),
}
)
)
monkeypatch.setattr("app.operation.subscription.node_operator", node_operator_mock)

try:
response = client.get(f"{user['subscription_url']}/online_ips")

assert response.status_code == status.HTTP_200_OK
assert response.json() == {"ips": {"198.51.100.10": 5, "203.0.113.20": 1}}
awaited_kwargs = node_operator_mock.get_user_ip_list_all_nodes.await_args.kwargs
assert awaited_kwargs["db"] is not None
assert awaited_kwargs["username"] == user["username"]
finally:
delete_user(access_token, user["username"])
cleanup_groups(access_token, core, groups)


def test_user_routes_by_id_and_by_username(access_token):
core, groups = setup_groups(access_token, 1)
user = create_user(access_token, group_ids=[groups[0]["id"]], payload={"username": unique_name("id_routes_user")})
Expand Down
Loading