|
18 | 18 |
|
19 | 19 | package com.google.genkit.ai; |
20 | 20 |
|
| 21 | +import java.util.ArrayList; |
21 | 22 | import java.util.HashMap; |
22 | 23 | import java.util.List; |
23 | 24 | import java.util.Map; |
24 | | -import java.util.UUID; |
25 | 25 | import java.util.function.Consumer; |
26 | 26 |
|
27 | 27 | import org.slf4j.Logger; |
@@ -119,27 +119,178 @@ public ModelResponse run(ActionContext ctx, GenerateActionOptions options, |
119 | 119 |
|
120 | 120 | logger.debug("Generating with model: {}", modelKey); |
121 | 121 |
|
122 | | - // Create span metadata for the model call |
123 | | - SpanMetadata spanMetadata = SpanMetadata.builder().name(modelName).type(ActionType.MODEL.getValue()) |
124 | | - .subtype("model").build(); |
| 122 | + // Determine if we should return tool requests without executing them |
| 123 | + boolean returnToolRequests = Boolean.TRUE.equals(options.getReturnToolRequests()); |
| 124 | + |
| 125 | + // Get max turns for tool loop (default to 5) |
| 126 | + int maxTurns = options.getMaxTurns() != null ? options.getMaxTurns() : 5; |
| 127 | + int turn = 0; |
125 | 128 |
|
126 | 129 | String flowName = ctx.getFlowName(); |
127 | | - if (flowName != null) { |
128 | | - spanMetadata.getAttributes().put("genkit:metadata:flow:name", flowName); |
129 | | - } |
130 | | - |
131 | | - // Run the model wrapped in a span |
132 | | - return Tracer.runInNewSpan(ctx, spanMetadata, request, (spanCtx, req) -> { |
133 | | - ActionContext newCtx = ctx.withSpanContext(spanCtx); |
134 | | - String spanPath = "/generate/" + modelName; |
135 | | - if (streamCallback != null && model.supportsStreaming()) { |
136 | | - return ModelTelemetryHelper.runWithTelemetryStreaming(modelName, flowName, spanPath, req, |
137 | | - r -> model.run(newCtx, r, streamCallback)); |
138 | | - } else { |
139 | | - return ModelTelemetryHelper.runWithTelemetry(modelName, flowName, spanPath, req, |
140 | | - r -> model.run(newCtx, r)); |
| 130 | + |
| 131 | + while (turn < maxTurns) { |
| 132 | + // Create span metadata for the model call |
| 133 | + SpanMetadata spanMetadata = SpanMetadata.builder().name(modelName).type(ActionType.MODEL.getValue()) |
| 134 | + .subtype("model").build(); |
| 135 | + |
| 136 | + if (flowName != null) { |
| 137 | + spanMetadata.getAttributes().put("genkit:metadata:flow:name", flowName); |
| 138 | + } |
| 139 | + |
| 140 | + final ModelRequest currentRequest = request; |
| 141 | + final String spanPath = "/generate/" + modelName; |
| 142 | + |
| 143 | + // Run the model wrapped in a span |
| 144 | + ModelResponse response = Tracer.runInNewSpan(ctx, spanMetadata, request, (spanCtx, req) -> { |
| 145 | + ActionContext newCtx = ctx.withSpanContext(spanCtx); |
| 146 | + if (streamCallback != null && model.supportsStreaming()) { |
| 147 | + return ModelTelemetryHelper.runWithTelemetryStreaming(modelName, flowName, spanPath, currentRequest, |
| 148 | + r -> model.run(newCtx, r, streamCallback)); |
| 149 | + } else { |
| 150 | + return ModelTelemetryHelper.runWithTelemetry(modelName, flowName, spanPath, currentRequest, |
| 151 | + r -> model.run(newCtx, r)); |
| 152 | + } |
| 153 | + }); |
| 154 | + |
| 155 | + // Check if the model requested tool calls |
| 156 | + List<Part> toolRequestParts = extractToolRequestParts(response); |
| 157 | + |
| 158 | + // If no tool requests or we should return them without executing, return |
| 159 | + // response |
| 160 | + if (toolRequestParts.isEmpty() || returnToolRequests) { |
| 161 | + return response; |
| 162 | + } |
| 163 | + |
| 164 | + // Check if we have tools to execute |
| 165 | + if (options.getTools() == null || options.getTools().isEmpty()) { |
| 166 | + // No tools available, return response with tool requests |
| 167 | + return response; |
| 168 | + } |
| 169 | + |
| 170 | + // Execute tools |
| 171 | + List<Part> toolResponseParts = executeTools(ctx, toolRequestParts, options.getTools()); |
| 172 | + |
| 173 | + // Add the assistant message with tool requests |
| 174 | + Message assistantMessage = response.getMessage(); |
| 175 | + List<Message> updatedMessages = new ArrayList<>(request.getMessages()); |
| 176 | + updatedMessages.add(assistantMessage); |
| 177 | + |
| 178 | + // Add tool response message |
| 179 | + Message toolResponseMessage = new Message(); |
| 180 | + toolResponseMessage.setRole(Role.TOOL); |
| 181 | + toolResponseMessage.setContent(toolResponseParts); |
| 182 | + updatedMessages.add(toolResponseMessage); |
| 183 | + |
| 184 | + // Update request with new messages for next turn |
| 185 | + request = ModelRequest.builder().messages(updatedMessages).config(request.getConfig()) |
| 186 | + .tools(request.getTools()).output(request.getOutput()).build(); |
| 187 | + |
| 188 | + turn++; |
| 189 | + } |
| 190 | + |
| 191 | + throw new GenkitException("Max tool execution turns (" + maxTurns + ") exceeded"); |
| 192 | + } |
| 193 | + |
| 194 | + /** |
| 195 | + * Extracts tool request parts from a model response. |
| 196 | + */ |
| 197 | + private List<Part> extractToolRequestParts(ModelResponse response) { |
| 198 | + List<Part> toolRequestParts = new ArrayList<>(); |
| 199 | + |
| 200 | + if (response.getMessage() != null && response.getMessage().getContent() != null) { |
| 201 | + for (Part part : response.getMessage().getContent()) { |
| 202 | + if (part.getToolRequest() != null) { |
| 203 | + toolRequestParts.add(part); |
| 204 | + } |
| 205 | + } |
| 206 | + } |
| 207 | + |
| 208 | + return toolRequestParts; |
| 209 | + } |
| 210 | + |
| 211 | + /** |
| 212 | + * Executes tools and returns the response parts. |
| 213 | + */ |
| 214 | + private List<Part> executeTools(ActionContext ctx, List<Part> toolRequestParts, List<String> toolNames) { |
| 215 | + List<Part> responseParts = new ArrayList<>(); |
| 216 | + |
| 217 | + for (Part toolRequestPart : toolRequestParts) { |
| 218 | + ToolRequest toolRequest = toolRequestPart.getToolRequest(); |
| 219 | + String toolName = toolRequest.getName(); |
| 220 | + Object toolInput = toolRequest.getInput(); |
| 221 | + |
| 222 | + // Find the tool |
| 223 | + Tool<?, ?> tool = findTool(toolName, toolNames); |
| 224 | + if (tool == null) { |
| 225 | + // Tool not found, create an error response |
| 226 | + Part errorPart = new Part(); |
| 227 | + ToolResponse errorResponse = new ToolResponse(toolRequest.getRef(), toolName, |
| 228 | + Map.of("error", "Tool not found: " + toolName)); |
| 229 | + errorPart.setToolResponse(errorResponse); |
| 230 | + responseParts.add(errorPart); |
| 231 | + logger.warn("Tool not found: {}", toolName); |
| 232 | + continue; |
| 233 | + } |
| 234 | + |
| 235 | + try { |
| 236 | + // Execute the tool |
| 237 | + @SuppressWarnings("unchecked") |
| 238 | + Tool<Object, Object> typedTool = (Tool<Object, Object>) tool; |
| 239 | + |
| 240 | + // Convert input if necessary |
| 241 | + Object convertedInput = toolInput; |
| 242 | + if (toolInput instanceof Map && tool.getInputClass() != null |
| 243 | + && !Map.class.isAssignableFrom(tool.getInputClass())) { |
| 244 | + convertedInput = objectMapper.convertValue(toolInput, tool.getInputClass()); |
| 245 | + } |
| 246 | + |
| 247 | + Object result = typedTool.run(ctx, convertedInput); |
| 248 | + |
| 249 | + // Create tool response part |
| 250 | + Part responsePart = new Part(); |
| 251 | + ToolResponse toolResponse = new ToolResponse(toolRequest.getRef(), toolName, result); |
| 252 | + responsePart.setToolResponse(toolResponse); |
| 253 | + responseParts.add(responsePart); |
| 254 | + |
| 255 | + logger.debug("Executed tool '{}' successfully", toolName); |
| 256 | + } catch (Exception e) { |
| 257 | + logger.error("Tool execution failed for '{}': {}", toolName, e.getMessage()); |
| 258 | + Part errorPart = new Part(); |
| 259 | + ToolResponse errorResponse = new ToolResponse(toolRequest.getRef(), toolName, |
| 260 | + Map.of("error", "Tool execution failed: " + e.getMessage())); |
| 261 | + errorPart.setToolResponse(errorResponse); |
| 262 | + responseParts.add(errorPart); |
| 263 | + } |
| 264 | + } |
| 265 | + |
| 266 | + return responseParts; |
| 267 | + } |
| 268 | + |
| 269 | + /** |
| 270 | + * Finds a tool by name from the list of tool names or registry. |
| 271 | + */ |
| 272 | + private Tool<?, ?> findTool(String toolName, List<String> toolNames) { |
| 273 | + // First try to find in registry by name |
| 274 | + String toolKey = toolName.startsWith("/tool/") ? toolName : "/tool/" + toolName; |
| 275 | + Action<?, ?, ?> action = registry.lookupAction(toolKey); |
| 276 | + if (action instanceof Tool) { |
| 277 | + return (Tool<?, ?>) action; |
| 278 | + } |
| 279 | + |
| 280 | + // Also try without prefix if the toolNames list includes it |
| 281 | + if (toolNames != null) { |
| 282 | + for (String name : toolNames) { |
| 283 | + String key = name.startsWith("/tool/") ? name : "/tool/" + name; |
| 284 | + if (key.equals(toolKey) || name.equals(toolName)) { |
| 285 | + action = registry.lookupAction(key); |
| 286 | + if (action instanceof Tool) { |
| 287 | + return (Tool<?, ?>) action; |
| 288 | + } |
| 289 | + } |
141 | 290 | } |
142 | | - }); |
| 291 | + } |
| 292 | + |
| 293 | + return null; |
143 | 294 | } |
144 | 295 |
|
145 | 296 | @Override |
@@ -169,9 +320,27 @@ public JsonNode runJson(ActionContext ctx, JsonNode input, Consumer<JsonNode> st |
169 | 320 | @Override |
170 | 321 | public ActionRunResult<JsonNode> runJsonWithTelemetry(ActionContext ctx, JsonNode input, |
171 | 322 | Consumer<JsonNode> streamCallback) throws GenkitException { |
172 | | - String traceId = UUID.randomUUID().toString(); |
173 | | - JsonNode result = runJson(ctx, input, streamCallback); |
174 | | - return new ActionRunResult<>(result, traceId, null); |
| 323 | + // Capture trace info from within the span |
| 324 | + final String[] capturedTraceInfo = new String[2]; // [traceId, spanId] |
| 325 | + |
| 326 | + SpanMetadata spanMetadata = SpanMetadata.builder().name("generate").type("util").build(); |
| 327 | + |
| 328 | + try { |
| 329 | + JsonNode result = Tracer.runInNewSpan(ctx, spanMetadata, input, (spanCtx, in) -> { |
| 330 | + // Capture the span context |
| 331 | + capturedTraceInfo[0] = spanCtx.getTraceId(); |
| 332 | + capturedTraceInfo[1] = spanCtx.getSpanId(); |
| 333 | + |
| 334 | + return runJson(ctx.withSpanContext(spanCtx), in, streamCallback); |
| 335 | + }); |
| 336 | + |
| 337 | + return new ActionRunResult<>(result, capturedTraceInfo[0], capturedTraceInfo[1]); |
| 338 | + } catch (Exception e) { |
| 339 | + if (e instanceof GenkitException) { |
| 340 | + throw (GenkitException) e; |
| 341 | + } |
| 342 | + throw new GenkitException("Generate action failed: " + e.getMessage(), e); |
| 343 | + } |
175 | 344 | } |
176 | 345 |
|
177 | 346 | @Override |
|
0 commit comments