-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathtree_search_agent.py
More file actions
274 lines (237 loc) · 10.3 KB
/
tree_search_agent.py
File metadata and controls
274 lines (237 loc) · 10.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
"""Tree Search Agent — MCTS for plan exploration.
This example demonstrates Monte Carlo Tree Search (MCTS) for exploring and
optimizing plan branches. The agent uses MCTS to:

1. Select promising branches using the UCB1 exploration strategy.
2. Expand nodes by asking the LLM to generate candidate actions (one per line).
3. Simulate branches by asking the LLM to score potential paths (a float in
   the range 0.0-1.0).
4. Backpropagate scores to improve future branch selection.

This creates an adaptive tree search in which the agent discovers and scores
different plan paths, eventually converging on ``best_plan`` (the
highest-scoring path).

Usage:
    # With FakeProvider (no LLM_API_KEY):
    uv run python examples/tree_search_agent.py

    # With a real LLM (set LLM_API_KEY, LLM_BASE_URL, LLM_MODEL):
    LLM_API_KEY=your-key uv run python examples/tree_search_agent.py

MCTS mode:
- FakeProvider: serves 9 scripted responses (3 expansions, 6 simulations).
  The script describes a deeper tree than the configured max_depth=1, so not
  all of these responses are necessarily consumed.
- Real LLM: generates actions and scores dynamically; results vary by model.

Response patterns:
- Expansion responses: one action per line; at most ``max_branching`` lines
  are used.
- Simulation responses: a float in 0.0-1.0 (or text containing a number).
"""
from __future__ import annotations
import asyncio
import os
from ecs_agent.components import (
ConversationComponent,
LLMComponent,
PlanSearchComponent,
)
from ecs_agent.core import Runner, World
from ecs_agent.providers import FakeProvider, OpenAIProvider
from ecs_agent.providers.config import ApiFormat, ProviderConfig
from ecs_agent.systems.tree_search import TreeSearchSystem
from ecs_agent.types import CompletionResult, Message
def _scripted_responses() -> list[CompletionResult]:
    """Return the canned replies served by FakeProvider, in request order.

    As this example assumes TreeSearchSystem works, each iteration:
      1. selects a leaf node (UCB1);
      2. if the leaf is at depth < max_depth and has no children, expands it
         (one provider call; reply = candidate actions, one per line, of which
         at most max_branching become children);
      3. simulates the selected/new node (one provider call; reply = a score
         in 0.0-1.0, parsed out of the text);
      4. backpropagates the score up the tree.

    The script below is three rounds of "one expansion reply followed by two
    score replies". Its author intended: expand the root, then score both
    strategies; then expand and score each strategy's sub-strategies.

    NOTE(review): ``main`` configures max_depth=1. Under the
    "expand only when depth < max_depth" rule, only the root can be expanded.
    The sub-strategy rounds are therefore presumably never requested. Confirm
    against TreeSearchSystem before relying on a particular trace.
    """

    def reply(text: str) -> CompletionResult:
        # Every scripted reply is a plain assistant message.
        return CompletionResult(message=Message(role="assistant", content=text))

    return [
        # Round 1: expand the root into two top-level strategies, score each.
        reply("Systematic step-by-step approach\nDivide-and-conquer strategy"),
        reply("0.75"),
        reply("0.85"),
        # Round 2 (intended): sub-strategies of the step-by-step approach.
        reply("Break into manageable sub-problems\nVerify each step thoroughly"),
        reply("0.80"),
        reply("0.82"),
        # Round 3 (intended): sub-strategies of divide-and-conquer.
        reply("Identify independent subproblems\nParallelize when possible"),
        reply("0.88"),
        reply("0.86"),
    ]


def _create_provider(
    api_key: str, base_url: str, model: str
) -> OpenAIProvider | FakeProvider:
    """Build the LLM provider, announcing the choice on stdout.

    Args:
        api_key: LLM API key; an empty string selects the scripted FakeProvider.
        base_url: OpenAI-compatible endpoint used when ``api_key`` is set.
        model: Model name passed to the real provider.

    Returns:
        An ``OpenAIProvider`` when ``api_key`` is non-empty, otherwise a
        ``FakeProvider`` loaded with ``_scripted_responses()``.
    """
    if not api_key:
        print("No LLM_API_KEY set. Using FakeProvider for demonstration.")
        print("To use a real API, set LLM_API_KEY, LLM_BASE_URL, and LLM_MODEL.")
        print()
        return FakeProvider(responses=_scripted_responses())

    print(f"Using OpenAIProvider with model: {model}")
    print(f"Base URL: {base_url}")
    config = ProviderConfig(
        provider_id="openai",
        base_url=base_url,
        api_key=api_key,
        api_format=ApiFormat.OPENAI_CHAT_COMPLETIONS,
    )
    return OpenAIProvider(config=config, model=model)


