Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions mesa_llm/tools/inbuilt_tools.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
import logging
from typing import TYPE_CHECKING, Any

Expand Down Expand Up @@ -203,6 +204,19 @@ def speak_to(
listener_agents_unique_ids: The unique ids of the agents receiving the message
message: The message to send
"""
if isinstance(listener_agents_unique_ids, str):
try:
listener_agents_unique_ids = json.loads(listener_agents_unique_ids)
except (json.JSONDecodeError, ValueError):
listener_agents_unique_ids = [
int(x.strip())
for x in listener_agents_unique_ids.strip("[]").split(",")
if x.strip()
]
listener_agents_unique_ids = [
int(uid) for uid in (listener_agents_unique_ids or [])
]

listener_agents = [
listener_agent
for listener_agent in agent.model.agents
Expand Down
20 changes: 16 additions & 4 deletions mesa_llm/tools/tool_manager.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import concurrent.futures
import contextlib
import inspect
import json
import logging
Expand Down Expand Up @@ -125,10 +126,21 @@ async def _process_tool_call(
sig = inspect.signature(function_to_call)
expects_agent = "agent" in sig.parameters

# Filter arguments to only those accepted
filtered_args = {
k: v for k, v in function_args.items() if k in sig.parameters
}
# Filter arguments to only those accepted by the function, with type coercion based on annotations
hints = getattr(function_to_call, "__annotations__", {})

coerce: dict[type, type] = {float: float, int: int}
filtered_args = {}
for k, v in function_args.items():
if k not in sig.parameters:
continue
expected = hints.get(k)
coerce_fn = coerce.get(expected)
new_value = v
if coerce_fn is not None and not isinstance(v, expected):
with contextlib.suppress(ValueError, TypeError):
new_value = coerce_fn(v)
filtered_args[k] = new_value

if expects_agent:
filtered_args["agent"] = agent
Expand Down
57 changes: 57 additions & 0 deletions tests/test_tools/test_inbuilt_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,63 @@ def test_speak_to_records_on_recipients(mocker):
assert ret == "sent message 'Hello there' to [11, 12]"


def test_speak_to_parses_json_string_ids(mocker):
model = DummyModel()

sender = DummyAgent(unique_id=1, model=model)
r1 = DummyAgent(unique_id=2, model=model)
r2 = DummyAgent(unique_id=3, model=model)

r1.memory = SimpleNamespace(add_to_memory=mocker.Mock())
r2.memory = SimpleNamespace(add_to_memory=mocker.Mock())

model.agents = [sender, r1, r2]

ret = speak_to(sender, "[2, 3]", "ping")

r1.memory.add_to_memory.assert_called_once()
r2.memory.add_to_memory.assert_called_once()
assert "ping" in ret and "[2, 3]" in ret


def test_speak_to_parses_bracketed_string_ids(mocker):
model = DummyModel()

sender = DummyAgent(unique_id=4, model=model)
r1 = DummyAgent(unique_id=5, model=model)
r2 = DummyAgent(unique_id=6, model=model)

r1.memory = SimpleNamespace(add_to_memory=mocker.Mock())
r2.memory = SimpleNamespace(add_to_memory=mocker.Mock())

model.agents = [sender, r1, r2]

ret = speak_to(sender, "[5, 6]", "hello")

r1.memory.add_to_memory.assert_called_once()
r2.memory.add_to_memory.assert_called_once()
assert "hello" in ret and "[5, 6]" in ret


def test_speak_to_parses_non_json_string_ids(mocker):
model = DummyModel()

sender = DummyAgent(unique_id=7, model=model)
r1 = DummyAgent(unique_id=8, model=model)
r2 = DummyAgent(unique_id=9, model=model)

r1.memory = SimpleNamespace(add_to_memory=mocker.Mock())
r2.memory = SimpleNamespace(add_to_memory=mocker.Mock())

model.agents = [sender, r1, r2]

ret = speak_to(sender, "8,9", "note")

r1.memory.add_to_memory.assert_called_once()
r2.memory.add_to_memory.assert_called_once()
assert "note" in ret and "[8, 9]" in ret


def test_move_one_step_invalid_direction():
model = DummyModel()
model.grid = MultiGrid(width=4, height=4, torus=False)
Expand Down
62 changes: 62 additions & 0 deletions tests/test_tools/test_tool_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,6 +502,68 @@ def simple_tool(required_param: str) -> str:
assert result[0]["tool_call_id"] == "call_123"
assert "Simple: test" in result[0]["response"]

def test_call_tools_type_coercion_float(self):
"""Test coercion of float arguments passed as JSON strings."""
manager = ToolManager()

@tool
def float_tool(agent, amount: float) -> str:
"""Float tool.
Args:
agent: The agent making the request
amount: Amount to format.
Returns:
Formatted amount.
"""
return f"{amount:.2f}"

mock_agent = Mock()

mock_tool_call = Mock()
mock_tool_call.id = "call_float"
mock_tool_call.function.name = "float_tool"
mock_tool_call.function.arguments = '{"amount": "35.0"}'

mock_response = Mock()
mock_response.tool_calls = [mock_tool_call]

result = manager.call_tools(mock_agent, mock_response)

assert len(result) == 1
assert result[0]["tool_call_id"] == "call_float"
assert result[0]["response"] == "35.00"

def test_call_tools_type_coercion_int(self):
"""Test coercion of int arguments passed as JSON strings."""
manager = ToolManager()

@tool
def int_tool(agent, count: int) -> int:
"""Int tool.
Args:
agent: The agent making the request
count: Count to increment.
Returns:
Incremented count.
"""
return count + 1

mock_agent = Mock()

mock_tool_call = Mock()
mock_tool_call.id = "call_int"
mock_tool_call.function.name = "int_tool"
mock_tool_call.function.arguments = '{"count": "5"}'

mock_response = Mock()
mock_response.tool_calls = [mock_tool_call]

result = manager.call_tools(mock_agent, mock_response)

assert len(result) == 1
assert result[0]["tool_call_id"] == "call_int"
assert result[0]["response"] == "6"

def test_call_tools_no_response(self):
"""Test call_tools when tool returns None."""
manager = ToolManager()
Expand Down
Loading