diff --git a/Cargo.lock b/Cargo.lock index 4021a21a..9bbfcd84 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -840,6 +840,7 @@ dependencies = [ "csv", "docx-rs", "futures", + "guardrail_ffi", "jemallocator", "lopdf 0.32.0", "mimalloc", @@ -880,6 +881,14 @@ dependencies = [ "tokio", ] +[[package]] +name = "guardrail_ffi" +version = "0.1.0" +dependencies = [ + "reqwest", + "serde_json", +] + [[package]] name = "half" version = "2.7.1" @@ -2085,6 +2094,7 @@ checksum = "ab3f43e3283ab1488b624b44b0e988d0acea0b3214e694730a055cb6b2efa801" dependencies = [ "base64", "bytes", + "futures-channel", "futures-core", "futures-util", "http", diff --git a/Cargo.toml b/Cargo.toml index 263ae092..a754d9bc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -41,7 +41,7 @@ version.workspace = true [profile.bench] debug = false incremental = false -lto = "fat" +lto = false opt-level = 3 # Workspace profiles (applied to all members) @@ -55,7 +55,8 @@ overflow-checks = true codegen-units = 1 debug = false incremental = false -lto = "fat" +# LTO disabled: prebuilt libguardrail_ffi.a has no bitcode; fat LTO would fail at link. 
+lto = false opt-level = 3 panic = "abort" strip = "symbols" @@ -64,7 +65,8 @@ strip = "symbols" [profile.release-python] codegen-units = 1 inherits = "release" -lto = "fat" +# LTO follows release (false) so we can link prebuilt libguardrail_ffi.a +lto = false opt-level = "s" # Optimize for size (important for Python extensions) panic = "abort" strip = "symbols" @@ -80,6 +82,7 @@ jemallocator.workspace = true [workspace] members = [ "core", + "guardrail_ffi", "python" ] resolver = "2" diff --git a/README.md b/README.md index ad84db85..20cf88fa 100644 --- a/README.md +++ b/README.md @@ -156,7 +156,7 @@ export ANTHROPIC_API_KEY=your_anthropic_api_key_here ```python import os -from graphbit import LlmConfig, Executor, Workflow, Node, tool +from graphbit import LlmConfig, Executor, Workflow, Node, tool, GuardRailPolicyConfig # Initialize and configure config = LlmConfig.openai(os.getenv("OPENAI_API_KEY"), "gpt-4o-mini") @@ -200,7 +200,9 @@ id1 = workflow.add_node(smart_agent) id2 = workflow.add_node(processor) workflow.connect(id1, id2) +# Run (optionally with a guardrail policy for PII masking/mapping) result = executor.execute(workflow) +# Or with policy: result = executor.execute(workflow, policy=GuardRailPolicyConfig.from_json('{"guardrail_policy": {"pii_rules": [...]}}')) print(f"Workflow completed: {result.is_success()}") print("\nSmart Agent Output: \n", result.get_node_output("Smart Agent")) print("\nData Processor Output: \n", result.get_node_output("Data Processor")) diff --git a/core/Cargo.toml b/core/Cargo.toml index 9ee69d4d..d2e0389a 100644 --- a/core/Cargo.toml +++ b/core/Cargo.toml @@ -1,5 +1,7 @@ [dependencies] pyo3 = {workspace = true, optional = true} +# GuardRail: prebuilt libguardrail_ffi.a only (see vendor/guardrail/README.md). 
+guardrail_ffi = { path = "../guardrail_ffi" } anyhow.workspace = true async-trait.workspace = true calamine.workspace = true diff --git a/core/src/lib.rs b/core/src/lib.rs index 4c20d0b0..96e4ba6e 100644 --- a/core/src/lib.rs +++ b/core/src/lib.rs @@ -61,6 +61,9 @@ pub use types::{ pub use validation::ValidationResult; pub use workflow::{Workflow, WorkflowBuilder, WorkflowExecutor}; +// Re-export guardrail types (from prebuilt libguardrail_ffi.a via guardrail_ffi crate) +pub use guardrail_ffi::{DecodeContext, EncodeContext, EncodeResult, DecodeResult, Enforcer, GuardRail, GuardRailConfig}; + /// Version information pub const VERSION: &str = env!("CARGO_PKG_VERSION"); diff --git a/core/src/workflow.rs b/core/src/workflow.rs index a0498bdb..0649a84b 100644 --- a/core/src/workflow.rs +++ b/core/src/workflow.rs @@ -12,6 +12,7 @@ use crate::types::{ ConcurrencyManager, ConcurrencyStats, MessageContent, NodeExecutionResult, NodeId, RetryConfig, TaskInfo, WorkflowContext, WorkflowExecutionStats, WorkflowId, WorkflowState, }; +use crate::{DecodeContext, EncodeContext, Enforcer}; use futures::future::join_all; use regex::Regex; use serde::{Deserialize, Serialize}; @@ -19,8 +20,8 @@ use std::collections::HashMap; use std::sync::{Arc, LazyLock}; use tokio::sync::{Mutex, RwLock}; - -static NODE_REF_PATTERN: LazyLock = LazyLock::new(|| Regex::new(r"\{\{node\.([a-zA-Z0-9_\-\.]+)\}\}").unwrap()); +static NODE_REF_PATTERN: LazyLock = + LazyLock::new(|| Regex::new(r"\{\{node\.([a-zA-Z0-9_\-\.]+)\}\}").unwrap()); /// A complete workflow definition #[derive(Debug, Clone, Serialize, Deserialize)] @@ -68,6 +69,7 @@ impl Workflow { /// Validate the workflow pub fn validate(&self) -> GraphBitResult<()> { + tracing::debug!("Workflow '{:#?}' validated successfully", self.graph); self.graph.validate() } @@ -243,7 +245,9 @@ impl WorkflowExecutor { } // 3. 
No default fallback - require explicit configuration as requested by user - tracing::error!("No LLM configuration found - neither node-level nor executor-level config provided. System requires explicit configuration."); + tracing::error!( + "No LLM configuration found - neither node-level nor executor-level config provided. System requires explicit configuration." + ); crate::llm::LlmConfig::Unconfigured { message: "No LLM configuration provided. The system requires explicit configuration from program or user input rather than hardcoded defaults.".to_string() } @@ -280,8 +284,15 @@ impl WorkflowExecutor { self.concurrency_manager.get_available_permits().await } - /// Execute a workflow with enhanced performance monitoring - pub async fn execute(&self, workflow: Workflow) -> GraphBitResult { + /// Execute a workflow with enhanced performance monitoring. + /// + /// When `guardrail_enforcer` is `Some`, PII is encoded before each LLM call and + /// decoded on LLM output; tool-call boundaries are handled by the executor layer. 
+ pub async fn execute( + &self, + workflow: Workflow, + guardrail_enforcer: Option>, + ) -> GraphBitResult { let start_time = std::time::Instant::now(); // Initialize workflow context with simple constructor @@ -472,6 +483,7 @@ impl WorkflowExecutor { let circuit_breaker_config = self.circuit_breaker_config.clone(); let retry_config = self.default_retry_config.clone(); let concurrency_manager = self.concurrency_manager.clone(); + let guardrail_enforcer = guardrail_enforcer.clone(); // Use lightweight task spawning without unnecessary permit acquisition overhead let task = tokio::spawn(async move { @@ -503,6 +515,7 @@ impl WorkflowExecutor { circuit_breakers_clone, circuit_breaker_config, retry_config, + guardrail_enforcer, ) .await }); @@ -633,6 +646,7 @@ impl WorkflowExecutor { circuit_breakers: Arc>>, circuit_breaker_config: CircuitBreakerConfig, retry_config: Option, + guardrail_enforcer: Option>, ) -> GraphBitResult { let start_time = std::time::Instant::now(); let mut attempt = 0; @@ -678,6 +692,7 @@ impl WorkflowExecutor { &node.config, context.clone(), agents.clone(), + guardrail_enforcer.clone(), ) .await } @@ -785,7 +800,8 @@ impl WorkflowExecutor { } } - /// Execute an agent node (static version) + /// Execute an agent node (static version). + /// When `guardrail_enforcer` is `Some`, encodes prompt before LLM and decodes response after. 
async fn execute_agent_node_static( current_node_id: &NodeId, agent_id: &crate::types::AgentId, @@ -793,6 +809,7 @@ impl WorkflowExecutor { node_config: &std::collections::HashMap, context: Arc>, agents: Arc>>>, + guardrail_enforcer: Option>, ) -> GraphBitResult { // Use read lock for better performance let agents_guard = agents.read().await; @@ -959,6 +976,7 @@ impl WorkflowExecutor { current_node_id, &node_name, context.clone(), + guardrail_enforcer.clone(), ) .await; tracing::info!("Agent with tools execution result: {:?}", result); @@ -967,9 +985,52 @@ impl WorkflowExecutor { // Execute agent without tools (original behavior) tracing::info!("NO TOOLS DETECTED - using standard agent execution"); + // Build the executions array for metadata + let mut executions: Vec = Vec::new(); + + // Guardrail: encode prompt before sending to LLM; combine injection text + payload + let mut encoded_payload_for_meta = String::new(); + let prompt_for_llm = if let Some(ref enforcer) = guardrail_enforcer { + tracing::debug!( + "Guardrail: encoding prompt before LLM call (sensitive data will be masked)" + ); + let encode_result = enforcer.encode( + serde_json::Value::String(resolved_prompt.clone()), + EncodeContext::Llm, + ); + tracing::debug!( + "Guardrail: prompt encoded for LLM (payload only): {}", + encode_result.payload.as_str().unwrap_or("") + ); + tracing::debug!( + "[GuardRail] encoded prompt (sent to LLM, payload only): {}", + encode_result.payload.as_str().unwrap_or("") + ); + + // Record guardrail encode execution entry + executions.push(serde_json::json!({ + "type": "guardrail_policy", + "operation": "encode", + "pii_rules_applied_count": encode_result.rules_applied_count, + "pii_rule_names": encode_result.rule_names, + "policy_name": encode_result.policy_name + })); + + // Capture encoded payload (without signature) for metadata user_input + encoded_payload_for_meta = encode_result.payload.as_str().unwrap_or("").to_string(); + + format!( + "{}{}", + 
encode_result.signature_injection_text, + encode_result.payload.as_str().unwrap_or("") + ) + } else { + resolved_prompt.clone() + }; + // Call LLM provider directly to capture metadata use crate::llm::LlmRequest; - let mut request = LlmRequest::new(resolved_prompt.clone()); + let mut request = LlmRequest::new(prompt_for_llm.clone()); // Apply node-level configuration overrides (temperature, max_tokens, etc.) if let Some(temp_value) = node_config.get("temperature") { @@ -989,8 +1050,97 @@ impl WorkflowExecutor { let llm_start = std::time::Instant::now(); let llm_response = agent.llm_provider().complete(request).await?; let llm_duration_ms = llm_start.elapsed().as_secs_f64() * 1000.0; + let llm_end_timestamp = chrono::Utc::now(); + + // Get provider name for metadata + let provider_name = agent.llm_provider().config().provider_name().to_string(); + + // Capture raw LLM content before decode for metadata + let raw_llm_content = llm_response.content.clone(); + if guardrail_enforcer.is_some() { + tracing::debug!( + "[GuardRail] raw LLM response (before decode): content={:?}, tool_calls={:?}", + llm_response.content, + llm_response.tool_calls + ); + } + + // Build the llm_call execution entry (before decode, captures raw LLM output) + let llm_call_entry = serde_json::json!({ + "type": "llm_call", + "id": llm_response.id.clone().unwrap_or_default(), + "model": llm_response.model, + "provider": provider_name, + "input": if guardrail_enforcer.is_some() { encoded_payload_for_meta.clone() } else { resolved_prompt.clone() }, + "output": llm_response.content, + "finish_reason": format!("{}", llm_response.finish_reason), + "tool_calls": [], + "start_time": execution_timestamp.to_rfc3339(), + "end_time": llm_end_timestamp.to_rfc3339(), + "duration_ms": llm_duration_ms, + "usage": { + "prompt_tokens": llm_response.usage.prompt_tokens, + "completion_tokens": llm_response.usage.completion_tokens, + "total_tokens": llm_response.usage.total_tokens, + "prompt_tokens_details": { + 
"cached_tokens": 0, + "audio_tokens": 0 + }, + "completion_tokens_details": { + "reasoning_tokens": 0, + "audio_tokens": 0, + "accepted_prediction_tokens": 0, + "rejected_prediction_tokens": 0 + } + }, + "retries": [] + }); + executions.push(llm_call_entry); + + // Guardrail: decode LLM output before storing in context + let llm_response = if let Some(ref enforcer) = guardrail_enforcer { + tracing::debug!( + "Guardrail: decoding LLM response (rehydrating for context) llm_response.content: {}", + llm_response.content + ); + tracing::debug!("Guardrail: decoding LLM response (rehydrating for context)"); + let payload = serde_json::json!({ + "content": llm_response.content, + "tool_calls": llm_response.tool_calls + }); + let decoded_result = enforcer.decode(payload, DecodeContext::LlmResponse); + tracing::debug!("Guardrail: LLM response decoded"); + + // Record guardrail decode execution entry + executions.push(serde_json::json!({ + "type": "guardrail_policy", + "operation": "decode", + "pii_rules_applied_count": decoded_result.rules_applied_count, + "pii_rule_names": decoded_result.rule_names, + "policy_name": decoded_result.policy_name + })); + + let content = decoded_result + .payload + .get("content") + .and_then(|v| v.as_str()) + .map(String::from) + .unwrap_or_else(|| llm_response.content.clone()); + let tool_calls = decoded_result + .payload + .get("tool_calls") + .and_then(|v| serde_json::from_value(v.clone()).ok()) + .unwrap_or_else(|| llm_response.tool_calls.clone()); + crate::llm::LlmResponse { + content, + tool_calls, + ..llm_response + } + } else { + llm_response + }; - // Store LLM response metadata AND request prompt in context for observability + // Build the node-level metadata with executions array { // First, get the node name before mutable borrow let node_name = { @@ -1004,36 +1154,60 @@ impl WorkflowExecutor { .unwrap_or_else(|| "unknown".to_string()) }; + let max_iterations = node_config + .get("max_iterations") + .and_then(|v| v.as_u64()) + 
.unwrap_or(5) as u32; + + let node_metadata = serde_json::json!({ + "node_id": current_node_id.to_string(), + "node_name": node_name, + "node_type": "Agent", + // When GR active: user_input = masked prompt, final_output = raw LLM content + // When GR inactive: user_input = original prompt, final_output = decoded content + "user_input": if guardrail_enforcer.is_some() { encoded_payload_for_meta.clone() } else { resolved_prompt.clone() }, + "tools_available": [], + "total_tools_available": 0, + "start_time": execution_timestamp.to_rfc3339(), + "end_time": llm_end_timestamp.to_rfc3339(), + "duration_ms": llm_duration_ms, + "success": true, + "error": serde_json::Value::Null, + "final_output": if guardrail_enforcer.is_some() { raw_llm_content } else { llm_response.content.clone() }, + "total_iterations": 0, + "max_iterations": max_iterations, + "exit_reason": llm_response.finish_reason, + "total_usage": { + "prompt_tokens": llm_response.usage.prompt_tokens, + "completion_tokens": llm_response.usage.completion_tokens, + "total_tokens": llm_response.usage.total_tokens, + "prompt_tokens_details": { + "cached_tokens": 0, + "audio_tokens": 0 + }, + "completion_tokens_details": { + "reasoning_tokens": 0, + "audio_tokens": 0, + "accepted_prediction_tokens": 0, + "rejected_prediction_tokens": 0 + } + }, + "total_tool_calls": 0, + "total_retries": 0, + "tools_used": [], + "executions": executions + }); + // Now store the metadata let mut ctx = context.lock().await; - if let Ok(mut response_metadata) = serde_json::to_value(&llm_response) { - // Add the request prompt to the metadata - if let Some(obj) = response_metadata.as_object_mut() { - obj.insert( - "prompt".to_string(), - serde_json::Value::String(resolved_prompt.clone()), - ); - // Add LLM call duration for accurate latency tracking - obj.insert( - "duration_ms".to_string(), - serde_json::json!(llm_duration_ms), - ); - // Add execution timestamp for chronological ordering - obj.insert( - "execution_timestamp".to_string(), 
- serde_json::json!(execution_timestamp.to_rfc3339()), - ); - } - - // Store by node ID - ctx.metadata.insert( - format!("node_response_{current_node_id}"), - response_metadata.clone(), - ); - // Store by node name - ctx.metadata - .insert(format!("node_response_{node_name}"), response_metadata); - } + // Store by node ID + ctx.metadata.insert( + format!("node_response_{current_node_id}"), + node_metadata.clone(), + ); + // Store by node name + ctx.metadata + .insert(format!("node_response_{node_name}"), node_metadata); } // Return the content as JSON value @@ -1046,7 +1220,8 @@ impl WorkflowExecutor { } } - /// Execute an agent with tool calling orchestration + /// Execute an agent with tool calling orchestration. + /// When `guardrail_enforcer` is `Some`, encodes prompt before LLM and decodes response after. async fn execute_agent_with_tools( _agent_id: &crate::types::AgentId, prompt: &str, @@ -1055,10 +1230,54 @@ impl WorkflowExecutor { node_id: &NodeId, node_name: &str, context: Arc>, + guardrail_enforcer: Option>, ) -> GraphBitResult { tracing::info!("Starting execute_agent_with_tools for agent: {_agent_id}"); use crate::llm::{LlmRequest, LlmTool}; + // Build the executions array for metadata + let mut executions: Vec = Vec::new(); + + // Guardrail: encode prompt before sending to LLM; combine injection text + payload + let mut encoded_payload_for_meta = String::new(); + let prompt_for_llm = if let Some(ref enforcer) = guardrail_enforcer { + tracing::debug!( + "Guardrail: encoding prompt before LLM call (tool path; sensitive data masked)" + ); + let encode_result = enforcer.encode( + serde_json::Value::String(prompt.to_string()), + EncodeContext::Llm, + ); + tracing::debug!( + "Guardrail: prompt encoded for LLM (payload only): {}", + encode_result.payload.as_str().unwrap_or_default() + ); + tracing::debug!( + "[GuardRail] encoded prompt (sent to LLM, payload only): {}", + encode_result.payload.as_str().unwrap_or_default() + ); + + // Record guardrail encode 
execution entry + executions.push(serde_json::json!({ + "type": "guardrail_policy", + "operation": "encode", + "pii_rules_applied_count": encode_result.rules_applied_count, + "pii_rule_names": encode_result.rule_names, + "policy_name": encode_result.policy_name + })); + + // Capture encoded payload (without signature) for metadata user_input + encoded_payload_for_meta = encode_result.payload.as_str().unwrap_or_default().to_string(); + + format!( + "{}{}", + encode_result.signature_injection_text, + encode_result.payload.as_str().unwrap_or_default() + ) + } else { + prompt.to_string() + }; + // Extract tool schemas from node config let tool_schemas = node_config .get("tool_schemas") @@ -1067,6 +1286,12 @@ impl WorkflowExecutor { tracing::info!("Found {} tool schemas", tool_schemas.len()); + // Collect tool names for metadata + let tool_names: Vec = tool_schemas + .iter() + .filter_map(|s| s.get("name").and_then(|v| v.as_str()).map(String::from)) + .collect(); + // Convert tool schemas to LlmTool objects let mut tools = Vec::new(); for schema in tool_schemas { @@ -1079,8 +1304,8 @@ impl WorkflowExecutor { } } - // Create initial LLM request with tools - let mut request = LlmRequest::new(prompt); + // Create initial LLM request with tools (using encoded prompt when guardrail is active) + let mut request = LlmRequest::new(prompt_for_llm.clone()); for tool in &tools { request = request.with_tool(tool.clone()); } @@ -1125,42 +1350,165 @@ impl WorkflowExecutor { // Measure LLM call duration and capture execution timestamp let execution_timestamp = chrono::Utc::now(); let llm_start = std::time::Instant::now(); - let llm_response = agent.llm_provider().complete(request).await?; + let mut llm_response = agent.llm_provider().complete(request).await?; let llm_duration_ms = llm_start.elapsed().as_secs_f64() * 1000.0; + let llm_end_timestamp = chrono::Utc::now(); - // Store LLM response metadata AND request prompt in context for observability - { - let mut ctx = 
context.lock().await; - if let Ok(mut response_metadata) = serde_json::to_value(&llm_response) { - // Add the request prompt to the metadata - if let Some(obj) = response_metadata.as_object_mut() { - obj.insert( - "prompt".to_string(), - serde_json::Value::String(prompt.to_string()), - ); - // Add LLM call duration for accurate latency tracking - obj.insert( - "duration_ms".to_string(), - serde_json::json!(llm_duration_ms), - ); - // Add execution timestamp for chronological ordering - obj.insert( - "execution_timestamp".to_string(), - serde_json::json!(execution_timestamp.to_rfc3339()), - ); + // Get provider name for metadata + let provider_name = agent.llm_provider().config().provider_name().to_string(); + + // Capture raw LLM content before decode for metadata + let raw_llm_content = llm_response.content.clone(); + if guardrail_enforcer.is_some() { + tracing::debug!( + "[GuardRail] raw LLM response (before decode): content={:?}, tool_calls={:?}", + llm_response.content, + llm_response.tool_calls + ); + } + + // Build tool_calls array for the llm_call execution entry (raw from LLM response) + let llm_tool_calls_for_metadata: Vec = llm_response + .tool_calls + .iter() + .map(|tc| { + serde_json::json!({ + "id": tc.id, + "type": "function", + "function": { + "name": tc.name, + "arguments": serde_json::to_string(&tc.parameters).unwrap_or_default() + } + }) + }) + .collect(); + + // Build the llm_call execution entry + let llm_call_entry = serde_json::json!({ + "type": "llm_call", + "id": llm_response.id.clone().unwrap_or_default(), + "model": llm_response.model, + "provider": provider_name, + "input": if guardrail_enforcer.is_some() { encoded_payload_for_meta.clone() } else { prompt.to_string() }, + "output": llm_response.content, + "finish_reason": format!("{}", llm_response.finish_reason), + "tool_calls": llm_tool_calls_for_metadata, + "start_time": execution_timestamp.to_rfc3339(), + "end_time": llm_end_timestamp.to_rfc3339(), + "duration_ms": llm_duration_ms, + 
"usage": { + "prompt_tokens": llm_response.usage.prompt_tokens, + "completion_tokens": llm_response.usage.completion_tokens, + "total_tokens": llm_response.usage.total_tokens, + "prompt_tokens_details": { + "cached_tokens": 0, + "audio_tokens": 0 + }, + "completion_tokens_details": { + "reasoning_tokens": 0, + "audio_tokens": 0, + "accepted_prediction_tokens": 0, + "rejected_prediction_tokens": 0 } + }, + "retries": [] + }); + executions.push(llm_call_entry); - // Store by node ID - ctx.metadata.insert( - format!("node_response_{node_id}"), - response_metadata.clone(), - ); - // Store by node name - ctx.metadata - .insert(format!("node_response_{node_name}"), response_metadata); + // Guardrail: decode LLM output before storing in context + if let Some(ref enforcer) = guardrail_enforcer { + tracing::debug!( + "Guardrail: decoding LLM response (tool path; rehydrating for context) llm_response.content: {}", + llm_response.content + ); + tracing::debug!( + "Guardrail: decoding LLM response (tool path; rehydrating for context)" + ); + let payload = serde_json::json!({ + "content": llm_response.content, + "tool_calls": llm_response.tool_calls + }); + let decoded_result = enforcer.decode(payload, DecodeContext::LlmResponse); + tracing::debug!("Guardrail: LLM response decoded"); + + // Record guardrail decode execution entry + executions.push(serde_json::json!({ + "type": "guardrail_policy", + "operation": "decode", + "pii_rules_applied_count": decoded_result.rules_applied_count, + "pii_rule_names": decoded_result.rule_names, + "policy_name": decoded_result.policy_name + })); + + if let Some(c) = decoded_result + .payload + .get("content") + .and_then(|v| v.as_str()) + { + llm_response.content = c.to_string(); + } + if let Some(tc) = decoded_result.payload.get("tool_calls") { + if let Ok(parsed) = serde_json::from_value(tc.clone()) { + llm_response.tool_calls = parsed; + } } } + // Build the initial node-level metadata with executions array + // The Python layer 
(handle_tool_calls_in_context) will extend this with tool_call and subsequent llm_call entries + let max_iterations = node_config + .get("max_iterations") + .and_then(|v| v.as_u64()) + .unwrap_or(5) as u32; + + let node_metadata = serde_json::json!({ + "node_id": node_id.to_string(), + "node_name": node_name, + "node_type": "Agent", + // When GR active: user_input = masked prompt, final_output = raw LLM content + // When GR inactive: user_input = original prompt, final_output = decoded content + "user_input": if guardrail_enforcer.is_some() { encoded_payload_for_meta.clone() } else { prompt.to_string() }, + "tools_available": tool_names, + "total_tools_available": tool_names.len(), + "start_time": execution_timestamp.to_rfc3339(), + "end_time": llm_end_timestamp.to_rfc3339(), + "duration_ms": llm_duration_ms, + "success": true, + "error": serde_json::Value::Null, + "final_output": if guardrail_enforcer.is_some() { raw_llm_content.clone() } else { llm_response.content.clone() }, + "total_iterations": 0, + "max_iterations": max_iterations, + "exit_reason": format!("{}", llm_response.finish_reason), + "total_usage": { + "prompt_tokens": llm_response.usage.prompt_tokens, + "completion_tokens": llm_response.usage.completion_tokens, + "total_tokens": llm_response.usage.total_tokens, + "prompt_tokens_details": { + "cached_tokens": 0, + "audio_tokens": 0 + }, + "completion_tokens_details": { + "reasoning_tokens": 0, + "audio_tokens": 0, + "accepted_prediction_tokens": 0, + "rejected_prediction_tokens": 0 + } + }, + "total_tool_calls": 0, + "total_retries": 0, + "tools_used": [], + "executions": executions + }); + + // Store the metadata + { + let mut ctx = context.lock().await; + ctx.metadata + .insert(format!("node_response_{node_id}"), node_metadata.clone()); + ctx.metadata + .insert(format!("node_response_{node_name}"), node_metadata); + } + // DEBUG: Log LLM response details tracing::info!("LLM Response - Content: '{}'", llm_response.content); tracing::info!( @@ 
-1188,13 +1536,20 @@ impl WorkflowExecutor { GraphBitError::workflow_execution(format!("Failed to serialize tool calls: {e}")) })?; - // Return a structured response that the Python layer can interpret - // Include token usage for budget tracking + // Return a structured response that the Python layer can interpret. + // When GuardRail is on, pass only the encoded payload (without the RULE signature + // injection text) so the executor can reconstruct the final prompt cleanly. + // The executor will re-encode the final prompt (adding a fresh RULE prefix) before + // the second LLM call; including the RULE here would cause it to appear in metadata. + let original_prompt_for_response = guardrail_enforcer + .as_ref() + .map(|_| encoded_payload_for_meta.clone()) + .unwrap_or_else(|| prompt.to_string()); Ok(serde_json::json!({ "type": "tool_calls_required", "content": llm_response.content, "tool_calls": tool_calls_json, - "original_prompt": prompt, + "original_prompt": original_prompt_for_response, "initial_tokens_used": llm_response.usage.completion_tokens, "max_tokens_configured": node_config.get("max_tokens").and_then(|v| v.as_u64()), "message": "Tool execution should be handled by Python layer with proper tool registry" diff --git a/examples/guardrail_financial/README.md b/examples/guardrail_financial/README.md new file mode 100644 index 00000000..79794f40 --- /dev/null +++ b/examples/guardrail_financial/README.md @@ -0,0 +1,146 @@ +# GuardRail Examples + +This directory contains examples demonstrating GraphBit's GuardRail feature for protecting Personally Identifiable Information (PII) in LLM workflows. + +## How GuardRail Works + +GuardRail provides intelligent masking of sensitive data: +- **LLM sees**: Masked tokens (e.g., `[CREDIT_CARD_1]`, `[EMAIL_1]`) +- **Tools receive**: Real unmasked values for accurate processing +- **Automatic handling**: Encode/decode at LLM and tool boundaries + +## Examples + +### 1. 
Phone Number Sum (`guardrail_phone/`) + +**Purpose**: Demonstrate GuardRail protecting phone numbers while allowing tools to process them correctly. + +**Pattern**: +- Policy masks phone numbers matching pattern `\d{3}-\d{4}` (e.g., `123-4567`) +- Tool `sum_digits_in_phone()` receives the real unmasked number +- LLM only sees masked token (e.g., `[PHONE_NUMBER_1]`) + +**Files**: +- `guardrail_phone_policy.json` - Policy defining phone number masking rules +- `run_guardrail_phone.py` - Example workflow with tool calling + +**Run**: +```bash +.venv/bin/python examples/guardrail_phone/run_guardrail_phone.py +``` + +### 2. Financial Payment Processing (`guardrail_financial/`) + +**Purpose**: Demonstrate GuardRail protecting multiple types of sensitive financial data (credit cards, emails, SSNs) in a payment processing workflow. + +**Pattern**: +- Policy masks: + - Credit cards: `\d{4}-\d{4}-\d{4}-\d{4}` (e.g., `4532-1234-5678-9010`) + - Emails: Standard email pattern (e.g., `customer@example.com`) + - SSNs: `\d{3}-\d{2}-\d{4}` (e.g., `123-45-6789`) +- Three tools with different data requirements: + - `validate_credit_card()` - Validates the card and returns last 4 digits + - `calculate_transaction_fee()` - Computes 2% fee on amount + - `send_payment_confirmation()` - Sends confirmation to recipient email +- LLM sees only masked tokens, tools get real values + +**Files**: +- `guardrail_financial_policy.json` - Policy defining credit card, email, and SSN masking +- `run_guardrail_financial.py` - Complete payment processing workflow + +**Run**: +```bash +.venv/bin/python examples/guardrail_financial/run_guardrail_financial.py +``` + +## Key Implementation Details + +### Policy Definition (JSON) + +```json +{ + "policy_name": "example_policy", + "policy_version": "1.0.0", + "active": true, + "guardrail_policy": { + "pii_rules": [ + { + "type": "regex", + "name": "RULE_NAME", + "pattern": "regex_pattern_here" + } + ], + "masking": true, + "mapping": true + } +} +``` + +### 
Tool Definition + +```python +from graphbit import tool + +@tool(_description="Description shown to LLM") +def my_tool(param: str) -> str: + """ + When GuardRail is enabled, this tool receives DECODED values. + """ + print(f"[Tool received] param = {param!r}") + # Process the real unmasked value + return result +``` + +### Workflow Execution with GuardRail + +```python +from graphbit import Executor, GuardRailPolicyConfig + +# Load policy +policy = GuardRailPolicyConfig.from_file("policy.json") + +# Execute with policy - LLM gets masked, tools get real data +result = executor.execute(workflow, policy=policy) +``` + +## Testing Without OpenAI + +These examples use OpenAI (GPT-4o-mini) if `OPENAI_API_KEY` is set; otherwise they fall back to Ollama. + +**Using Ollama**: +```bash +# Ensure Ollama is running +ollama pull llama3.2 +ollama serve + +# In another terminal +.venv/bin/python examples/guardrail_financial/run_guardrail_financial.py +``` + +**Using OpenAI**: +```bash +export OPENAI_API_KEY="sk-..." +.venv/bin/python examples/guardrail_financial/run_guardrail_financial.py +``` + +## Debug Output + +Run with debug logging to see GuardRail in action: +- `Guardrail: encoding prompt before LLM` - Shows data being masked +- `Guardrail: decoding tool call parameters` - Shows unmasking at tool boundary +- `[Tool received]` - Shows the real unmasked values tools receive + +Look for masked tokens in logs like: +- `[CREDIT_CARD_1]` +- `[EMAIL_1]` +- `[SSN_1]` + +## Creating Your Own Example + +Follow this pattern: + +1. **Define a policy** (JSON with PII patterns) +2. **Create tools** (Python functions marked with `@tool`) +3. **Build workflow** (Node.agent with tools parameter) +4. **Execute with policy** (executor.execute(workflow, policy=policy)) +5. 
**Verify** (check tool received real values in debug output) diff --git a/examples/guardrail_financial/guardrail_financial_policy.json b/examples/guardrail_financial/guardrail_financial_policy.json new file mode 100644 index 00000000..46030400 --- /dev/null +++ b/examples/guardrail_financial/guardrail_financial_policy.json @@ -0,0 +1,26 @@ +{ + "policy_name": "financial_pii_policy", + "policy_version": "1.0.0", + "active": true, + "guardrail_policy": { + "pii_rules": [ + { + "type": "regex", + "name": "CREDIT_CARD", + "pattern": "\\d{4}-\\d{4}-\\d{4}-\\d{4}" + }, + { + "type": "regex", + "name": "EMAIL", + "pattern": "[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\\.[a-zA-Z]{2,}" + }, + { + "type": "regex", + "name": "SSN", + "pattern": "\\d{3}-\\d{2}-\\d{4}" + } + ], + "masking": true, + "mapping": true + } +} diff --git a/examples/guardrail_financial/run_guardrail_financial.py b/examples/guardrail_financial/run_guardrail_financial.py new file mode 100644 index 00000000..3dc05984 --- /dev/null +++ b/examples/guardrail_financial/run_guardrail_financial.py @@ -0,0 +1,153 @@ +""" +GuardRail Financial Example: Secure Payment Processing with PII Masking + +This example demonstrates how GuardRail masks sensitive financial information +(credit card numbers, emails, SSNs) from the LLM while ensuring tools receive +the actual decoded values for proper processing. + +Pattern: +- LLM sees masked tokens (e.g., [CREDIT_CARD_1], [EMAIL_1]) +- Tools receive the real unmasked values for accurate computation +- GuardRail automatically handles encode/decode boundaries +""" + +import os +import sys + +import graphbit +from graphbit import ( + Executor, + GuardRailPolicyConfig, + LlmConfig, + Node, + Workflow, + tool, + init, +) + +# using our policy +POLICY_DIR = os.path.dirname(os.path.abspath(__file__)) +POLICY_PATH = os.path.join(POLICY_DIR, "guardrail_financial_policy.json") + + +@tool(_description="Validate a credit card number and return the last 4 digits safely. 
Input is a full CC number in format XXXX-XXXX-XXXX-XXXX.") +def validate_credit_card(card_number: str) -> str: + """ + Validate a credit card by checking digit count and return last 4 digits. + When GuardRail is enabled, this tool receives the DECODED (real) card number. + """ + digits_only = card_number.replace("-", "") + is_valid = len(digits_only) == 16 and digits_only.isdigit() + last_four = digits_only[-4:] if len(digits_only) >= 4 else "INVALID" + + print(f"[Tool received] card_number = {card_number!r}") + print(f"[Tool validates] Valid: {is_valid}, Last 4 digits: {last_four}") + + return f"Card valid: {is_valid}, Last 4: {last_four}" + + +@tool(_description="Send a payment confirmation email. Recipient email must be in format user@domain.com") +def send_payment_confirmation(recipient_email: str, amount: str) -> str: + """ + Send payment confirmation to recipient. + When GuardRail is enabled, this tool receives the DECODED (real) email address. + """ + print(f"[Tool received] recipient_email = {recipient_email!r}, amount = {amount!r}") + + # Simulate email validation + if "@" in recipient_email and "." in recipient_email.split("@")[1]: + result = f"Confirmation email sent to {recipient_email} for amount ${amount}" + print(f"[Tool result] {result}") + return result + else: + return "Error: Invalid email format" + + +@tool(_description="Calculate transaction fee based on amount. Amount should be a number like 100.50") +def calculate_transaction_fee(amount: str) -> str: + """ + Calculate the transaction fee (2% of amount). 
+ """ + try: + amt = float(amount) + fee = amt * 0.02 + print(f"[Tool received] amount = {amount!r} -> fee = ${fee:.2f}") + return f"Transaction fee: ${fee:.2f}" + except ValueError: + return "Error: Amount must be a number" + + +def main(): + init(enable_tracing=True, log_level="debug") + + print("=" * 80) + print("GuardRail Financial Example: Secure Payment Processing") + print("=" * 80) + print("\nScenario: Process a payment with sensitive information") + print("- LLM should NOT see real credit card, email, or SSN") + print("- Tools MUST receive actual values for proper processing\n") + + # Prefer OpenAI for reliable tool calling; fallback to Ollama + llm_config = LlmConfig.openai(os.getenv("OPENAI_API_KEY"), "gpt-4o-mini") + + executor = Executor(llm_config) + + # Load policy that masks credit cards, emails, and SSNs + if not os.path.isfile(POLICY_PATH): + print(f"Policy file not found: {POLICY_PATH}") + sys.exit(1) + + policy = GuardRailPolicyConfig.from_file(POLICY_PATH) + print(f"Loaded policy: {policy.policy_name()}, active={policy.is_active()}\n") + + workflow = Workflow("Secure Payment Processing [GuardRail]") + + payment_agent = Node.agent( + name="Payment Processor", + prompt=( + "I want to send money using my bank with this credit card and sender info:\n" + "Credit Card: 4532-1234-5678-9010\n" + "Recipient Email: customer@example.com\n" + "Amount: 250.00\n" + "Customer SSN (for verification): 123-45-6789\n" + + ), + system_prompt=( + "You are a secure payment processor. Use the available tools to validate " + "payments, calculate fees, and send confirmations. Always use the tools " + "with the exact information provided." 
+ ), + tools=[validate_credit_card, calculate_transaction_fee, send_payment_confirmation], + ) + + workflow.add_node(payment_agent) + workflow.validate() + + # Execute WITH policy: LLM sees masked data; tools receive decoded data + print("Executing workflow with GuardRail policy...\n") + result = executor.execute(workflow, policy=policy) + + print("\n" + "=" * 80) + print("--- Result ---") + print("=" * 80) + + if result.is_success(): + out = result.get_node_output("Payment Processor") + print(f"\nAgent Output:\n{out}") + print("\n" + "-" * 80) + print("Verification Notes:") + print("- Above '[Tool received]' entries should show REAL values (unmasked)") + print(" - Credit card: 4532-1234-5678-9010") + print(" - Email: customer@example.com") + print(" - SSN: 123-45-6789") + print("\n- In debug logs you should see:") + print(" - 'Guardrail: encoding prompt before LLM'") + print(" - 'Guardrail: decoding tool call parameters before execution'") + print(" - Masked tokens like [CREDIT_CARD_1], [EMAIL_1], [SSN_1]") + print("-" * 80) + else: + print(f"Workflow failed: {result.state()}") + + +if __name__ == "__main__": + main() diff --git a/examples/guardrail_phone/guardrail_phone_policy.json b/examples/guardrail_phone/guardrail_phone_policy.json new file mode 100644 index 00000000..0831bb2f --- /dev/null +++ b/examples/guardrail_phone/guardrail_phone_policy.json @@ -0,0 +1,16 @@ +{ + "policy_name": "phone_mask_policy", + "policy_version": "1.0.0", + "active": true, + "guardrail_policy": { + "pii_rules": [ + { + "type": "regex", + "name": "PHONE_NUMBER", + "pattern": "\\d{3}-\\d{4}" + } + ], + "masking": true, + "mapping": true + } +} diff --git a/examples/guardrail_phone/run_guardrail_phone.py b/examples/guardrail_phone/run_guardrail_phone.py new file mode 100644 index 00000000..8b49a1d3 --- /dev/null +++ b/examples/guardrail_phone/run_guardrail_phone.py @@ -0,0 +1,83 @@ +import os +import sys + +import graphbit +from graphbit import ( + Executor, + GuardRailPolicyConfig, 
+ LlmConfig, + Node, + Workflow, + tool, + init, +) + +# using our policy +POLICY_DIR = os.path.dirname(os.path.abspath(__file__)) +POLICY_PATH = os.path.join(POLICY_DIR, "guardrail_phone_policy.json") + + +@tool(_description="Sum all digits in a phone number. Input is a string that may look like 555-5555 or a token.") +def sum_digits_in_phone(phone_number: str) -> str: + """ + Compute the sum of all digits in the given phone number. + When GuardRail is enabled, this tool receives the DECODED (real) number so it can compute correctly. + """ + digits = [int(c) for c in phone_number if c.isdigit()] + total = sum(digits) + print(f"[Tool received] phone_number = {phone_number!r} -> sum of digits = {total}") + return str(total) + + +def main(): + init(enable_tracing=True, log_level="debug") + + print("GuardRail phone example: LLM should never see 123-4567; tool should always receive it.\n") + + # Prefer OpenAI for reliable tool calling; fallback to Ollama + if os.getenv("OPENAI_API_KEY"): + llm_config = LlmConfig.openai(os.getenv("OPENAI_API_KEY"), "gpt-5") + else: + print("No OPENAI_API_KEY set. Using Ollama (ollama run llama3.2).") + llm_config = LlmConfig.ollama("llama3.2") + + executor = Executor(llm_config) + + # Load policy that masks 123-4567-style numbers + if not os.path.isfile(POLICY_PATH): + print(f"Policy file not found: {POLICY_PATH}") + sys.exit(1) + policy = GuardRailPolicyConfig.from_file(POLICY_PATH) + print(f"Loaded policy: {policy.policy_name()}, active={policy.is_active()}\n") + + workflow = Workflow("Phone digits sum [GuardRail]") + agent = Node.agent( + name="Phone Agent", + prompt=( + "The user's phone number is 123-4567. " + "Use the sum_digits_in_phone tool to compute the sum of all digits in that phone number, " + "then reply with the result." + ), + system_prompt="You have a tool to sum digits in a phone number. 
Use it when asked.", + tools=[sum_digits_in_phone], + max_tokens=1000, + ) + workflow.add_node(agent) + workflow.validate() + + # Execute WITH policy: LLM sees masked data; tool receives decoded data + result = executor.execute(workflow, policy=policy) + + print("\n--- Result ---") + + if result.is_success(): + out = result.get_node_output("Phone Agent") + print(f"Agent output: {out}") + print("\nVerify: above '[Tool received]' should show phone_number = '123-4567' (decoded).") + print("In debug logs you should see 'Guardrail: encoding prompt before LLM' and 'decoding tool call parameters'.") + else: + print(f"Workflow failed: {result.state()}") + print("\n--- Node Response Metadata ---") + print(result.get_all_node_response_metadata()) +if __name__ == "__main__": + main() diff --git a/guardrail_ffi/Cargo.toml b/guardrail_ffi/Cargo.toml new file mode 100644 index 00000000..12048ee4 --- /dev/null +++ b/guardrail_ffi/Cargo.toml @@ -0,0 +1,12 @@ +[package] +name = "guardrail_ffi" +version = "0.1.0" +edition = "2024" +description = "GraphBit wrapper for prebuilt GuardRail C ABI (libguardrail_ffi.a)" + +[dependencies] +serde_json = "1.0" +reqwest = { version = "0.13", default-features = false, features = ["json", "rustls", "blocking"] } + +[lints.rust] +unsafe_code = "allow" diff --git a/guardrail_ffi/build.rs b/guardrail_ffi/build.rs new file mode 100644 index 00000000..2032e11c --- /dev/null +++ b/guardrail_ffi/build.rs @@ -0,0 +1,26 @@ +//! Link the prebuilt GuardRail library: static on Unix (libguardrail_ffi.a), import lib on Windows (guardrail_ffi.lib for guardrail_ffi.dll). 
+ +use std::env; +use std::path::Path; + +fn main() { + let manifest_dir = env::var("CARGO_MANIFEST_DIR").expect("CARGO_MANIFEST_DIR"); + let lib_dir = env::var("GUARDRAIL_LIB_DIR") + .unwrap_or_else(|_| Path::new(&manifest_dir).join("../vendor/guardrail").to_string_lossy().into_owned()); + + let lib_path = Path::new(&lib_dir); + if !lib_path.exists() { + eprintln!("cargo:warning=GuardRail lib dir not found: {} (set GUARDRAIL_LIB_DIR or add vendor/guardrail/)", lib_dir); + return; + } + + println!("cargo:rustc-link-search=native={}", lib_path.display()); + + let target_os = env::var("CARGO_CFG_TARGET_OS").unwrap_or_default(); + if target_os == "windows" { + // Link the DLL's import library; guardrail_ffi.dll must be shipped next to the .pyd (see workflow / python-src). + println!("cargo:rustc-link-lib=dylib=guardrail_ffi"); + } else { + println!("cargo:rustc-link-lib=static=guardrail_ffi"); + } +} diff --git a/guardrail_ffi/src/lib.rs b/guardrail_ffi/src/lib.rs new file mode 100644 index 00000000..652a1f3a --- /dev/null +++ b/guardrail_ffi/src/lib.rs @@ -0,0 +1,397 @@ +//! GuardRail FFI wrapper — links the prebuilt `libguardrail_ffi.a` only (no guardrail source). 
+ +use std::ffi::{CStr, CString}; +use std::os::raw::{c_char, c_uint}; +use std::path::Path; +use std::sync::Arc; + +unsafe extern "C" { + fn guardrail_config_from_json(json_ptr: *const c_char, json_len: usize) -> *mut std::ffi::c_void; + fn guardrail_config_default() -> *mut std::ffi::c_void; + fn guardrail_config_clone(handle: *mut std::ffi::c_void) -> *mut std::ffi::c_void; + fn guardrail_config_drop(handle: *mut std::ffi::c_void); + fn guardrail_config_policy_name(handle: *mut std::ffi::c_void) -> *mut c_char; + fn guardrail_config_policy_version(handle: *mut std::ffi::c_void) -> *mut c_char; + fn guardrail_config_active(handle: *mut std::ffi::c_void) -> bool; + + fn guardrail_enforcer_create( + config_handle: *mut std::ffi::c_void, + workflow_id_ptr: *const c_char, + workflow_id_len: usize, + ) -> *mut std::ffi::c_void; + fn guardrail_enforcer_drop(handle: *mut std::ffi::c_void); + + fn guardrail_encode( + enforcer_handle: *mut std::ffi::c_void, + json_ptr: *const c_char, + json_len: usize, + encode_context: c_uint, + ) -> *mut c_char; + fn guardrail_decode( + enforcer_handle: *mut std::ffi::c_void, + json_ptr: *const c_char, + json_len: usize, + context: c_uint, + ) -> *mut c_char; + fn guardrail_free(ptr: *mut c_char); +} + +const CONTEXT_TOOL_BOUNDARY: c_uint = 0; +const CONTEXT_LLM_RESPONSE: c_uint = 1; +const CONTEXT_MANUAL_CALL: c_uint = 2; + +/// Encode context: Llm adds signature and instruction text; Manual does not. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum EncodeContext { + /// No signature or instruction (e.g. manual / logging). + Manual, + /// Add 3-digit signature to tokens and prepend instruction text for LLM. + Llm, +} + +const ENCODE_CONTEXT_MANUAL: c_uint = 0; +const ENCODE_CONTEXT_LLM: c_uint = 1; + +/// Decode context for rehydration. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum DecodeContext { + /// Rehydrate at tool boundary so the tool receives real PII. + ToolBoundary, + /// Rehydrate LLM output for context. 
+ LlmResponse, + /// Explicit host decode. + ManualCall, +} + +/// Opaque config handle (refcounted via clone/drop). +pub struct GuardRailConfigInner { + pub(crate) handle: *mut std::ffi::c_void, +} + +impl Drop for GuardRailConfigInner { + fn drop(&mut self) { + if !self.handle.is_null() { + unsafe { guardrail_config_drop(self.handle) }; + self.handle = std::ptr::null_mut(); + } + } +} + +impl Clone for GuardRailConfigInner { + fn clone(&self) -> Self { + let handle = if self.handle.is_null() { + std::ptr::null_mut() + } else { + unsafe { guardrail_config_clone(self.handle) } + }; + Self { handle } + } +} + +// Opaque FFI handle: safe to Send/Sync as the C library manages thread safety. +unsafe impl Send for GuardRailConfigInner {} +unsafe impl Sync for GuardRailConfigInner {} + +/// GuardRail policy configuration. +#[derive(Clone)] +pub struct GuardRailConfig { + pub(crate) inner: Arc, +} + +impl GuardRailConfig { + /// Load from JSON string. + pub fn new(json: &str) -> Result { + let c_str = CString::new(json).map_err(|e| e.to_string())?; + let handle = unsafe { guardrail_config_from_json(c_str.as_ptr(), c_str.as_bytes().len()) }; + if handle.is_null() { + return Err("GuardRail config from_json failed".into()); + } + Ok(Self { + inner: Arc::new(GuardRailConfigInner { handle }), + }) + } + + /// Load from file. + pub fn from_file(path: &Path) -> Result { + let json = std::fs::read_to_string(path).map_err(|e| e.to_string())?; + Self::new(&json) + } + + /// Load from URL (blocking GET). + pub fn from_url(url: &str) -> Result { + let client = reqwest::blocking::Client::new(); + let json = client + .get(url) + .send() + .map_err(|e| e.to_string())? + .text() + .map_err(|e| e.to_string())?; + Self::new(&json) + } + + /// Default inactive config. 
+ pub fn default_config() -> Self { + let handle = unsafe { guardrail_config_default() }; + assert!(!handle.is_null(), "guardrail_config_default failed"); + Self { + inner: Arc::new(GuardRailConfigInner { handle }), + } + } + + pub(crate) fn ptr(&self) -> *mut std::ffi::c_void { + self.inner.handle + } + + /// Policy name. + pub fn policy_name(&self) -> String { + if self.ptr().is_null() { + return String::new(); + } + let p = unsafe { guardrail_config_policy_name(self.ptr()) }; + if p.is_null() { + return String::new(); + } + let s = unsafe { CStr::from_ptr(p).to_string_lossy().into_owned() }; + unsafe { guardrail_free(p) }; + s + } + + /// Policy version. + pub fn policy_version(&self) -> String { + if self.ptr().is_null() { + return String::new(); + } + let p = unsafe { guardrail_config_policy_version(self.ptr()) }; + if p.is_null() { + return String::new(); + } + let s = unsafe { CStr::from_ptr(p).to_string_lossy().into_owned() }; + unsafe { guardrail_free(p) }; + s + } + + /// Whether the policy is active. + pub fn active(&self) -> bool { + if self.ptr().is_null() { + return false; + } + unsafe { guardrail_config_active(self.ptr()) } + } +} + +/// Enforcer for one workflow (encode/decode). +pub struct Enforcer { + pub(crate) handle: *mut std::ffi::c_void, +} + +impl Drop for Enforcer { + fn drop(&mut self) { + if !self.handle.is_null() { + unsafe { guardrail_enforcer_drop(self.handle) }; + self.handle = std::ptr::null_mut(); + } + } +} + +impl std::fmt::Debug for Enforcer { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Enforcer").finish_non_exhaustive() + } +} + +unsafe impl Send for Enforcer {} +unsafe impl Sync for Enforcer {} + +/// Result of encode: payload (masked only) plus optional injection text and metadata. +#[derive(Debug, Clone)] +pub struct EncodeResult { + pub payload: serde_json::Value, + /// Rule text to prepend when sending to LLM; empty when not applicable. Caller concatenates with payload. 
+ pub signature_injection_text: String, + pub rules_applied_count: u32, + pub rule_names: Vec, + pub policy_name: String, +} + +/// Result of decode: payload plus metadata. +#[derive(Debug, Clone)] +pub struct DecodeResult { + pub payload: serde_json::Value, + pub rules_applied_count: u32, + pub rule_names: Vec, + pub policy_name: String, +} + +impl Enforcer { + /// Encode payload (mask PII). When context is Llm, tokens get a 3-digit signature and instruction text is prepended. + pub fn encode(&self, payload: serde_json::Value, context: EncodeContext) -> EncodeResult { + let default_result = EncodeResult { + payload: payload.clone(), + signature_injection_text: String::new(), + rules_applied_count: 0, + rule_names: Vec::new(), + policy_name: String::new(), + }; + if self.handle.is_null() { + return default_result; + } + let json = match serde_json::to_string(&payload) { + Ok(s) => s, + Err(_) => return default_result, + }; + let c_str = match CString::new(json.as_bytes()) { + Ok(c) => c, + Err(_) => return default_result, + }; + let enc_ctx = match context { + EncodeContext::Llm => ENCODE_CONTEXT_LLM, + EncodeContext::Manual => ENCODE_CONTEXT_MANUAL, + }; + let out = unsafe { + guardrail_encode( + self.handle, + c_str.as_ptr(), + c_str.as_bytes().len(), + enc_ctx, + ) + }; + if out.is_null() { + return default_result; + } + let out_slice = unsafe { CStr::from_ptr(out).to_bytes() }; + let out_str = String::from_utf8_lossy(out_slice).into_owned(); + unsafe { guardrail_free(out) }; + parse_encode_result(&out_str) + .or_else(|| parse_encode_result_legacy(&out_str)) + .unwrap_or(default_result) + } + + /// Decode payload (rehydrate PII). 
+ pub fn decode(&self, payload: serde_json::Value, context: DecodeContext) -> DecodeResult { + let default_result = DecodeResult { + payload: payload.clone(), + rules_applied_count: 0, + rule_names: Vec::new(), + policy_name: String::new(), + }; + if self.handle.is_null() { + return default_result; + } + let json = match serde_json::to_string(&payload) { + Ok(s) => s, + Err(_) => return default_result, + }; + let c_str = match CString::new(json.as_bytes()) { + Ok(c) => c, + Err(_) => return default_result, + }; + let ctx = match context { + DecodeContext::ToolBoundary => CONTEXT_TOOL_BOUNDARY, + DecodeContext::LlmResponse => CONTEXT_LLM_RESPONSE, + DecodeContext::ManualCall => CONTEXT_MANUAL_CALL, + }; + let out = + unsafe { guardrail_decode(self.handle, c_str.as_ptr(), c_str.as_bytes().len(), ctx) }; + if out.is_null() { + return default_result; + } + let out_slice = unsafe { CStr::from_ptr(out).to_bytes() }; + let out_str = String::from_utf8_lossy(out_slice).into_owned(); + unsafe { guardrail_free(out) }; + parse_decode_result(&out_str) + .or_else(|| parse_decode_result_legacy(&out_str)) + .unwrap_or(default_result) + } +} + +/// Legacy FFI return: raw encoded payload as JSON (old guardrail_encode returned Value serialized). +fn parse_encode_result_legacy(s: &str) -> Option { + let payload: serde_json::Value = serde_json::from_str(s).ok()?; + Some(EncodeResult { + payload, + signature_injection_text: String::new(), + rules_applied_count: 0, + rule_names: Vec::new(), + policy_name: String::new(), + }) +} + +/// Legacy FFI return: raw decoded payload as JSON (old guardrail_decode returned Value serialized). 
+fn parse_decode_result_legacy(s: &str) -> Option { + let payload: serde_json::Value = serde_json::from_str(s).ok()?; + Some(DecodeResult { + payload, + rules_applied_count: 0, + rule_names: Vec::new(), + policy_name: String::new(), + }) +} + +fn parse_encode_result(s: &str) -> Option { + let v: serde_json::Value = serde_json::from_str(s).ok()?; + let payload = v.get("payload")?.clone(); + let signature_injection_text = v + .get("signature_injection_text") + .and_then(|x| x.as_str()) + .unwrap_or("") + .to_string(); + let rules_applied_count = v.get("rules_applied_count")?.as_u64()? as u32; + let rule_names: Vec = v + .get("rule_names")? + .as_array()? + .iter() + .filter_map(|x| x.as_str().map(String::from)) + .collect(); + let policy_name = v.get("policy_name")?.as_str()?.to_string(); + Some(EncodeResult { + payload, + signature_injection_text, + rules_applied_count, + rule_names, + policy_name, + }) +} + +fn parse_decode_result(s: &str) -> Option { + let v: serde_json::Value = serde_json::from_str(s).ok()?; + let payload = v.get("payload")?.clone(); + let rules_applied_count = v.get("rules_applied_count")?.as_u64()? as u32; + let rule_names: Vec = v + .get("rule_names")? + .as_array()? + .iter() + .filter_map(|x| x.as_str().map(String::from)) + .collect(); + let policy_name = v.get("policy_name")?.as_str()?.to_string(); + Some(DecodeResult { + payload, + rules_applied_count, + rule_names, + policy_name, + }) +} + +/// Singleton entry point to create enforcers. +#[derive(Debug, Clone)] +pub struct GuardRail; + +impl GuardRail { + /// Initialize (no-op for FFI; state is inside the lib). + #[must_use] + pub fn init() -> Self { + Self + } + + /// Create an enforcer for this workflow. 
+ pub fn enforcer_for(config: Arc, workflow_id: impl Into) -> Enforcer { + let workflow_id = workflow_id.into(); + let (ptr, len) = if workflow_id.is_empty() { + (std::ptr::null(), 0) + } else { + (workflow_id.as_ptr() as *const c_char, workflow_id.len()) + }; + let handle = unsafe { guardrail_enforcer_create(config.ptr(), ptr, len) }; + assert!(!handle.is_null(), "guardrail_enforcer_create failed"); + Enforcer { handle } + } +} diff --git a/python/python-src/graphbit/__init__.py b/python/python-src/graphbit/__init__.py index 0cb9cbd8..fa08787a 100644 --- a/python/python-src/graphbit/__init__.py +++ b/python/python-src/graphbit/__init__.py @@ -10,7 +10,6 @@ configure_runtime, shutdown, ) - # Document loader classes from .graphbit import ( DocumentLoaderConfig, @@ -30,6 +29,7 @@ # Workflow classes from .graphbit import ( + GuardRailPolicyConfig, Node, Workflow, WorkflowContext, @@ -99,6 +99,8 @@ "FinishReason", "LlmToolCall", "LlmResponse", + # GuardRail + "GuardRailPolicyConfig", # Workflow "Node", "Workflow", diff --git a/python/src/guardrail.rs b/python/src/guardrail.rs new file mode 100644 index 00000000..2706f933 --- /dev/null +++ b/python/src/guardrail.rs @@ -0,0 +1,90 @@ +//! GuardRail policy config exposed to Python as `GuardRailPolicyConfig`. +//! Used with `Executor.execute(workflow, policy=...)` for PII masking/mapping. + +use graphbit_core::GuardRailConfig; +use pyo3::exceptions::PyValueError; +use pyo3::prelude::*; +use std::sync::Arc; + +/// Python-facing guardrail policy configuration. +/// +/// Create via `GuardRailPolicyConfig.from_json(...)`, `from_file(...)`, or `from_url(...)`, +/// then pass to `executor.execute(workflow, policy=config)`. +#[pyclass] +#[derive(Clone)] +pub struct GuardRailPolicyConfig { + pub(crate) inner: Arc, +} + +#[pymethods] +impl GuardRailPolicyConfig { + /// Create a config from a JSON string. + /// + /// # Errors + /// Raises `ValueError` if the JSON is invalid or validation fails. 
+ #[staticmethod] + pub fn from_json(json_str: &str) -> PyResult { + let config = GuardRailConfig::new(json_str) + .map_err(|e| PyValueError::new_err(format!("GuardRail config error: {}", e)))?; + Ok(Self { + inner: Arc::new(config), + }) + } + + /// Create a config from a local file path. + /// + /// # Errors + /// Raises `ValueError` if the file cannot be read or validation fails. + #[staticmethod] + pub fn from_file(path: &str) -> PyResult { + let config = GuardRailConfig::from_file(std::path::Path::new(path)) + .map_err(|e| PyValueError::new_err(format!("GuardRail config error: {}", e)))?; + Ok(Self { + inner: Arc::new(config), + }) + } + + /// Create a config from a remote URL (HTTP GET). + /// + /// # Errors + /// Raises `ValueError` if the URL cannot be fetched or validation fails. + #[staticmethod] + pub fn from_url(url: &str) -> PyResult { + let config = GuardRailConfig::from_url(url) + .map_err(|e| PyValueError::new_err(format!("GuardRail config error: {}", e)))?; + Ok(Self { + inner: Arc::new(config), + }) + } + + /// Default (inactive) policy — no masking or mapping. + #[staticmethod] + pub fn default_config() -> Self { + Self { + inner: Arc::new(GuardRailConfig::default_config()), + } + } + + /// Policy name. + pub fn policy_name(&self) -> String { + self.inner.policy_name() + } + + /// Policy version. + pub fn policy_version(&self) -> String { + self.inner.policy_version() + } + + /// Whether the policy is active. + pub fn is_active(&self) -> bool { + self.inner.active() + } +} + +impl GuardRailPolicyConfig { + /// Return the inner config for use by the executor (same-crate only). Not exposed to Python. 
+ #[inline] + pub(crate) fn get_inner(&self) -> Arc { + Arc::clone(&self.inner) + } +} diff --git a/python/src/lib.rs b/python/src/lib.rs index 458d7e38..23ce3995 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -33,6 +33,7 @@ use tracing::{error, info, warn}; mod document_loader; mod embeddings; mod errors; +mod guardrail; mod llm; mod runtime; mod text_splitter; @@ -49,6 +50,7 @@ pub use text_splitter::{ TokenSplitter, }; pub use tools::{ToolDecorator, ToolExecutor, ToolRegistry, ToolResult}; +pub use guardrail::GuardRailPolicyConfig; pub use workflow::{Executor, Node, Workflow, WorkflowContext, WorkflowResult}; /// Global initialization flag to ensure init is called only once @@ -386,6 +388,9 @@ fn graphbit(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; + // GuardRail policy config (optional for executor.execute(workflow, policy=...)) + m.add_class::()?; + // Workflow classes m.add_class::()?; m.add_class::()?; diff --git a/python/src/workflow/executor.rs b/python/src/workflow/executor.rs index 7997531d..6647df32 100644 --- a/python/src/workflow/executor.rs +++ b/python/src/workflow/executor.rs @@ -8,13 +8,16 @@ //! 
- Graceful error handling and recovery use graphbit_core::workflow::WorkflowExecutor as CoreWorkflowExecutor; +use graphbit_core::{DecodeContext, EncodeContext, Enforcer, GuardRail}; use pyo3::prelude::*; use pyo3::types::{PyAny, PyDict}; +use std::sync::Arc; use std::time::{Duration, Instant}; use tracing::{debug, error, info, instrument, warn}; use super::{result::WorkflowResult, workflow::Workflow}; use crate::errors::{timeout_error, to_py_runtime_error, validation_error}; +use crate::guardrail::GuardRailPolicyConfig; use crate::llm::config::LlmConfig; use crate::runtime::get_runtime; @@ -90,10 +93,11 @@ pub struct Executor { #[pymethods] impl Executor { #[new] - #[pyo3(signature = (config, _lightweight_mode=None, timeout_seconds=None, debug=None))] + #[pyo3(signature = (config, lightweight_mode=None, timeout_seconds=None, debug=None))] + #[allow(unused_variables)] fn new( config: LlmConfig, - _lightweight_mode: Option, + lightweight_mode: Option, timeout_seconds: Option, debug: Option, ) -> PyResult { @@ -132,9 +136,18 @@ impl Executor { }) } - /// Execute a workflow with comprehensive error handling and monitoring - #[instrument(skip(self, py, workflow), fields(workflow_name = %workflow.inner.name))] - fn execute(&mut self, py: Python<'_>, workflow: &Workflow) -> PyResult { + /// Execute a workflow with comprehensive error handling and monitoring. + /// + /// `policy` is optional. When provided: encode before every LLM call, decode after every LLM call; + /// before tool usage decode (so tools see real PII); after tool usage do nothing (no encode). 
+ #[instrument(skip(self, py, workflow, policy), fields(workflow_name = %workflow.inner.name))] + #[pyo3(signature = (workflow, policy=None))] + fn execute( + &mut self, + py: Python<'_>, + workflow: &Workflow, + policy: Option<&Bound<'_, GuardRailPolicyConfig>>, + ) -> PyResult { let start_time = Instant::now(); // Validate workflow @@ -161,6 +174,15 @@ impl Executor { let timeout_duration = config.timeout; let debug = config.enable_tracing; // Capture debug flag + // Build optional guardrail enforcer from policy (for encode/decode at LLM and tool boundaries) + let guardrail_enforcer = policy.map(|p| { + let config = p.borrow().get_inner(); + Arc::new(GuardRail::enforcer_for( + config, + workflow_clone.id.to_string(), + )) + }); + if debug { debug!("Starting workflow execution with mode: {:?}", config.mode); } @@ -171,7 +193,13 @@ impl Executor { get_runtime().block_on(async move { // Apply timeout to the entire execution tokio::time::timeout(timeout_duration, async move { - Self::execute_workflow_internal(llm_config, workflow_clone, config).await + Self::execute_workflow_internal( + llm_config, + workflow_clone, + config, + guardrail_enforcer, + ) + .await }) .await }) @@ -210,8 +238,14 @@ impl Executor { } /// Async execution with enhanced performance optimizations - #[instrument(skip(self, workflow, py), fields(workflow_name = %workflow.inner.name))] - fn run_async<'a>(&mut self, workflow: &Workflow, py: Python<'a>) -> PyResult> { + #[instrument(skip(self, workflow, py, policy), fields(workflow_name = %workflow.inner.name))] + #[pyo3(signature = (workflow, policy=None))] + fn run_async<'a>( + &mut self, + workflow: &Workflow, + py: Python<'a>, + policy: Option<&Bound<'_, GuardRailPolicyConfig>>, + ) -> PyResult> { // Validate workflow if let Err(e) = workflow.inner.validate() { return Err(validation_error( @@ -226,7 +260,14 @@ impl Executor { let config = self.config.clone(); let timeout_duration = config.timeout; let start_time = Instant::now(); - let debug = 
config.enable_tracing; // Capture debug flag + let debug = config.enable_tracing; + let guardrail_enforcer = policy.map(|p| { + let config = p.borrow().get_inner(); + Arc::new(GuardRail::enforcer_for( + config, + workflow_clone.id.to_string(), + )) + }); if debug { debug!( @@ -236,9 +277,14 @@ impl Executor { } pyo3_async_runtimes::tokio::future_into_py(py, async move { - // Apply timeout to the entire execution let result = tokio::time::timeout(timeout_duration, async move { - Self::execute_workflow_internal(llm_config, workflow_clone, config).await + Self::execute_workflow_internal( + llm_config, + workflow_clone, + config, + guardrail_enforcer, + ) + .await }) .await; @@ -381,19 +427,25 @@ impl Executor { } impl Executor { - /// Internal workflow execution with mode-specific optimizations and tool call handling + /// Internal workflow execution with mode-specific optimizations and tool call handling. + /// When `guardrail_enforcer` is `Some`, the core encodes before LLM and decodes after LLM; + /// we decode before tool usage only (no encode after tool). 
async fn execute_workflow_internal( llm_config: graphbit_core::llm::LlmConfig, workflow: graphbit_core::workflow::Workflow, config: ExecutionConfig, + guardrail_enforcer: Option>, ) -> Result { let executor = match config.mode { - ExecutionMode::Balanced => CoreWorkflowExecutor::new() - .with_default_llm_config(llm_config.clone()), + ExecutionMode::Balanced => { + CoreWorkflowExecutor::new().with_default_llm_config(llm_config.clone()) + } }; - // Execute the workflow - let mut context = executor.execute(workflow.clone()).await?; + // Execute the workflow (core applies encode before LLM, decode after LLM when enforcer is Some) + let mut context = executor + .execute(workflow.clone(), guardrail_enforcer.clone()) + .await?; // Store LLM config in context metadata for tool call handling if let Ok(llm_config_json) = serde_json::to_value(&llm_config) { @@ -402,18 +454,33 @@ impl Executor { .insert("llm_config".to_string(), llm_config_json); } + // Store workflow name in context metadata for result schema + context.metadata.insert( + "workflow_name".to_string(), + serde_json::Value::String(workflow.name.clone()), + ); + // Check if any node outputs contain tool_calls_required responses and handle them - context = Self::handle_tool_calls_in_context(context, &workflow).await?; + context = Self::handle_tool_calls_in_context( + context, + &workflow, + guardrail_enforcer.as_ref().map(|arc| arc.as_ref()), + ) + .await?; Ok(context) } - /// Handle tool calls in workflow context by executing them and updating the context + /// Handle tool calls in workflow context by executing them and updating the context. + /// When `guardrail_enforcer` is `Some`, decodes tool-call parameters before execution only; + /// after tool execution we do nothing (no encode of tool results). 
async fn handle_tool_calls_in_context( mut context: graphbit_core::types::WorkflowContext, workflow: &graphbit_core::workflow::Workflow, + guardrail_enforcer: Option<&Enforcer>, ) -> Result { use crate::workflow::node::execute_production_tool_calls; + use graphbit_core::DecodeContext; use graphbit_core::llm::{LlmProvider, LlmRequest}; // Check each node output for tool_calls_required responses @@ -447,31 +514,48 @@ impl Executor { }) .unwrap_or_default(); - // Convert tool calls to the format expected by Python layer - let python_tool_calls: Vec = - if let Some(tool_calls_array) = tool_calls.as_array() { - tool_calls_array + // Convert tool calls to the format expected by Python layer. + // Guardrail: decode parameters before tool execution so tools see real PII. + if guardrail_enforcer.is_some() { + tracing::debug!( + "[GuardRail] tool call parameters from LLM (before decode): {:?}", + tool_calls + ); + } + let python_tool_calls: Vec = if let Some( + tool_calls_array, + ) = + tool_calls.as_array() + { + tool_calls_array .iter() .map(|tc| { - // Extract name and parameters from the tool call object let name = tc .get("name") .and_then(|v| v.as_str()) .unwrap_or("unknown"); - let parameters = tc + let mut parameters = tc .get("parameters") .cloned() .unwrap_or(serde_json::json!({})); - + if let Some(enforcer) = guardrail_enforcer { + tracing::debug!( + "Guardrail: decoding tool call parameters (tool boundary — tool will receive real PII)" + ); + let decoded_result = + enforcer.decode(parameters, DecodeContext::ToolBoundary); + parameters = decoded_result.payload; + } serde_json::json!({ + "id": tc.get("id").and_then(|v| v.as_str()).unwrap_or(""), "tool_name": name, "parameters": parameters }) }) .collect() - } else { - Vec::new() - }; + } else { + Vec::new() + }; let tool_calls_json = serde_json::to_string(&python_tool_calls) .map_err(|e| { @@ -524,12 +608,123 @@ impl Executor { summary_lines.join("\n") }; - // Create final prompt with tool results summary + // 
Guardrail: before tool we decode; after tool we do nothing (no encode of results). + let summary_for_llm = tool_results_summary.clone(); + + // --- Build tool_call execution entries --- + // Read existing node metadata (seeded by core execute_agent_with_tools) + let existing_node_metadata = context + .metadata + .get(&format!("node_response_{}", node.id)) + .cloned(); + + // Extract existing executions array from the seeded metadata + let mut executions: Vec = existing_node_metadata + .as_ref() + .and_then(|m| m.get("executions")) + .and_then(|e| e.as_array()) + .cloned() + .unwrap_or_default(); + + // Append a tool_call entry for each tool execution result + let mut tools_used: Vec = Vec::new(); + for (i, tc) in python_tool_calls.iter().enumerate() { + let tool_name = tc + .get("tool_name") + .and_then(|v| v.as_str()) + .unwrap_or("unknown") + .to_string(); + + let tool_result = tool_execution_results.get(i); + let success = tool_result + .and_then(|r| r.get("success").and_then(|v| v.as_bool())) + .unwrap_or(false); + let output = tool_result + .and_then(|r| r.get("output").and_then(|v| v.as_str())) + .unwrap_or("") + .to_string(); + let error = tool_result + .and_then(|r| r.get("error").and_then(|v| v.as_str())) + .map(|e| serde_json::Value::String(e.to_string())) + .unwrap_or(serde_json::Value::Null); + let start_time = tool_result + .and_then(|r| r.get("start_time")) + .cloned() + .unwrap_or(serde_json::Value::Null); + let end_time = tool_result + .and_then(|r| r.get("end_time")) + .cloned() + .unwrap_or(serde_json::Value::Null); + let latency_ms = tool_result + .and_then(|r| r.get("latency_ms")) + .cloned() + .unwrap_or(serde_json::json!(0.0)); + + if !tools_used.contains(&tool_name) { + tools_used.push(tool_name.clone()); + } + + executions.push(serde_json::json!({ + "type": "tool_call", + "id": tc.get("id").and_then(|v| v.as_str()).unwrap_or(""), + "tool_name": tool_name, + "parameters": tc.get("parameters").cloned().unwrap_or(serde_json::json!({})), + 
"output": output, + "success": success, + "error": error, + "start_time": start_time, + "end_time": end_time, + "latency_ms": latency_ms, + "retries": [] + })); + } + + // Build final prompt; when GuardRail is active encode it and debug-print. let final_prompt = format!( "{}\n\nTool execution results:\n{}\n\nPlease provide a comprehensive response based on the tool results.", - original_prompt, tool_results_summary + original_prompt, summary_for_llm ); + // Guardrail: encode final prompt before LLM call + let mut encoded_final_payload_for_meta = String::new(); + let prompt_for_final_llm = if let Some(ref enforcer) = + guardrail_enforcer + { + tracing::info!( + "[GuardRail] final prompt (before encode): {}", + final_prompt + ); + let encode_result = enforcer.encode( + serde_json::Value::String(final_prompt.clone()), + EncodeContext::Llm, + ); + + // Record guardrail encode execution entry + executions.push(serde_json::json!({ + "type": "guardrail_policy", + "operation": "encode", + "pii_rules_applied_count": encode_result.rules_applied_count, + "pii_rule_names": encode_result.rule_names, + "policy_name": encode_result.policy_name + })); + + // Capture encoded payload only (no RULE signature) for metadata + encoded_final_payload_for_meta = encode_result.payload.as_str().unwrap_or_default().to_string(); + + let encoded_str = format!( + "{}{}", + encode_result.signature_injection_text, + encode_result.payload.as_str().unwrap_or_default() + ); + tracing::info!( + "[GuardRail] final prompt (after encode, sent to LLM, payload only): {}", + encode_result.payload.as_str().unwrap_or_default() + ); + encoded_str + } else { + final_prompt.clone() + }; + // Get LLM provider from node configuration and make final call if let graphbit_core::graph::NodeType::Agent { .. 
} = &node.node_type @@ -551,10 +746,10 @@ impl Executor { ) { Ok(provider_trait) => { let llm_provider = - LlmProvider::new(provider_trait, llm_config); + LlmProvider::new(provider_trait, llm_config.clone()); - // Create final request and apply node configuration parameters - let mut final_request = LlmRequest::new(final_prompt.clone()); + // Create final request (with encoded prompt when GuardRail is on) + let mut final_request = LlmRequest::new(prompt_for_final_llm); // CUMULATIVE TOKEN BUDGET TRACKING // Extract initial tokens used and max_tokens to calculate remaining budget @@ -614,8 +809,20 @@ impl Executor { } } + // Measure final LLM call timing + let final_llm_timestamp = chrono::Utc::now(); + let final_llm_start = std::time::Instant::now(); + match llm_provider.complete(final_request).await { Ok(final_response) => { + let final_llm_duration_ms = final_llm_start.elapsed().as_secs_f64() * 1000.0; + let final_llm_end_timestamp = chrono::Utc::now(); + + tracing::info!( + "[GuardRail] final LLM response (GuardRail active={}); before decode: {}", + guardrail_enforcer.is_some(), + final_response.content + ); tracing::debug!( "Final LLM response received - content: '{}', tokens: {}, finish_reason: {:?}", final_response.content, @@ -623,100 +830,128 @@ impl Executor { final_response.finish_reason ); - // Clone the content to avoid borrow checker issues - let response_content = - final_response.content.clone(); - - // Store full LLM response metadata in context - // This enables observability tools to capture complete LLM metadata - // IMPORTANT: Preserve existing metadata fields (prompt, duration_ms, execution_timestamp, tool_calls) - if let Ok(mut response_metadata) = serde_json::to_value(&final_response) { - // Get existing metadata to preserve prompt, duration_ms, execution_timestamp, and tool_calls - let existing_metadata_by_id = context.metadata.get(&format!("node_response_{}", node.id)).cloned(); - - // Merge existing metadata fields into new metadata - if 
let (Some(existing), Some(response_obj)) = (existing_metadata_by_id.as_ref(), response_metadata.as_object_mut()) { - if let Some(existing_obj) = existing.as_object() { - // Preserve these critical fields from the initial LLM call - if let Some(prompt) = existing_obj.get("prompt") { - response_obj.insert("prompt".to_string(), prompt.clone()); - } - if let Some(duration_ms) = existing_obj.get("duration_ms") { - response_obj.insert("duration_ms".to_string(), duration_ms.clone()); - } - if let Some(execution_timestamp) = existing_obj.get("execution_timestamp") { - response_obj.insert("execution_timestamp".to_string(), execution_timestamp.clone()); - } + // Build the final llm_call execution entry + let final_llm_call_entry = serde_json::json!({ + "type": "llm_call", + "id": final_response.id.clone().unwrap_or_default(), + "model": final_response.model, + "provider": llm_config.provider_name(), + "input": if guardrail_enforcer.is_some() { encoded_final_payload_for_meta.clone() } else { final_prompt.clone() }, + "output": final_response.content, + "finish_reason": format!("{}", final_response.finish_reason), + "tool_calls": [], + "start_time": final_llm_timestamp.to_rfc3339(), + "end_time": final_llm_end_timestamp.to_rfc3339(), + "duration_ms": final_llm_duration_ms, + "usage": { + "prompt_tokens": final_response.usage.prompt_tokens, + "completion_tokens": final_response.usage.completion_tokens, + "total_tokens": final_response.usage.total_tokens, + "prompt_tokens_details": { + "cached_tokens": 0, + "audio_tokens": 0 + }, + "completion_tokens_details": { + "reasoning_tokens": 0, + "audio_tokens": 0, + "accepted_prediction_tokens": 0, + "rejected_prediction_tokens": 0 + } + }, + "retries": [] + }); + executions.push(final_llm_call_entry); + + // Guardrail: decode after every LLM call so user sees rehydrated content + let response_content = if let Some(ref enforcer) = guardrail_enforcer { + let payload = serde_json::json!({ + "content": final_response.content + }); + let 
decoded_result = + enforcer.decode(payload, DecodeContext::LlmResponse); + + // Record guardrail decode execution entry + executions.push(serde_json::json!({ + "type": "guardrail_policy", + "operation": "decode", + "pii_rules_applied_count": decoded_result.rules_applied_count, + "pii_rule_names": decoded_result.rule_names, + "policy_name": decoded_result.policy_name + })); + + let content = decoded_result + .payload + .get("content") + .and_then(|v| v.as_str()) + .map(String::from) + .unwrap_or_else(|| final_response.content.clone()); + tracing::info!("[GuardRail] final LLM response (after decode): {}", content); + content + } else { + final_response.content.clone() + }; + + // --- Update node-level metadata with completed data --- + // Aggregate total_usage from all llm_call executions + let mut total_prompt_tokens: u32 = 0; + let mut total_completion_tokens: u32 = 0; + let mut total_tokens: u32 = 0; + for exec in &executions { + if exec.get("type").and_then(|v| v.as_str()) == Some("llm_call") { + if let Some(usage) = exec.get("usage") { + total_prompt_tokens += usage.get("prompt_tokens").and_then(|v| v.as_u64()).unwrap_or(0) as u32; + total_completion_tokens += usage.get("completion_tokens").and_then(|v| v.as_u64()).unwrap_or(0) as u32; + total_tokens += usage.get("total_tokens").and_then(|v| v.as_u64()).unwrap_or(0) as u32; } } + } - // IMPORTANT: Add the original tool_calls from the initial LLM response - // The final_response.tool_calls will be empty since tools were already executed - // We need to preserve the original tool calls for observability - if let Some(response_obj) = response_metadata.as_object_mut() { - // Enrich tool_calls with their execution results - let mut enriched_tool_calls = tool_calls.clone(); - if let Some(calls_array) = enriched_tool_calls.as_array_mut() { - for (i, call) in calls_array.iter_mut().enumerate() { - if let Some(result) = tool_execution_results.get(i) { - if let Some(call_obj) = call.as_object_mut() { - let mut 
result_clone = result.clone(); - - // Extract timing details and add to tool call object - if let Some(start_time) = result_clone.get("start_time") { - call_obj.insert("start_time".to_string(), start_time.clone()); - } - if let Some(end_time) = result_clone.get("end_time") { - call_obj.insert("end_time".to_string(), end_time.clone()); - } - if let Some(latency) = result_clone.get("latency_ms") { - call_obj.insert("latency_ms".to_string(), latency.clone()); - } - - // Remove timing fields and redundant tool name from the output object to avoid duplication - if let Some(result_obj) = result_clone.as_object_mut() { - result_obj.remove("start_time"); - result_obj.remove("end_time"); - result_obj.remove("latency_ms"); - result_obj.remove("tool_name"); - } - - // Insert the cleaned result object as "output" - call_obj.insert("output".to_string(), result_clone); - } - } + let total_tool_calls = tool_execution_results.len(); + + // Build complete node metadata by updating the seeded metadata + if let Some(mut node_meta) = existing_node_metadata.clone() { + if let Some(obj) = node_meta.as_object_mut() { + obj.insert("end_time".to_string(), serde_json::json!(final_llm_end_timestamp.to_rfc3339())); + // Calculate total duration from node start to now + if let Some(start_str) = obj.get("start_time").and_then(|v| v.as_str()) { + if let Ok(start_dt) = chrono::DateTime::parse_from_rfc3339(start_str) { + let total_duration = (final_llm_end_timestamp - start_dt.with_timezone(&chrono::Utc)).num_milliseconds() as f64; + obj.insert("duration_ms".to_string(), serde_json::json!(total_duration)); } } - - response_obj.insert("tool_calls".to_string(), enriched_tool_calls); - - // 1. Prepare the value (unwrap and make it mutable) - let mut initial_response_value = existing_metadata_by_id - .clone() - .unwrap_or(serde_json::Value::Null); - - // 2. 
If the value is a JSON Object, remove the unwanted fields - if let Some(obj) = initial_response_value.as_object_mut() { - obj.remove("content"); - obj.remove("duration_ms"); - obj.remove("execution_timestamp"); - obj.remove("metadata"); - } - - // 3. Insert the cleaned object into your response_obj - response_obj.insert( - "initial_response".to_string(), - initial_response_value - ); - - // Add final input - response_obj.insert("final_input".to_string(), serde_json::Value::String(final_prompt.clone())); + // When GR active: final_output = raw LLM content (before decode) + // When GR inactive: final_output = response content (same as raw) + let final_output_for_meta = if guardrail_enforcer.is_some() { + final_response.content.clone() + } else { + response_content.clone() + }; + obj.insert("final_output".to_string(), serde_json::Value::String(final_output_for_meta)); + obj.insert("exit_reason".to_string(), serde_json::json!(format!("{}", final_response.finish_reason))); + obj.insert("total_tool_calls".to_string(), serde_json::json!(total_tool_calls)); + obj.insert("tools_used".to_string(), serde_json::json!(tools_used)); + obj.insert("total_usage".to_string(), serde_json::json!({ + "prompt_tokens": total_prompt_tokens, + "completion_tokens": total_completion_tokens, + "total_tokens": total_tokens, + "prompt_tokens_details": { + "cached_tokens": 0, + "audio_tokens": 0 + }, + "completion_tokens_details": { + "reasoning_tokens": 0, + "audio_tokens": 0, + "accepted_prediction_tokens": 0, + "rejected_prediction_tokens": 0 + } + })); + obj.insert("executions".to_string(), serde_json::json!(executions)); } // Store by node ID context.metadata.insert( format!("node_response_{}", node.id), - response_metadata.clone(), + node_meta.clone(), ); // Also store by node name if available @@ -729,7 +964,7 @@ impl Executor { { context.metadata.insert( format!("node_response_{}", node_name), - response_metadata, + node_meta, ); } } @@ -794,7 +1029,9 @@ impl Executor { } } else { // No 
LLM configuration available, just keep tool results - tracing::warn!("No LLM configuration found in context metadata for final response. Using tool results only."); + tracing::warn!( + "No LLM configuration found in context metadata for final response. Using tool results only." + ); context.set_node_output( &node.id, serde_json::Value::String(tool_results_summary.clone()), diff --git a/python/src/workflow/result.rs b/python/src/workflow/result.rs index 08f914e4..9aa6b7e9 100644 --- a/python/src/workflow/result.rs +++ b/python/src/workflow/result.rs @@ -126,20 +126,22 @@ impl WorkflowResult { .collect() } - /// Get LLM response metadata for a specific node + /// Get node execution metadata for a specific node /// - /// Returns a dictionary containing the full LLM response metadata including: - /// - content: The generated text - /// - usage: Token usage statistics (prompt_tokens, completion_tokens, total_tokens) - /// - finish_reason: Why the LLM stopped generating - /// - model: The model used - /// - metadata: Additional provider-specific metadata + /// Returns the full node-level metadata object containing: + /// - node_id, node_name, node_type, user_input, final_output + /// - tools_available, total_tools_available + /// - start_time, end_time, duration_ms, success, error + /// - total_iterations, max_iterations, exit_reason + /// - total_usage (aggregated token usage) + /// - total_tool_calls, total_retries, tools_used + /// - executions: chronological array of llm_call, tool_call, guardrail_policy entries /// /// # Arguments /// * `node_id` - Node ID or node name /// /// # Returns - /// Dictionary with LLM response metadata, or None if not found + /// Dictionary with node execution metadata, or None if not found fn get_node_response_metadata( &self, py: Python<'_>, @@ -156,32 +158,142 @@ impl WorkflowResult { } } - /// Get all node LLM response metadata + /// Get complete workflow execution metadata /// - /// Returns a dictionary mapping node IDs/names to 
their LLM response metadata. - /// Each metadata entry contains: - /// - content: The generated text - /// - usage: Token usage statistics (prompt_tokens, completion_tokens, total_tokens) - /// - finish_reason: Why the LLM stopped generating - /// - model: The model used - /// - metadata: Additional provider-specific metadata + /// Returns the full workflow-level schema containing: + /// - workflow_id, workflow_name + /// - start_time, end_time, duration_ms + /// - user_input, final_output (from first/last nodes) + /// - workflow_state: completed/failed/cancelled/paused + /// - nodes: array of per-node metadata objects (each with executions array) + /// - total_usage: aggregated token usage across all nodes + /// - total_tool_calls: sum of tool calls across all nodes /// /// # Returns - /// Dictionary mapping node IDs/names to their LLM response metadata + /// Dictionary with the complete workflow-level metadata fn get_all_node_response_metadata(&self, py: Python<'_>) -> PyResult { - use pyo3::types::PyDict; - - let result_dict = PyDict::new(py); + // Collect node metadata entries (by node ID only, skip name duplicates) + let mut nodes: Vec = Vec::new(); + let mut seen_node_ids: std::collections::HashSet = std::collections::HashSet::new(); for (k, v) in self.inner.metadata.iter() { - // Only include keys that start with "node_response_" - if k.starts_with("node_response_") { - let node_id = k.strip_prefix("node_response_").unwrap(); - let py_value = pythonize::pythonize(py, v)?; - result_dict.set_item(node_id, py_value)?; + if let Some(node_id) = k.strip_prefix("node_response_") { + // Skip if this is a name-based duplicate (node names are typically not UUIDs) + // We include if the node_id is a UUID format or if the value has a node_id field matching + if let Some(stored_node_id) = v.get("node_id").and_then(|v| v.as_str()) { + if seen_node_ids.contains(stored_node_id) { + continue; + } + seen_node_ids.insert(stored_node_id.to_string()); + } else if 
seen_node_ids.contains(node_id) { + continue; + } else { + seen_node_ids.insert(node_id.to_string()); + } + nodes.push(v.clone()); + } + } + + // Aggregate total_usage and total_tool_calls across all nodes + let mut total_prompt_tokens: u64 = 0; + let mut total_completion_tokens: u64 = 0; + let mut total_tokens: u64 = 0; + let mut total_tool_calls: u64 = 0; + + for node in &nodes { + if let Some(usage) = node.get("total_usage") { + total_prompt_tokens += usage + .get("prompt_tokens") + .and_then(|v| v.as_u64()) + .unwrap_or(0); + total_completion_tokens += usage + .get("completion_tokens") + .and_then(|v| v.as_u64()) + .unwrap_or(0); + total_tokens += usage + .get("total_tokens") + .and_then(|v| v.as_u64()) + .unwrap_or(0); } + total_tool_calls += node + .get("total_tool_calls") + .and_then(|v| v.as_u64()) + .unwrap_or(0); } - Ok(result_dict.into()) + // Determine user_input (from first node) and final_output (from last node) + let user_input = nodes + .first() + .and_then(|n| n.get("user_input")) + .cloned() + .unwrap_or(serde_json::Value::String(String::new())); + let final_output = nodes + .last() + .and_then(|n| n.get("final_output")) + .cloned() + .unwrap_or(serde_json::Value::String(String::new())); + + // Determine workflow_state from context state + let workflow_state = match &self.inner.state { + WorkflowState::Completed => "completed", + WorkflowState::Failed { .. } => "failed", + WorkflowState::Cancelled => "cancelled", + WorkflowState::Paused { .. } => "paused", + WorkflowState::Running { .. 
} => "running", + WorkflowState::Pending => "pending", + }; + + // Get workflow name from metadata (stored during execution) + let workflow_name = self + .inner + .metadata + .get("workflow_name") + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string(); + + // Build timing fields + let start_time = self.inner.started_at.to_rfc3339(); + let end_time = self + .inner + .completed_at + .map(|t| t.to_rfc3339()) + .unwrap_or_default(); + let duration_ms = self.inner.execution_duration_ms().unwrap_or(0) as f64; + + // Build the workflow-level metadata object + let workflow_metadata = serde_json::json!({ + "workflow_id": self.inner.workflow_id.to_string(), + "workflow_name": workflow_name, + "start_time": start_time, + "end_time": end_time, + "duration_ms": duration_ms, + "user_input": user_input, + // TODO: Remove these placeholder fields in a future release + "user_input_masked": "", + "final_output": final_output, + "final_output_masked": "", + "workflow_state": workflow_state, + "nodes": nodes, + "total_usage": { + "prompt_tokens": total_prompt_tokens, + "completion_tokens": total_completion_tokens, + "total_tokens": total_tokens, + "prompt_tokens_details": { + "cached_tokens": 0, + "audio_tokens": 0 + }, + "completion_tokens_details": { + "reasoning_tokens": 0, + "audio_tokens": 0, + "accepted_prediction_tokens": 0, + "rejected_prediction_tokens": 0 + } + }, + "total_tool_calls": total_tool_calls + }); + + let py_obj = pythonize::pythonize(py, &workflow_metadata)?; + Ok(py_obj.into()) } } diff --git a/tests/rust_integration_tests/full_workflow_tests.rs b/tests/rust_integration_tests/full_workflow_tests.rs index 49653de5..4fb5dec6 100644 --- a/tests/rust_integration_tests/full_workflow_tests.rs +++ b/tests/rust_integration_tests/full_workflow_tests.rs @@ -593,7 +593,7 @@ async fn test_real_llm_workflow_execution() { let workflow = builder.build().expect("Failed to build workflow"); // Try to execute the workflow - let result = executor.execute(workflow).await; + 
let result = executor.execute(workflow, None).await; match result { Ok(context) => { println!("Real LLM workflow executed successfully"); @@ -721,7 +721,7 @@ async fn test_workflow_error_propagation() { // Try to execute - should fail gracefully let executor = WorkflowExecutor::new(); - let result = executor.execute(workflow).await; + let result = executor.execute(workflow, None).await; // Should handle gracefully (may succeed or fail depending on implementation) match result { @@ -802,7 +802,7 @@ async fn test_multi_provider_workflow_execution() { let workflow = builder.build().expect("Failed to build workflow"); - let result = executor.execute(workflow).await; + let result = executor.execute(workflow, None).await; match result { Ok(context) => { println!("{provider_name} workflow executed successfully"); @@ -936,7 +936,7 @@ async fn test_comprehensive_real_api_workflow() { .expect("Failed to build workflow"); // Execute workflow - let result = executor.execute(workflow).await; + let result = executor.execute(workflow, None).await; match result { Ok(context) => { println!("Comprehensive real API workflow executed successfully"); @@ -981,7 +981,7 @@ async fn test_workflow_timeout_handling() { // Execute with timeout let executor = WorkflowExecutor::new().with_max_node_execution_time(2000); // 2 second max let start = std::time::Instant::now(); - let result = executor.execute(workflow).await; + let result = executor.execute(workflow, None).await; let duration = start.elapsed(); // Should complete quickly due to timeout, not wait full 10 seconds diff --git a/tests/rust_unit_tests/python_bindings_tests.rs b/tests/rust_unit_tests/python_bindings_tests.rs index aafef5ab..b91fb61f 100644 --- a/tests/rust_unit_tests/python_bindings_tests.rs +++ b/tests/rust_unit_tests/python_bindings_tests.rs @@ -82,7 +82,7 @@ async fn test_workflow_executor() { assert!(workflow.metadata.contains_key("test_key")); let executor = WorkflowExecutor::new(); - let result = 
executor.execute(workflow).await; + let result = executor.execute(workflow, None).await; // Empty workflows may fail execution, which is expected behavior // Let's check what the actual result is @@ -109,7 +109,7 @@ async fn test_workflow_integration() { .unwrap(); let executor = WorkflowExecutor::new(); - let result = executor.execute(workflow).await; + let result = executor.execute(workflow, None).await; // Empty workflows may fail execution, which is expected behavior match result { diff --git a/tests/rust_unit_tests/workflow_tests.rs b/tests/rust_unit_tests/workflow_tests.rs index fa92719d..db06808d 100644 --- a/tests/rust_unit_tests/workflow_tests.rs +++ b/tests/rust_unit_tests/workflow_tests.rs @@ -191,7 +191,7 @@ async fn test_workflow_execute_with_dummy_agent_success() { let exec = WorkflowExecutor::new(); exec.register_agent(agent).await; - let ctx = exec.execute(wf).await.expect("workflow should execute"); + let ctx = exec.execute(wf, None).await.expect("workflow should execute"); assert!(matches!(ctx.state, WorkflowState::Completed)); let stats = ctx.stats.expect("stats present"); assert!(stats.total_nodes >= 3); @@ -233,7 +233,7 @@ async fn test_workflow_execute_fail_fast_on_error() { exec.register_agent(agent).await; let ctx = exec - .execute(wf) + .execute(wf, None) .await .expect("execution should return context"); // Current executor records node failure but continues; ensure at least one failed node counted @@ -454,7 +454,7 @@ async fn test_workflow_with_llm() { .expect("Failed to build agent"); executor.register_agent(Arc::new(agent)).await; - let result = executor.execute(workflow).await; + let result = executor.execute(workflow, None).await; assert!(result.is_ok()); } @@ -497,7 +497,7 @@ async fn test_workflow_with_anthropic() { .expect("Failed to build agent"); executor.register_agent(Arc::new(agent)).await; - let result = executor.execute(workflow).await; + let result = executor.execute(workflow, None).await; assert!(result.is_ok()); } @@ 
-542,7 +542,7 @@ async fn test_workflow_with_ollama() { .expect("Failed to build agent"); executor.register_agent(Arc::new(agent)).await; - let result = executor.execute(workflow).await; + let result = executor.execute(workflow, None).await; assert!(result.is_ok()); } diff --git a/vendor/guardrail/guardrail_ffi.dll b/vendor/guardrail/guardrail_ffi.dll new file mode 100644 index 00000000..00b8778c Binary files /dev/null and b/vendor/guardrail/guardrail_ffi.dll differ diff --git a/vendor/guardrail/guardrail_ffi.lib b/vendor/guardrail/guardrail_ffi.lib new file mode 100644 index 00000000..6f525513 Binary files /dev/null and b/vendor/guardrail/guardrail_ffi.lib differ diff --git a/vendor/guardrail/libguardrail_ffi.a b/vendor/guardrail/libguardrail_ffi.a new file mode 100644 index 00000000..e83f00c7 Binary files /dev/null and b/vendor/guardrail/libguardrail_ffi.a differ