diff --git a/sdk/python/src/lib.rs b/sdk/python/src/lib.rs index b76522b..0c218ab 100644 --- a/sdk/python/src/lib.rs +++ b/sdk/python/src/lib.rs @@ -49,6 +49,11 @@ use a3s_code_core::hooks::{ HookMatcher as RustHookMatcher, HookResponse as RustHookResponse, }; use a3s_code_core::llm::Message as RustMessage; +use a3s_code_core::orchestration::{ + execute_pipeline, execute_steps_parallel, execute_steps_parallel_resumable, + AgentStepSpec as RustAgentStepSpec, PipelineStage as RustPipelineStage, + StepOutcome as RustStepOutcome, +}; use a3s_code_core::permissions::{ PermissionDecision as RustPermissionDecision, PermissionPolicy as RustPermissionPolicy, PermissionRule as RustPermissionRule, @@ -1396,6 +1401,110 @@ impl PySession { Ok(PyAgentResult::from(result)) } + /// Run `specs` as a fan-out of agent steps and return each step's outcome + /// (a dict) in input order. Each spec is a dict with snake_case keys: + /// `task_id`, `agent`, `description`, `prompt`, optional `max_steps`, + /// `parent_session_id`, `output_schema`. A failed step surfaces as + /// `success: False` without failing the batch. + fn parallel(&self, py: Python<'_>, specs: Vec>) -> PyResult> { + let rust_specs = specs + .iter() + .map(|s| py_to_step_spec(py, s)) + .collect::>>()?; + let session = self.inner.clone(); + let outcomes = py.allow_threads(move || { + get_runtime().block_on(async move { + let executor = session.agent_executor(); + execute_steps_parallel(executor, rust_specs, None).await + }) + }); + outcomes.iter().map(|o| step_outcome_to_py(py, o)).collect() + } + + /// Like `parallel`, but resumable: progress is journaled under + /// `workflow_id` via the session's store, so an interrupted run skips + /// already-completed steps. Raises if no `session_store` is configured. + fn parallel_resumable( + &self, + py: Python<'_>, + specs: Vec>, + workflow_id: String, + ) -> PyResult> { + let rust_specs = specs + .iter() + .map(|s| py_to_step_spec(py, s)) + .collect::>>()?; + let session = self.inner.clone(); + let outcomes = py + .allow_threads(move || { + get_runtime().block_on(async move { + let Some(store) = session.session_store() else { + return Err("parallel_resumable requires a session_store on the session"); + }; + let executor = session.agent_executor(); + Ok(execute_steps_parallel_resumable( + executor, + rust_specs, + &workflow_id, + store, + None, + ) + .await) + }) + }) + .map_err(PyRuntimeError::new_err)?; + outcomes.iter().map(|o| step_outcome_to_py(py, o)).collect() + } + + /// Run each item through a chain of `stages`, with no barrier between + /// stages. Each stage is a callable `stage(ctx) -> spec_dict | None`, where + /// `ctx = {"previous": , "item": }`. Return a + /// spec dict (snake_case keys) to run that step, or `None` to stop the + /// item's chain. A chain also stops when a step fails. Returns one entry + /// per item (the last outcome dict, or `None`), in input order. + /// + /// A stage callable that raises is caught and treated as `None` (stops that + /// chain). Per-stage `output_schema` is not supported here — use `parallel` + /// for schema-validated steps. + fn pipeline( + &self, + py: Python<'_>, + items: Vec>, + stages: Vec>, + ) -> PyResult>> { + let rust_items = items + .iter() + .map(|i| py_to_json_value(py, i)) + .collect::>>()?; + let rust_stages: Vec> = stages + .into_iter() + .map(|s| { + let stage = std::sync::Arc::new(PythonPipelineStage { + callback: s.unbind(), + }); + let ps: RustPipelineStage = + std::sync::Arc::new(move |prev, item| stage.invoke(prev, item)); + ps + }) + .collect(); + + let session = self.inner.clone(); + let outcomes = py.allow_threads(move || { + get_runtime().block_on(async move { + let executor = session.agent_executor(); + execute_pipeline(executor, rust_items, rust_stages, None).await + }) + }); + + outcomes + .iter() + .map(|o| match o { + Some(outcome) => step_outcome_to_py(py, outcome).map(Some), + None => Ok(None), + }) + .collect() + } + /// Send a prompt or request and get a streaming iterator of events. /// /// When ``history`` is omitted, session history and verification evidence are @@ -3205,6 +3314,84 @@ fn parse_py_hook_response( Ok(RustHookResponse::continue_()) } +// ============================================================================ +// Orchestration: Python <-> Rust conversion + pipeline-stage bridge +// ============================================================================ + +fn py_dumps(py: Python<'_>, obj: &Bound<'_, PyAny>) -> PyResult { + let json_mod = py.import("json")?; + json_mod.call_method1("dumps", (obj,))?.extract() +} + +/// Convert a Python spec dict into an `AgentStepSpec` (snake_case keys) via a +/// JSON round-trip. +fn py_to_step_spec(py: Python<'_>, obj: &Bound<'_, PyAny>) -> PyResult { + serde_json::from_str(&py_dumps(py, obj)?) + .map_err(|e| PyValueError::new_err(format!("invalid AgentStepSpec: {e}"))) +} + +/// Convert an arbitrary Python value into a `serde_json::Value`. +fn py_to_json_value(py: Python<'_>, obj: &Bound<'_, PyAny>) -> PyResult { + serde_json::from_str(&py_dumps(py, obj)?) + .map_err(|e| PyValueError::new_err(format!("invalid JSON: {e}"))) +} + +/// Convert a `StepOutcome` into a Python dict. +fn step_outcome_to_py(py: Python<'_>, outcome: &RustStepOutcome) -> PyResult { + let json = serde_json::to_string(outcome) + .map_err(|e| PyRuntimeError::new_err(format!("serialize outcome: {e}")))?; + json_string_to_py(py, &json) +} + +/// Bridges a Python pipeline-stage callable into a synchronous `PipelineStage`. +/// +/// GIL safety: `pipeline()` releases the GIL via `py.allow_threads`, so +/// re-acquiring it here from a tokio worker thread does not deadlock (same as +/// the hook/budget bridges). A raised exception is caught and treated as +/// `None` (stop the chain). +struct PythonPipelineStage { + callback: pyo3::Py, +} + +impl PythonPipelineStage { + fn invoke( + &self, + prev: Option<&RustStepOutcome>, + item: &serde_json::Value, + ) -> Option { + pyo3::Python::with_gil(|py| { + let result = (|| -> PyResult> { + let json_mod = py.import("json")?; + let previous = match prev { + Some(o) => { + let s = serde_json::to_string(o) + .map_err(|e| PyValueError::new_err(e.to_string()))?; + json_mod.call_method1("loads", (s,))? + } + None => py.None().into_bound(py), + }; + let item_str = serde_json::to_string(item) + .map_err(|e| PyValueError::new_err(e.to_string()))?; + let item_py = json_mod.call_method1("loads", (item_str,))?; + let ctx = PyDict::new(py); + ctx.set_item("previous", previous)?; + ctx.set_item("item", item_py)?; + let ret = self.callback.call1(py, (ctx,))?; + let bound = ret.bind(py); + if bound.is_none() { + return Ok(None); + } + let spec_json: String = json_mod.call_method1("dumps", (bound,))?.extract()?; + serde_json::from_str::(&spec_json) + .map(Some) + .map_err(|e| PyValueError::new_err(format!("invalid step spec: {e}"))) + })(); + // Fail-closed: any exception → stop this chain. + result.unwrap_or(None) + }) + } +} + // ============================================================================ // Python BudgetGuard bridge // ============================================================================