diff --git a/.env.example b/.env.example index 6c254ba11..1a1f0ca60 100644 --- a/.env.example +++ b/.env.example @@ -27,6 +27,8 @@ UVICORN_PORT = 8000 ## External config to import into v2ray format subscription # EXTERNAL_CONFIG = "config://..." +## WireGuard: enable peer allocation and subscription output +# WIREGUARD_ENABLED = True ## WireGuard: IPv4 range for auto-allocated user peer addresses (and validation of manual peer_ips) # WIREGUARD_GLOBAL_POOL = "10.0.0.0/8" ## Comma-separated IPv4 CIDR subnets never assigned from the pool (e.g. network + gateway on a point-to-point link) diff --git a/app/operation/subscription.py b/app/operation/subscription.py index c8c59e3b9..b7b8452c0 100644 --- a/app/operation/subscription.py +++ b/app/operation/subscription.py @@ -16,7 +16,7 @@ from app.settings import subscription_settings from app.subscription.share import encode_title, generate_subscription, setup_format_variables from app.templates import render_template -from config import template_settings +from config import template_settings, wireguard_settings from . import BaseOperation from .user import UserOperation @@ -299,6 +299,8 @@ async def user_subscription( client_type = matched_rule.target if matched_rule else None if client_type == ConfigFormat.block or not client_type: await self.raise_error(message="Client not supported", code=406) + if client_type == ConfigFormat.wireguard and not wireguard_settings.enabled: + await self.raise_error(message="Client not supported", code=406) # Update user subscription info await user_sub_update(db, db_user.id, user_agent) @@ -351,6 +353,9 @@ async def user_subscription_with_client_type( """Provides a subscription link based on the specified client type (e.g., Clash, V2Ray).""" sub_settings: SubSettings = await subscription_settings() + if client_type == ConfigFormat.wireguard and not wireguard_settings.enabled: + await self.raise_error(message="Client not supported", code=406) + if client_type == ConfigFormat.block or not getattr(sub_settings.manual_sub_request, client_type): await self.raise_error(message="Client not supported", code=406) db_user = await self.get_validated_sub(db, token=token) @@ -374,6 +379,9 @@ async def user_subscription_by_user( client_type: ConfigFormat, request_url: str = "", ): + if client_type == ConfigFormat.wireguard and not wireguard_settings.enabled: + await self.raise_error(message="Client not supported", code=406) + if client_type == ConfigFormat.block: await self.raise_error(message="Client not supported", code=406) diff --git a/app/settings/__init__.py b/app/settings/__init__.py index 88f284cd5..3bae33ad3 100644 --- a/app/settings/__init__.py +++ b/app/settings/__init__.py @@ -59,6 +59,15 @@ async def subscription_settings() -> settings.Subscription: return validated_settings +@cached() +async def general_settings() -> settings.General: + async with GetDB() as db: + db_settings = await get_settings(db) + + validated_settings = settings.General.model_validate(db_settings.general) + return validated_settings + + async def refresh_caches() -> None: await telegram_settings.cache.clear() await discord_settings.cache.clear() @@ -66,6 +75,7 @@ async def refresh_caches() -> None: await notification_settings.cache.clear() await notification_enable.cache.clear() await subscription_settings.cache.clear() + await general_settings.cache.clear() async def handle_settings_message(_: dict): diff --git a/app/subscription/share.py b/app/subscription/share.py index 96a88c381..de77a953a 100644 --- a/app/subscription/share.py +++ b/app/subscription/share.py @@ -13,6 +13,7 @@ from app.models.user import UsersResponseWithInbounds from app.subscription.client_templates import subscription_client_templates, subscription_xray_templates from app.utils.system import get_public_ip, get_public_ipv6, readable_size +from config import wireguard_settings from . import ( ClashConfiguration, @@ -365,6 +366,9 @@ def _resolve_host_xray_template_content(inbound: SubscriptionInboundData) -> str return xray_template_overrides.get(template_id) for host_data in hosts: + if host_data.protocol == "wireguard" and not wireguard_settings.enabled: + continue + result = await process_host(host_data, format_variables, user.inbounds, proxy_settings) if not result: continue diff --git a/app/utils/wireguard.py b/app/utils/wireguard.py index 51815252e..3e5b186a4 100644 --- a/app/utils/wireguard.py +++ b/app/utils/wireguard.py @@ -19,6 +19,7 @@ collect_used_peer_networks_from_proxy_settings_rows, peer_ips_outside_global_pool, ) +from config import wireguard_settings def _normalized_peer_networks(peer_ips: Iterable[str]) -> list[str]: @@ -125,6 +126,9 @@ async def prepare_wireguard_proxy_settings( elif not proxy_settings.wireguard.public_key: proxy_settings.wireguard.public_key = get_wireguard_public_key(proxy_settings.wireguard.private_key) + if not wireguard_settings.enabled: + return proxy_settings + peer_ips = list(proxy_settings.wireguard.peer_ips or []) # Use merged allocate+validate function to avoid double DB scan @@ -189,6 +193,16 @@ async def bulk_reallocate_wireguard_peer_ips( ``target_users`` should be the users allowed by bulk scope (group/admin/user filters). """ + if not wireguard_settings.enabled: + return { + "wireguard_inbound_tags": 0, + "candidates": 0, + "updated": 0, + "dry_run": dry_run, + "sample_usernames": [], + "affected_users": 0, + } + wg_tags = await get_wireguard_inbound_tags_from_db(db) if not wg_tags: return { diff --git a/config.py b/config.py index b34b35e8a..f6dc6177f 100644 --- a/config.py +++ b/config.py @@ -188,6 +188,7 @@ class FeatureSettings(EnvSettings): class WireGuardSettings(EnvSettings): + enabled: bool = Field(default=True, validation_alias="WIREGUARD_ENABLED") global_pool: str = Field(default="10.0.0.0/8", validation_alias="WIREGUARD_GLOBAL_POOL") reserved: str = Field(default="10.0.0.0/31", validation_alias="WIREGUARD_RESERVED") diff --git a/dashboard/src/components/dialogs/user-modal.tsx b/dashboard/src/components/dialogs/user-modal.tsx index 1219aab74..feb1c14ce 100644 --- a/dashboard/src/components/dialogs/user-modal.tsx +++ b/dashboard/src/components/dialogs/user-modal.tsx @@ -22,8 +22,6 @@ import useDirDetection from '@/hooks/use-dir-detection' import useDynamicErrorHandler from '@/hooks/use-dynamic-errors.ts' import { cn } from '@/lib/utils' import { - getGeneralSettings, - getGetGeneralSettingsQueryKey, getGetGroupsSimpleQueryKey, useCreateUser, useCreateUserFromTemplate, @@ -41,7 +39,7 @@ import { parseDateInput } from '@/utils/dateTimeParsing' import { bytesToFormGigabytes, formatBytes, gbToBytes } from '@/utils/formatByte' import { invalidateUserMetricsQueries, upsertUserInUsersCache } from '@/utils/usersCache' import { generateWireGuardKeyPair, getWireGuardPublicKey } from '@/utils/wireguard' -import { useQuery, useQueryClient } from '@tanstack/react-query' +import { useQueryClient } from '@tanstack/react-query' import { CalendarClock, CalendarPlus, ChevronDown, EllipsisVertical, Info, Layers, Link2Off, ListStart, Lock, Network, PieChart, RefreshCcw, Group, Users, Pencil, UserRoundPlus } from 'lucide-react' import React, { useEffect, useState } from 'react' import { UseFormReturn } from 'react-hook-form' @@ -494,13 +492,6 @@ function UserModal({ isDialogOpen, onOpenChange, form, editingUser, editingUserI }, ) - const { data: generalSettings } = useQuery({ - queryKey: getGetGeneralSettingsQueryKey(), - queryFn: () => getGeneralSettings(), - enabled: isDialogOpen, - refetchOnMount: true, - }) - const syncUserCacheFromApiResponse = (user: UserResponse, options?: { allowInsert?: boolean; notifySuccessCallback?: boolean }) => { upsertUserInUsersCache(queryClient, user, { allowInsert: options?.allowInsert ?? false }) invalidateUserMetricsQueries(queryClient) diff --git a/dashboard/src/pages/_dashboard.settings.general.tsx b/dashboard/src/pages/_dashboard.settings.general.tsx index 312ab4082..a56af289f 100644 --- a/dashboard/src/pages/_dashboard.settings.general.tsx +++ b/dashboard/src/pages/_dashboard.settings.general.tsx @@ -24,6 +24,7 @@ const generalSettingsSchema = z.object({ }) type GeneralSettingsFormInput = z.input +type GeneralSettingsStringField = 'default_flow' | 'default_method' export default function General() { const { t } = useTranslation() @@ -151,7 +152,7 @@ export default function General() { ) } - const clearField = (field: keyof GeneralSettingsFormInput) => { + const clearField = (field: GeneralSettingsStringField) => { return (e: React.MouseEvent) => { e.preventDefault() e.stopPropagation() @@ -230,6 +231,7 @@ export default function General() { )} /> + diff --git a/tests/api/test_user.py b/tests/api/test_user.py index 0782f5cbd..1d25a98a2 100644 --- a/tests/api/test_user.py +++ b/tests/api/test_user.py @@ -477,6 +477,59 @@ def test_wireguard_subscription_outputs_are_consistent(access_token): delete_core(access_token, core["id"]) +def test_wireguard_disabled_skips_peer_ip_allocation_and_subscription_outputs(access_token, monkeypatch): + monkeypatch.setattr("config.wireguard_settings.enabled", False) + + interface_private_key, _ = generate_wireguard_keypair() + interface_name = unique_name("wg_disabled") + endpoint = "198.51.100.20" + + core = create_core( + access_token, + name=unique_name("wireguard_disabled_core"), + config={ + "interface_name": interface_name, + "private_key": interface_private_key, + "listen_port": 51820, + "address": ["10.40.0.1/24"], + }, + type="wg", + fallbacks=[], + ) + host_response = client.post( + "/api/host", + headers=auth_headers(access_token), + json={ + "remark": "Disabled WG {USERNAME}", + "address": [endpoint], + "port": 51820, + "inbound_tag": interface_name, + "priority": 1, + }, + ) + assert host_response.status_code == status.HTTP_201_CREATED + host_id = host_response.json()["id"] + group = create_group(access_token, name=unique_name("wg_disabled_group"), inbound_tags=[interface_name]) + user = create_user(access_token, group_ids=[group["id"]], payload={"username": unique_name("wg_disabled_user")}) + + try: + assert user["proxy_settings"]["wireguard"]["private_key"] + assert user["proxy_settings"]["wireguard"]["public_key"] + assert user["proxy_settings"]["wireguard"]["peer_ips"] == [] + + links_response = client.get(f"{user['subscription_url']}/links") + wireguard_response = client.get(f"{user['subscription_url']}/wireguard") + + assert links_response.status_code == status.HTTP_200_OK + assert "wireguard://" not in links_response.text + assert wireguard_response.status_code == status.HTTP_406_NOT_ACCEPTABLE + finally: + delete_user(access_token, user["username"]) + delete_group(access_token, group["id"]) + client.delete(f"/api/host/{host_id}", headers=auth_headers(access_token)) + delete_core(access_token, core["id"]) + + def test_xray_subscription_includes_wireguard_outbound(access_token): interface_private_key, _ = generate_wireguard_keypair() interface_public_key = get_wireguard_public_key(interface_private_key)