Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/microsoft/opentelemetry/a365/core/exporters/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import os
import threading
import time
from collections.abc import Callable, Sequence
from collections.abc import Callable, Mapping, Sequence
from dataclasses import dataclass, field
from typing import Any, List, Optional, TypeVar
from urllib.parse import urlparse
Expand Down Expand Up @@ -259,7 +259,7 @@ def build_export_url(endpoint: str, agent_id: str, tenant_id: str, use_s2s_endpo
return f"https://{endpoint}{endpoint_path}?api-version=1"


def parse_retry_after(headers: dict[str, str]) -> float | None:
def parse_retry_after(headers: Mapping[str, str]) -> float | None:
"""Parse the ``Retry-After`` header value.

Only numeric (seconds) values are supported. HTTP-date values are ignored.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
CHANNEL_LINK_KEY,
CHANNEL_NAME_KEY,
)
from microsoft.opentelemetry.a365.hosting.scope_helpers.utils import resolve_sub_channel
from microsoft.opentelemetry.a365.core.models.response import Response
from microsoft.opentelemetry.a365.core.models.user_details import UserDetails
from microsoft.opentelemetry.a365.core.request import Request
Expand Down Expand Up @@ -76,13 +77,13 @@ def _derive_channel(
"""Derive channel (name and link) from TurnContext."""
channel_id = getattr(context.activity, "channel_id", None)
channel_name: str | None = None
sub_channel: str | None = None
if channel_id is not None:
if isinstance(channel_id, str):
channel_name = channel_id
elif hasattr(channel_id, "channel"):
if hasattr(channel_id, "channel"):
channel_name = channel_id.channel
sub_channel = channel_id.sub_channel
elif isinstance(channel_id, str):
channel_name = channel_id

sub_channel = resolve_sub_channel(context.activity) if context.activity else None
return {"name": channel_name, "link": sub_channel}


Expand Down
38 changes: 32 additions & 6 deletions src/microsoft/opentelemetry/a365/hosting/scope_helpers/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import json
from collections.abc import Iterator
from typing import Any

Expand All @@ -24,6 +25,31 @@
AGENT_ROLE = "agenticUser"


def resolve_sub_channel(activity: Activity) -> str | None:
"""Resolve sub_channel from ChannelId, falling back to productContext in channel_data."""
channel_id = activity.channel_id
sub_channel = None

if channel_id is not None and hasattr(channel_id, "channel"):
sub_channel = channel_id.sub_channel

if not sub_channel and activity.channel_data:
try:
channel_data = activity.channel_data
if isinstance(channel_data, str):
channel_data = json.loads(channel_data)
elif hasattr(channel_data, "__dict__"):
channel_data = channel_data.__dict__

product_context = channel_data.get("productContext") if isinstance(channel_data, dict) else None
if product_context:
sub_channel = product_context
except (json.JSONDecodeError, AttributeError, TypeError):
pass

return sub_channel


def _is_agentic(entity: Any) -> bool:
if not entity:
return False
Expand Down Expand Up @@ -72,16 +98,16 @@ def get_channel_pairs(activity: Activity) -> Iterator[tuple[str, Any]]:

# Extract channel name from either string or ChannelId object
channel_name = None
sub_channel = None

if channel_id is not None:
if isinstance(channel_id, str):
# Direct string value
channel_name = channel_id
elif hasattr(channel_id, "channel"):
if hasattr(channel_id, "channel"):
# ChannelId object
channel_name = channel_id.channel
sub_channel = channel_id.sub_channel
elif isinstance(channel_id, str):
# Direct string value
channel_name = channel_id

sub_channel = resolve_sub_channel(activity)

# Yield channel name as source name
yield CHANNEL_NAME_KEY, channel_name
Expand Down
109 changes: 109 additions & 0 deletions tests/a365/hosting/middleware/test_baggage_middleware.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import json
from typing import Any
from unittest.mock import MagicMock

import pytest
Expand All @@ -9,10 +11,12 @@
ActivityEventNames,
ActivityTypes,
ChannelAccount,
ChannelId,
ConversationAccount,
)
from microsoft_agents.hosting.core import TurnContext
from microsoft.opentelemetry.a365.core.constants import (
CHANNEL_LINK_KEY,
TENANT_ID_KEY,
USER_ID_KEY,
)
Expand Down Expand Up @@ -95,3 +99,108 @@ async def logic():
assert logic_called is True
# Baggage should NOT be set because the middleware skipped it
assert captured_caller_id is None


def _make_channel_data_turn_context(
channel_id: str | ChannelId = "test-channel",
channel_data: Any = None,
) -> TurnContext:
"""Create a TurnContext with channel_data for productContext tests."""
activity = Activity(
type="message",
text="Hello",
from_property=ChannelAccount(
aad_object_id="caller-id",
name="Caller",
agentic_user_id="caller-upn",
tenant_id="tenant-id",
),
recipient=ChannelAccount(
tenant_id="tenant-123",
role="user",
name="Agent",
),
conversation=ConversationAccount(id="conv-id"),
service_url="https://example.com",
channel_id=channel_id,
channel_data=channel_data,
)
adapter = MagicMock()
return TurnContext(adapter, activity)


@pytest.mark.asyncio
async def test_baggage_middleware_extracts_product_context_from_dict_channel_data():
"""BaggageMiddleware should extract productContext from dict channel_data as sub_channel."""
middleware = BaggageMiddleware()
ctx = _make_channel_data_turn_context(
channel_data={"productContext": "word"},
)

captured_channel_link = None

async def logic():
nonlocal captured_channel_link
captured_channel_link = baggage.get_baggage(CHANNEL_LINK_KEY)

await middleware.on_turn(ctx, logic)

assert captured_channel_link == "word"


@pytest.mark.asyncio
async def test_baggage_middleware_sub_channel_takes_precedence_over_product_context():
"""When ChannelId has sub_channel set, it should take precedence over productContext."""
middleware = BaggageMiddleware()
ctx = _make_channel_data_turn_context(
channel_id=ChannelId(channel="teams", sub_channel="from-channel-id"),
channel_data={"productContext": "from-product-context"},
)

captured_channel_link = None

async def logic():
nonlocal captured_channel_link
captured_channel_link = baggage.get_baggage(CHANNEL_LINK_KEY)

await middleware.on_turn(ctx, logic)

assert captured_channel_link == "from-channel-id"


@pytest.mark.asyncio
async def test_baggage_middleware_extracts_product_context_from_json_string():
"""BaggageMiddleware should extract productContext from JSON string channel_data."""
middleware = BaggageMiddleware()
ctx = _make_channel_data_turn_context(
channel_data=json.dumps({"productContext": "excel"}),
)

captured_channel_link = None

async def logic():
nonlocal captured_channel_link
captured_channel_link = baggage.get_baggage(CHANNEL_LINK_KEY)

await middleware.on_turn(ctx, logic)

assert captured_channel_link == "excel"


@pytest.mark.asyncio
async def test_baggage_middleware_handles_invalid_json_channel_data_gracefully():
"""BaggageMiddleware should not raise on invalid JSON in channel_data."""
middleware = BaggageMiddleware()
ctx = _make_channel_data_turn_context(
channel_data="not-valid-json{{{",
)

captured_channel_link = None

async def logic():
nonlocal captured_channel_link
captured_channel_link = baggage.get_baggage(CHANNEL_LINK_KEY)

await middleware.on_turn(ctx, logic)

assert captured_channel_link is None
55 changes: 55 additions & 0 deletions tests/a365/hosting/middleware/test_output_logging_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,14 @@
from microsoft_agents.activity import (
Activity,
ChannelAccount,
ChannelId,
ConversationAccount,
)
from microsoft_agents.hosting.core import TurnContext
from microsoft.opentelemetry.a365.hosting.middleware.output_logging_middleware import (
A365_PARENT_TRACEPARENT_KEY,
OutputLoggingMiddleware,
_derive_channel,
)


Expand Down Expand Up @@ -217,3 +219,56 @@ async def test_send_handler_rethrows_errors():

mock_scope.record_error.assert_called_once_with(send_error)
mock_scope.dispose.assert_called_once()


def _make_channel_turn_context(
channel_id: str | ChannelId = "test-channel",
channel_data=None,
) -> TurnContext:
"""Create a TurnContext with specific channel_id and channel_data."""
activity = Activity(
type="message",
text="Hello",
from_property=ChannelAccount(
aad_object_id="caller-id",
name="Caller",
agentic_user_id="caller-upn",
tenant_id="caller-tenant-id",
),
recipient=ChannelAccount(
tenant_id="tenant-123",
role="agenticAppInstance",
name="Agent One",
agentic_app_id="agent-app-id",
aad_object_id="agent-auid",
agentic_user_id="agent-upn",
),
conversation=ConversationAccount(id="conv-id"),
service_url="https://example.com",
channel_id=channel_id,
channel_data=channel_data,
)
adapter = MagicMock()
return TurnContext(adapter, activity)


def test_derive_channel_uses_product_context_from_channel_data():
"""_derive_channel uses productContext from channel_data when sub_channel is not set."""
ctx = _make_channel_turn_context(
channel_id="test-channel",
channel_data={"productContext": "word"},
)
result = _derive_channel(ctx)
assert result["name"] == "test-channel"
assert result["link"] == "word"


def test_derive_channel_sub_channel_takes_precedence_over_product_context():
"""_derive_channel uses sub_channel when both sub_channel and productContext are present."""
ctx = _make_channel_turn_context(
channel_id=ChannelId(channel="teams", sub_channel="from-channel-id"),
channel_data={"productContext": "from-product-context"},
)
result = _derive_channel(ctx)
assert result["name"] == "teams"
assert result["link"] == "from-channel-id"
Loading