diff --git a/src/microsoft/opentelemetry/a365/core/exporters/utils.py b/src/microsoft/opentelemetry/a365/core/exporters/utils.py index 282b5f21..5e3b7ecf 100644 --- a/src/microsoft/opentelemetry/a365/core/exporters/utils.py +++ b/src/microsoft/opentelemetry/a365/core/exporters/utils.py @@ -16,7 +16,7 @@ import os import threading import time -from collections.abc import Callable, Sequence +from collections.abc import Callable, Mapping, Sequence from dataclasses import dataclass, field from typing import Any, List, Optional, TypeVar from urllib.parse import urlparse @@ -259,7 +259,7 @@ def build_export_url(endpoint: str, agent_id: str, tenant_id: str, use_s2s_endpo return f"https://{endpoint}{endpoint_path}?api-version=1" -def parse_retry_after(headers: dict[str, str]) -> float | None: +def parse_retry_after(headers: Mapping[str, str]) -> float | None: """Parse the ``Retry-After`` header value. Only numeric (seconds) values are supported. HTTP-date values are ignored. diff --git a/src/microsoft/opentelemetry/a365/hosting/middleware/output_logging_middleware.py b/src/microsoft/opentelemetry/a365/hosting/middleware/output_logging_middleware.py index 34411d9a..b5381848 100644 --- a/src/microsoft/opentelemetry/a365/hosting/middleware/output_logging_middleware.py +++ b/src/microsoft/opentelemetry/a365/hosting/middleware/output_logging_middleware.py @@ -15,6 +15,7 @@ CHANNEL_LINK_KEY, CHANNEL_NAME_KEY, ) +from microsoft.opentelemetry.a365.hosting.scope_helpers.utils import resolve_sub_channel from microsoft.opentelemetry.a365.core.models.response import Response from microsoft.opentelemetry.a365.core.models.user_details import UserDetails from microsoft.opentelemetry.a365.core.request import Request @@ -76,13 +77,13 @@ def _derive_channel( """Derive channel (name and link) from TurnContext.""" channel_id = getattr(context.activity, "channel_id", None) channel_name: str | None = None - sub_channel: str | None = None if channel_id is not None: - if isinstance(channel_id, str): - channel_name = channel_id - elif hasattr(channel_id, "channel"): + if hasattr(channel_id, "channel"): channel_name = channel_id.channel - sub_channel = channel_id.sub_channel + elif isinstance(channel_id, str): + channel_name = channel_id + + sub_channel = resolve_sub_channel(context.activity) if context.activity else None return {"name": channel_name, "link": sub_channel} diff --git a/src/microsoft/opentelemetry/a365/hosting/scope_helpers/utils.py b/src/microsoft/opentelemetry/a365/hosting/scope_helpers/utils.py index 9476f5ed..7554f1e6 100644 --- a/src/microsoft/opentelemetry/a365/hosting/scope_helpers/utils.py +++ b/src/microsoft/opentelemetry/a365/hosting/scope_helpers/utils.py @@ -1,6 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +import json from collections.abc import Iterator from typing import Any @@ -24,6 +25,31 @@ AGENT_ROLE = "agenticUser" +def resolve_sub_channel(activity: Activity) -> str | None: + """Resolve sub_channel from ChannelId, falling back to productContext in channel_data.""" + channel_id = activity.channel_id + sub_channel = None + + if channel_id is not None and hasattr(channel_id, "channel"): + sub_channel = channel_id.sub_channel + + if not sub_channel and activity.channel_data: + try: + channel_data = activity.channel_data + if isinstance(channel_data, str): + channel_data = json.loads(channel_data) + elif hasattr(channel_data, "__dict__"): + channel_data = channel_data.__dict__ + + product_context = channel_data.get("productContext") if isinstance(channel_data, dict) else None + if product_context: + sub_channel = product_context + except (json.JSONDecodeError, AttributeError, TypeError): + pass + + return sub_channel + + def _is_agentic(entity: Any) -> bool: if not entity: return False @@ -72,16 +98,16 @@ def get_channel_pairs(activity: Activity) -> Iterator[tuple[str, Any]]: # Extract channel name from either string or ChannelId object channel_name = None - sub_channel = None if channel_id is not None: - if isinstance(channel_id, str): - # Direct string value - channel_name = channel_id - elif hasattr(channel_id, "channel"): + if hasattr(channel_id, "channel"): # ChannelId object channel_name = channel_id.channel - sub_channel = channel_id.sub_channel + elif isinstance(channel_id, str): + # Direct string value + channel_name = channel_id + + sub_channel = resolve_sub_channel(activity) # Yield channel name as source name yield CHANNEL_NAME_KEY, channel_name diff --git a/tests/a365/hosting/middleware/test_baggage_middleware.py b/tests/a365/hosting/middleware/test_baggage_middleware.py index 04d38864..cd8dca1c 100644 --- a/tests/a365/hosting/middleware/test_baggage_middleware.py +++ b/tests/a365/hosting/middleware/test_baggage_middleware.py @@ -1,6 +1,8 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +import json +from typing import Any from unittest.mock import MagicMock import pytest @@ -9,10 +11,12 @@ ActivityEventNames, ActivityTypes, ChannelAccount, + ChannelId, ConversationAccount, ) from microsoft_agents.hosting.core import TurnContext from microsoft.opentelemetry.a365.core.constants import ( + CHANNEL_LINK_KEY, TENANT_ID_KEY, USER_ID_KEY, ) @@ -95,3 +99,108 @@ async def logic(): assert logic_called is True # Baggage should NOT be set because the middleware skipped it assert captured_caller_id is None + + +def _make_channel_data_turn_context( + channel_id: str | ChannelId = "test-channel", + channel_data: Any = None, +) -> TurnContext: + """Create a TurnContext with channel_data for productContext tests.""" + activity = Activity( + type="message", + text="Hello", + from_property=ChannelAccount( + aad_object_id="caller-id", + name="Caller", + agentic_user_id="caller-upn", + tenant_id="tenant-id", + ), + recipient=ChannelAccount( + tenant_id="tenant-123", + role="user", + name="Agent", + ), + conversation=ConversationAccount(id="conv-id"), + service_url="https://example.com", + channel_id=channel_id, + channel_data=channel_data, + ) + adapter = MagicMock() + return TurnContext(adapter, activity) + + +@pytest.mark.asyncio +async def test_baggage_middleware_extracts_product_context_from_dict_channel_data(): + """BaggageMiddleware should extract productContext from dict channel_data as sub_channel.""" + middleware = BaggageMiddleware() + ctx = _make_channel_data_turn_context( + channel_data={"productContext": "word"}, + ) + + captured_channel_link = None + + async def logic(): + nonlocal captured_channel_link + captured_channel_link = baggage.get_baggage(CHANNEL_LINK_KEY) + + await middleware.on_turn(ctx, logic) + + assert captured_channel_link == "word" + + +@pytest.mark.asyncio +async def test_baggage_middleware_sub_channel_takes_precedence_over_product_context(): + """When ChannelId has sub_channel set, it should take precedence over productContext.""" + middleware = BaggageMiddleware() + ctx = _make_channel_data_turn_context( + channel_id=ChannelId(channel="teams", sub_channel="from-channel-id"), + channel_data={"productContext": "from-product-context"}, + ) + + captured_channel_link = None + + async def logic(): + nonlocal captured_channel_link + captured_channel_link = baggage.get_baggage(CHANNEL_LINK_KEY) + + await middleware.on_turn(ctx, logic) + + assert captured_channel_link == "from-channel-id" + + +@pytest.mark.asyncio +async def test_baggage_middleware_extracts_product_context_from_json_string(): + """BaggageMiddleware should extract productContext from JSON string channel_data.""" + middleware = BaggageMiddleware() + ctx = _make_channel_data_turn_context( + channel_data=json.dumps({"productContext": "excel"}), + ) + + captured_channel_link = None + + async def logic(): + nonlocal captured_channel_link + captured_channel_link = baggage.get_baggage(CHANNEL_LINK_KEY) + + await middleware.on_turn(ctx, logic) + + assert captured_channel_link == "excel" + + +@pytest.mark.asyncio +async def test_baggage_middleware_handles_invalid_json_channel_data_gracefully(): + """BaggageMiddleware should not raise on invalid JSON in channel_data.""" + middleware = BaggageMiddleware() + ctx = _make_channel_data_turn_context( + channel_data="not-valid-json{{{", + ) + + captured_channel_link = None + + async def logic(): + nonlocal captured_channel_link + captured_channel_link = baggage.get_baggage(CHANNEL_LINK_KEY) + + await middleware.on_turn(ctx, logic) + + assert captured_channel_link is None diff --git a/tests/a365/hosting/middleware/test_output_logging_middleware.py b/tests/a365/hosting/middleware/test_output_logging_middleware.py index 58805a94..a17ac1ef 100644 --- a/tests/a365/hosting/middleware/test_output_logging_middleware.py +++ b/tests/a365/hosting/middleware/test_output_logging_middleware.py @@ -7,12 +7,14 @@ from microsoft_agents.activity import ( Activity, ChannelAccount, + ChannelId, ConversationAccount, ) from microsoft_agents.hosting.core import TurnContext from microsoft.opentelemetry.a365.hosting.middleware.output_logging_middleware import ( A365_PARENT_TRACEPARENT_KEY, OutputLoggingMiddleware, + _derive_channel, ) @@ -217,3 +219,56 @@ async def test_send_handler_rethrows_errors(): mock_scope.record_error.assert_called_once_with(send_error) mock_scope.dispose.assert_called_once() + + +def _make_channel_turn_context( + channel_id: str | ChannelId = "test-channel", + channel_data=None, +) -> TurnContext: + """Create a TurnContext with specific channel_id and channel_data.""" + activity = Activity( + type="message", + text="Hello", + from_property=ChannelAccount( + aad_object_id="caller-id", + name="Caller", + agentic_user_id="caller-upn", + tenant_id="caller-tenant-id", + ), + recipient=ChannelAccount( + tenant_id="tenant-123", + role="agenticAppInstance", + name="Agent One", + agentic_app_id="agent-app-id", + aad_object_id="agent-auid", + agentic_user_id="agent-upn", + ), + conversation=ConversationAccount(id="conv-id"), + service_url="https://example.com", + channel_id=channel_id, + channel_data=channel_data, + ) + adapter = MagicMock() + return TurnContext(adapter, activity) + + +def test_derive_channel_uses_product_context_from_channel_data(): + """_derive_channel uses productContext from channel_data when sub_channel is not set.""" + ctx = _make_channel_turn_context( + channel_id="test-channel", + channel_data={"productContext": "word"}, + ) + result = _derive_channel(ctx) + assert result["name"] == "test-channel" + assert result["link"] == "word" + + +def test_derive_channel_sub_channel_takes_precedence_over_product_context(): + """_derive_channel uses sub_channel when both sub_channel and productContext are present.""" + ctx = _make_channel_turn_context( + channel_id=ChannelId(channel="teams", sub_channel="from-channel-id"), + channel_data={"productContext": "from-product-context"}, + ) + result = _derive_channel(ctx) + assert result["name"] == "teams" + assert result["link"] == "from-channel-id"