Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file.
157 changes: 157 additions & 0 deletions examples/stock_market/agents.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
import examples.stock_market.tools # noqa: F401, to register tools
from mesa_llm.llm_agent import LLMAgent
from mesa_llm.tools.tool_manager import ToolManager

trader_tool_manager = ToolManager()
analyst_tool_manager = ToolManager()


def get_trading_history(agent, max_messages: int = 5) -> str:
history = []
memory_source = None
if hasattr(agent.memory, "short_term_memory"):
memory_source = agent.memory.short_term_memory
elif hasattr(agent.memory, "memory_entries"):
memory_source = agent.memory.memory_entries

if memory_source:
entries_to_check = min(len(memory_source), max_messages * 2)
for entry in reversed(list(memory_source)[-entries_to_check:]):
if len(history) >= max_messages:
break
if isinstance(entry.content, dict) and "message" in entry.content:
sender = entry.content.get("sender", "Unknown")
msg = entry.content.get("message", "")
if hasattr(sender, "unique_id"):
sender_name = f"{type(sender).__name__} {sender.unique_id}"
elif isinstance(sender, int):
try:
agent_obj = next(
a for a in agent.model.agents if a.unique_id == sender
)
sender_name = f"{type(agent_obj).__name__} {sender}"
except StopIteration:
sender_name = f"Agent {sender}"
else:
sender_name = str(sender)
history.append(f"- {sender_name}: {msg}")

history.reverse()
return "\n".join(history) if history else "No recent activity."


class TraderAgent(LLMAgent):
def __init__(
self,
model,
reasoning,
llm_model,
system_prompt,
vision,
internal_state,
budget,
api_base=None,
):
super().__init__(
model=model,
reasoning=reasoning,
llm_model=llm_model,
system_prompt=system_prompt,
api_base=api_base,
vision=vision,
internal_state=internal_state,
)
self.tool_manager = trader_tool_manager
self.budget = budget
self.shares = 0
self.trades = 0

def step(self):
observation = self.generate_obs()
history = get_trading_history(self)
price = self.model.current_price
prompt = (
f"MARKET DATA:\n"
f"- Price: ${price:.2f}\n"
f"- Trend: {self.model.price_trend()}\n"
f"- RSI: {self.model.rsi():.1f} (>70 overbought, <30 oversold)\n"
f"- Budget: ${self.budget:.2f} | Shares: {self.shares}\n\n"
f"RECENT ACTIVITY:\n{history}\n\n"
"Use execute_trade to BUY, SELL, or HOLD. Justify briefly."
)
plan = self.reasoning.plan(
prompt=prompt, obs=observation, selected_tools=["execute_trade", "speak_to"]
)
self.apply_plan(plan)

async def astep(self):
observation = self.generate_obs()
history = get_trading_history(self)
price = self.model.current_price
prompt = (
f"MARKET DATA:\n"
f"- Price: ${price:.2f}\n"
f"- Trend: {self.model.price_trend()}\n"
f"- RSI: {self.model.rsi():.1f} (>70 overbought, <30 oversold)\n"
f"- Budget: ${self.budget:.2f} | Shares: {self.shares}\n\n"
f"RECENT ACTIVITY:\n{history}\n\n"
"Use execute_trade to BUY, SELL, or HOLD. Justify briefly."
)
plan = await self.reasoning.aplan(
prompt=prompt, obs=observation, selected_tools=["execute_trade", "speak_to"]
)
self.apply_plan(plan)


class AnalystAgent(LLMAgent):
def __init__(
self,
model,
reasoning,
llm_model,
system_prompt,
vision,
internal_state,
api_base=None,
):
super().__init__(
model=model,
reasoning=reasoning,
llm_model=llm_model,
system_prompt=system_prompt,
api_base=api_base,
vision=vision,
internal_state=internal_state,
)
self.tool_manager = analyst_tool_manager
self.recommendations_sent = 0

def step(self):
observation = self.generate_obs()
prompt = (
f"MARKET SUMMARY:\n"
f"- Price: ${self.model.current_price:.2f}\n"
f"- Trend: {self.model.price_trend()}\n"
f"- RSI: {self.model.rsi():.1f}\n"
f"- Volatility: {self.model.volatility():.4f}\n\n"
"Broadcast a BUY/HOLD/SELL signal with brief reasoning to nearby traders using speak_to."
)
plan = self.reasoning.plan(
prompt=prompt, obs=observation, selected_tools=["speak_to"]
)
self.apply_plan(plan)

async def astep(self):
observation = self.generate_obs()
prompt = (
f"MARKET SUMMARY:\n"
f"- Price: ${self.model.current_price:.2f}\n"
f"- Trend: {self.model.price_trend()}\n"
f"- RSI: {self.model.rsi():.1f}\n"
f"- Volatility: {self.model.volatility():.4f}\n\n"
"Broadcast a BUY/HOLD/SELL signal with brief reasoning to nearby traders using speak_to."
)
plan = await self.reasoning.aplan(
prompt=prompt, obs=observation, selected_tools=["speak_to"]
)
self.apply_plan(plan)
88 changes: 88 additions & 0 deletions examples/stock_market/app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
import logging
import warnings

import pandas as pd
import solara
from dotenv import load_dotenv
from mesa.visualization import SolaraViz, make_space_component

import examples.stock_market.tools # noqa: F401, registers tools
from examples.stock_market.agents import AnalystAgent, TraderAgent
from examples.stock_market.model import StockMarketModel
from mesa_llm.parallel_stepping import enable_automatic_parallel_stepping
from mesa_llm.reasoning.react import ReActReasoning

warnings.filterwarnings("ignore", category=UserWarning, module="pydantic.main")
logging.getLogger("pydantic").setLevel(logging.ERROR)

enable_automatic_parallel_stepping(mode="threading")
load_dotenv()

model_params = {
"seed": {"type": "InputText", "value": 42, "label": "Random Seed"},
"initial_traders": 4,
"n_analysts": 2,
"width": 5,
"height": 5,
"reasoning": ReActReasoning,
"llm_model": "openai/gpt-4o",
"vision": 3,
"initial_price": 100.0,
"api_base": None,
}

model = StockMarketModel(
initial_traders=model_params["initial_traders"],
n_analysts=model_params["n_analysts"],
width=model_params["width"],
height=model_params["height"],
reasoning=model_params["reasoning"],
llm_model=model_params["llm_model"],
vision=model_params["vision"],
initial_price=model_params["initial_price"],
api_base=model_params["api_base"],
seed=model_params["seed"]["value"],
)

if __name__ == "__main__":

def model_portrayal(agent):
if agent is None:
return
portrayal = {"size": 25}
if isinstance(agent, AnalystAgent):
portrayal["color"] = "tab:orange"
portrayal["marker"] = "D"
portrayal["zorder"] = 3
elif isinstance(agent, TraderAgent):
portrayal["color"] = "tab:green" if agent.budget > 500.0 else "tab:red"
portrayal["marker"] = "o"
portrayal["zorder"] = 2
return portrayal

@solara.component
def MarketStatsPanel(*args, **kwargs):
show = solara.use_reactive(False)
df = solara.use_memo(
lambda: (
model.datacollector.get_model_vars_dataframe()
if show.value
else pd.DataFrame()
),
[show.value],
)
solara.Button(label="Show Market Data", on_click=lambda: show.set(True))
if show.value and not df.empty:
solara.DataFrame(df)

page = SolaraViz(
model,
components=[make_space_component(model_portrayal), MarketStatsPanel],
model_params=model_params,
name="Stock Market",
)

"""
Run with:
conda activate mesa-llm && solara run examples/stock_market/app.py
"""
Loading