def _print_results(world: World, agent) -> None:
    """Print the search outcome and a truncated conversation transcript.

    Args:
        world: The ECS world the search ran in.
        agent: Entity handle returned by ``world.create_entity()``.
    """
    print()
    print("=" * 70)
    print("MCTS RESULTS")
    print("=" * 70)
    print()

    search = world.get_component(agent, PlanSearchComponent)
    if search:
        print(f"Search completed: search_active={search.search_active}")
        if search.best_plan:
            best_path = " → ".join(search.best_plan)
            print(f"Best plan (highest-scoring path): {best_path}")
        else:
            print("Best plan: (none found - may have exhausted responses)")
        print()
        print("Summary:")
        print(f" - Exploration depth: {search.max_depth} levels")
        print(f" - Branching factor: {search.max_branching} children/node")
        print()

    conv = world.get_component(agent, ConversationComponent)
    if conv:
        role_labels = {"user": "User", "assistant": "LLM", "system": "System"}
        print(f"Conversation history ({len(conv.messages)} messages):")
        for i, msg in enumerate(conv.messages):
            role_label = role_labels.get(msg.role, msg.role.title())
            content = msg.content or ""
            # Show at most 70 characters per message, marking truncation.
            preview = content[:70] if content else "(empty)"
            if len(content) > 70:
                preview += "..."
            print(f" {i + 1}. [{role_label}] {preview}")
        print()


async def main() -> None:
    """Run a Tree Search Agent exploring problem-solving strategies.

    Builds a one-entity ECS world with LLM, conversation and plan-search
    components. It registers ``TreeSearchSystem``, runs the runner for a
    single tick, and prints the best plan and conversation history.

    Environment variables:
        LLM_API_KEY: When set, a real OpenAI-compatible provider is used;
            otherwise a scripted ``FakeProvider`` stands in.
        LLM_BASE_URL: Endpoint for the real provider (defaults to DashScope's
            OpenAI-compatible URL).
        LLM_MODEL: Model for the real provider (defaults to ``qwen3.5-flash``).
    """
    api_key = os.environ.get("LLM_API_KEY", "")
    base_url = os.environ.get(
        "LLM_BASE_URL", "https://dashscope.aliyuncs.com/compatible-mode/v1"
    )
    model = os.environ.get("LLM_MODEL", "qwen3.5-flash")

    # Single source of truth for the search configuration. It feeds both the
    # PlanSearchComponent and the "Configuration:" banner so they cannot drift.
    max_depth = 1  # explore one decision level below the root
    max_branching = 2  # keep at most two candidate actions per expansion
    exploration_weight = 1.414  # UCB1 explore/exploit constant (~sqrt(2))

    # =========================================================================
    # Create the ECS World
    # =========================================================================
    world = World()
    provider = _create_provider(api_key, base_url, model)

    agent = world.create_entity()
    world.add_component(
        agent,
        LLMComponent(
            provider=provider,
            model=model if api_key else "fake-model",
            system_prompt=(
                "You are a planning expert using MCTS to explore solution strategies. "
                "When asked to generate actions, return one per line. "
                "When asked to score a path, return a number from 0 to 1."
            ),
        ),
    )
    world.add_component(
        agent,
        ConversationComponent(
            messages=[
                Message(
                    role="user",
                    content="Find the best strategy to solve a complex algorithmic problem",
                ),
            ],
            max_messages=100,
        ),
    )
    world.add_component(
        agent,
        PlanSearchComponent(
            max_depth=max_depth,
            max_branching=max_branching,
            exploration_weight=exploration_weight,
        ),
    )

    # =========================================================================
    # Register the TreeSearchSystem
    # =========================================================================
    world.register_system(TreeSearchSystem(priority=0), priority=0)

    # =========================================================================
    # Run the MCTS loop
    # =========================================================================
    print("=" * 70)
    print("TREE SEARCH AGENT — MCTS for Plan Exploration")
    print("=" * 70)
    print()
    print("Configuration:")
    print(f" max_depth: {max_depth} (explore up to {max_depth} decision level)")
    print(f" max_branching: {max_branching} (at most {max_branching} options per level)")
    print(f" exploration_weight: {exploration_weight} (UCB1 balance: explore vs. exploit)")
    print()
    print("Tree Search Process:")
    print(" 1. Select: Use UCB1 to pick promising leaf nodes")
    print(" 2. Expand: Ask LLM for candidate actions (one per line)")
    print(" 3. Simulate: Ask LLM to score each candidate (float 0.0-1.0)")
    print(" 4. Backpropagate: Update node statistics with scores")
    print(" 5. Repeat until max_depth reached or no more expandable nodes")
    print()

    # Run a single tick. How much of the search TreeSearchSystem performs per
    # tick is not visible from this file.
    # NOTE(review): confirm whether one tick finishes the search at
    # max_depth=1. The search_active flag printed below shows whether it did.
    runner = Runner()
    await runner.run(world, max_ticks=1)

    # =========================================================================
    # Display Results
    # =========================================================================
    _print_results(world, agent)

    print("=" * 70)
    print("Tree Search Agent completed successfully!")
    print("=" * 70)
if __name__ == "__main__":
    # Script entry point: drive the async demo to completion on a fresh
    # event loop.
    asyncio.run(main())