Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import json
from collections.abc import Iterator
from typing import Any

Expand Down Expand Up @@ -74,13 +75,35 @@ def get_channel_pairs(activity: Activity) -> Iterator[tuple[str, Any]]:
sub_channel = None

if channel_id is not None:
if isinstance(channel_id, str):
# Direct string value
channel_name = channel_id
elif hasattr(channel_id, "channel"):
# Check for ChannelId object first
if hasattr(channel_id, "channel"):
# ChannelId object
channel_name = channel_id.channel
sub_channel = channel_id.sub_channel
elif isinstance(channel_id, str):
# Direct string value
channel_name = channel_id

# Try to get sub_channel from productContext in channel_data if sub_channel is not set
if not sub_channel and activity.channel_data:
try:
# Convert channel_data to dict if it's a string
if isinstance(activity.channel_data, str):
channel_data_dict = json.loads(activity.channel_data)
elif isinstance(activity.channel_data, dict):
channel_data_dict = activity.channel_data
Comment thread
gwharris7 marked this conversation as resolved.
else:
# Try to convert to dict if it has __dict__
channel_data_dict = getattr(activity.channel_data, "__dict__", {})

# Extract productContext if available
if isinstance(channel_data_dict, dict) and "productContext" in channel_data_dict:
product_context = channel_data_dict["productContext"]
if isinstance(product_context, str):
sub_channel = product_context
except (json.JSONDecodeError, AttributeError, TypeError):
Comment thread
gwharris7 marked this conversation as resolved.
# Silently ignore any parsing errors
pass

# Yield channel name as source name
yield CHANNEL_NAME_KEY, channel_name
Expand Down
114 changes: 114 additions & 0 deletions tests/observability/hosting/middleware/test_baggage_middleware.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import json
from unittest.mock import MagicMock

import pytest
Expand All @@ -9,10 +10,12 @@
ActivityEventNames,
ActivityTypes,
ChannelAccount,
ChannelId,
ConversationAccount,
)
from microsoft_agents.hosting.core import TurnContext
from microsoft_agents_a365.observability.core.constants import (
CHANNEL_LINK_KEY,
TENANT_ID_KEY,
USER_ID_KEY,
)
Expand Down Expand Up @@ -53,6 +56,31 @@ def _make_turn_context(
return TurnContext(adapter, activity)


def _make_channel_data_turn_context(
channel_id: ChannelId | str = "msteams",
channel_data: object | None = None,
) -> TurnContext:
Comment thread
gwharris7 marked this conversation as resolved.
"""Create a TurnContext with channel_data for testing."""
activity = Activity(
type="message",
text="Hello",
from_property=ChannelAccount(
aad_object_id="caller-id",
name="Caller",
),
recipient=ChannelAccount(
tenant_id="tenant-123",
name="Agent",
),
conversation=ConversationAccount(id="conv-id"),
service_url="https://example.com",
channel_id=channel_id,
channel_data=channel_data,
)
adapter = MagicMock()
return TurnContext(adapter, activity)


@pytest.mark.asyncio
async def test_baggage_middleware_propagates_baggage():
"""BaggageMiddleware should set baggage context for the downstream logic."""
Expand Down Expand Up @@ -95,3 +123,89 @@ async def logic():
assert logic_called is True
# Baggage should NOT be set because the middleware skipped it
assert captured_caller_id is None


@pytest.mark.asyncio
async def test_baggage_middleware_extracts_product_context_from_channel_data():
"""BaggageMiddleware should extract productContext from channel_data when sub_channel is not set."""

Comment thread
gwharris7 marked this conversation as resolved.
middleware = BaggageMiddleware()
ctx = _make_channel_data_turn_context(
channel_id=ChannelId(channel="msteams"), # No sub_channel
channel_data={"productContext": "COPILOT"},
)

captured_channel_link = None

async def logic():
nonlocal captured_channel_link
captured_channel_link = baggage.get_baggage(CHANNEL_LINK_KEY)

await middleware.on_turn(ctx, logic)

assert captured_channel_link == "COPILOT"


@pytest.mark.asyncio
async def test_baggage_middleware_sub_channel_takes_precedence_over_product_context():
"""BaggageMiddleware should use sub_channel when both sub_channel and productContext are present."""

middleware = BaggageMiddleware()
ctx = _make_channel_data_turn_context(
channel_id=ChannelId(channel="msteams", sub_channel="teams-subchannel"),
channel_data={"productContext": "COPILOT"}, # Should be ignored
)

captured_channel_link = None

async def logic():
nonlocal captured_channel_link
captured_channel_link = baggage.get_baggage(CHANNEL_LINK_KEY)

await middleware.on_turn(ctx, logic)

# sub_channel should take precedence, productContext should be ignored
assert captured_channel_link == "teams-subchannel"


@pytest.mark.asyncio
async def test_baggage_middleware_extracts_product_context_from_json_string_channel_data():
"""BaggageMiddleware should extract productContext from channel_data when it's a JSON string."""

middleware = BaggageMiddleware()
ctx = _make_channel_data_turn_context(
channel_id=ChannelId(channel="msteams"), # No sub_channel
channel_data=json.dumps({"productContext": "COPILOT"}), # JSON string
)

captured_channel_link = None

async def logic():
nonlocal captured_channel_link
captured_channel_link = baggage.get_baggage(CHANNEL_LINK_KEY)

await middleware.on_turn(ctx, logic)

assert captured_channel_link == "COPILOT"


@pytest.mark.asyncio
async def test_baggage_middleware_handles_invalid_json_channel_data_gracefully():
"""BaggageMiddleware should handle invalid JSON in channel_data gracefully without setting baggage."""

middleware = BaggageMiddleware()
ctx = _make_channel_data_turn_context(
channel_id=ChannelId(channel="msteams"), # No sub_channel
channel_data="not valid json", # Non-JSON string
)

captured_channel_link = None

async def logic():
nonlocal captured_channel_link
captured_channel_link = baggage.get_baggage(CHANNEL_LINK_KEY)

await middleware.on_turn(ctx, logic)

# Should not set ChannelLink, should fail gracefully
assert captured_channel_link is None
Loading