diff --git a/wavefront/client/src/config/authenticators.ts b/wavefront/client/src/config/authenticators.ts index 1f8ef2dc..080787f0 100644 --- a/wavefront/client/src/config/authenticators.ts +++ b/wavefront/client/src/config/authenticators.ts @@ -188,6 +188,122 @@ export const AUTHENTICATOR_PROVIDERS_CONFIG: Record = ({ return ( - + Create New Authenticator Configure a new authentication provider for {selectedApp?.app_name} diff --git a/wavefront/client/src/pages/apps/layout.tsx b/wavefront/client/src/pages/apps/layout.tsx index 323402fa..d9564c13 100644 --- a/wavefront/client/src/pages/apps/layout.tsx +++ b/wavefront/client/src/pages/apps/layout.tsx @@ -4,6 +4,7 @@ import { DatasourcesIcon, ModelInferenceIcon, ModelRepositoryIcon, + PermissionIcon, PhoneIcon, RagIcon, WorkflowIcon, @@ -21,12 +22,12 @@ const navItems = [ link: `/apps/:appId/agents`, description: 'Manage and configure agents for this application', }, - // { - // name: 'Authenticators', - // icon: PermissionIcon, - // link: `/apps/:appId/authenticators`, - // description: 'Manage authentication provider configurations', - // }, + { + name: 'Authenticators', + icon: PermissionIcon, + link: `/apps/:appId/authenticators`, + description: 'Manage authentication provider configurations', + }, { id: 'datasources', name: 'Datasources', diff --git a/wavefront/client/src/types/authenticator.ts b/wavefront/client/src/types/authenticator.ts index 6fd54886..857e30db 100644 --- a/wavefront/client/src/types/authenticator.ts +++ b/wavefront/client/src/types/authenticator.ts @@ -1,7 +1,7 @@ import { IApiResponse } from '@app/lib/axios'; // Authenticator type union matching API auth_type field -export type AuthenticatorType = 'google_oauth' | 'microsoft_oauth' | 'email_password'; +export type AuthenticatorType = 'google_oauth' | 'microsoft_oauth' | 'microsoft_adfs' | 'email_password'; // Main Authenticator entity interface export interface Authenticator { @@ -76,6 +76,25 @@ export interface MicrosoftOAuthConfig { response_mode?: string; } +// Microsoft ADFS (OIDC) specific config interface +export interface MicrosoftADFSConfig { + client_id: string; + client_secret: string; + authority: string; + redirect_uri: string; + client_redirect_success_url: string; + client_redirect_failure_url: string; + scopes: string[]; + response_type?: string; + response_mode?: string; + authorize_path?: string; + token_path?: string; + jwks_path?: string; + expected_issuer?: string; + clock_skew_seconds?: number; + verify_ssl?: boolean; +} + // Email/Password specific config interface export interface EmailPasswordConfig { password_policy: { diff --git a/wavefront/server/modules/db_repo_module/db_repo_module/alembic/versions/2026_06_11_2008-add_username_to_user_table.py b/wavefront/server/modules/db_repo_module/db_repo_module/alembic/versions/2026_06_11_2008-add_username_to_user_table.py new file mode 100644 index 00000000..4a279170 --- /dev/null +++ b/wavefront/server/modules/db_repo_module/db_repo_module/alembic/versions/2026_06_11_2008-add_username_to_user_table.py @@ -0,0 +1,32 @@ +"""add username to user table + +Revision ID: a1b2c3d4e5f8 +Revises: 3b5b1bf90e6c +Create Date: 2026-06-11 20:08:00.000000 + +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = 'a1b2c3d4e5f8' +down_revision: Union[str, None] = '3b5b1bf90e6c' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.add_column( + 'user', + sa.Column('username', sa.String(length=150), nullable=True), + ) + op.create_unique_constraint('uq_user_username', 'user', ['username']) + + +def downgrade() -> None: + op.drop_constraint('uq_user_username', 'user', type_='unique') + op.drop_column('user', 'username') diff --git a/wavefront/server/modules/db_repo_module/db_repo_module/models/user.py b/wavefront/server/modules/db_repo_module/db_repo_module/models/user.py index ac13925c..3c41afe4 100644 --- a/wavefront/server/modules/db_repo_module/db_repo_module/models/user.py +++ b/wavefront/server/modules/db_repo_module/db_repo_module/models/user.py @@ -18,6 +18,7 @@ class User(Base): primary_key=True, default=uuid.uuid4, index=True ) email: Mapped[str] = mapped_column(nullable=False, unique=True) + username: Mapped[Optional[str]] = mapped_column(nullable=True, unique=True) password: Mapped[str] = mapped_column(nullable=False) first_name: Mapped[str] = mapped_column(nullable=False) last_name: Mapped[str] = mapped_column(nullable=False) @@ -46,6 +47,7 @@ def to_dict(self): return { 'id': str(self.id), 'email': self.email, + 'username': self.username, 'first_name': self.first_name, 'last_name': self.last_name, } diff --git a/wavefront/server/modules/plugins_module/plugins_module/utils/authenticator_helper.py b/wavefront/server/modules/plugins_module/plugins_module/utils/authenticator_helper.py index 93e31927..1a93a647 100644 --- a/wavefront/server/modules/plugins_module/plugins_module/utils/authenticator_helper.py +++ b/wavefront/server/modules/plugins_module/plugins_module/utils/authenticator_helper.py @@ -12,7 +12,13 @@ def validate_google_oauth_config(config: Dict[str, Any]) -> List[str]: """Validate Google OAuth configuration and return list of errors.""" errors = [] - required_fields = ['client_id', 'client_secret', 'redirect_uri'] + required_fields = [ + 'client_id', + 'client_secret', + 'redirect_uri', + 'client_redirect_success_url', + 'client_redirect_failure_url', + ] for field in required_fields: if not config.get(field): errors.append(f'Missing required field: {field}') @@ -24,6 +30,11 @@ def validate_google_oauth_config(config: Dict[str, Any]) -> List[str]: ): errors.append('redirect_uri must be a valid HTTP/HTTPS URL') + for field in ('client_redirect_success_url', 'client_redirect_failure_url'): + value = config.get(field) + if value and not (value.startswith('http://') or value.startswith('https://')): + errors.append(f'{field} must be a valid HTTP/HTTPS URL') + # Validate scopes scopes = config.get('scopes', []) if not isinstance(scopes, list) or len(scopes) == 0: @@ -61,6 +72,32 @@ def validate_microsoft_oauth_config(config: Dict[str, Any]) -> List[str]: return errors +def validate_microsoft_adfs_config(config: Dict[str, Any]) -> List[str]: + """Validate Microsoft ADFS configuration and return list of errors.""" + errors = [] + + required_fields = ['client_id', 'client_secret', 'authority', 'redirect_uri'] + for field in required_fields: + if not config.get(field): + errors.append(f'Missing required field: {field}') + + authority = config.get('authority', '') + if authority and not authority.startswith('https://'): + errors.append('authority must be a valid HTTPS URL') + + redirect_uri = config.get('redirect_uri') + if redirect_uri and not ( + redirect_uri.startswith('http://') or redirect_uri.startswith('https://') + ): + errors.append('redirect_uri must be a valid HTTP/HTTPS URL') + + scopes = config.get('scopes', []) + if not isinstance(scopes, list) or len(scopes) == 0: + errors.append('scopes must be a non-empty list') + + return errors + + def validate_email_password_config(config: Dict[str, Any]) -> List[str]: """Validate email/password configuration and return list of errors.""" errors = [] @@ -130,6 +167,23 @@ def get_config_template(auth_type: str) -> Dict[str, Any]: 'response_type': 'code', 'response_mode': 'query', }, + 'microsoft_adfs': { + 'client_id': 'YOUR_ADFS_CLIENT_ID', + 'client_secret': 'YOUR_ADFS_CLIENT_SECRET', + 'authority': 'https://fs.your-domain.com', + 'redirect_uri': 'https://your-domain.com/v1/oauth/adfs/callback', + 'client_redirect_success_url': 'https://your-domain.com/login/success', + 'client_redirect_failure_url': 'https://your-domain.com/login/failed', + 'scopes': ['openid', 'profile', 'email'], + 'response_type': 'code', + 'response_mode': 'query', + 'authorize_path': '/adfs/oauth2/authorize', + 'token_path': '/adfs/oauth2/token', + 'jwks_path': '/adfs/discovery/keys', + 'expected_issuer': 'https://fs.your-domain.com/adfs', + 'clock_skew_seconds': 60, + 'verify_ssl': True, + }, } return templates.get(auth_type, {}) diff --git a/wavefront/server/modules/user_management_module/user_management_module/authorization/require_auth.py b/wavefront/server/modules/user_management_module/user_management_module/authorization/require_auth.py index ff7629d3..8c73d3e4 100644 --- a/wavefront/server/modules/user_management_module/user_management_module/authorization/require_auth.py +++ b/wavefront/server/modules/user_management_module/user_management_module/authorization/require_auth.py @@ -45,6 +45,7 @@ '/floware/v1/plugin-auth/authenticate', '/floware/v1/oauth/google/callback', '/floware/v1/oauth/microsoft/callback', + '/floware/v1/oauth/adfs/callback', '/floware/v1/plugin-auth/oauth/init', '/floware/v1/settings/config', ] diff --git a/wavefront/server/modules/user_management_module/user_management_module/controllers/auth_plugin_controller.py b/wavefront/server/modules/user_management_module/user_management_module/controllers/auth_plugin_controller.py index f634ff00..449b8132 100644 --- a/wavefront/server/modules/user_management_module/user_management_module/controllers/auth_plugin_controller.py +++ b/wavefront/server/modules/user_management_module/user_management_module/controllers/auth_plugin_controller.py @@ -1,4 +1,6 @@ import json +import logging +import secrets from uuid import uuid4 from db_repo_module.models.resource import ResourceScope from dependency_injector.wiring import inject, Provide @@ -36,6 +38,53 @@ auth_plugin_router = APIRouter() +logger = logging.getLogger(__name__) + +# Per-flow OAuth state lives in Redis with a short TTL. Each authorize -> +# callback round-trip is a one-shot: stored on init, consumed (deleted) on +# successful callback. Replays after consumption miss the cache and are +# rejected. +OAUTH_FLOW_TTL_SECONDS = 600 +_OAUTH_FLOW_KEY_PREFIX = 'oauth:flow:' + + +def _store_oauth_flow(cache_manager: CacheManager, auth_id: str) -> tuple[str, str]: + """Mint and persist an opaque state+nonce pair bound to auth_id. + + `nx=True` prevents accidental overwrite if (astronomically) the same + 32-byte token is minted twice. + """ + state = secrets.token_urlsafe(32) + nonce = secrets.token_urlsafe(32) + cache_manager.add( + f'{_OAUTH_FLOW_KEY_PREFIX}{state}', + json.dumps({'auth_id': str(auth_id), 'nonce': nonce}), + expiry=OAUTH_FLOW_TTL_SECONDS, + nx=True, + ) + return state, nonce + + +def _consume_oauth_flow( + cache_manager: CacheManager, state: Optional[str] +) -> Optional[Dict[str, str]]: + """Single-use lookup of the flow record. Returns None on miss/parse error.""" + if not state: + return None + key = f'{_OAUTH_FLOW_KEY_PREFIX}{state}' + raw = cache_manager.get_str(key) + if not raw: + return None + try: + flow = json.loads(raw) + except (TypeError, ValueError): + cache_manager.remove(key) + return None + cache_manager.remove(key) + if not isinstance(flow, dict) or 'auth_id' not in flow: + return None + return flow + class UnifiedAuthRequest(BaseModel): auth_id: str @@ -204,10 +253,12 @@ async def init_oauth_flow( authenticator_repository: SQLAlchemyRepository[Authenticator] = Depends( Provide[PluginsContainer.authenticator_repository] ), + cache_manager: CacheManager = Depends(Provide[UserContainer.cache_manager]), ): """Initialize OAuth flow and return authorization URL.""" try: + logger.debug('OAuth init requested for auth_id=%s', oauth_request.auth_id) # Get authenticator instance by ID auth_id = UUID(oauth_request.auth_id) authenticator = await get_authenticator_instance( @@ -215,6 +266,10 @@ async def init_oauth_flow( ) if not authenticator: + logger.debug( + 'OAuth init: no enabled authenticator for auth_id=%s', + oauth_request.auth_id, + ) return JSONResponse( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, content=response_formatter.buildErrorResponse( @@ -222,9 +277,17 @@ async def init_oauth_flow( ), ) - # Generate state and get authorization URL - state = json.dumps({'auth_id': oauth_request.auth_id}) - auth_url = authenticator.get_authorization_url(state) + # Mint opaque CSRF state + OIDC nonce, persist server-side, and pass + # both into the provider so they end up in the authorize URL. + state, nonce = _store_oauth_flow(cache_manager, oauth_request.auth_id) + logger.debug( + 'OAuth flow stored: auth_id=%s state=%s nonce=%s ttl=%ss', + oauth_request.auth_id, + state, + nonce, + OAUTH_FLOW_TTL_SECONDS, + ) + auth_url = authenticator.get_authorization_url(state, nonce=nonce) if not auth_url: return JSONResponse( @@ -274,11 +337,20 @@ async def google_oauth_callback( token_service: TokenService = Depends(Provide[AuthContainer.token_service]), ): """Handle Google OAuth callback.""" - state_obj = json.loads(state) - auth_id = state_obj['auth_id'] + logger.debug( + 'Google OAuth callback: has_code=%s has_state=%s has_error=%s state=%s', + bool(code), + bool(state), + bool(error), + state, + ) + flow = _consume_oauth_flow(cache_manager, state) + if flow is None: + logger.warning('Google OAuth callback received unknown/expired state') + return RedirectResponse(url='about:blank') return await _handle_oauth_callback( - auth_id, + flow['auth_id'], {'authorization_code': code, 'state': state, 'error': error}, request, response_formatter, @@ -288,6 +360,7 @@ async def google_oauth_callback( session_repository, cache_manager, token_service, + expected_nonce=flow.get('nonce'), ) @@ -317,11 +390,71 @@ async def microsoft_oauth_callback( token_service: TokenService = Depends(Provide[AuthContainer.token_service]), ): """Handle Microsoft OAuth callback.""" - state_obj = json.loads(state) - auth_id = state_obj['auth_id'] + logger.debug( + 'Microsoft OAuth callback: has_code=%s has_state=%s has_error=%s state=%s', + bool(code), + bool(state), + bool(error), + state, + ) + flow = _consume_oauth_flow(cache_manager, state) + if flow is None: + logger.warning('Microsoft OAuth callback received unknown/expired state') + return RedirectResponse(url='about:blank') + + return await _handle_oauth_callback( + flow['auth_id'], + {'authorization_code': code, 'state': state, 'error': error}, + request, + response_formatter, + authenticator_repository, + user_service, + user_repository, + session_repository, + cache_manager, + token_service, + expected_nonce=flow.get('nonce'), + ) + + +@auth_plugin_router.get('/v1/oauth/adfs/callback') +@inject +async def microsoft_adfs_oauth_callback( + request: Request, + state: str = Query(...), + code: Optional[str] = Query(None), + error: Optional[str] = Query(None), + response_formatter: ResponseFormatter = Depends( + Provide[CommonContainer.response_formatter] + ), + authenticator_repository: SQLAlchemyRepository[Authenticator] = Depends( + Provide[PluginsContainer.authenticator_repository] + ), + user_repository: SQLAlchemyRepository[User] = Depends( + Provide[UserContainer.user_repository] + ), + user_service: UserService = Depends(Provide[UserContainer.user_service]), + session_repository: SQLAlchemyRepository[Session] = Depends( + Provide[UserContainer.session_repository] + ), + cache_manager: CacheManager = Depends(Provide[UserContainer.cache_manager]), + token_service: TokenService = Depends(Provide[AuthContainer.token_service]), +): + """Handle Microsoft ADFS OAuth callback.""" + logger.debug( + 'Microsoft ADFS callback: has_code=%s has_state=%s has_error=%s state=%s', + bool(code), + bool(state), + bool(error), + state, + ) + flow = _consume_oauth_flow(cache_manager, state) + if flow is None: + logger.warning('Microsoft ADFS callback received unknown/expired state') + return RedirectResponse(url='about:blank') return await _handle_oauth_callback( - auth_id, + flow['auth_id'], {'authorization_code': code, 'state': state, 'error': error}, request, response_formatter, @@ -331,6 +464,7 @@ async def microsoft_oauth_callback( session_repository, cache_manager, token_service, + expected_nonce=flow.get('nonce'), ) @@ -345,10 +479,19 @@ async def _handle_oauth_callback( session_repository: SQLAlchemyRepository[Session], cache_manager: CacheManager, token_service: TokenService, + expected_nonce: Optional[str] = None, ) -> RedirectResponse: """Common OAuth callback handler.""" try: + logger.debug( + '_handle_oauth_callback: auth_id=%s has_code=%s has_error=%s ' + 'expected_nonce_set=%s', + auth_id, + bool(callback_data.get('authorization_code')), + bool(callback_data.get('error')), + expected_nonce is not None, + ) # Get authenticator instance and config auth_uuid = UUID(auth_id) authenticator, config_data = await get_authenticator_with_config( @@ -379,6 +522,12 @@ def get_failure_redirect(error_msg: str) -> RedirectResponse: provider = config_data.get('auth_type') success_url = config_data.get('config', {}).get('client_redirect_success_url') failure_url = config_data.get('config', {}).get('client_redirect_failure_url') + logger.debug( + '_handle_oauth_callback: provider=%s success_url=%s failure_url=%s', + provider, + success_url, + failure_url, + ) # Handle OAuth error from provider if callback_data.get('error'): @@ -393,7 +542,19 @@ def get_failure_redirect(error_msg: str) -> RedirectResponse: return RedirectResponse(url='about:blank') # Handle OAuth callback - auth_result = authenticator.handle_callback(callback_data) + auth_result = authenticator.handle_callback( + callback_data, expected_nonce=expected_nonce + ) + ui = auth_result.user_info + logger.debug( + '_handle_oauth_callback: provider auth_result success=%s error_code=%s ' + 'email=%s upn=%s unique_name=%s', + auth_result.success, + auth_result.error_code, + ui.email if ui else None, + ui.upn if ui else None, + ui.unique_name if ui else None, + ) if not auth_result.success: if failure_url: @@ -406,8 +567,18 @@ def get_failure_redirect(error_msg: str) -> RedirectResponse: return RedirectResponse(url=f'{failure_url}?{params}') return RedirectResponse(url='about:blank') - # Create session from auth result - user = await user_repository.find_one(email=auth_result.user_info.email) + if ui is None: + return get_failure_redirect('OAuth authentication returned no user info') + + user = await user_repository.find_one(email=ui.email) + if user is None and ui.email: + user = await user_repository.find_one(username=ui.email.lower()) + logger.debug( + '_handle_oauth_callback: user lookup by identifier=%s found=%s deleted=%s', + ui.email, + user is not None, + user.deleted if user else None, + ) if user is None: if failure_url: params = urlencode( @@ -448,6 +619,11 @@ def get_failure_redirect(error_msg: str) -> RedirectResponse: role_id = await user_service.get_user_role_for_scope( user_id=str(user.id), scope=ResourceScope.CONSOLE ) + logger.debug( + '_handle_oauth_callback: console role lookup user_id=%s role_id=%s', + str(user.id), + role_id, + ) if not role_id: if failure_url: @@ -468,12 +644,24 @@ def get_failure_redirect(error_msg: str) -> RedirectResponse: # Success: redirect to success URL with access token if success_url: + logger.debug( + '_handle_oauth_callback: success redirect provider=%s user_id=%s ' + 'session_id=%s -> %s', + provider, + str(user.id), + str(session.id), + success_url, + ) params = urlencode({'provider': provider, 'access_token': token}) return RedirectResponse(url=f'{success_url}?{params}') + logger.debug( + '_handle_oauth_callback: no success_url configured, redirecting to about:blank' + ) return RedirectResponse(url='about:blank') except Exception as e: + logger.debug('_handle_oauth_callback raised: %s', e) # Try to get config for failure URL try: auth_uuid = UUID(auth_id) diff --git a/wavefront/server/modules/user_management_module/user_management_module/controllers/user_controller.py b/wavefront/server/modules/user_management_module/user_management_module/controllers/user_controller.py index b6b30387..25d28b09 100644 --- a/wavefront/server/modules/user_management_module/user_management_module/controllers/user_controller.py +++ b/wavefront/server/modules/user_management_module/user_management_module/controllers/user_controller.py @@ -95,6 +95,18 @@ async def create_user( ), ) + if new_user.username: + existing_by_username = await user_repository.find_one( + username=new_user.username + ) + if existing_by_username and not existing_by_username.deleted: + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content=response_formatter.buildErrorResponse( + 'User with the same username already exists' + ), + ) + async with user_repository.session() as session: try: get_console_resources_query = ( @@ -120,6 +132,7 @@ async def create_user( hashed_password = hash_password(new_user.password) user = User( email=new_user.email, + username=new_user.username, password=hashed_password, first_name=new_user.first_name, last_name=new_user.last_name, @@ -301,6 +314,7 @@ async def get_all_user( User.first_name, User.last_name, User.email, + User.username, func.array_agg( func.json_build_object( 'id', @@ -326,6 +340,7 @@ async def get_all_user( if len(name) > 1 and name[1]: filters.append(User.last_name.ilike(f'%{name[1]}%')) filters.append(User.email.ilike(f'%{search}%')) + filters.append(User.username.ilike(f'%{search}%')) query = query.where(or_(*filters)) # Add role filter diff --git a/wavefront/server/modules/user_management_module/user_management_module/models/user_schema.py b/wavefront/server/modules/user_management_module/user_management_module/models/user_schema.py index 170041c1..28dd29ab 100644 --- a/wavefront/server/modules/user_management_module/user_management_module/models/user_schema.py +++ b/wavefront/server/modules/user_management_module/user_management_module/models/user_schema.py @@ -11,6 +11,7 @@ class NewUser(BaseModel): email: EmailStr = Field(..., max_length=254) # RFC 5321 standard max length + username: Optional[str] = Field(None, min_length=3, max_length=50) password: str = Field(..., min_length=8) first_name: Optional[str] = Field(None, min_length=1, max_length=50) last_name: Optional[str] = Field(None, max_length=50) @@ -44,6 +45,17 @@ def validate_email_format(cls, v): return v.lower() # Normalize email to lowercase + @field_validator('username') + @classmethod + def validate_username_format(cls, v): + if v is None: + return v + if not re.match(r'^[a-zA-Z0-9._@+-]+$', v): + raise ValueError( + 'Username may only contain letters, digits, and the characters . _ @ + -' + ) + return v.lower() + @field_validator('password') @classmethod def validate_password_strength(cls, v): diff --git a/wavefront/server/plugins/authenticator/authenticator/__init__.py b/wavefront/server/plugins/authenticator/authenticator/__init__.py index d87f4e2d..62b9324b 100644 --- a/wavefront/server/plugins/authenticator/authenticator/__init__.py +++ b/wavefront/server/plugins/authenticator/authenticator/__init__.py @@ -11,6 +11,7 @@ from .email_password.config import EmailPasswordConfig from .google_oauth.config import GoogleOAuthConfig from .microsoft_oauth.config import MicrosoftOAuthConfig +from .microsoft_adfs.config import MicrosoftADFSConfig __all__ = [ 'AuthenticatorFactory', @@ -24,4 +25,5 @@ 'EmailPasswordConfig', 'GoogleOAuthConfig', 'MicrosoftOAuthConfig', + 'MicrosoftADFSConfig', ] diff --git a/wavefront/server/plugins/authenticator/authenticator/email_password/authenticator.py b/wavefront/server/plugins/authenticator/authenticator/email_password/authenticator.py index 76c77fbd..71c4a1d5 100644 --- a/wavefront/server/plugins/authenticator/authenticator/email_password/authenticator.py +++ b/wavefront/server/plugins/authenticator/authenticator/email_password/authenticator.py @@ -116,11 +116,15 @@ def validate_config(self) -> bool: except Exception: return False - def get_authorization_url(self, state: Optional[str] = None) -> Optional[str]: + def get_authorization_url( + self, state: Optional[str] = None, nonce: Optional[str] = None + ) -> Optional[str]: """Email/password doesn't need authorization URL.""" return None - def handle_callback(self, callback_data: Dict[str, Any]) -> AuthResult: + def handle_callback( + self, callback_data: Dict[str, Any], expected_nonce: Optional[str] = None + ) -> AuthResult: """Email/password doesn't use OAuth callbacks.""" return AuthResult( success=False, diff --git a/wavefront/server/plugins/authenticator/authenticator/factory.py b/wavefront/server/plugins/authenticator/authenticator/factory.py index b33bd3c3..d6c46c1b 100644 --- a/wavefront/server/plugins/authenticator/authenticator/factory.py +++ b/wavefront/server/plugins/authenticator/authenticator/factory.py @@ -5,9 +5,11 @@ from .email_password import EmailPasswordAuthenticator from .google_oauth import GoogleOAuthAuthenticator from .microsoft_oauth import MicrosoftOAuthAuthenticator +from .microsoft_adfs import MicrosoftADFSAuthenticator from .email_password.config import EmailPasswordConfig from .google_oauth.config import GoogleOAuthConfig from .microsoft_oauth.config import MicrosoftOAuthConfig +from .microsoft_adfs.config import MicrosoftADFSConfig class AuthenticatorFactory: @@ -28,6 +30,7 @@ def __init__(self): if not hasattr(self, '_initialized'): self._google_instances: Dict[str, GoogleOAuthAuthenticator] = {} self._microsoft_instances: Dict[str, MicrosoftOAuthAuthenticator] = {} + self._adfs_instances: Dict[str, MicrosoftADFSAuthenticator] = {} self._email_instances: Dict[str, EmailPasswordAuthenticator] = {} self._instances_lock = threading.Lock() self._initialized = True @@ -85,6 +88,8 @@ def validate_config( return GoogleOAuthAuthenticator.validate_config_static(config) elif auth_type == AuthenticatorType.MICROSOFT_OAUTH: return MicrosoftOAuthAuthenticator.validate_config_static(config) + elif auth_type == AuthenticatorType.MICROSOFT_ADFS: + return MicrosoftADFSAuthenticator.validate_config_static(config) else: raise ValueError(f'Unsupported authenticator type: {auth_type}') @@ -157,6 +162,7 @@ def get_cached_instance_count( return ( len(self._google_instances) + len(self._microsoft_instances) + + len(self._adfs_instances) + len(self._email_instances) ) @@ -165,6 +171,7 @@ def clear_all_instances(self) -> None: with self._instances_lock: self._google_instances.clear() self._microsoft_instances.clear() + self._adfs_instances.clear() self._email_instances.clear() def _get_cache_for_type( @@ -175,6 +182,8 @@ def _get_cache_for_type( return self._google_instances elif auth_type == AuthenticatorType.MICROSOFT_OAUTH: return self._microsoft_instances + elif auth_type == AuthenticatorType.MICROSOFT_ADFS: + return self._adfs_instances elif auth_type == AuthenticatorType.EMAIL_PASSWORD: return self._email_instances else: @@ -196,6 +205,10 @@ def _create_authenticator( typed_config = MicrosoftOAuthConfig(**config) return MicrosoftOAuthAuthenticator(typed_config) + elif auth_type == AuthenticatorType.MICROSOFT_ADFS: + typed_config = MicrosoftADFSConfig(**config) + return MicrosoftADFSAuthenticator(typed_config) + else: raise ValueError(f'Unsupported authenticator type: {auth_type}') diff --git a/wavefront/server/plugins/authenticator/authenticator/google_oauth/authenticator.py b/wavefront/server/plugins/authenticator/authenticator/google_oauth/authenticator.py index 5c4d5b6d..6d53a20e 100644 --- a/wavefront/server/plugins/authenticator/authenticator/google_oauth/authenticator.py +++ b/wavefront/server/plugins/authenticator/authenticator/google_oauth/authenticator.py @@ -151,16 +151,13 @@ def validate_config(self) -> bool: except Exception: return False - def get_authorization_url(self, state: Optional[str] = None) -> Optional[str]: + def get_authorization_url( + self, state: Optional[str] = None, nonce: Optional[str] = None + ) -> Optional[str]: """Get the Google OAuth authorization URL.""" if not state: raise ValueError("State doesn't exist Google Oauth") - state_obj = json.loads(state) - - if state_obj['auth_id'] is None: - raise ValueError("Auth Id doesn't exist in Google Oauth state") - params = { 'response_type': 'code', 'client_id': self.config.client_id, @@ -171,13 +168,24 @@ def get_authorization_url(self, state: Optional[str] = None) -> Optional[str]: 'prompt': self.config.prompt, } + if nonce: + params['nonce'] = nonce + if self.config.hosted_domain: params['hd'] = self.config.hosted_domain return f'{self.auth_url}?{urlencode(params)}' - def handle_callback(self, callback_data: Dict[str, Any]) -> AuthResult: - """Handle Google OAuth callback.""" + def handle_callback( + self, callback_data: Dict[str, Any], expected_nonce: Optional[str] = None + ) -> AuthResult: + """Handle Google OAuth callback. + + Identity comes from Google's userinfo endpoint (not an id_token), so + `expected_nonce` is accepted for ABC compatibility but not enforced. + State CSRF protection is performed by the controller before this is + called. + """ return self.authenticate(callback_data) def refresh_token(self, refresh_token: str) -> TokenResult: diff --git a/wavefront/server/plugins/authenticator/authenticator/helper.py b/wavefront/server/plugins/authenticator/authenticator/helper.py index cdc96e02..e4f517d9 100644 --- a/wavefront/server/plugins/authenticator/authenticator/helper.py +++ b/wavefront/server/plugins/authenticator/authenticator/helper.py @@ -102,6 +102,7 @@ def get_authenticator_display_name(auth_type: AuthenticatorType) -> str: AuthenticatorType.EMAIL_PASSWORD: 'Email & Password', AuthenticatorType.GOOGLE_OAUTH: 'Google OAuth', AuthenticatorType.MICROSOFT_OAUTH: 'Microsoft OAuth', + AuthenticatorType.MICROSOFT_ADFS: 'Microsoft ADFS', AuthenticatorType.SAML: 'SAML', AuthenticatorType.LDAP: 'LDAP', } @@ -110,7 +111,11 @@ def get_authenticator_display_name(auth_type: AuthenticatorType) -> str: def is_oauth_provider(auth_type: AuthenticatorType) -> bool: """Check if authenticator type is an OAuth provider.""" - oauth_types = {AuthenticatorType.GOOGLE_OAUTH, AuthenticatorType.MICROSOFT_OAUTH} + oauth_types = { + AuthenticatorType.GOOGLE_OAUTH, + AuthenticatorType.MICROSOFT_OAUTH, + AuthenticatorType.MICROSOFT_ADFS, + } return auth_type in oauth_types diff --git a/wavefront/server/plugins/authenticator/authenticator/microsoft_adfs/__init__.py b/wavefront/server/plugins/authenticator/authenticator/microsoft_adfs/__init__.py new file mode 100644 index 00000000..c1c8889f --- /dev/null +++ b/wavefront/server/plugins/authenticator/authenticator/microsoft_adfs/__init__.py @@ -0,0 +1,4 @@ +from .authenticator import MicrosoftADFSAuthenticator +from .config import MicrosoftADFSConfig + +__all__ = ['MicrosoftADFSAuthenticator', 'MicrosoftADFSConfig'] diff --git a/wavefront/server/plugins/authenticator/authenticator/microsoft_adfs/authenticator.py b/wavefront/server/plugins/authenticator/authenticator/microsoft_adfs/authenticator.py new file mode 100644 index 00000000..67ea03d7 --- /dev/null +++ b/wavefront/server/plugins/authenticator/authenticator/microsoft_adfs/authenticator.py @@ -0,0 +1,500 @@ +import json +import logging +import re +import ssl +import jwt +import requests +from datetime import datetime +from jwt import PyJWKClient +from typing import Dict, Any, Optional +from urllib.parse import urlencode, urlparse + +from ..types import AuthenticatorABC, AuthResult, TokenResult, HealthStatus, UserInfo +from .config import MicrosoftADFSConfig + +logger = logging.getLogger(__name__) + +_ALLOWED_ID_TOKEN_ALGS = ['RS256', 'RS384', 'RS512', 'ES256', 'ES384', 'ES512'] + + +class MicrosoftADFSAuthenticator(AuthenticatorABC): + """Microsoft ADFS (OIDC) authenticator. + + Identity is sourced from the `id_token` returned in the token response + rather than a userinfo / Graph call, since on-prem ADFS does not always + expose `/adfs/userinfo` and Microsoft Graph is unreachable. + """ + + def __init__(self, config: MicrosoftADFSConfig): + self.config = config + base = config.authority.rstrip('/') + self.auth_url = f'{base}{config.authorize_path}' + self.token_url = f'{base}{config.token_path}' + self.jwks_url = f'{base}{config.jwks_path}' + + ssl_ctx = ssl.create_default_context() + if not config.verify_ssl: + ssl_ctx.check_hostname = False + ssl_ctx.verify_mode = ssl.CERT_NONE + # PyJWKClient caches keys in-process; safe to construct once per instance. + self._jwks_client = PyJWKClient(self.jwks_url, ssl_context=ssl_ctx) + + @staticmethod + def validate_config_static(config: Dict[str, Any]) -> bool: + required_fields = [ + 'client_id', + 'client_secret', + 'authority', + 'redirect_uri', + 'client_redirect_success_url', + 'client_redirect_failure_url', + 'scopes', + ] + for field_name in required_fields: + if not config.get(field_name): + raise ValueError(f'{field_name} is required') + + authority = config['authority'] + if not authority.startswith('https://'): + raise ValueError('authority must be a valid HTTPS URL') + + parsed_uri = urlparse(config['redirect_uri']) + if not parsed_uri.scheme or not parsed_uri.netloc: + raise ValueError('redirect_uri must be a valid URL with scheme and netloc') + + for url_field in ['client_redirect_success_url', 'client_redirect_failure_url']: + parsed_url = urlparse(config[url_field]) + if not parsed_url.scheme or not parsed_url.netloc: + raise ValueError( + f'{url_field} must be a valid URL with scheme and netloc' + ) + + scopes = config.get('scopes', []) + if not scopes or len(scopes) == 0: + raise ValueError('scopes array cannot be empty') + + return True + + def authenticate( + self, + credentials: Dict[str, Any], + expected_nonce: Optional[str] = None, + ) -> AuthResult: + authorization_code = credentials.get('authorization_code') + + if not authorization_code: + return AuthResult( + success=False, + error='Authorization code is required', + error_code='MISSING_AUTH_CODE', + ) + + token_result, id_token = self._exchange_code_for_token(authorization_code) + + if not token_result.success: + return AuthResult( + success=False, + error=token_result.error, + error_code='TOKEN_EXCHANGE_FAILED', + ) + + if not id_token: + return AuthResult( + success=False, + error='ADFS response missing id_token', + error_code='ID_TOKEN_MISSING', + ) + + user_info = self._get_user_info_from_id_token(id_token, expected_nonce) + + if not user_info: + return AuthResult( + success=False, + error='Failed to extract user information from id_token', + error_code='USER_INFO_FAILED', + ) + + return AuthResult( + success=True, + user_info=user_info, + access_token=token_result.access_token, + refresh_token=token_result.refresh_token, + ) + + def validate_config(self) -> bool: + try: + required_fields = [ + 'client_id', + 'client_secret', + 'authority', + 'redirect_uri', + 'client_redirect_success_url', + 'client_redirect_failure_url', + 'scopes', + ] + for field_name in required_fields: + if not getattr(self.config, field_name, None): + return False + + if not self.config.authority.startswith('https://'): + return False + + for url in ( + self.config.redirect_uri, + self.config.client_redirect_success_url, + self.config.client_redirect_failure_url, + ): + parsed = urlparse(url) + if not parsed.scheme or not parsed.netloc: + return False + + if not self.config.scopes or len(self.config.scopes) == 0: + return False + + return True + + except Exception: + return False + + def get_authorization_url( + self, state: Optional[str] = None, nonce: Optional[str] = None + ) -> Optional[str]: + if not state: + raise ValueError("State doesn't exist Microsoft ADFS") + + params = { + 'response_type': self.config.response_type, + 'client_id': self.config.client_id, + 'redirect_uri': self.config.redirect_uri, + 'scope': ' '.join(self.config.scopes), + 'state': state, + 'response_mode': self.config.response_mode, + 'prompt': 'select_account', + } + + if nonce: + params['nonce'] = nonce + + url = f'{self.auth_url}?{urlencode(params)}' + logger.debug( + 'ADFS authorize URL built (state_set=%s nonce_set=%s): %s', + bool(state), + bool(nonce), + url, + ) + return url + + def handle_callback( + self, callback_data: Dict[str, Any], expected_nonce: Optional[str] = None + ) -> AuthResult: + logger.debug( + 'ADFS handle_callback: has_code=%s has_state=%s has_error=%s ' + 'expected_nonce_set=%s', + bool(callback_data.get('authorization_code')), + bool(callback_data.get('state')), + bool(callback_data.get('error')), + expected_nonce is not None, + ) + return self.authenticate(callback_data, expected_nonce=expected_nonce) + + def refresh_token(self, refresh_token: str) -> TokenResult: + if not refresh_token: + return TokenResult(success=False, error='Refresh token is required') + + data = { + 'grant_type': 'refresh_token', + 'refresh_token': refresh_token, + 'client_id': self.config.client_id, + 'client_secret': self.config.client_secret, + 'scope': ' '.join(self.config.scopes), + } + + try: + response = requests.post( + self.token_url, + data=data, + timeout=10, + verify=self.config.verify_ssl, + ) + response.raise_for_status() + token_data = response.json() + + return TokenResult( + success=True, + access_token=token_data.get('access_token'), + refresh_token=token_data.get('refresh_token', refresh_token), + expires_in=token_data.get('expires_in'), + ) + + except requests.exceptions.RequestException as e: + return TokenResult(success=False, error=f'Token refresh failed: {str(e)}') + except json.JSONDecodeError: + return TokenResult( + success=False, error='Invalid response from ADFS token endpoint' + ) + + def logout(self, user_session: Dict[str, Any]) -> bool: + return True + + def get_health_status(self) -> HealthStatus: + is_healthy = True + details = { + 'config_valid': self.validate_config(), + 'authority': self.config.authority, + 'scopes': self.config.scopes, + } + + discovery_url = ( + f'{self.config.authority.rstrip("/")}/adfs/.well-known/openid-configuration' + ) + try: + response = requests.get( + discovery_url, timeout=5, verify=self.config.verify_ssl + ) + details['discovery_reachable'] = response.status_code == 200 + if response.status_code != 200: + is_healthy = False + except Exception: + details['discovery_reachable'] = False + is_healthy = False + + return HealthStatus( + healthy=is_healthy, + message='Microsoft ADFS authenticator is operational' + if is_healthy + else 'ADFS discovery endpoint unreachable', + last_check=datetime.now(), + details=details, + ) + + def get_user_info(self, access_token: str) -> Optional[UserInfo]: + # ADFS access tokens are opaque without a guaranteed userinfo endpoint. + # Identity is resolved from the id_token at login time instead. + return None + + def _exchange_code_for_token( + self, authorization_code: str + ) -> tuple[TokenResult, Optional[str]]: + data = { + 'grant_type': 'authorization_code', + 'code': authorization_code, + 'client_id': self.config.client_id, + 'client_secret': self.config.client_secret, + 'redirect_uri': self.config.redirect_uri, + 'scope': ' '.join(self.config.scopes), + } + + logger.debug('ADFS token exchange: POST %s', self.token_url) + + try: + response = requests.post( + self.token_url, + data=data, + timeout=10, + verify=self.config.verify_ssl, + ) + response.raise_for_status() + token_data = response.json() + + id_token = token_data.get('id_token') + logger.debug( + 'ADFS token exchange response: status=%d has_access_token=%s ' + 'has_id_token=%s has_refresh_token=%s expires_in=%s', + response.status_code, + bool(token_data.get('access_token')), + bool(id_token), + bool(token_data.get('refresh_token')), + token_data.get('expires_in'), + ) + logger.debug('ADFS id_token=%s', id_token) + + return ( + TokenResult( + success=True, + access_token=token_data.get('access_token'), + refresh_token=token_data.get('refresh_token'), + expires_in=token_data.get('expires_in'), + ), + id_token, + ) + + except requests.exceptions.RequestException as e: + logger.debug('ADFS token exchange request failed: %s', e) + return ( + TokenResult(success=False, error=f'Token exchange failed: {str(e)}'), + None, + ) + except json.JSONDecodeError as e: + logger.debug('ADFS token endpoint returned non-JSON: %s', e) + return ( + TokenResult( + success=False, error='Invalid response from ADFS token endpoint' + ), + None, + ) + + def _extract_identifier_from_claim(self, value: str) -> Optional[str]: + """Pull the user identifier out of a raw UPN or unique_name string. + + Resolution order: + 1. If ``user_id_pattern`` is set, use it as a regex. The first capture + group (if any) is returned; otherwise the full match. + 2. DOMAIN\\userid — return the segment after the last backslash. + 3. userid@domain — return the local part (before '@'). + 4. Fall back to the whole trimmed value. + """ + value = value.strip() + if not value: + return None + + pattern = self.config.user_id_pattern + if pattern: + m = re.search(pattern, value, re.IGNORECASE) + if m: + return (m.group(1) if m.lastindex else m.group(0)).strip() or None + logger.debug( + 'ADFS identifier extraction: pattern %r did not match %r', + pattern, + value, + ) + return None + + if '\\' in value: + return value.rsplit('\\', 1)[-1].strip() or None + + if '@' in value: + local = value.split('@', 1)[0].strip() + return local or None + + return value or None + + def _resolve_login_email( + self, + email: Optional[str], + upn: Optional[str], + unique_name: Optional[str], + ) -> Optional[str]: + """Return the value to use as the login email for DB lookup. + + If a real ``email`` claim is present it is used as-is. Otherwise the + identifier is extracted from ``upn`` then ``unique_name``. When the + extracted value is a bare ID (no '@'), ``email_fallback_domain`` is + appended to produce a valid email-shaped string. + """ + if email: + return email.lower() + + for candidate in (upn, unique_name): + if not candidate: + continue + identifier = self._extract_identifier_from_claim(candidate) + if not identifier: + continue + + if '@' in identifier: + return identifier.lower() + + if self.config.email_fallback_domain: + domain = self.config.email_fallback_domain.lstrip('@') + return f'{identifier}@{domain}'.lower() + + # No fallback domain — return the bare identifier so it can match + # a username stored in the email column by convention. + return identifier.lower() + + return None + + def _get_user_info_from_id_token( + self, id_token: str, expected_nonce: Optional[str] = None + ) -> Optional[UserInfo]: + claims = self._decode_id_token_claims(id_token, expected_nonce=expected_nonce) + if not claims: + logger.debug('ADFS user_info: no claims (decode/validate failed)') + return None + + raw_email = claims.get('email') + upn = claims.get('upn') + unique_name = claims.get('unique_name') + if not raw_email and not upn and not unique_name: + logger.debug( + 'ADFS user_info: no email/upn/unique_name claim present in id_token' + ) + return None + + email = self._resolve_login_email(raw_email, upn, unique_name) + logger.debug( + 'ADFS user_info resolved: email=%s (raw=%s) upn=%s unique_name=%s ' + 'given_name=%s family_name=%s', + email, + raw_email, + upn, + unique_name, + claims.get('given_name'), + claims.get('family_name'), + ) + + first_name = claims.get('given_name') + if not first_name and email and '@' in email: + first_name = email.split('@')[0] + + return UserInfo( + email=email, + upn=upn, + unique_name=unique_name, + first_name=first_name, + last_name=claims.get('family_name'), + user_id=claims.get('sub') or claims.get('oid'), + provider='microsoft_adfs', + avatar_url=None, + additional_info={ + 'name': claims.get('name'), + 'groups': claims.get('groups'), + }, + ) + + def _decode_id_token_claims( + self, id_token: str, expected_nonce: Optional[str] = None + ) -> Optional[Dict[str, Any]]: + # Verify signature against the IdP's JWKS and enforce aud + exp/nbf. + # iss is only enforced when expected_issuer is configured, because some + # IdPs (e.g. Authentik in mixed http/https setups) advertise an iss host + # that legitimately differs from the configured `authority`. + try: + signing_key = self._jwks_client.get_signing_key_from_jwt(id_token) + logger.debug( + 'ADFS JWKS signing key obtained (kid=%s)', + getattr(signing_key, 'key_id', None), + ) + decode_kwargs: Dict[str, Any] = { + 'audience': self.config.client_id, + 'leeway': self.config.clock_skew_seconds, + 'algorithms': _ALLOWED_ID_TOKEN_ALGS, + 'options': { + 'verify_signature': True, + 'verify_aud': True, + 'verify_exp': True, + 'verify_nbf': True, + 'verify_iss': self.config.expected_issuer is not None, + }, + } + if self.config.expected_issuer: + decode_kwargs['issuer'] = self.config.expected_issuer + + claims = jwt.decode(id_token, signing_key.key, **decode_kwargs) + logger.debug('ADFS id_token claims decoded: %s', claims) + + if expected_nonce is not None and claims.get('nonce') != expected_nonce: + logger.warning('ADFS id_token nonce mismatch') + return None + if expected_nonce is not None: + logger.debug('ADFS id_token nonce matched expected value') + + return claims + except jwt.PyJWTError as e: + logger.warning('ADFS id_token JWT validation failed: %s', e) + return None + except Exception as e: + logger.warning( + 'ADFS id_token decode failed (jwks_url=%s): %s', self.jwks_url, e + ) + return None diff --git a/wavefront/server/plugins/authenticator/authenticator/microsoft_adfs/config.py b/wavefront/server/plugins/authenticator/authenticator/microsoft_adfs/config.py new file mode 100644 index 00000000..5d8abe99 --- /dev/null +++ b/wavefront/server/plugins/authenticator/authenticator/microsoft_adfs/config.py @@ -0,0 +1,39 @@ +from dataclasses import dataclass, field +from typing import Optional + + +@dataclass +class MicrosoftADFSConfig: + client_id: str + client_secret: str + # ADFS server base, e.g. 'https://fs.customer.com' + authority: str + redirect_uri: str + client_redirect_success_url: str + client_redirect_failure_url: str + scopes: list[str] = field(default_factory=lambda: ['openid', 'profile', 'email']) + response_type: str = 'code' + response_mode: str = 'query' + # Endpoint paths under `authority`. Defaults match on-prem ADFS; + # override to point at Authentik/Keycloak (or reverse-proxied ADFS). + authorize_path: str = '/adfs/oauth2/authorize' + token_path: str = '/adfs/oauth2/token' + # JWKS endpoint used to verify id_token signatures. + jwks_path: str = '/adfs/discovery/keys' + # If set, id_token `iss` must match exactly. Leave None to skip the + # issuer check (e.g. Authentik where iss host can differ from authority). + expected_issuer: Optional[str] = None + # Allowed clock skew (seconds) when checking exp/nbf claims. + clock_skew_seconds: int = 60 + # Set to False ONLY for local testing against IdPs with self-signed certs + # (e.g. dockerised Authentik). Must stay True for any real ADFS. + verify_ssl: bool = True + # When the id_token has no `email` claim, the authenticator falls back to + # extracting a user identifier from `upn` or `unique_name`. + # If the extracted value is a bare ID (e.g. "EMP12345"), append this domain + # to form "emp12345@domain.com" so it matches what is stored in the DB. + email_fallback_domain: Optional[str] = None + # Optional regex with a single capture group to pull the ID out of a longer + # upn/unique_name string, e.g. r"EMP\d+" or r"(?<=_)EMP\d+(?=_)". + # When None the full local part (before '@') or post-backslash segment is used. + user_id_pattern: Optional[str] = None diff --git a/wavefront/server/plugins/authenticator/authenticator/microsoft_oauth/authenticator.py b/wavefront/server/plugins/authenticator/authenticator/microsoft_oauth/authenticator.py index 797a1064..b1e53d6c 100644 --- a/wavefront/server/plugins/authenticator/authenticator/microsoft_oauth/authenticator.py +++ b/wavefront/server/plugins/authenticator/authenticator/microsoft_oauth/authenticator.py @@ -150,16 +150,13 @@ def validate_config(self) -> bool: except Exception: return False - def get_authorization_url(self, state: Optional[str] = None) -> Optional[str]: + def get_authorization_url( + self, state: Optional[str] = None, nonce: Optional[str] = None + ) -> Optional[str]: """Get the Microsoft OAuth authorization URL.""" if not state: raise ValueError("State doesn't exist Microsoft Oauth") - state_obj = json.loads(state) - - if state_obj['auth_id'] is None: - raise ValueError("Auth Id doesn't exist in Microsoft Oauth state") - params = { 'response_type': self.config.response_type, 'client_id': self.config.client_id, @@ -170,10 +167,21 @@ def get_authorization_url(self, state: Optional[str] = None) -> Optional[str]: 'prompt': 'select_account', } + if nonce: + params['nonce'] = nonce + return f'{self.auth_url}?{urlencode(params)}' - def handle_callback(self, callback_data: Dict[str, Any]) -> AuthResult: - """Handle Microsoft OAuth callback.""" + def handle_callback( + self, callback_data: Dict[str, Any], expected_nonce: Optional[str] = None + ) -> AuthResult: + """Handle Microsoft OAuth (Entra) callback. + + Identity comes from Microsoft Graph (not an id_token), so + `expected_nonce` is accepted for ABC compatibility but not enforced. + State CSRF protection is performed by the controller before this is + called. + """ return self.authenticate(callback_data) def refresh_token(self, refresh_token: str) -> TokenResult: diff --git a/wavefront/server/plugins/authenticator/authenticator/types.py b/wavefront/server/plugins/authenticator/authenticator/types.py index 2dafd706..a2a77966 100644 --- a/wavefront/server/plugins/authenticator/authenticator/types.py +++ b/wavefront/server/plugins/authenticator/authenticator/types.py @@ -24,7 +24,9 @@ class AuthenticatorResult(Generic[T]): @dataclass class UserInfo: email: str - first_name: str + upn: Optional[str] = None + unique_name: Optional[str] = None + first_name: Optional[str] = None last_name: Optional[str] = None user_id: Optional[str] = None provider: Optional[str] = None @@ -71,6 +73,7 @@ class AuthenticatorType(Enum): EMAIL_PASSWORD = 'email_password' GOOGLE_OAUTH = 'google_oauth' MICROSOFT_OAUTH = 'microsoft_oauth' + MICROSOFT_ADFS = 'microsoft_adfs' SAML = 'saml' LDAP = 'ldap' @@ -107,12 +110,18 @@ def validate_config(self) -> bool: pass @abstractmethod - def get_authorization_url(self, state: Optional[str] = None) -> Optional[str]: + def get_authorization_url( + self, state: Optional[str] = None, nonce: Optional[str] = None + ) -> Optional[str]: """ Get the authorization URL for OAuth flow. Args: - state: Optional state parameter for OAuth flow + state: Opaque CSRF state token issued and tracked by the controller. + Providers must treat it as an opaque string and not parse it. + nonce: Optional OIDC nonce to bind the resulting id_token to this + authorize request. Providers that consume id_tokens should + forward this value and verify it on callback. Returns: Optional[str]: Authorization URL for OAuth providers, None for email/password @@ -120,12 +129,17 @@ def get_authorization_url(self, state: Optional[str] = None) -> Optional[str]: pass @abstractmethod - def handle_callback(self, callback_data: Dict[str, Any]) -> AuthResult: + def handle_callback( + self, callback_data: Dict[str, Any], expected_nonce: Optional[str] = None + ) -> AuthResult: """ Handle OAuth callback from provider. Args: callback_data: Dictionary containing callback data (code, state, etc.) + expected_nonce: Nonce that was sent on the matching authorize + request. Providers that decode id_tokens must reject the + callback if the id_token's `nonce` claim does not match. Returns: AuthResult: Authentication result diff --git a/wavefront/server/plugins/authenticator/pyproject.toml b/wavefront/server/plugins/authenticator/pyproject.toml index 26c57bc9..1c5ab6d6 100644 --- a/wavefront/server/plugins/authenticator/pyproject.toml +++ b/wavefront/server/plugins/authenticator/pyproject.toml @@ -7,6 +7,7 @@ requires-python = ">=3.11" dependencies = [ "requests>=2.25.0", + "pyjwt[crypto]>=2.9.0", ] [tool.pytest.ini_options] diff --git a/wavefront/server/uv.lock b/wavefront/server/uv.lock index e4714500..c37813c0 100644 --- a/wavefront/server/uv.lock +++ b/wavefront/server/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 3 +revision = 2 requires-python = ">=3.11" resolution-markers = [ "python_full_version >= '3.14' and sys_platform == 'darwin'", @@ -473,11 +473,15 @@ name = "authenticator" version = "0.1.0" source = { editable = "plugins/authenticator" } dependencies = [ + { name = "pyjwt", extra = ["crypto"] }, { name = "requests" }, ] [package.metadata] -requires-dist = [{ name = "requests", specifier = ">=2.25.0" }] +requires-dist = [ + { name = "pyjwt", extras = ["crypto"], specifier = ">=2.9.0" }, + { name = "requests", specifier = ">=2.25.0" }, +] [[package]] name = "authlib"