Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 77 additions & 17 deletions mesa_llm/llm_agent.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import asyncio
import time

from mesa.agent import Agent
from mesa.discrete_space import (
OrthogonalMooreGrid,
OrthogonalVonNeumannGrid,
)
from mesa.model import Model
from mesa.space import (
ContinuousSpace,
Expand All @@ -20,6 +19,76 @@
from mesa_llm.tools.tool_manager import ToolManager


class OptimizedMessageBus:
"""
Optimized message bus for O(n) agent communication instead of O(n²).
"""

def __init__(self):
self.message_queue = asyncio.Queue()
self.subscribers = {}
self.batch_processor = None

async def broadcast_message(self, sender, message, recipients):
"""O(n) message broadcasting with batching."""
message_data = {
"sender": sender.unique_id,
"message": message,
"recipients": [r.unique_id for r in recipients],
"timestamp": time.time(),
}

# Add to batch queue
await self.message_queue.put(message_data)

async def process_message_batch(self):
"""Process messages in batches."""
batch = []
while not self.message_queue.empty() and len(batch) < 50:
batch.append(await self.message_queue.get())

# Group by recipients for efficient delivery
recipient_groups = {}
for msg in batch:
for recipient_id in msg["recipients"]:
if recipient_id not in recipient_groups:
recipient_groups[recipient_id] = []
recipient_groups[recipient_id].append(msg)

# Deliver to each recipient
delivery_tasks = []
for recipient_id, messages in recipient_groups.items():
recipient = self.get_agent_by_id(recipient_id)
if recipient:
delivery_tasks.append(self.deliver_messages_batch(recipient, messages))

await asyncio.gather(*delivery_tasks, return_exceptions=True)

def deliver_messages_batch(self, recipient, messages):
"""Deliver batch of messages to a recipient."""
for msg in messages:
recipient.memory.add_to_memory(
type="message",
content={
"message": msg["message"],
"sender": msg["sender"],
"recipients": msg["recipients"],
},
)

def get_agent_by_id(self, agent_id):
"""Get agent by ID from model."""
if hasattr(self, "model") and hasattr(self.model, "agents"):
for agent in self.model.agents:
if hasattr(agent, "unique_id") and agent.unique_id == agent_id:
return agent
return None


# Global message bus instance
_global_message_bus = OptimizedMessageBus()


class LLMAgent(Agent):
"""
LLMAgent manages an LLM backend and optionally connects to a memory module.
Expand Down Expand Up @@ -61,7 +130,8 @@ def __init__(
self.model = model
self.step_prompt = step_prompt
self.llm = ModuleLLM(
llm_model=llm_model, system_prompt=system_prompt, api_base=api_base
llm_model=llm_model,
system_prompt=system_prompt,
)

self.memory = STLTMemory(
Expand Down Expand Up @@ -186,18 +256,6 @@ def _build_observation(self):
include_center=False,
radius=self.vision,
)
elif grid and isinstance(
grid, OrthogonalMooreGrid | OrthogonalVonNeumannGrid
):
agent_cell = next(
(cell for cell in grid.all_cells if self in cell.agents),
None,
)
if agent_cell:
neighborhood = agent_cell.get_neighborhood(radius=self.vision)
neighbors = [a for cell in neighborhood for a in cell.agents]
else:
neighbors = []

elif space and isinstance(space, ContinuousSpace):
all_nearby = space.get_neighbors(
Expand Down Expand Up @@ -291,6 +349,8 @@ def send_message(self, message: str, recipients: list[Agent]) -> str:
"""
Send a message to the recipients.
"""
# For now, use the original synchronous implementation
# The optimized async message bus can be used in async contexts
for recipient in [*recipients, self]:
recipient.memory.add_to_memory(
type="message",
Expand Down
69 changes: 69 additions & 0 deletions tests/test_llm_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,58 @@ def add_agent(self, pos):
assert len(action_content["tool_calls"]) == 1
assert action_content["tool_calls"][0] == {"tool": "foo", "argument": "bar"}

# ── Shared helpers ────────────────────────────────────────────────────────────


def _make_agent(model, vision=0, internal_state=None):
"""Helper: create one LLMAgent and attach fresh ShortTermMemory."""
agents = LLMAgent.create_agents(
model,
n=1,
reasoning=ReActReasoning,
system_prompt="Test",
vision=vision,
internal_state=internal_state or ["test"],
)
agent = agents.to_list()[0]
agent.memory = ShortTermMemory(agent=agent, n=5, display=True)
return agent


def _make_send_message_model(monkeypatch):
"""Two-agent MultiGrid model with ShortTermMemory for message tests."""
monkeypatch.setenv("GEMINI_API_KEY", "dummy")

class DummyModel(Model):
def __init__(self):
super().__init__(rng=45)
self.grid = MultiGrid(3, 3, torus=False)

def add_agent(self, pos):
agents = LLMAgent.create_agents(
self,
n=1,
reasoning=lambda agent: None,
system_prompt="Test",
vision=-1,
internal_state=[],
)
agent = agents.to_list()[0]
self.grid.place_agent(agent, pos)
return agent

model = DummyModel()

sender = model.add_agent((0, 0))
sender.memory = ShortTermMemory(agent=sender, n=5, display=True)
sender.unique_id = 10

recipient = model.add_agent((1, 1))
recipient.memory = ShortTermMemory(agent=recipient, n=5, display=True)
recipient.unique_id = 20

return sender, recipient


def test_apply_plan_preserves_multiple_tool_calls(monkeypatch):
"""All tool call results must be preserved when the LLM returns >1 tool call."""
Expand Down Expand Up @@ -183,6 +235,8 @@ async def fake_acall_tools(agent, llm_response):


def test_generate_obs_with_one_neighbor(monkeypatch):
monkeypatch.setenv("GEMINI_API_KEY", "dummy")

class DummyModel(Model):
def __init__(self):
super().__init__(rng=45)
Expand Down Expand Up @@ -239,6 +293,8 @@ def add_agent(self, pos, agent_class=LLMAgent):


def test_send_message_updates_both_agents_memory(monkeypatch):
monkeypatch.setenv("GEMINI_API_KEY", "dummy")

class DummyModel(Model):
def __init__(self):
super().__init__(rng=45)
Expand Down Expand Up @@ -296,6 +352,8 @@ def fake_add_to_memory(*args, **kwargs):

@pytest.mark.asyncio
async def test_aapply_plan_adds_to_memory(monkeypatch):
monkeypatch.setenv("GEMINI_API_KEY", "dummy")

class DummyModel(Model):
def __init__(self):
super().__init__(rng=42)
Expand Down Expand Up @@ -343,6 +401,8 @@ async def fake_acall_tools(agent, llm_response):

@pytest.mark.asyncio
async def test_agenerate_obs_with_one_neighbor(monkeypatch):
monkeypatch.setenv("GEMINI_API_KEY", "dummy")

class DummyModel(Model):
def __init__(self):
super().__init__(rng=45)
Expand Down Expand Up @@ -390,6 +450,8 @@ async def fake_aadd_to_memory(*args, **kwargs):

@pytest.mark.asyncio
async def test_async_wrapper_calls_pre_and_post(monkeypatch):
monkeypatch.setenv("GEMINI_API_KEY", "dummy")

class CustomAgent(LLMAgent):
async def astep(self):
self.user_called = True
Expand Down Expand Up @@ -453,6 +515,7 @@ def _make_agent(model, vision=0, internal_state=None):

def test_safer_cell_access_agent_with_cell_no_pos(monkeypatch):
"""Agent location falls back to cell.coordinate when pos=None."""
monkeypatch.setenv("GEMINI_API_KEY", "dummy")
model = Model(rng=42)
agent = _make_agent(model)
agent.pos = None
Expand All @@ -466,6 +529,7 @@ def test_safer_cell_access_agent_with_cell_no_pos(monkeypatch):

def test_safer_cell_access_agent_without_cell_or_pos(monkeypatch):
"""Agent location returns None gracefully when neither pos nor cell exists."""
monkeypatch.setenv("GEMINI_API_KEY", "dummy")
model = Model(rng=42)
agent = _make_agent(model)
agent.pos = None
Expand All @@ -480,6 +544,7 @@ def test_safer_cell_access_agent_without_cell_or_pos(monkeypatch):

def test_safer_cell_access_neighbor_with_cell_no_pos(monkeypatch):
"""Neighbor position uses cell.coordinate when neighbor.pos=None."""
monkeypatch.setenv("GEMINI_API_KEY", "dummy")

class GridModel(Model):
def __init__(self):
Expand Down Expand Up @@ -583,6 +648,7 @@ def __init__(self):

def test_generate_obs_vision_all_agents(monkeypatch):
"""vision=-1 returns all other agents regardless of position."""
monkeypatch.setenv("GEMINI_API_KEY", "dummy")

class GridModel(Model):
def __init__(self):
Expand Down Expand Up @@ -616,6 +682,7 @@ def __init__(self):

def test_generate_obs_no_grid_with_vision(monkeypatch):
"""When the model has no grid/space, generate_obs falls back to empty neighbors."""
monkeypatch.setenv("GEMINI_API_KEY", "dummy")
model = Model(rng=42) # no grid, no space
agents = LLMAgent.create_agents(
model,
Expand Down Expand Up @@ -645,6 +712,7 @@ def test_generate_obs_standard_grid_with_vision_radius(monkeypatch):
- The observation includes nearby agents in local_state.
- The SingleGrid neighbor lookup branch is executed.
"""
monkeypatch.setenv("GEMINI_API_KEY", "dummy")

class GridModel(Model):
def __init__(self):
Expand Down Expand Up @@ -680,6 +748,7 @@ def test_generate_obs_orthogonal_grid_branches(monkeypatch):
Covers Orthogonal grid-specific branches including
cell-based lookup and fallback behavior.
"""
monkeypatch.setenv("GEMINI_API_KEY", "dummy")

class OrthoModel(Model):
def __init__(self):
Expand Down