Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
187 changes: 187 additions & 0 deletions sdk/python/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,11 @@ use a3s_code_core::hooks::{
HookMatcher as RustHookMatcher, HookResponse as RustHookResponse,
};
use a3s_code_core::llm::Message as RustMessage;
use a3s_code_core::orchestration::{
execute_pipeline, execute_steps_parallel, execute_steps_parallel_resumable,
AgentStepSpec as RustAgentStepSpec, PipelineStage as RustPipelineStage,
StepOutcome as RustStepOutcome,
};
use a3s_code_core::permissions::{
PermissionDecision as RustPermissionDecision, PermissionPolicy as RustPermissionPolicy,
PermissionRule as RustPermissionRule,
Expand Down Expand Up @@ -1396,6 +1401,110 @@ impl PySession {
Ok(PyAgentResult::from(result))
}

/// Run `specs` as a fan-out of agent steps and return each step's outcome
/// (a dict) in input order. Each spec is a dict with snake_case keys:
/// `task_id`, `agent`, `description`, `prompt`, optional `max_steps`,
/// `parent_session_id`, `output_schema`. A failed step surfaces as
/// `success: False` without failing the batch.
fn parallel(&self, py: Python<'_>, specs: Vec<Bound<'_, PyAny>>) -> PyResult<Vec<PyObject>> {
let rust_specs = specs
.iter()
.map(|s| py_to_step_spec(py, s))
.collect::<PyResult<Vec<_>>>()?;
let session = self.inner.clone();
let outcomes = py.allow_threads(move || {
get_runtime().block_on(async move {
let executor = session.agent_executor();
execute_steps_parallel(executor, rust_specs, None).await
})
});
outcomes.iter().map(|o| step_outcome_to_py(py, o)).collect()
}

/// Like `parallel`, but resumable: progress is journaled under
/// `workflow_id` via the session's store, so an interrupted run skips
/// already-completed steps. Raises if no `session_store` is configured.
fn parallel_resumable(
&self,
py: Python<'_>,
specs: Vec<Bound<'_, PyAny>>,
workflow_id: String,
) -> PyResult<Vec<PyObject>> {
let rust_specs = specs
.iter()
.map(|s| py_to_step_spec(py, s))
.collect::<PyResult<Vec<_>>>()?;
let session = self.inner.clone();
let outcomes = py
.allow_threads(move || {
get_runtime().block_on(async move {
let Some(store) = session.session_store() else {
return Err("parallel_resumable requires a session_store on the session");
};
let executor = session.agent_executor();
Ok(execute_steps_parallel_resumable(
executor,
rust_specs,
&workflow_id,
store,
None,
)
.await)
})
})
.map_err(PyRuntimeError::new_err)?;
outcomes.iter().map(|o| step_outcome_to_py(py, o)).collect()
}

/// Run each item through a chain of `stages`, with no barrier between
/// stages. Each stage is a callable `stage(ctx) -> spec_dict | None`, where
/// `ctx = {"previous": <outcome dict or None>, "item": <item>}`. Return a
/// spec dict (snake_case keys) to run that step, or `None` to stop the
/// item's chain. A chain also stops when a step fails. Returns one entry
/// per item (the last outcome dict, or `None`), in input order.
///
/// A stage callable that raises is caught and treated as `None` (stops that
/// chain). Per-stage `output_schema` is not supported here — use `parallel`
/// for schema-validated steps.
fn pipeline(
&self,
py: Python<'_>,
items: Vec<Bound<'_, PyAny>>,
stages: Vec<Bound<'_, PyAny>>,
) -> PyResult<Vec<Option<PyObject>>> {
let rust_items = items
.iter()
.map(|i| py_to_json_value(py, i))
.collect::<PyResult<Vec<_>>>()?;
let rust_stages: Vec<RustPipelineStage<serde_json::Value>> = stages
.into_iter()
.map(|s| {
let stage = std::sync::Arc::new(PythonPipelineStage {
callback: s.unbind(),
});
let ps: RustPipelineStage<serde_json::Value> =
std::sync::Arc::new(move |prev, item| stage.invoke(prev, item));
ps
})
.collect();

let session = self.inner.clone();
let outcomes = py.allow_threads(move || {
get_runtime().block_on(async move {
let executor = session.agent_executor();
execute_pipeline(executor, rust_items, rust_stages, None).await
})
});

outcomes
.iter()
.map(|o| match o {
Some(outcome) => step_outcome_to_py(py, outcome).map(Some),
None => Ok(None),
})
.collect()
}

/// Send a prompt or request and get a streaming iterator of events.
///
/// When ``history`` is omitted, session history and verification evidence are
Expand Down Expand Up @@ -3205,6 +3314,84 @@ fn parse_py_hook_response(
Ok(RustHookResponse::continue_())
}

// ============================================================================
// Orchestration: Python <-> Rust conversion + pipeline-stage bridge
// ============================================================================

fn py_dumps(py: Python<'_>, obj: &Bound<'_, PyAny>) -> PyResult<String> {
let json_mod = py.import("json")?;
json_mod.call_method1("dumps", (obj,))?.extract()
}

/// Convert a Python spec dict into an `AgentStepSpec` (snake_case keys) via a
/// JSON round-trip.
fn py_to_step_spec(py: Python<'_>, obj: &Bound<'_, PyAny>) -> PyResult<RustAgentStepSpec> {
serde_json::from_str(&py_dumps(py, obj)?)
.map_err(|e| PyValueError::new_err(format!("invalid AgentStepSpec: {e}")))
}

/// Convert an arbitrary Python value into a `serde_json::Value`.
fn py_to_json_value(py: Python<'_>, obj: &Bound<'_, PyAny>) -> PyResult<serde_json::Value> {
serde_json::from_str(&py_dumps(py, obj)?)
.map_err(|e| PyValueError::new_err(format!("invalid JSON: {e}")))
}

/// Convert a `StepOutcome` into a Python dict.
fn step_outcome_to_py(py: Python<'_>, outcome: &RustStepOutcome) -> PyResult<PyObject> {
let json = serde_json::to_string(outcome)
.map_err(|e| PyRuntimeError::new_err(format!("serialize outcome: {e}")))?;
json_string_to_py(py, &json)
}

/// Bridges a Python pipeline-stage callable into a synchronous `PipelineStage`.
///
/// GIL safety: `pipeline()` releases the GIL via `py.allow_threads`, so
/// re-acquiring it here from a tokio worker thread does not deadlock (same as
/// the hook/budget bridges). A raised exception is caught and treated as
/// `None` (stop the chain).
struct PythonPipelineStage {
callback: pyo3::Py<pyo3::PyAny>,
}

impl PythonPipelineStage {
fn invoke(
&self,
prev: Option<&RustStepOutcome>,
item: &serde_json::Value,
) -> Option<RustAgentStepSpec> {
pyo3::Python::with_gil(|py| {
let result = (|| -> PyResult<Option<RustAgentStepSpec>> {
let json_mod = py.import("json")?;
let previous = match prev {
Some(o) => {
let s = serde_json::to_string(o)
.map_err(|e| PyValueError::new_err(e.to_string()))?;
json_mod.call_method1("loads", (s,))?
}
None => py.None().into_bound(py),
};
let item_str = serde_json::to_string(item)
.map_err(|e| PyValueError::new_err(e.to_string()))?;
let item_py = json_mod.call_method1("loads", (item_str,))?;
let ctx = PyDict::new(py);
ctx.set_item("previous", previous)?;
ctx.set_item("item", item_py)?;
let ret = self.callback.call1(py, (ctx,))?;
let bound = ret.bind(py);
if bound.is_none() {
return Ok(None);
}
let spec_json: String = json_mod.call_method1("dumps", (bound,))?.extract()?;
serde_json::from_str::<RustAgentStepSpec>(&spec_json)
.map(Some)
.map_err(|e| PyValueError::new_err(format!("invalid step spec: {e}")))
})();
// Fail-closed: any exception → stop this chain.
result.unwrap_or(None)
})
}
}

// ============================================================================
// Python BudgetGuard bridge
// ============================================================================
Expand Down
Loading