diff --git a/libraries/microsoft-agents-a365-observability-hosting/microsoft_agents_a365/observability/hosting/scope_helpers/utils.py b/libraries/microsoft-agents-a365-observability-hosting/microsoft_agents_a365/observability/hosting/scope_helpers/utils.py index 298294c1..02aa5e78 100644 --- a/libraries/microsoft-agents-a365-observability-hosting/microsoft_agents_a365/observability/hosting/scope_helpers/utils.py +++ b/libraries/microsoft-agents-a365-observability-hosting/microsoft_agents_a365/observability/hosting/scope_helpers/utils.py @@ -1,5 +1,6 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +import json from collections.abc import Iterator from typing import Any @@ -74,13 +75,35 @@ def get_channel_pairs(activity: Activity) -> Iterator[tuple[str, Any]]: sub_channel = None if channel_id is not None: - if isinstance(channel_id, str): - # Direct string value - channel_name = channel_id - elif hasattr(channel_id, "channel"): + # Check for ChannelId object first + if hasattr(channel_id, "channel"): # ChannelId object channel_name = channel_id.channel sub_channel = channel_id.sub_channel + elif isinstance(channel_id, str): + # Direct string value + channel_name = channel_id + + # Try to get sub_channel from productContext in channel_data if sub_channel is not set + if not sub_channel and activity.channel_data: + try: + # Convert channel_data to dict if it's a string + if isinstance(activity.channel_data, str): + channel_data_dict = json.loads(activity.channel_data) + elif isinstance(activity.channel_data, dict): + channel_data_dict = activity.channel_data + else: + # Try to convert to dict if it has __dict__ + channel_data_dict = getattr(activity.channel_data, "__dict__", {}) + + # Extract productContext if available + if isinstance(channel_data_dict, dict) and "productContext" in channel_data_dict: + product_context = channel_data_dict["productContext"] + if isinstance(product_context, str): + sub_channel = product_context + except (json.JSONDecodeError, AttributeError, TypeError): + # Silently ignore any parsing errors + pass # Yield channel name as source name yield CHANNEL_NAME_KEY, channel_name diff --git a/tests/observability/hosting/middleware/test_baggage_middleware.py b/tests/observability/hosting/middleware/test_baggage_middleware.py index 4263e898..4c3e2574 100644 --- a/tests/observability/hosting/middleware/test_baggage_middleware.py +++ b/tests/observability/hosting/middleware/test_baggage_middleware.py @@ -1,6 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +import json from unittest.mock import MagicMock import pytest @@ -9,10 +10,12 @@ ActivityEventNames, ActivityTypes, ChannelAccount, + ChannelId, ConversationAccount, ) from microsoft_agents.hosting.core import TurnContext from microsoft_agents_a365.observability.core.constants import ( + CHANNEL_LINK_KEY, TENANT_ID_KEY, USER_ID_KEY, ) @@ -53,6 +56,31 @@ def _make_turn_context( return TurnContext(adapter, activity) +def _make_channel_data_turn_context( + channel_id: ChannelId | str = "msteams", + channel_data: object | None = None, +) -> TurnContext: + """Create a TurnContext with channel_data for testing.""" + activity = Activity( + type="message", + text="Hello", + from_property=ChannelAccount( + aad_object_id="caller-id", + name="Caller", + ), + recipient=ChannelAccount( + tenant_id="tenant-123", + name="Agent", + ), + conversation=ConversationAccount(id="conv-id"), + service_url="https://example.com", + channel_id=channel_id, + channel_data=channel_data, + ) + adapter = MagicMock() + return TurnContext(adapter, activity) + + @pytest.mark.asyncio async def test_baggage_middleware_propagates_baggage(): """BaggageMiddleware should set baggage context for the downstream logic.""" @@ -95,3 +123,89 @@ async def logic(): assert logic_called is True # Baggage should NOT be set because the middleware skipped it assert captured_caller_id is None + + +@pytest.mark.asyncio +async def test_baggage_middleware_extracts_product_context_from_channel_data(): + """BaggageMiddleware should extract productContext from channel_data when sub_channel is not set.""" + + middleware = BaggageMiddleware() + ctx = _make_channel_data_turn_context( + channel_id=ChannelId(channel="msteams"), # No sub_channel + channel_data={"productContext": "COPILOT"}, + ) + + captured_channel_link = None + + async def logic(): + nonlocal captured_channel_link + captured_channel_link = baggage.get_baggage(CHANNEL_LINK_KEY) + + await middleware.on_turn(ctx, logic) + + assert captured_channel_link == "COPILOT" + + +@pytest.mark.asyncio +async def test_baggage_middleware_sub_channel_takes_precedence_over_product_context(): + """BaggageMiddleware should use sub_channel when both sub_channel and productContext are present.""" + + middleware = BaggageMiddleware() + ctx = _make_channel_data_turn_context( + channel_id=ChannelId(channel="msteams", sub_channel="teams-subchannel"), + channel_data={"productContext": "COPILOT"}, # Should be ignored + ) + + captured_channel_link = None + + async def logic(): + nonlocal captured_channel_link + captured_channel_link = baggage.get_baggage(CHANNEL_LINK_KEY) + + await middleware.on_turn(ctx, logic) + + # sub_channel should take precedence, productContext should be ignored + assert captured_channel_link == "teams-subchannel" + + +@pytest.mark.asyncio +async def test_baggage_middleware_extracts_product_context_from_json_string_channel_data(): + """BaggageMiddleware should extract productContext from channel_data when it's a JSON string.""" + + middleware = BaggageMiddleware() + ctx = _make_channel_data_turn_context( + channel_id=ChannelId(channel="msteams"), # No sub_channel + channel_data=json.dumps({"productContext": "COPILOT"}), # JSON string + ) + + captured_channel_link = None + + async def logic(): + nonlocal captured_channel_link + captured_channel_link = baggage.get_baggage(CHANNEL_LINK_KEY) + + await middleware.on_turn(ctx, logic) + + assert captured_channel_link == "COPILOT" + + +@pytest.mark.asyncio +async def test_baggage_middleware_handles_invalid_json_channel_data_gracefully(): + """BaggageMiddleware should handle invalid JSON in channel_data gracefully without setting baggage.""" + + middleware = BaggageMiddleware() + ctx = _make_channel_data_turn_context( + channel_id=ChannelId(channel="msteams"), # No sub_channel + channel_data="not valid json", # Non-JSON string + ) + + captured_channel_link = None + + async def logic(): + nonlocal captured_channel_link + captured_channel_link = baggage.get_baggage(CHANNEL_LINK_KEY) + + await middleware.on_turn(ctx, logic) + + # Should not set ChannelLink, should fail gracefully + assert captured_channel_link is None