From 7ecaa1914a9a3d7b3e46432f6b6d7604c214b4f6 Mon Sep 17 00:00:00 2001 From: Nasr Date: Mon, 9 Feb 2026 14:35:12 -0600 Subject: [PATCH 1/9] feat(sozo): add controller session management commands --- bin/sozo/src/commands/mod.rs | 11 + .../commands/options/account/controller.rs | 248 ++++++++++++++++- bin/sozo/src/commands/options/account/mod.rs | 19 +- bin/sozo/src/commands/session.rs | 260 ++++++++++++++++++ 4 files changed, 533 insertions(+), 5 deletions(-) create mode 100644 bin/sozo/src/commands/session.rs diff --git a/bin/sozo/src/commands/mod.rs b/bin/sozo/src/commands/mod.rs index f987bc0164..539e2b5d15 100644 --- a/bin/sozo/src/commands/mod.rs +++ b/bin/sozo/src/commands/mod.rs @@ -24,6 +24,8 @@ pub(crate) mod mcp; pub(crate) mod migrate; pub(crate) mod model; pub(crate) mod options; +#[cfg(feature = "controller")] +pub(crate) mod session; pub(crate) mod starknet; pub(crate) mod test; pub(crate) mod version; @@ -45,6 +47,8 @@ use invoke::InvokeArgs; use mcp::McpArgs; use migrate::MigrateArgs; use model::ModelArgs; +#[cfg(feature = "controller")] +use session::SessionArgs; #[cfg(feature = "walnut")] use sozo_walnut::walnut::WalnutArgs; use starknet::StarknetArgs; @@ -86,6 +90,9 @@ pub enum Commands { Migrate(Box), #[command(about = "Inspect a model")] Model(Box), + #[cfg(feature = "controller")] + #[command(about = "Manage Cartridge controller sessions")] + Session(Box), #[command(about = "Runs cairo tests")] Test(Box), #[command(about = "Print version")] @@ -118,6 +125,8 @@ impl fmt::Display for Commands { Commands::Inspect(_) => write!(f, "Inspect"), Commands::Migrate(_) => write!(f, "Migrate"), Commands::Model(_) => write!(f, "Model"), + #[cfg(feature = "controller")] + Commands::Session(_) => write!(f, "Session"), Commands::Test(_) => write!(f, "Test"), Commands::Version(_) => write!(f, "Version"), Commands::Mcp(_) => write!(f, "Mcp"), @@ -150,6 +159,8 @@ pub async fn run(command: Commands, scarb_metadata: &Metadata, ui: &SozoUi) -> R Commands::Mcp(args) => args.run(scarb_metadata).await, Commands::Migrate(args) => args.run(scarb_metadata, ui).await, Commands::Model(args) => args.run(scarb_metadata, ui).await, + #[cfg(feature = "controller")] + Commands::Session(args) => args.run(scarb_metadata, ui).await, Commands::Test(args) => args.run(scarb_metadata), Commands::Version(args) => args.run(scarb_metadata), #[cfg(feature = "walnut")] diff --git a/bin/sozo/src/commands/options/account/controller.rs b/bin/sozo/src/commands/options/account/controller.rs index 2a2f0d72bf..25064dfe6d 100644 --- a/bin/sozo/src/commands/options/account/controller.rs +++ b/bin/sozo/src/commands/options/account/controller.rs @@ -1,7 +1,12 @@ use std::collections::HashMap; +use std::io::{Read, Write}; +use std::net::{TcpListener, TcpStream}; +use std::str::FromStr; +use std::time::Duration; -use anyhow::{bail, Result}; +use anyhow::{Context, Result, anyhow, bail}; use dojo_world::contracts::contract_info::ContractInfo; +use serde::{Deserialize, Serialize}; use slot::account_sdk::account::session::account::SessionAccount; use slot::account_sdk::account::session::merkle::MerkleTree; use slot::account_sdk::account::session::policy::{CallPolicy, MerkleLeaf, Policy, ProvedPolicy}; @@ -17,6 +22,63 @@ use url::Url; #[allow(missing_debug_implementations)] pub type ControllerAccount = SessionAccount; +const CONTROLLER_OAUTH_TIMEOUT_SECS: u64 = 300; +const CONTROLLER_OAUTH_CALLBACK_PATH: &str = "/callback"; +const CONTROLLER_LOGIN_PATH: &str = "/slot"; +const CONTROLLER_ACCOUNT_INFO_QUERY: &str = r#" +query ControllerAccountInfo { + me { + id + username + controllers { + edges { + node { + id + address + } + } + } + } +} +"#; + +#[derive(Debug, Deserialize)] +struct ControllerAccountInfoResponse { + me: Option, +} + +#[derive(Debug, Deserialize)] +struct ControllerAccountInfo { + id: String, + username: String, + controllers: ControllerEdges, +} + +#[derive(Debug, Deserialize)] +struct ControllerEdges { + edges: Option>>, +} + +#[derive(Debug, Deserialize)] +struct ControllerEdge { + node: Option, +} + +#[derive(Debug, Deserialize)] +struct ControllerNode { + id: String, + address: String, +} + +#[derive(Debug, Serialize)] +struct GraphqlRequest<'a, T> +where + T: Serialize, +{ + query: &'a str, + variables: T, +} + /// Create a new Catridge Controller account based on session key. /// /// For now, Controller guarantees that if the provided network is among one of the supported @@ -38,7 +100,7 @@ pub async fn create_controller( let chain_id = rpc_provider.chain_id().await?; trace!(target: "account::controller", "Loading Slot credentials."); - let credentials = slot::credential::Credentials::load()?; + let credentials = load_or_bootstrap_credentials().await?; let username = credentials.account.id; // Right now, the Cartridge Controller API ensures that there's always a Controller associated @@ -85,6 +147,169 @@ pub async fn create_controller( Ok(session_details.into_account(rpc_provider)) } +async fn load_or_bootstrap_credentials() -> Result { + match slot::credential::Credentials::load() { + Ok(credentials) => Ok(credentials), + Err(err) if should_bootstrap_credentials(&err) => { + trace!( + target: "account::controller", + error = %err, + "No valid controller credentials found. Starting inline authorization flow." + ); + bootstrap_credentials().await?; + slot::credential::Credentials::load() + .context("Controller credentials were created but could not be loaded") + .map_err(Into::into) + } + Err(err) => Err(err.into()), + } +} + +fn should_bootstrap_credentials(err: &slot::Error) -> bool { + matches!( + err, + slot::Error::Unauthorized | slot::Error::MalformedCredentials | slot::Error::InvalidOAuth + ) +} + +async fn bootstrap_credentials() -> Result<()> { + let listener = TcpListener::bind("127.0.0.1:0") + .context("Failed to start local callback listener for controller authorization")?; + + let callback_uri = format!( + "http://127.0.0.1:{}{}", + listener.local_addr()?.port(), + CONTROLLER_OAUTH_CALLBACK_PATH + ); + + let mut authorize_url = Url::parse(&slot::vars::get_cartridge_keychain_url()) + .context("Invalid Cartridge keychain URL")?; + authorize_url.set_path(CONTROLLER_LOGIN_PATH); + authorize_url.query_pairs_mut().append_pair("callback_uri", &callback_uri); + + println!("Authorize your controller account in browser:\n\n {}\n", authorize_url); + + slot::browser::open(authorize_url.as_str())?; + + let code = tokio::time::timeout( + Duration::from_secs(CONTROLLER_OAUTH_TIMEOUT_SECS), + tokio::task::spawn_blocking(move || wait_for_oauth_code(listener)), + ) + .await + .map_err(|_| { + anyhow!( + "Timed out waiting for controller authorization callback after {} seconds.", + CONTROLLER_OAUTH_TIMEOUT_SECS + ) + })? + .map_err(|e| anyhow!("Failed to run controller authorization callback listener: {e}"))??; + + let mut api = slot::api::Client::new(); + let token = api.oauth2(&code).await.context("Failed to exchange OAuth code")?; + api.set_token(token.clone()); + + let account_info = fetch_controller_account_info(&api) + .await + .context("Failed to load Controller account details after authorization")?; + + let path = slot::credential::Credentials::new(account_info, token) + .store() + .context("Failed to store controller credentials")?; + + trace!( + target: "account::controller", + path = %path.display(), + "Controller credentials stored." + ); + + Ok(()) +} + +async fn fetch_controller_account_info( + api: &slot::api::Client, +) -> Result { + let request = + GraphqlRequest { query: CONTROLLER_ACCOUNT_INFO_QUERY, variables: serde_json::json!({}) }; + + let response: ControllerAccountInfoResponse = api.query(&request).await?; + let me = response.me.ok_or_else(|| anyhow!("Missing `me` account info in API response"))?; + + let mut controllers = Vec::new(); + for edge in me.controllers.edges.unwrap_or_default().into_iter().flatten() { + let Some(node) = edge.node else { + continue; + }; + + let address = Felt::from_str(&node.address) + .with_context(|| format!("Invalid controller address `{}`", node.address))?; + + controllers.push(slot::account::Controller { id: node.id, address }); + } + + Ok(slot::account::AccountInfo { + id: me.id, + username: me.username, + controllers, + credentials: Vec::new(), + }) +} + +fn wait_for_oauth_code(listener: TcpListener) -> Result { + let (mut stream, _) = + listener.accept().context("Failed to accept controller OAuth callback connection")?; + + let mut buffer = [0_u8; 8192]; + let bytes_read = + stream.read(&mut buffer).context("Failed to read controller OAuth callback request")?; + if bytes_read == 0 { + bail!("Controller OAuth callback request was empty."); + } + + let request = String::from_utf8_lossy(&buffer[..bytes_read]); + let request_line = request.lines().next().unwrap_or_default(); + let target = request_line + .split_whitespace() + .nth(1) + .ok_or_else(|| anyhow!("Invalid callback request line: `{request_line}`"))?; + + let Some(code) = extract_oauth_code(target) else { + write_http_response( + &mut stream, + "400 Bad Request", + "Missing authorization code. You can close this tab and retry.", + )?; + bail!("Controller OAuth callback does not contain `code` query parameter."); + }; + + write_http_response( + &mut stream, + "200 OK", + "Controller authorization received. You can close this tab and return to sozo.", + )?; + + Ok(code) +} + +fn extract_oauth_code(target: &str) -> Option { + let callback_url = Url::parse(&format!("http://localhost{target}")).ok()?; + if callback_url.path() != CONTROLLER_OAUTH_CALLBACK_PATH { + return None; + } + + callback_url.query_pairs().find_map(|(key, value)| (key == "code").then(|| value.into_owned())) +} + +fn write_http_response(stream: &mut TcpStream, status: &str, body: &str) -> Result<()> { + let response = format!( + "HTTP/1.1 {status}\r\nContent-Type: text/plain; charset=utf-8\r\nContent-Length: \ + {}\r\nConnection: close\r\n\r\n{body}", + body.len() + ); + stream.write_all(response.as_bytes())?; + stream.flush()?; + Ok(()) +} + // Check if the new policies are equal to the ones in the existing session // // This function would compute the merkle root of the new policies and compare it with the root in @@ -171,7 +396,7 @@ mod tests { use scarb_metadata_ext::MetadataDojoExt; use starknet::macros::felt; - use super::{collect_policies, PolicyMethod}; + use super::{PolicyMethod, collect_policies, extract_oauth_code}; #[test] fn collect_policies_from_project() { @@ -200,4 +425,21 @@ mod tests { }); } } + + #[test] + fn extract_oauth_code_from_callback_target() { + let code = extract_oauth_code("/callback?code=abc123&state=xyz"); + assert_eq!(code.as_deref(), Some("abc123")); + } + + #[test] + fn extract_oauth_code_decodes_url_encoded_value() { + let code = extract_oauth_code("/callback?code=abc%2F123"); + assert_eq!(code.as_deref(), Some("abc/123")); + } + + #[test] + fn extract_oauth_code_rejects_non_callback_target() { + assert_eq!(extract_oauth_code("/not-callback?code=abc123"), None); + } } diff --git a/bin/sozo/src/commands/options/account/mod.rs b/bin/sozo/src/commands/options/account/mod.rs index 6c18d698bb..20aa30c647 100644 --- a/bin/sozo/src/commands/options/account/mod.rs +++ b/bin/sozo/src/commands/options/account/mod.rs @@ -2,7 +2,7 @@ use std::collections::HashMap; use std::str::FromStr; use std::sync::Arc; -use anyhow::{anyhow, Result}; +use anyhow::{Result, anyhow}; use clap::{Args, ValueEnum}; use dojo_utils::env::DOJO_ACCOUNT_ADDRESS_ENV_VAR; use dojo_world::config::Environment; @@ -52,6 +52,13 @@ pub struct AccountOptions { #[cfg(feature = "controller")] pub controller: bool, + #[arg(global = true)] + #[arg(long = "session")] + #[arg(help_heading = "Controller options")] + #[arg(help = "Use Cartridge Controller session account (alias of --slot.controller)")] + #[cfg(feature = "controller")] + pub session: bool, + #[command(flatten)] #[command(next_help_heading = "Signer options")] pub signer: SignerOptions, @@ -85,7 +92,7 @@ impl AccountOptions { P: Send + Sync, { #[cfg(feature = "controller")] - if self.controller { + if self.controller || self.session { let url = starknet.url(env_metadata)?; let cartridge_provider = CartridgeJsonRpcProvider::new(url.clone()); let account = self.controller(url, cartridge_provider.clone(), contracts).await?; @@ -222,6 +229,14 @@ mod tests { ); } + #[cfg(feature = "controller")] + #[test] + fn controller_session_alias_flag_is_parsed() { + let cmd = Command::parse_from(["sozo", "--session"]); + assert!(cmd.account.session); + assert!(!cmd.account.controller); + } + #[test] fn account_address_from_both() { let env_metadata = dojo_world::config::Environment { diff --git a/bin/sozo/src/commands/session.rs b/bin/sozo/src/commands/session.rs new file mode 100644 index 0000000000..9c4ffcfbf4 --- /dev/null +++ b/bin/sozo/src/commands/session.rs @@ -0,0 +1,260 @@ +use std::collections::HashMap; +use std::fs; +use std::path::PathBuf; + +use anyhow::{Result, anyhow}; +use clap::{Args, Subcommand}; +use dojo_world::contracts::ContractInfo; +use scarb_metadata::Metadata; +use scarb_metadata_ext::MetadataDojoExt; +use slot::account_sdk::provider::CartridgeJsonRpcProvider; +use sozo_ui::SozoUi; +use starknet::providers::Provider; +use tracing::trace; + +use super::options::account::controller; +use super::options::starknet::StarknetOptions; +use super::options::world::WorldOptions; +use crate::utils; + +#[derive(Debug, Args)] +pub struct SessionArgs { + #[command(subcommand)] + command: SessionCommand, +} + +#[derive(Debug, Subcommand)] +pub enum SessionCommand { + #[command(about = "Create or refresh a controller session from project contracts.")] + Create { + #[arg(long)] + #[arg(help = "Load contracts from world diff (chain) instead of local manifest.")] + diff: bool, + + #[command(flatten)] + starknet: StarknetOptions, + + #[command(flatten)] + world: WorldOptions, + }, + + #[command(about = "Show current controller session status for the selected network.")] + Status { + #[command(flatten)] + starknet: StarknetOptions, + }, + + #[command(about = "Discard stored controller session(s).")] + Discard { + #[arg(long)] + #[arg(help = "Discard all stored sessions for the authenticated controller account.")] + all: bool, + + #[command(flatten)] + starknet: StarknetOptions, + }, +} + +impl SessionArgs { + pub async fn run(self, scarb_metadata: &Metadata, ui: &SozoUi) -> Result<()> { + trace!(args = ?self); + + match self.command { + SessionCommand::Create { diff, starknet, world } => { + create_session(diff, starknet, world, scarb_metadata, ui).await + } + SessionCommand::Status { starknet } => { + status_session(starknet, scarb_metadata, ui).await + } + SessionCommand::Discard { all, starknet } => { + discard_session(all, starknet, scarb_metadata, ui).await + } + } + } +} + +async fn create_session( + diff: bool, + starknet: StarknetOptions, + world: WorldOptions, + scarb_metadata: &Metadata, + ui: &SozoUi, +) -> Result<()> { + ui.title("Create controller session"); + + let profile_config = scarb_metadata.load_dojo_profile_config()?; + let rpc_url = starknet.url(profile_config.env.as_ref())?; + let contracts = load_contracts(diff, starknet.clone(), world, scarb_metadata, ui).await?; + + ui.step("Authorize and register session"); + let rpc_provider = CartridgeJsonRpcProvider::new(rpc_url.clone()); + let chain_id = rpc_provider.chain_id().await?; + let _ = controller::create_controller(rpc_url, rpc_provider, &contracts).await?; + + let session = slot::session::get(chain_id)? + .ok_or_else(|| anyhow!("Session was not found in local storage after creation."))?; + + let session_path = session_file_path(&session.auth.username, chain_id); + ui.result("Session is ready."); + ui.print(format!("Controller address: {:#066x}", session.auth.address)); + ui.print(format!("Chain id : {chain_id:#x}")); + ui.print(format!("Policies : {}", session.session.proved_policies.len())); + ui.print(format!("Expires at (unix) : {}", session.session.inner.expires_at)); + ui.print(format!("Stored session : {}", session_path.display())); + ui.print( + "Use `sozo execute ... --session` (or `--slot.controller`) to execute with this session.", + ); + + Ok(()) +} + +async fn status_session( + starknet: StarknetOptions, + scarb_metadata: &Metadata, + ui: &SozoUi, +) -> Result<()> { + ui.title("Controller session status"); + + let credentials = match slot::credential::Credentials::load() { + Ok(credentials) => credentials, + Err( + slot::Error::Unauthorized + | slot::Error::MalformedCredentials + | slot::Error::InvalidOAuth, + ) => { + ui.warn("No controller credentials found. Run `sozo session create` first."); + return Ok(()); + } + Err(err) => return Err(err.into()), + }; + + let profile_config = scarb_metadata.load_dojo_profile_config()?; + let rpc_url = starknet.url(profile_config.env.as_ref())?; + let chain_id = CartridgeJsonRpcProvider::new(rpc_url).chain_id().await?; + + ui.print(format!("Account id : {}", credentials.account.id)); + ui.print(format!("Username : {}", credentials.account.username)); + ui.print(format!("Chain id : {chain_id:#x}")); + + if let Some(controller) = credentials.account.controllers.first() { + ui.print(format!("Controller address: {:#066x}", controller.address)); + } else { + ui.warn("No controller is associated with the authenticated account."); + } + + let session_path = session_file_path(&credentials.account.id, chain_id); + let session = slot::session::get(chain_id)?; + + if let Some(session) = session { + ui.result("Session: active"); + ui.print(format!("Policies : {}", session.session.proved_policies.len())); + ui.print(format!("Expires at (unix) : {}", session.session.inner.expires_at)); + ui.print(format!("Stored session : {}", session_path.display())); + } else { + ui.warn("Session: not found for this network."); + ui.print(format!("Expected path : {}", session_path.display())); + } + + Ok(()) +} + +async fn discard_session( + all: bool, + starknet: StarknetOptions, + scarb_metadata: &Metadata, + ui: &SozoUi, +) -> Result<()> { + ui.title("Discard controller session"); + + let credentials = match slot::credential::Credentials::load() { + Ok(credentials) => credentials, + Err( + slot::Error::Unauthorized + | slot::Error::MalformedCredentials + | slot::Error::InvalidOAuth, + ) => { + ui.warn("No controller credentials found."); + return Ok(()); + } + Err(err) => return Err(err.into()), + }; + + let mut removed = 0usize; + if all { + let user_dir = slot::utils::config_dir().join(&credentials.account.id); + if user_dir.exists() { + for entry in fs::read_dir(&user_dir)? { + let path = entry?.path(); + let is_session = path + .file_name() + .and_then(|name| name.to_str()) + .is_some_and(|name| name.ends_with("-session.json")); + + if is_session { + fs::remove_file(&path)?; + removed += 1; + } + } + } + + ui.result(format!("Discarded {removed} session(s).")); + return Ok(()); + } + + let profile_config = scarb_metadata.load_dojo_profile_config()?; + let rpc_url = starknet.url(profile_config.env.as_ref())?; + let chain_id = CartridgeJsonRpcProvider::new(rpc_url).chain_id().await?; + + let session_path = session_file_path(&credentials.account.id, chain_id); + if session_path.exists() { + fs::remove_file(&session_path)?; + ui.result("Session discarded."); + ui.print(format!("Removed: {}", session_path.display())); + } else { + ui.warn("No stored session found for this network."); + ui.print(format!("Expected path: {}", session_path.display())); + } + + Ok(()) +} + +async fn load_contracts( + diff: bool, + starknet: StarknetOptions, + world: WorldOptions, + scarb_metadata: &Metadata, + ui: &SozoUi, +) -> Result> { + if diff { + let (world_diff, _, _) = + utils::get_world_diff_and_provider(starknet, world, scarb_metadata, ui).await?; + return Ok((&world_diff).into()); + } + + let manifest = scarb_metadata.read_dojo_manifest_profile()?.ok_or_else(|| { + anyhow!( + "Project manifest not found. Run `sozo migrate` first or pass `--diff` to derive \ + contracts from chain." + ) + })?; + + Ok((&manifest).into()) +} + +fn session_file_path(username: &str, chain_id: starknet::core::types::Felt) -> PathBuf { + slot::utils::config_dir().join(username).join(format!("{chain_id:#x}-session.json")) +} + +#[cfg(test)] +mod tests { + use starknet::macros::felt; + + use super::session_file_path; + + #[test] + fn session_file_path_contains_expected_suffix() { + let path = session_file_path("my-user", felt!("0x534e5f5345504f4c4941")); + let file = path.file_name().and_then(|name| name.to_str()).unwrap(); + assert!(file.ends_with("-session.json")); + } +} From 27e5d54d303f1f635601d5776b103cd983be0f4b Mon Sep 17 00:00:00 2001 From: Nasr Date: Mon, 9 Feb 2026 15:12:50 -0600 Subject: [PATCH 2/9] feat(sozo): limit controller session policies to user systems --- .../commands/options/account/controller.rs | 105 +++++++++++++++--- 1 file changed, 92 insertions(+), 13 deletions(-) diff --git a/bin/sozo/src/commands/options/account/controller.rs b/bin/sozo/src/commands/options/account/controller.rs index 25064dfe6d..a72f5f05f1 100644 --- a/bin/sozo/src/commands/options/account/controller.rs +++ b/bin/sozo/src/commands/options/account/controller.rs @@ -363,7 +363,18 @@ fn collect_policies_from_contracts( let mut policies: Vec = Vec::new(); for (tag, info) in contracts { + // Exclude core world entrypoints from session policies. + if tag == "world" { + trace!(target: "account::controller", tag, "Skipping world contract policies"); + continue; + } + for e in &info.entrypoints { + if !is_session_entrypoint(e) { + trace!(target: "account::controller", tag, method = %e, "Skipping non-session entrypoint"); + continue; + } + let policy = PolicyMethod { target: info.address, method: e.clone() }; trace!(target: "account::controller", tag, target = format!("{:#x}", policy.target), method = %policy.method, "Adding policy"); policies.push(policy); @@ -386,6 +397,11 @@ fn collect_policies_from_contracts( Ok(policies) } +fn is_session_entrypoint(method: &str) -> bool { + // Exclude internal/core methods from app sessions. + method != "upgrade" && !method.starts_with("__") +} + #[cfg(test)] mod tests { use std::collections::HashMap; @@ -406,24 +422,87 @@ mod tests { let manifest = scarb_metadata.read_dojo_manifest_profile().expect("Failed to read manifest").unwrap(); let contracts: HashMap = (&manifest).into(); + let world_address = contracts.get("world").unwrap().address; + let actions = contracts.get("ns-actions").unwrap(); + let actions_address = actions.address; let user_addr = felt!("0x2af9427c5a277474c079a1283c880ee8a6f0f8fbf73ce969c08d88befec1bba"); let policies = collect_policies(user_addr, &contracts).unwrap(); - if std::env::var("POLICIES_FIX").is_ok() { - let policies_json = serde_json::to_string_pretty(&policies).unwrap(); - println!("{}", policies_json); - } else { - let test_data = include_str!("../../../../tests/test_data/policies.json"); - let expected_policies: Vec = serde_json::from_str(test_data).unwrap(); - - // Compare the collected policies with the test data. - assert_eq!(policies.len(), expected_policies.len()); - expected_policies.iter().for_each(|p| { - assert!(policies.contains(p), "Policy method '{}' is missing", p.method) - }); - } + // Should include user systems. + assert!( + policies.contains(&PolicyMethod { target: actions_address, method: "spawn".into() }) + ); + assert!( + policies.contains(&PolicyMethod { target: actions_address, method: "move".into() }) + ); + + // Should not include world/core policy methods. + assert!( + !policies.iter().any(|p| p.target == world_address), + "world entrypoints should not be included in session policies" + ); + + // Should not include upgrade/internal methods. + assert!( + !policies.iter().any(|p| p.method == "upgrade"), + "upgrade should not be included in session policies" + ); + assert!( + !policies + .iter() + .any(|p| p.method.starts_with("__") && p.method != "__declare_transaction__"), + "internal methods should not be included in session policies" + ); + + // Should keep required meta policies. + assert!( + policies.contains(&PolicyMethod { + target: user_addr, + method: "__declare_transaction__".into(), + }), + "declare policy is missing" + ); + assert!( + policies.contains(&PolicyMethod { + target: felt!("0x041a78e741e5af2fec34b695679bc6891742439f7afb8484ecd7766661ad02bf"), + method: "deployContract".into(), + }), + "UDC deployment policy is missing" + ); + } + + #[test] + fn collect_policies_filters_world_and_upgrade() { + let user_addr = felt!("0x123"); + let world_addr = felt!("0x456"); + let actions_addr = felt!("0x789"); + + let mut contracts = HashMap::new(); + contracts.insert( + "world".to_string(), + ContractInfo { + tag_or_name: "world".to_string(), + address: world_addr, + entrypoints: vec!["register_model".into(), "set_entity".into()], + }, + ); + contracts.insert( + "ns-actions".to_string(), + ContractInfo { + tag_or_name: "ns-actions".to_string(), + address: actions_addr, + entrypoints: vec!["spawn".into(), "move".into(), "upgrade".into()], + }, + ); + + let policies = collect_policies(user_addr, &contracts).unwrap(); + + assert!(policies.contains(&PolicyMethod { target: actions_addr, method: "spawn".into() })); + assert!(policies.contains(&PolicyMethod { target: actions_addr, method: "move".into() })); + assert!(!policies.iter().any(|p| p.target == world_addr)); + assert!(!policies.iter().any(|p| p.method == "upgrade")); } #[test] From 7d3289e5baf1d5f436bbafccf1ac2137c190b801 Mon Sep 17 00:00:00 2001 From: Nasr Date: Mon, 9 Feb 2026 15:33:29 -0600 Subject: [PATCH 3/9] feat(sozo): move session flows under controller command --- bin/sozo/src/commands/controller.rs | 26 +++++ bin/sozo/src/commands/mod.rs | 14 ++- .../commands/options/account/controller.rs | 109 ++++++++++++------ bin/sozo/src/commands/options/account/mod.rs | 14 +-- bin/sozo/src/commands/session.rs | 6 +- 5 files changed, 114 insertions(+), 55 deletions(-) create mode 100644 bin/sozo/src/commands/controller.rs diff --git a/bin/sozo/src/commands/controller.rs b/bin/sozo/src/commands/controller.rs new file mode 100644 index 0000000000..c7ee207dfc --- /dev/null +++ b/bin/sozo/src/commands/controller.rs @@ -0,0 +1,26 @@ +use anyhow::Result; +use clap::{Args, Subcommand}; +use scarb_metadata::Metadata; +use sozo_ui::SozoUi; + +use super::session::SessionArgs; + +#[derive(Debug, Args)] +pub struct ControllerArgs { + #[command(subcommand)] + command: ControllerCommand, +} + +#[derive(Debug, Subcommand)] +pub enum ControllerCommand { + #[command(about = "Manage Cartridge controller sessions")] + Session(Box), +} + +impl ControllerArgs { + pub async fn run(self, scarb_metadata: &Metadata, ui: &SozoUi) -> Result<()> { + match self.command { + ControllerCommand::Session(args) => args.run(scarb_metadata, ui).await, + } + } +} diff --git a/bin/sozo/src/commands/mod.rs b/bin/sozo/src/commands/mod.rs index 539e2b5d15..754d5e501f 100644 --- a/bin/sozo/src/commands/mod.rs +++ b/bin/sozo/src/commands/mod.rs @@ -11,6 +11,8 @@ pub(crate) mod bindgen; pub(crate) mod build; pub(crate) mod call; pub(crate) mod clean; +#[cfg(feature = "controller")] +pub(crate) mod controller; pub(crate) mod declare; pub(crate) mod deploy; pub(crate) mod events; @@ -35,6 +37,8 @@ use bindgen::BindgenArgs; use build::BuildArgs; use call::CallArgs; use clean::CleanArgs; +#[cfg(feature = "controller")] +use controller::ControllerArgs; use declare::DeclareArgs; use deploy::DeployArgs; use events::EventsArgs; @@ -47,8 +51,6 @@ use invoke::InvokeArgs; use mcp::McpArgs; use migrate::MigrateArgs; use model::ModelArgs; -#[cfg(feature = "controller")] -use session::SessionArgs; #[cfg(feature = "walnut")] use sozo_walnut::walnut::WalnutArgs; use starknet::StarknetArgs; @@ -91,8 +93,8 @@ pub enum Commands { #[command(about = "Inspect a model")] Model(Box), #[cfg(feature = "controller")] - #[command(about = "Manage Cartridge controller sessions")] - Session(Box), + #[command(about = "Controller utility commands")] + Controller(Box), #[command(about = "Runs cairo tests")] Test(Box), #[command(about = "Print version")] @@ -126,7 +128,7 @@ impl fmt::Display for Commands { Commands::Migrate(_) => write!(f, "Migrate"), Commands::Model(_) => write!(f, "Model"), #[cfg(feature = "controller")] - Commands::Session(_) => write!(f, "Session"), + Commands::Controller(_) => write!(f, "Controller"), Commands::Test(_) => write!(f, "Test"), Commands::Version(_) => write!(f, "Version"), Commands::Mcp(_) => write!(f, "Mcp"), @@ -160,7 +162,7 @@ pub async fn run(command: Commands, scarb_metadata: &Metadata, ui: &SozoUi) -> R Commands::Migrate(args) => args.run(scarb_metadata, ui).await, Commands::Model(args) => args.run(scarb_metadata, ui).await, #[cfg(feature = "controller")] - Commands::Session(args) => args.run(scarb_metadata, ui).await, + Commands::Controller(args) => args.run(scarb_metadata, ui).await, Commands::Test(args) => args.run(scarb_metadata), Commands::Version(args) => args.run(scarb_metadata), #[cfg(feature = "walnut")] diff --git a/bin/sozo/src/commands/options/account/controller.rs b/bin/sozo/src/commands/options/account/controller.rs index a72f5f05f1..4509790e08 100644 --- a/bin/sozo/src/commands/options/account/controller.rs +++ b/bin/sozo/src/commands/options/account/controller.rs @@ -363,18 +363,7 @@ fn collect_policies_from_contracts( let mut policies: Vec = Vec::new(); for (tag, info) in contracts { - // Exclude core world entrypoints from session policies. - if tag == "world" { - trace!(target: "account::controller", tag, "Skipping world contract policies"); - continue; - } - for e in &info.entrypoints { - if !is_session_entrypoint(e) { - trace!(target: "account::controller", tag, method = %e, "Skipping non-session entrypoint"); - continue; - } - let policy = PolicyMethod { target: info.address, method: e.clone() }; trace!(target: "account::controller", tag, target = format!("{:#x}", policy.target), method = %policy.method, "Adding policy"); policies.push(policy); @@ -394,12 +383,10 @@ fn collect_policies_from_contracts( policies.push(PolicyMethod { target: UDC_ADDRESS, method }); trace!(target: "account::controller", "Adding UDC deployment policy"); - Ok(policies) -} + // Keep a deterministic policy order so session root comparison remains stable across runs. + policies.sort_by(|a, b| a.target.cmp(&b.target).then_with(|| a.method.cmp(&b.method))); -fn is_session_entrypoint(method: &str) -> bool { - // Exclude internal/core methods from app sessions. - method != "upgrade" && !method.starts_with("__") + Ok(policies) } #[cfg(test)] @@ -412,7 +399,9 @@ mod tests { use scarb_metadata_ext::MetadataDojoExt; use starknet::macros::felt; - use super::{PolicyMethod, collect_policies, extract_oauth_code}; + use super::{ + PolicyMethod, collect_policies, collect_policies_from_contracts, extract_oauth_code, + }; #[test] fn collect_policies_from_project() { @@ -425,6 +414,7 @@ mod tests { let world_address = contracts.get("world").unwrap().address; let actions = contracts.get("ns-actions").unwrap(); let actions_address = actions.address; + let world = contracts.get("world").unwrap(); let user_addr = felt!("0x2af9427c5a277474c079a1283c880ee8a6f0f8fbf73ce969c08d88befec1bba"); @@ -438,23 +428,18 @@ mod tests { policies.contains(&PolicyMethod { target: actions_address, method: "move".into() }) ); - // Should not include world/core policy methods. + // Should include world contract policies. assert!( - !policies.iter().any(|p| p.target == world_address), - "world entrypoints should not be included in session policies" + policies.iter().any(|p| p.target == world_address), + "world entrypoints should be included in session policies" ); - // Should not include upgrade/internal methods. - assert!( - !policies.iter().any(|p| p.method == "upgrade"), - "upgrade should not be included in session policies" - ); - assert!( - !policies - .iter() - .any(|p| p.method.starts_with("__") && p.method != "__declare_transaction__"), - "internal methods should not be included in session policies" - ); + // World methods from manifest should be part of policies. + for method in &world.entrypoints { + assert!( + policies.contains(&PolicyMethod { target: world_address, method: method.clone() }) + ); + } // Should keep required meta policies. assert!( @@ -474,7 +459,7 @@ mod tests { } #[test] - fn collect_policies_filters_world_and_upgrade() { + fn collect_policies_includes_world_and_upgrade() { let user_addr = felt!("0x123"); let world_addr = felt!("0x456"); let actions_addr = felt!("0x789"); @@ -499,10 +484,66 @@ mod tests { let policies = collect_policies(user_addr, &contracts).unwrap(); + assert!( + policies + .contains(&PolicyMethod { target: world_addr, method: "register_model".into() }) + ); + assert!( + policies.contains(&PolicyMethod { target: world_addr, method: "set_entity".into() }) + ); assert!(policies.contains(&PolicyMethod { target: actions_addr, method: "spawn".into() })); assert!(policies.contains(&PolicyMethod { target: actions_addr, method: "move".into() })); - assert!(!policies.iter().any(|p| p.target == world_addr)); - assert!(!policies.iter().any(|p| p.method == "upgrade")); + assert!( + policies.contains(&PolicyMethod { target: actions_addr, method: "upgrade".into() }) + ); + } + + #[test] + fn collect_policies_has_stable_order() { + let user_addr = felt!("0x123"); + let a_addr = felt!("0x2"); + let b_addr = felt!("0x1"); + + let mut contracts_a = HashMap::new(); + contracts_a.insert( + "a".to_string(), + ContractInfo { + tag_or_name: "a".to_string(), + address: a_addr, + entrypoints: vec!["z".into(), "a".into()], + }, + ); + contracts_a.insert( + "b".to_string(), + ContractInfo { + tag_or_name: "b".to_string(), + address: b_addr, + entrypoints: vec!["m".into()], + }, + ); + + let mut contracts_b = HashMap::new(); + contracts_b.insert( + "b".to_string(), + ContractInfo { + tag_or_name: "b".to_string(), + address: b_addr, + entrypoints: vec!["m".into()], + }, + ); + contracts_b.insert( + "a".to_string(), + ContractInfo { + tag_or_name: "a".to_string(), + address: a_addr, + entrypoints: vec!["z".into(), "a".into()], + }, + ); + + let policies_a = collect_policies_from_contracts(user_addr, &contracts_a).unwrap(); + let policies_b = collect_policies_from_contracts(user_addr, &contracts_b).unwrap(); + + assert_eq!(policies_a, policies_b); } #[test] diff --git a/bin/sozo/src/commands/options/account/mod.rs b/bin/sozo/src/commands/options/account/mod.rs index 20aa30c647..f12386eec1 100644 --- a/bin/sozo/src/commands/options/account/mod.rs +++ b/bin/sozo/src/commands/options/account/mod.rs @@ -45,17 +45,10 @@ pub struct AccountOptions { #[arg(help = "Use one of Katana's pre-funded dev accounts (katana0..katana9).")] pub katana_account: Option, - #[arg(global = true)] - #[arg(long = "slot.controller")] - #[arg(help_heading = "Controller options")] - #[arg(help = "Use Slot's Controller account")] - #[cfg(feature = "controller")] - pub controller: bool, - #[arg(global = true)] #[arg(long = "session")] #[arg(help_heading = "Controller options")] - #[arg(help = "Use Cartridge Controller session account (alias of --slot.controller)")] + #[arg(help = "Use Cartridge Controller session account")] #[cfg(feature = "controller")] pub session: bool, @@ -92,7 +85,7 @@ impl AccountOptions { P: Send + Sync, { #[cfg(feature = "controller")] - if self.controller || self.session { + if self.session { let url = starknet.url(env_metadata)?; let cartridge_provider = CartridgeJsonRpcProvider::new(url.clone()); let account = self.controller(url, cartridge_provider.clone(), contracts).await?; @@ -231,10 +224,9 @@ mod tests { #[cfg(feature = "controller")] #[test] - fn controller_session_alias_flag_is_parsed() { + fn controller_session_flag_is_parsed() { let cmd = Command::parse_from(["sozo", "--session"]); assert!(cmd.account.session); - assert!(!cmd.account.controller); } #[test] diff --git a/bin/sozo/src/commands/session.rs b/bin/sozo/src/commands/session.rs index 9c4ffcfbf4..3007e523f8 100644 --- a/bin/sozo/src/commands/session.rs +++ b/bin/sozo/src/commands/session.rs @@ -101,9 +101,7 @@ async fn create_session( ui.print(format!("Policies : {}", session.session.proved_policies.len())); ui.print(format!("Expires at (unix) : {}", session.session.inner.expires_at)); ui.print(format!("Stored session : {}", session_path.display())); - ui.print( - "Use `sozo execute ... --session` (or `--slot.controller`) to execute with this session.", - ); + ui.print("Use `sozo execute ... --session` to execute with this session."); Ok(()) } @@ -122,7 +120,7 @@ async fn status_session( | slot::Error::MalformedCredentials | slot::Error::InvalidOAuth, ) => { - ui.warn("No controller credentials found. Run `sozo session create` first."); + ui.warn("No controller credentials found. Run `sozo controller session create` first."); return Ok(()); } Err(err) => return Err(err.into()), From b59c4ecd7d1690dddf0681540d397a4e84db5b68 Mon Sep 17 00:00:00 2001 From: Nasr Date: Mon, 9 Feb 2026 15:42:03 -0600 Subject: [PATCH 4/9] fix(sozo): verify controller session is registered onchain --- .../commands/options/account/controller.rs | 80 ++++++++++++++++++- 1 file changed, 77 insertions(+), 3 deletions(-) diff --git a/bin/sozo/src/commands/options/account/controller.rs b/bin/sozo/src/commands/options/account/controller.rs index 4509790e08..a1c3ba3966 100644 --- a/bin/sozo/src/commands/options/account/controller.rs +++ b/bin/sozo/src/commands/options/account/controller.rs @@ -10,9 +10,10 @@ use serde::{Deserialize, Serialize}; use slot::account_sdk::account::session::account::SessionAccount; use slot::account_sdk::account::session::merkle::MerkleTree; use slot::account_sdk::account::session::policy::{CallPolicy, MerkleLeaf, Policy, ProvedPolicy}; +use slot::account_sdk::hash::MessageHashRev1; use slot::account_sdk::provider::CartridgeJsonRpcProvider; use slot::session::{FullSessionInfo, PolicyMethod}; -use starknet::core::types::Felt; +use starknet::core::types::{BlockId, BlockTag, Felt, FunctionCall}; use starknet::core::utils::get_selector_from_name; use starknet::macros::felt; use starknet::providers::Provider; @@ -25,6 +26,8 @@ pub type ControllerAccount = SessionAccount; const CONTROLLER_OAUTH_TIMEOUT_SECS: u64 = 300; const CONTROLLER_OAUTH_CALLBACK_PATH: &str = "/callback"; const CONTROLLER_LOGIN_PATH: &str = "/slot"; +const CONTROLLER_SESSION_REGISTRATION_TIMEOUT_SECS: u64 = 60; +const CONTROLLER_SESSION_REGISTRATION_POLL_MS: u64 = 1_500; const CONTROLLER_ACCOUNT_INFO_QUERY: &str = r#" query ControllerAccountInfo { me { @@ -119,17 +122,33 @@ pub async fn create_controller( // Check if the policies have changed let is_equal = is_equal_to_existing(&policies, &session); - if is_equal { + let is_registered = is_session_registered_onchain( + &rpc_provider, + session.auth.address, + chain_id, + &session, + ) + .await?; + + if is_equal && is_registered { session } else { trace!( target: "account::controller", new_policies = policies.len(), existing_policies = session.session.requested_policies.len(), - "Policies have changed. Creating new session." + is_registered, + "Session missing onchain or policies changed. Creating new session." ); let session = slot::session::create(rpc_url.clone(), &policies).await?; + ensure_session_registered_onchain( + &rpc_provider, + session.auth.address, + chain_id, + &session, + ) + .await?; slot::session::store(chain_id, &session)?; session } @@ -139,6 +158,13 @@ pub async fn create_controller( None => { trace!(target: "account::controller", %username, chain = format!("{chain_id:#}"), "Creating new session."); let session = slot::session::create(rpc_url.clone(), &policies).await?; + ensure_session_registered_onchain( + &rpc_provider, + session.auth.address, + chain_id, + &session, + ) + .await?; slot::session::store(chain_id, &session)?; session } @@ -342,6 +368,54 @@ fn is_equal_to_existing(new_policies: &[PolicyMethod], session_info: &FullSessio new_policies_root == session_info.session.inner.allowed_policies_root } +async fn is_session_registered_onchain( + provider: &CartridgeJsonRpcProvider, + controller_address: Felt, + chain_id: Felt, + session: &FullSessionInfo, +) -> Result { + let session_hash = session.session.inner.get_message_hash_rev_1(chain_id, controller_address); + + let call = FunctionCall { + contract_address: controller_address, + entry_point_selector: get_selector_from_name("is_session_registered") + .context("Failed to resolve selector for `is_session_registered`")?, + calldata: vec![session_hash], + }; + + let result = provider.call(call, BlockId::Tag(BlockTag::Latest)).await?; + Ok(result.first().is_some_and(|v| *v != Felt::ZERO)) +} + +async fn ensure_session_registered_onchain( + provider: &CartridgeJsonRpcProvider, + controller_address: Felt, + chain_id: Felt, + session: &FullSessionInfo, +) -> Result<()> { + if is_session_registered_onchain(provider, controller_address, chain_id, session).await? { + return Ok(()); + } + + let timeout = Duration::from_secs(CONTROLLER_SESSION_REGISTRATION_TIMEOUT_SECS); + let poll = Duration::from_millis(CONTROLLER_SESSION_REGISTRATION_POLL_MS); + let started = std::time::Instant::now(); + + while started.elapsed() < timeout { + tokio::time::sleep(poll).await; + + if is_session_registered_onchain(provider, controller_address, chain_id, session).await? { + return Ok(()); + } + } + + bail!( + "Controller session was created locally but is not registered onchain yet (timeout: {}s). \ + Please retry `sozo controller session create`.", + CONTROLLER_SESSION_REGISTRATION_TIMEOUT_SECS + ); +} + /// Policies are the building block of a session key. It's what defines what methods are allowed for /// an external signer to execute using the session key. /// From e13353c368d7bdbb2a6ba18c28a83106b7969825 Mon Sep 17 00:00:00 2001 From: Nasr Date: Mon, 9 Feb 2026 15:47:31 -0600 Subject: [PATCH 5/9] fix(sozo): pass owner guid to session registration check --- bin/sozo/src/commands/options/account/controller.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bin/sozo/src/commands/options/account/controller.rs b/bin/sozo/src/commands/options/account/controller.rs index a1c3ba3966..2767ebd5f8 100644 --- a/bin/sozo/src/commands/options/account/controller.rs +++ b/bin/sozo/src/commands/options/account/controller.rs @@ -380,7 +380,7 @@ async fn is_session_registered_onchain( contract_address: controller_address, entry_point_selector: get_selector_from_name("is_session_registered") .context("Failed to resolve selector for `is_session_registered`")?, - calldata: vec![session_hash], + calldata: vec![session_hash, session.auth.owner_guid], }; let result = provider.call(call, BlockId::Tag(BlockTag::Latest)).await?; From 956b56d4c2f518344ca88f26abcbcfea3d44be29 Mon Sep 17 00:00:00 2001 From: Nasr Date: Mon, 9 Feb 2026 16:24:55 -0600 Subject: [PATCH 6/9] fix(sozo): remove blocking onchain wait after session callback --- .../commands/options/account/controller.rs | 45 ------------------- 1 file changed, 45 deletions(-) diff --git a/bin/sozo/src/commands/options/account/controller.rs b/bin/sozo/src/commands/options/account/controller.rs index 2767ebd5f8..099d6693cd 100644 --- a/bin/sozo/src/commands/options/account/controller.rs +++ b/bin/sozo/src/commands/options/account/controller.rs @@ -26,8 +26,6 @@ pub type ControllerAccount = SessionAccount; const CONTROLLER_OAUTH_TIMEOUT_SECS: u64 = 300; const CONTROLLER_OAUTH_CALLBACK_PATH: &str = "/callback"; const CONTROLLER_LOGIN_PATH: &str = "/slot"; -const CONTROLLER_SESSION_REGISTRATION_TIMEOUT_SECS: u64 = 60; -const CONTROLLER_SESSION_REGISTRATION_POLL_MS: u64 = 1_500; const CONTROLLER_ACCOUNT_INFO_QUERY: &str = r#" query ControllerAccountInfo { me { @@ -142,13 +140,6 @@ pub async fn create_controller( ); let session = slot::session::create(rpc_url.clone(), &policies).await?; - ensure_session_registered_onchain( - &rpc_provider, - session.auth.address, - chain_id, - &session, - ) - .await?; slot::session::store(chain_id, &session)?; session } @@ -158,13 +149,6 @@ pub async fn create_controller( None => { trace!(target: "account::controller", %username, chain = format!("{chain_id:#}"), "Creating new session."); let session = slot::session::create(rpc_url.clone(), &policies).await?; - ensure_session_registered_onchain( - &rpc_provider, - session.auth.address, - chain_id, - &session, - ) - .await?; slot::session::store(chain_id, &session)?; session } @@ -387,35 +371,6 @@ async fn is_session_registered_onchain( Ok(result.first().is_some_and(|v| *v != Felt::ZERO)) } -async fn ensure_session_registered_onchain( - provider: &CartridgeJsonRpcProvider, - controller_address: Felt, - chain_id: Felt, - session: &FullSessionInfo, -) -> Result<()> { - if is_session_registered_onchain(provider, controller_address, chain_id, session).await? { - return Ok(()); - } - - let timeout = Duration::from_secs(CONTROLLER_SESSION_REGISTRATION_TIMEOUT_SECS); - let poll = Duration::from_millis(CONTROLLER_SESSION_REGISTRATION_POLL_MS); - let started = std::time::Instant::now(); - - while started.elapsed() < timeout { - tokio::time::sleep(poll).await; - - if is_session_registered_onchain(provider, controller_address, chain_id, session).await? { - return Ok(()); - } - } - - bail!( - "Controller session was created locally but is not registered onchain yet (timeout: {}s). \ - Please retry `sozo controller session create`.", - CONTROLLER_SESSION_REGISTRATION_TIMEOUT_SECS - ); -} - /// Policies are the building block of a session key. It's what defines what methods are allowed for /// an external signer to execute using the session key. /// From 6b4ca534622da2ff04c56574627849bce491c688 Mon Sep 17 00:00:00 2001 From: Nasr Date: Mon, 9 Feb 2026 16:28:12 -0600 Subject: [PATCH 7/9] fix(sozo): stop re-registering controller session on execute --- .../commands/options/account/controller.rs | 35 ++----------------- 1 file changed, 3 insertions(+), 32 deletions(-) diff --git a/bin/sozo/src/commands/options/account/controller.rs b/bin/sozo/src/commands/options/account/controller.rs index 099d6693cd..4509790e08 100644 --- a/bin/sozo/src/commands/options/account/controller.rs +++ b/bin/sozo/src/commands/options/account/controller.rs @@ -10,10 +10,9 @@ use serde::{Deserialize, Serialize}; use slot::account_sdk::account::session::account::SessionAccount; use slot::account_sdk::account::session::merkle::MerkleTree; use slot::account_sdk::account::session::policy::{CallPolicy, MerkleLeaf, Policy, ProvedPolicy}; -use slot::account_sdk::hash::MessageHashRev1; use slot::account_sdk::provider::CartridgeJsonRpcProvider; use slot::session::{FullSessionInfo, PolicyMethod}; -use starknet::core::types::{BlockId, BlockTag, Felt, FunctionCall}; +use starknet::core::types::Felt; use starknet::core::utils::get_selector_from_name; use starknet::macros::felt; use starknet::providers::Provider; @@ -120,23 +119,14 @@ pub async fn create_controller( // Check if the policies have changed let is_equal = is_equal_to_existing(&policies, &session); - let is_registered = is_session_registered_onchain( - &rpc_provider, - session.auth.address, - chain_id, - &session, - ) - .await?; - - if is_equal && is_registered { + if is_equal { session } else { trace!( target: "account::controller", new_policies = policies.len(), existing_policies = session.session.requested_policies.len(), - is_registered, - "Session missing onchain or policies changed. Creating new session." + "Policies have changed. Creating new session." ); let session = slot::session::create(rpc_url.clone(), &policies).await?; @@ -352,25 +342,6 @@ fn is_equal_to_existing(new_policies: &[PolicyMethod], session_info: &FullSessio new_policies_root == session_info.session.inner.allowed_policies_root } -async fn is_session_registered_onchain( - provider: &CartridgeJsonRpcProvider, - controller_address: Felt, - chain_id: Felt, - session: &FullSessionInfo, -) -> Result { - let session_hash = session.session.inner.get_message_hash_rev_1(chain_id, controller_address); - - let call = FunctionCall { - contract_address: controller_address, - entry_point_selector: get_selector_from_name("is_session_registered") - .context("Failed to resolve selector for `is_session_registered`")?, - calldata: vec![session_hash, session.auth.owner_guid], - }; - - let result = provider.call(call, BlockId::Tag(BlockTag::Latest)).await?; - Ok(result.first().is_some_and(|v| *v != Felt::ZERO)) -} - /// Policies are the building block of a session key. It's what defines what methods are allowed for /// an external signer to execute using the session key. /// From 139789ec85e061dc71522b4337345db07eab3daf Mon Sep 17 00:00:00 2001 From: Nasr Date: Mon, 9 Feb 2026 16:34:40 -0600 Subject: [PATCH 8/9] fix(sozo): align session policy ordering with controller canonical sort --- .../commands/options/account/controller.rs | 46 ++++++++++++++++++- 1 file changed, 44 insertions(+), 2 deletions(-) diff --git a/bin/sozo/src/commands/options/account/controller.rs b/bin/sozo/src/commands/options/account/controller.rs index 4509790e08..d354aef358 100644 --- a/bin/sozo/src/commands/options/account/controller.rs +++ b/bin/sozo/src/commands/options/account/controller.rs @@ -383,8 +383,14 @@ fn collect_policies_from_contracts( policies.push(PolicyMethod { target: UDC_ADDRESS, method }); trace!(target: "account::controller", "Adding UDC deployment policy"); - // Keep a deterministic policy order so session root comparison remains stable across runs. - policies.sort_by(|a, b| a.target.cmp(&b.target).then_with(|| a.method.cmp(&b.method))); + // Keep canonical ordering aligned with controller/keychain sorting: + // contract address (as lowercase hex string) then method name. + policies.sort_by(|a, b| { + format!("{:#x}", a.target) + .to_lowercase() + .cmp(&format!("{:#x}", b.target).to_lowercase()) + .then_with(|| a.method.cmp(&b.method)) + }); Ok(policies) } @@ -546,6 +552,42 @@ mod tests { assert_eq!(policies_a, policies_b); } + #[test] + fn collect_policies_uses_controller_canonical_address_sort() { + let user_addr = felt!("0x123"); + let addr_2 = felt!("0x2"); + let addr_10 = felt!("0x10"); + + let mut contracts = HashMap::new(); + contracts.insert( + "two".to_string(), + ContractInfo { + tag_or_name: "two".to_string(), + address: addr_2, + entrypoints: vec!["exec".into()], + }, + ); + contracts.insert( + "ten".to_string(), + ContractInfo { + tag_or_name: "ten".to_string(), + address: addr_10, + entrypoints: vec!["exec".into()], + }, + ); + + let policies = collect_policies_from_contracts(user_addr, &contracts).unwrap(); + + // Controller canonical sort is string-based, so 0x10 comes before 0x2. + let first_two = policies + .iter() + .filter(|p| p.method == "exec") + .take(2) + .map(|p| p.target) + .collect::>(); + assert_eq!(first_two, vec![addr_10, addr_2]); + } + #[test] fn extract_oauth_code_from_callback_target() { let code = extract_oauth_code("/callback?code=abc123&state=xyz"); From 99aa87d8a47801b6f45f663f4fbbb11cabd23e67 Mon Sep 17 00:00:00 2001 From: Nasr Date: Tue, 10 Feb 2026 17:14:07 -0600 Subject: [PATCH 9/9] feat(sozo): scope controller sessions by project/profile context --- .../commands/options/account/controller.rs | 1047 ++++++++++++++++- bin/sozo/src/commands/session.rs | 136 ++- bin/sozo/src/main.rs | 10 +- 3 files changed, 1124 insertions(+), 69 deletions(-) diff --git a/bin/sozo/src/commands/options/account/controller.rs b/bin/sozo/src/commands/options/account/controller.rs index d354aef358..68e3c929af 100644 --- a/bin/sozo/src/commands/options/account/controller.rs +++ b/bin/sozo/src/commands/options/account/controller.rs @@ -1,21 +1,30 @@ -use std::collections::HashMap; +use std::cmp::Reverse; +use std::collections::{BTreeSet, HashMap}; +use std::fs; use std::io::{Read, Write}; use std::net::{TcpListener, TcpStream}; +use std::path::PathBuf; use std::str::FromStr; use std::time::Duration; use anyhow::{Context, Result, anyhow, bail}; +use base64::{Engine as _, engine::general_purpose}; +use cainome_cairo_serde::NonZero; use dojo_world::contracts::contract_info::ContractInfo; use serde::{Deserialize, Serialize}; +use slot::account_sdk::abigen::controller::{Signer as ControllerSigner, StarknetSigner}; use slot::account_sdk::account::session::account::SessionAccount; -use slot::account_sdk::account::session::merkle::MerkleTree; -use slot::account_sdk::account::session::policy::{CallPolicy, MerkleLeaf, Policy, ProvedPolicy}; +use slot::account_sdk::account::session::hash::Session; +use slot::account_sdk::account::session::policy::{CallPolicy, Policy}; +use slot::account_sdk::hash::MessageHashRev1; use slot::account_sdk::provider::CartridgeJsonRpcProvider; use slot::session::{FullSessionInfo, PolicyMethod}; use starknet::core::types::Felt; use starknet::core::utils::get_selector_from_name; use starknet::macros::felt; use starknet::providers::Provider; +use starknet::providers::jsonrpc::{HttpTransport, JsonRpcClient}; +use starknet::signers::SigningKey; use tracing::trace; use url::Url; @@ -25,6 +34,11 @@ pub type ControllerAccount = SessionAccount; const CONTROLLER_OAUTH_TIMEOUT_SECS: u64 = 300; const CONTROLLER_OAUTH_CALLBACK_PATH: &str = "/callback"; const CONTROLLER_LOGIN_PATH: &str = "/slot"; +const CONTROLLER_SESSION_CREATION_PATH: &str = "/session"; +const CONTROLLER_SHORTENER_PATH: &str = "/s"; +const CONTROLLER_SESSION_TIMEOUT_SECS: u64 = 300; +const MULTI_SESSION_FILE_INFIX: &str = "-session-"; +const SOZO_PROFILE_ENV_VAR: &str = "SOZO_PROFILE"; const CONTROLLER_ACCOUNT_INFO_QUERY: &str = r#" query ControllerAccountInfo { me { @@ -79,6 +93,193 @@ where variables: T, } +#[derive(Debug, Serialize)] +struct ShortUrlRequest<'a> { + url: &'a str, +} + +#[derive(Debug, Deserialize)] +struct ShortUrlResponse { + url: String, +} + +#[derive(Debug, Clone, Deserialize)] +#[serde(rename_all = "camelCase")] +struct ControllerSessionResponse { + username: String, + address: Felt, + owner_guid: Felt, + expires_at: String, + transaction_hash: Option, + #[serde(default)] + already_registered: bool, + allowed_policies_root: Option, + metadata_hash: Option, + session_key_guid: Option, + guardian_key_guid: Option, +} + +fn session_user_dir(username: &str) -> PathBuf { + slot::utils::config_dir().join(username) +} + +fn fnv1a64(input: &[u8]) -> u64 { + let mut hash: u64 = 0xcbf29ce484222325; + for byte in input { + hash ^= u64::from(*byte); + hash = hash.wrapping_mul(0x100000001b3); + } + hash +} + +fn discover_project_root() -> PathBuf { + let cwd = std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")); + for ancestor in cwd.ancestors() { + if ancestor.join("Scarb.toml").exists() { + return ancestor.to_path_buf(); + } + } + cwd +} + +pub(crate) fn current_session_context_hash() -> String { + let profile = std::env::var(SOZO_PROFILE_ENV_VAR) + .or_else(|_| std::env::var("SCARB_PROFILE")) + .unwrap_or_else(|_| "dev".to_string()); + let project_root = discover_project_root(); + let context_raw = format!("project={}|profile={}", project_root.display(), profile); + format!("{:016x}", fnv1a64(context_raw.as_bytes())) +} + +fn multi_session_file_path( + username: &str, + chain_id: Felt, + context_hash: &str, + policy_root: Felt, +) -> PathBuf { + session_user_dir(username).join(format!( + "{chain_id:#x}{MULTI_SESSION_FILE_INFIX}{context_hash}-{policy_root:064x}.json" + )) +} + +fn load_multi_sessions_for_chain( + username: &str, + chain_id: Felt, + context_hash: &str, +) -> Result> { + let user_dir = session_user_dir(username); + if !user_dir.exists() { + return Ok(Vec::new()); + } + + let chain_prefix = format!("{chain_id:#x}{MULTI_SESSION_FILE_INFIX}{context_hash}-"); + let mut sessions = Vec::new(); + + for entry in fs::read_dir(&user_dir).context("Failed to read controller session directory")? { + let path = entry?.path(); + let Some(file_name) = path.file_name().and_then(|name| name.to_str()) else { + continue; + }; + + if !file_name.starts_with(&chain_prefix) || !file_name.ends_with(".json") { + continue; + } + + let contents = match fs::read_to_string(&path) { + Ok(contents) => contents, + Err(err) => { + trace!( + target: "account::controller", + path = %path.display(), + error = %err, + "Failed to read stored multi-session file, skipping." + ); + continue; + } + }; + + match serde_json::from_str::(&contents) { + Ok(session) if session.chain_id == chain_id => sessions.push(session), + Ok(_) => { + trace!( + target: "account::controller", + path = %path.display(), + "Skipping multi-session file with mismatched chain id." + ); + } + Err(err) => { + trace!( + target: "account::controller", + path = %path.display(), + error = %err, + "Failed to parse stored multi-session file, skipping." + ); + } + } + } + + Ok(sessions) +} + +fn find_matching_stored_session( + username: &str, + chain_id: Felt, + context_hash: &str, + policies: &[PolicyMethod], +) -> Result> { + let mut candidates = load_multi_sessions_for_chain(username, chain_id, context_hash)?; + if candidates.is_empty() { + // Backward-compatible fallback for users that only have the legacy single-session file. + if let Some(session) = slot::session::get(chain_id)? { + candidates.push(session); + } + } + + let mut dedup = BTreeSet::new(); + candidates.retain(|session| { + let key = format!( + "{:#x}:{:#x}:{:#x}:{}", + session.auth.address, + session.auth.owner_guid, + session.session.inner.session_key_guid, + session.session.inner.expires_at + ); + dedup.insert(key) + }); + + let mut matching = candidates + .into_iter() + .filter(|session| is_equal_to_existing(policies, session)) + .collect::>(); + + matching.sort_by_key(|session| Reverse(session.session.inner.expires_at)); + Ok(matching.into_iter().next()) +} + +fn persist_session_files( + chain_id: Felt, + context_hash: &str, + session: &FullSessionInfo, +) -> Result<()> { + slot::session::store(chain_id, session)?; + + let path = multi_session_file_path( + &session.auth.username, + chain_id, + context_hash, + session.session.inner.allowed_policies_root, + ); + if let Some(parent) = path.parent() { + fs::create_dir_all(parent).context("Failed to create controller session directory")?; + } + + let contents = + serde_json::to_string_pretty(session).context("Failed to serialize controller session")?; + fs::write(&path, contents).context("Failed to persist controller multi-session file")?; + + Ok(()) +} + /// Create a new Catridge Controller account based on session key. /// /// For now, Controller guarantees that if the provided network is among one of the supported @@ -110,41 +311,406 @@ pub async fn create_controller( }; let policies = collect_policies(contract_address, contracts)?; + let context_hash = current_session_context_hash(); - // Check if the session exists, if not create a new one - let session_details = match slot::session::get(chain_id)? { - Some(session) => { - trace!(target: "account::controller", expires_at = %session.session.inner.expires_at, policies = session.session.proved_policies.len(), "Found existing session."); - - // Check if the policies have changed - let is_equal = is_equal_to_existing(&policies, &session); - - if is_equal { + // Resolve the best stored session for this policy set and chain. + // This allows multiple project sessions to coexist on the same account/network. + let session_details = + match find_matching_stored_session(&username, chain_id, &context_hash, &policies)? { + Some(session) if !session.session.is_expired() => { + trace!( + target: "account::controller", + context_hash = %context_hash, + expires_at = %session.session.inner.expires_at, + policies = session.session.proved_policies.len(), + "Reusing matching stored session." + ); session - } else { + } + Some(session) => { + trace!( + target: "account::controller", + context_hash = %context_hash, + expires_at = %session.session.inner.expires_at, + "Matching stored session is expired. Creating a new session." + ); + create_session_with_short_url( + rpc_url.clone(), + chain_id, + contract_address, + None, + &policies, + ) + .await? + } + None => { trace!( target: "account::controller", - new_policies = policies.len(), - existing_policies = session.session.requested_policies.len(), - "Policies have changed. Creating new session." + %username, + context_hash = %context_hash, + chain = format!("{chain_id:#}"), + "No matching stored session found. Creating new session." ); + create_session_with_short_url( + rpc_url.clone(), + chain_id, + contract_address, + None, + &policies, + ) + .await? + } + }; - let session = slot::session::create(rpc_url.clone(), &policies).await?; - slot::session::store(chain_id, &session)?; - session + persist_session_files(chain_id, &context_hash, &session_details)?; + + Ok(session_details.into_account(rpc_provider)) +} + +async fn create_session_with_short_url( + rpc_url: Url, + chain_id: Felt, + expected_controller_address: Felt, + existing_session: Option<&FullSessionInfo>, + policies: &[PolicyMethod], +) -> Result { + let signer = SigningKey::from_random(); + let pubkey = signer.verifying_key().scalar(); + + let credentials = slot::credential::Credentials::load()?; + let username = credentials.account.id; + + let response = + create_user_session_with_short_url(pubkey, &username, rpc_url.clone(), policies).await?; + trace!( + target: "account::controller", + already_registered = response.already_registered, + transaction_hash = ?response.transaction_hash, + "Received controller session callback response." + ); + if response.address != expected_controller_address { + bail!( + "Controller session callback address mismatch. expected={:#x}, callback={:#x}", + expected_controller_address, + response.address + ); + } + + let expires_at = response.expires_at.parse::().map_err(|e| anyhow!(e))?; + let mut session = build_session_from_policies(policies, expires_at, &signer, &response)?; + let mut local_hash = + session.inner.get_message_hash_rev_1(chain_id, expected_controller_address); + + // Trust on-chain registration status instead of GraphQL replication state. + let mut local_hash_registered = is_session_hash_registered_onchain( + &rpc_url, + expected_controller_address, + response.owner_guid, + local_hash, + ) + .await?; + + if !local_hash_registered { + // If controller reports already-registered, prefer reusing the currently stored session + // when it is still registered on-chain. + if response.already_registered { + if let Some(existing) = existing_session { + let existing_hash = existing + .session + .inner + .get_message_hash_rev_1(chain_id, expected_controller_address); + if is_session_hash_registered_onchain( + &rpc_url, + expected_controller_address, + response.owner_guid, + existing_hash, + ) + .await? + { + trace!( + target: "account::controller", + existing_hash = format!("{:#x}", existing_hash), + local_hash = format!("{:#x}", local_hash), + "Reusing previously stored registered session after callback hash mismatch." + ); + return Ok(existing.clone()); + } } } - // Create a new session if not found - None => { - trace!(target: "account::controller", %username, chain = format!("{chain_id:#}"), "Creating new session."); - let session = slot::session::create(rpc_url.clone(), &policies).await?; - slot::session::store(chain_id, &session)?; - session + // Try alternate deterministic policy orderings to match keychain canonicalization. + for candidate in alternate_policy_orders(policies) { + let candidate_session = + build_session_from_policies(&candidate, expires_at, &signer, &response)?; + let candidate_hash = candidate_session + .inner + .get_message_hash_rev_1(chain_id, expected_controller_address); + + if candidate_hash == local_hash { + continue; + } + + if is_session_hash_registered_onchain( + &rpc_url, + expected_controller_address, + response.owner_guid, + candidate_hash, + ) + .await? + { + trace!( + target: "account::controller", + previous_hash = format!("{:#x}", local_hash), + matched_hash = format!("{:#x}", candidate_hash), + "Recovered registered controller session hash using alternate policy ordering." + ); + session = candidate_session; + local_hash = candidate_hash; + local_hash_registered = true; + break; + } } + } + + if !local_hash_registered { + bail!( + "Registered session hash mismatch. local={:#x}, controller={:#x}, owner_guid={:#x}. The session was not found on-chain for this owner/session tuple.", + local_hash, + expected_controller_address, + response.owner_guid + ); + } + + let auth = slot::session::SessionAuth { + address: response.address, + username: response.username, + owner_guid: response.owner_guid, + signer: signer.secret_scalar(), }; - Ok(session_details.into_account(rpc_provider)) + Ok(FullSessionInfo { auth, session, chain_id }) +} + +fn build_session_from_policies( + policies: &[PolicyMethod], + expires_at: u64, + signer: &SigningKey, + response: &ControllerSessionResponse, +) -> Result { + let methods = policies + .iter() + .map(|p| -> Result { + let selector = get_selector_from_name(&p.method)?; + Ok(Policy::Call(CallPolicy { + contract_address: p.target, + selector, + authorized: Some(true), + })) + }) + .collect::>>()?; + + let mut session = Session::new( + methods, + expires_at, + &ControllerSigner::Starknet(StarknetSigner { + pubkey: NonZero::new(signer.verifying_key().scalar()) + .expect("public key scalar should not be zero"), + }), + Felt::ZERO, + )?; + + apply_session_response_overrides(&mut session, response)?; + Ok(session) +} + +fn alternate_policy_orders(policies: &[PolicyMethod]) -> Vec> { + let mut candidates = Vec::new(); + + let mut by_unpadded_address_then_method = policies.to_vec(); + by_unpadded_address_then_method.sort_by(|a, b| { + format!("{:#x}", a.target) + .to_ascii_lowercase() + .cmp(&format!("{:#x}", b.target).to_ascii_lowercase()) + .then_with(|| a.method.cmp(&b.method)) + }); + candidates.push(by_unpadded_address_then_method); + + let mut by_address_then_method_casefold = policies.to_vec(); + by_address_then_method_casefold.sort_by(|a, b| { + format!("0x{:064x}", a.target).cmp(&format!("0x{:064x}", b.target)).then_with(|| { + a.method + .to_ascii_lowercase() + .cmp(&b.method.to_ascii_lowercase()) + .then_with(|| a.method.cmp(&b.method)) + }) + }); + candidates.push(by_address_then_method_casefold); + + let mut by_method_then_address = policies.to_vec(); + by_method_then_address.sort_by(|a, b| { + a.method + .cmp(&b.method) + .then_with(|| format!("0x{:064x}", a.target).cmp(&format!("0x{:064x}", b.target))) + }); + candidates.push(by_method_then_address); + + // Keep first occurrence only while preserving insertion order. + let mut unique = Vec::new(); + for candidate in candidates { + if !unique.contains(&candidate) { + unique.push(candidate); + } + } + + unique +} + +async fn create_user_session_with_short_url( + public_key: Felt, + username: &str, + rpc_url: Url, + policies: &[PolicyMethod], +) -> Result { + let listener = TcpListener::bind("localhost:0") + .context("Failed to start local callback listener for controller session authorization")?; + let callback_uri = format!( + "http://localhost:{}{}", + listener.local_addr()?.port(), + CONTROLLER_OAUTH_CALLBACK_PATH + ); + + let authorize_url = build_session_creation_url( + public_key, + username, + rpc_url.as_str(), + policies, + &callback_uri, + )?; + let open_url = shorten_session_authorize_url(&authorize_url).await.unwrap_or_else(|err| { + trace!( + target: "account::controller", + error = %err, + "Failed to shorten controller session URL, falling back to full URL." + ); + authorize_url.clone() + }); + + println!("Authorize your controller session in browser:\n\n {}\n", open_url); + slot::browser::open(open_url.as_str())?; + + let payload = tokio::time::timeout( + Duration::from_secs(CONTROLLER_SESSION_TIMEOUT_SECS), + tokio::task::spawn_blocking(move || wait_for_session_payload(listener)), + ) + .await + .map_err(|_| { + anyhow!( + "Timed out waiting for controller session callback after {} seconds.", + CONTROLLER_SESSION_TIMEOUT_SECS + ) + })? + .map_err(|e| anyhow!("Failed to run controller session callback listener: {e}"))??; + + parse_session_creation_response(&payload) +} + +fn build_session_creation_url( + public_key: Felt, + username: &str, + rpc_url: &str, + policies: &[PolicyMethod], + callback_uri: &str, +) -> Result { + let encoded_policies = policies + .iter() + .map(serde_json::to_string) + .map(|p| Ok(url_encode_query_component(&p?))) + .collect::, serde_json::Error>>()? + .join(","); + + let params = format!( + "username={username}&public_key={public_key}&rpc_url={rpc_url}&policies=[{encoded_policies}]" + ); + let host = slot::vars::get_cartridge_keychain_url(); + let mut url = Url::parse(&format!("{host}{CONTROLLER_SESSION_CREATION_PATH}?{params}")) + .context("Invalid Cartridge keychain URL")?; + url.query_pairs_mut().append_pair("callback_uri", callback_uri); + Ok(url) +} + +fn url_encode_query_component(value: &str) -> String { + url::form_urlencoded::byte_serialize(value.as_bytes()).collect() +} + +async fn shorten_session_authorize_url(authorize_url: &Url) -> Result { + let base = slot::vars::get_cartridge_api_url(); + let endpoint = format!( + "{}/{}", + base.trim_end_matches('/'), + CONTROLLER_SHORTENER_PATH.trim_start_matches('/') + ); + + let response = reqwest::Client::new() + .post(endpoint) + .json(&ShortUrlRequest { url: authorize_url.as_str() }) + .send() + .await + .context("Failed to call Cartridge short URL endpoint")?; + + if !response.status().is_success() { + bail!("Cartridge short URL endpoint returned HTTP {}", response.status()); + } + + let body: ShortUrlResponse = + response.json().await.context("Failed to decode Cartridge short URL response body")?; + Url::parse(&body.url).context("Invalid short URL returned by Cartridge API") +} + +async fn is_session_hash_registered_onchain( + rpc_url: &Url, + controller_address: Felt, + owner_guid: Felt, + session_hash: Felt, +) -> Result { + let provider = JsonRpcClient::new(HttpTransport::new(rpc_url.clone())); + let reader = + slot::account_sdk::abigen::controller::ControllerReader::new(controller_address, provider); + + // Check both owner GUID and controller address. Different deployments may use one or the other + // as the authorizer key for `is_session_registered`. + let mut authorizers = vec![owner_guid, controller_address]; + authorizers.dedup(); + + let mut successful_checks = 0usize; + let mut last_error = None; + + for authorizer in authorizers { + match reader.is_session_registered(&session_hash, &authorizer).call().await { + Ok(true) => return Ok(true), + Ok(false) => { + successful_checks += 1; + } + Err(err) => { + trace!( + target: "account::controller", + authorizer = format!("{:#x}", authorizer), + error = %err, + "Failed to query session registration for authorizer." + ); + last_error = Some(err); + } + } + } + + if successful_checks == 0 { + if let Some(err) = last_error { + return Err(anyhow!( + "Failed to query session registration status on controller contract: {err}" + )); + } + } + + Ok(false) } async fn load_or_bootstrap_credentials() -> Result { @@ -173,11 +739,11 @@ fn should_bootstrap_credentials(err: &slot::Error) -> bool { } async fn bootstrap_credentials() -> Result<()> { - let listener = TcpListener::bind("127.0.0.1:0") + let listener = TcpListener::bind("localhost:0") .context("Failed to start local callback listener for controller authorization")?; let callback_uri = format!( - "http://127.0.0.1:{}{}", + "http://localhost:{}{}", listener.local_addr()?.port(), CONTROLLER_OAUTH_CALLBACK_PATH ); @@ -290,6 +856,201 @@ fn wait_for_oauth_code(listener: TcpListener) -> Result { Ok(code) } +fn wait_for_session_payload(listener: TcpListener) -> Result { + loop { + let (mut stream, _) = + listener.accept().context("Failed to accept controller session callback connection")?; + let request = read_http_request(&mut stream)?; + + let Some(headers_end) = request.windows(4).position(|window| window == b"\r\n\r\n") else { + write_http_response( + &mut stream, + "400 Bad Request", + "Malformed callback request. You can close this tab and retry.", + )?; + continue; + }; + + let head = String::from_utf8_lossy(&request[..headers_end]); + let request_line = head.lines().next().unwrap_or_default(); + let mut request_line_parts = request_line.split_whitespace(); + let method = request_line_parts.next().unwrap_or_default(); + let Some(target) = request_line_parts.next() else { + write_http_response( + &mut stream, + "400 Bad Request", + "Malformed callback request line. You can close this tab and retry.", + )?; + continue; + }; + + let callback_url = Url::parse(&format!("http://localhost{target}")) + .context("Failed to parse callback target URL")?; + if callback_url.path() != CONTROLLER_OAUTH_CALLBACK_PATH { + write_http_response( + &mut stream, + "400 Bad Request", + "Invalid callback path. You can close this tab and retry.", + )?; + continue; + } + + if method.eq_ignore_ascii_case("OPTIONS") { + write_http_response(&mut stream, "204 No Content", "")?; + continue; + } + + if !method.eq_ignore_ascii_case("POST") { + write_http_response( + &mut stream, + "405 Method Not Allowed", + "Unsupported callback method. You can close this tab and retry.", + )?; + continue; + } + + let content_length = head + .lines() + .find_map(|line| { + let (key, value) = line.split_once(':')?; + key.eq_ignore_ascii_case("content-length").then(|| value.trim().parse::()) + }) + .transpose() + .context("Invalid `content-length` header in controller session callback")? + .unwrap_or_default(); + + let body_start = headers_end + 4; + if request.len() < body_start + content_length { + write_http_response( + &mut stream, + "400 Bad Request", + "Incomplete callback payload. You can close this tab and retry.", + )?; + continue; + } + + let body_bytes = &request[body_start..body_start + content_length]; + let body = String::from_utf8(body_bytes.to_vec()) + .context("Controller session callback body is not valid UTF-8")?; + let body = body.trim(); + let payload = serde_json::from_str::(body).unwrap_or_else(|_| body.to_string()); + + if payload.is_empty() { + write_http_response( + &mut stream, + "400 Bad Request", + "Missing session payload. You can close this tab and retry.", + )?; + continue; + } + + write_http_response( + &mut stream, + "200 OK", + "Controller session received. You can close this tab and return to sozo.", + )?; + + return Ok(payload); + } +} + +fn parse_session_creation_response(payload: &str) -> Result { + if let Ok(response) = parse_session_response_encoded(payload) { + return Ok(response); + } + + serde_json::from_str(payload) + .context("Failed to decode controller session callback payload as session JSON.") +} + +fn parse_session_response_encoded(encoded: &str) -> Result { + let bytes = general_purpose::STANDARD_NO_PAD + .decode(encoded) + .context("Failed to decode base64 session callback payload")?; + let decoded = + String::from_utf8(bytes).context("Session callback payload is not valid UTF-8")?; + serde_json::from_str(&decoded).context("Failed to decode session callback JSON payload") +} + +fn apply_session_response_overrides( + session: &mut Session, + response: &ControllerSessionResponse, +) -> Result<()> { + if let Some(session_key_guid) = response.session_key_guid { + if session_key_guid != session.inner.session_key_guid { + bail!( + "Controller returned a session key guid that does not match the generated session signer." + ); + } + session.inner.session_key_guid = session_key_guid; + } + + if let Some(allowed_policies_root) = response.allowed_policies_root { + if allowed_policies_root != session.inner.allowed_policies_root { + bail!( + "Controller returned a policy root that differs from local policy hashing. Check policy ordering." + ); + } + session.inner.allowed_policies_root = allowed_policies_root; + } + + if let Some(metadata_hash) = response.metadata_hash { + session.inner.metadata_hash = metadata_hash; + } + + if let Some(guardian_key_guid) = response.guardian_key_guid { + session.inner.guardian_key_guid = guardian_key_guid; + } + + Ok(()) +} + +fn read_http_request(stream: &mut TcpStream) -> Result> { + const MAX_REQUEST_SIZE: usize = 1024 * 1024; + + let mut request = Vec::with_capacity(8192); + let mut chunk = [0_u8; 8192]; + + loop { + let bytes_read = stream + .read(&mut chunk) + .context("Failed to read controller session callback request")?; + if bytes_read == 0 { + break; + } + + request.extend_from_slice(&chunk[..bytes_read]); + if request.len() > MAX_REQUEST_SIZE { + bail!("Controller session callback request is too large."); + } + + if let Some(headers_end) = request.windows(4).position(|window| window == b"\r\n\r\n") { + let headers = String::from_utf8_lossy(&request[..headers_end]); + let content_length = headers + .lines() + .find_map(|line| { + let (key, value) = line.split_once(':')?; + key.eq_ignore_ascii_case("content-length") + .then(|| value.trim().parse::()) + }) + .transpose() + .context("Invalid `content-length` header in controller session callback")? + .unwrap_or_default(); + + let expected_len = headers_end + 4 + content_length; + if request.len() >= expected_len { + break; + } + } + } + + if request.is_empty() { + bail!("Controller session callback request was empty."); + } + + Ok(request) +} + fn extract_oauth_code(target: &str) -> Option { let callback_url = Url::parse(&format!("http://localhost{target}")).ok()?; if callback_url.path() != CONTROLLER_OAUTH_CALLBACK_PATH { @@ -302,7 +1063,7 @@ fn extract_oauth_code(target: &str) -> Option { fn write_http_response(stream: &mut TcpStream, status: &str, body: &str) -> Result<()> { let response = format!( "HTTP/1.1 {status}\r\nContent-Type: text/plain; charset=utf-8\r\nContent-Length: \ - {}\r\nConnection: close\r\n\r\n{body}", + {}\r\nConnection: close\r\nAccess-Control-Allow-Origin: *\r\nAccess-Control-Allow-Methods: POST, OPTIONS\r\nAccess-Control-Allow-Headers: Content-Type\r\n\r\n{body}", body.len() ); stream.write_all(response.as_bytes())?; @@ -315,31 +1076,33 @@ fn write_http_response(stream: &mut TcpStream, status: &str, body: &str) -> Resu // This function would compute the merkle root of the new policies and compare it with the root in // the existing SessionMetadata. fn is_equal_to_existing(new_policies: &[PolicyMethod], session_info: &FullSessionInfo) -> bool { - let new_policies = new_policies - .iter() - .map(|p| { - Policy::Call(CallPolicy { - authorized: Some(true), - contract_address: p.target, - selector: get_selector_from_name(&p.method).expect("valid selector"), - }) - }) - .collect::>(); - - // Copied from Session::new - let hashes = new_policies.iter().map(Policy::as_merkle_leaf).collect::>(); + // Compare by canonical call policy content only (contract+selector), ignoring ordering, + // duplicates, and authorized toggles. + let new_calls = { + let mut set = BTreeSet::new(); + for policy in new_policies { + let Ok(selector) = get_selector_from_name(&policy.method) else { + return false; + }; + set.insert((format!("0x{:064x}", policy.target), format!("0x{:064x}", selector))); + } + set + }; - let new_policies = new_policies - .into_iter() - .enumerate() - .map(|(i, policy)| ProvedPolicy { - policy, - proof: MerkleTree::compute_proof(hashes.clone(), i), + let existing_calls = session_info + .session + .requested_policies + .iter() + .filter_map(|policy| match policy { + Policy::Call(call) => Some(( + format!("0x{:064x}", call.contract_address), + format!("0x{:064x}", call.selector), + )), + _ => None, }) - .collect::>(); + .collect::>(); - let new_policies_root = MerkleTree::compute_root(hashes[0], new_policies[0].proof.clone()); - new_policies_root == session_info.session.inner.allowed_policies_root + new_calls == existing_calls } /// Policies are the building block of a session key. It's what defines what methods are allowed for @@ -384,11 +1147,10 @@ fn collect_policies_from_contracts( trace!(target: "account::controller", "Adding UDC deployment policy"); // Keep canonical ordering aligned with controller/keychain sorting: - // contract address (as lowercase hex string) then method name. + // normalized lowercase padded hex address, then method name. policies.sort_by(|a, b| { - format!("{:#x}", a.target) - .to_lowercase() - .cmp(&format!("{:#x}", b.target).to_lowercase()) + format!("0x{:064x}", a.target) + .cmp(&format!("0x{:064x}", b.target)) .then_with(|| a.method.cmp(&b.method)) }); @@ -399,16 +1161,51 @@ fn collect_policies_from_contracts( mod tests { use std::collections::HashMap; + use base64::{Engine as _, engine::general_purpose}; + use cainome_cairo_serde::NonZero; use dojo_test_utils::setup::TestSetup; use dojo_world::contracts::ContractInfo; use scarb_interop::Profile; use scarb_metadata_ext::MetadataDojoExt; + use slot::account_sdk::abigen::controller::{Signer as ControllerSigner, StarknetSigner}; + use slot::account_sdk::account::session::hash::Session; + use slot::account_sdk::account::session::policy::{CallPolicy, Policy, TypedDataPolicy}; + use slot::session::{FullSessionInfo, SessionAuth}; + use starknet::core::types::Felt; + use starknet::core::utils::get_selector_from_name; use starknet::macros::felt; + use starknet::signers::SigningKey; use super::{ - PolicyMethod, collect_policies, collect_policies_from_contracts, extract_oauth_code, + PolicyMethod, alternate_policy_orders, collect_policies, collect_policies_from_contracts, + extract_oauth_code, is_equal_to_existing, parse_session_creation_response, }; + fn session_with_requested_policies(requested_policies: Vec) -> FullSessionInfo { + let signer = SigningKey::from_secret_scalar(felt!("0x12345")); + let session = Session::new( + requested_policies, + 4_102_444_800, + &ControllerSigner::Starknet(StarknetSigner { + pubkey: NonZero::new(signer.verifying_key().scalar()) + .expect("public key scalar should not be zero"), + }), + Felt::ZERO, + ) + .expect("session should build"); + + FullSessionInfo { + chain_id: felt!("0x534e5f5345504f4c4941"), + auth: SessionAuth { + username: "alice".into(), + address: felt!("0xabc"), + owner_guid: felt!("0xdef"), + signer: signer.secret_scalar(), + }, + session, + } + } + #[test] fn collect_policies_from_project() { let setup = TestSetup::from_examples("../../crates/dojo/core", "../../examples/"); @@ -578,14 +1375,15 @@ mod tests { let policies = collect_policies_from_contracts(user_addr, &contracts).unwrap(); - // Controller canonical sort is string-based, so 0x10 comes before 0x2. + // Controller canonical sort is done on normalized/padded address strings. + // So 0x2 comes before 0x10. let first_two = policies .iter() .filter(|p| p.method == "exec") .take(2) .map(|p| p.target) .collect::>(); - assert_eq!(first_two, vec![addr_10, addr_2]); + assert_eq!(first_two, vec![addr_2, addr_10]); } #[test] @@ -604,4 +1402,137 @@ mod tests { fn extract_oauth_code_rejects_non_callback_target() { assert_eq!(extract_oauth_code("/not-callback?code=abc123"), None); } + + #[test] + fn parse_session_creation_response_decodes_full_encoded_payload() { + let payload = serde_json::json!({ + "username": "alice", + "address": "0x123", + "ownerGuid": "0x456", + "expiresAt": "1735689600", + "transactionHash": "0x789", + "alreadyRegistered": true, + "allowedPoliciesRoot": "0x111", + "metadataHash": "0x222", + "sessionKeyGuid": "0x333", + "guardianKeyGuid": "0x444" + }); + let encoded = general_purpose::STANDARD_NO_PAD.encode(payload.to_string()); + + let decoded = parse_session_creation_response(&encoded).unwrap(); + + assert_eq!(decoded.username, "alice"); + assert_eq!(decoded.address, felt!("0x123")); + assert_eq!(decoded.owner_guid, felt!("0x456")); + assert_eq!(decoded.expires_at, "1735689600"); + assert_eq!(decoded.transaction_hash, Some(felt!("0x789"))); + assert!(decoded.already_registered); + assert_eq!(decoded.allowed_policies_root, Some(felt!("0x111"))); + assert_eq!(decoded.metadata_hash, Some(felt!("0x222"))); + assert_eq!(decoded.session_key_guid, Some(felt!("0x333"))); + assert_eq!(decoded.guardian_key_guid, Some(felt!("0x444"))); + } + + #[test] + fn alternate_policy_orders_produces_unique_deterministic_candidates() { + let policies = vec![ + PolicyMethod { target: felt!("0x10"), method: "z".into() }, + PolicyMethod { target: felt!("0x2"), method: "a".into() }, + PolicyMethod { target: felt!("0x2"), method: "m".into() }, + ]; + + let candidates = alternate_policy_orders(&policies); + assert!(!candidates.is_empty()); + + // No duplicates across candidate orderings. + for i in 0..candidates.len() { + for j in (i + 1)..candidates.len() { + assert_ne!(candidates[i], candidates[j]); + } + } + + // At least one candidate should keep both addresses present (sanity). + let has_both_addresses = candidates.iter().any(|candidate| { + let targets = candidate.iter().map(|p| p.target).collect::>(); + targets.contains(&felt!("0x2")) && targets.contains(&felt!("0x10")) + }); + assert!(has_both_addresses); + } + + #[test] + fn alternate_policy_orders_keeps_method_order_variant() { + let policies = vec![ + PolicyMethod { target: felt!("0x1"), method: "spawn".into() }, + PolicyMethod { target: felt!("0x2"), method: "move".into() }, + PolicyMethod { target: felt!("0x3"), method: "attack".into() }, + ]; + + let candidates = alternate_policy_orders(&policies); + let has_method_first = candidates.iter().any(|candidate| { + candidate.iter().map(|p| p.method.as_str()).collect::>() + == vec!["attack", "move", "spawn"] + }); + assert!(has_method_first); + } + + #[test] + fn is_equal_to_existing_ignores_order_and_authorized_toggle() { + let new_policies = vec![ + PolicyMethod { target: felt!("0x1"), method: "spawn".into() }, + PolicyMethod { target: felt!("0x2"), method: "move".into() }, + ]; + + let requested = vec![ + Policy::Call(CallPolicy { + contract_address: felt!("0x2"), + selector: get_selector_from_name("move").unwrap(), + authorized: Some(false), + }), + Policy::Call(CallPolicy { + contract_address: felt!("0x1"), + selector: get_selector_from_name("spawn").unwrap(), + authorized: Some(true), + }), + ]; + + let session = session_with_requested_policies(requested); + assert!(is_equal_to_existing(&new_policies, &session)); + } + + #[test] + fn is_equal_to_existing_ignores_non_call_requested_policies() { + let new_policies = vec![PolicyMethod { target: felt!("0x1"), method: "spawn".into() }]; + + let requested = vec![ + Policy::Call(CallPolicy { + contract_address: felt!("0x1"), + selector: get_selector_from_name("spawn").unwrap(), + authorized: Some(true), + }), + Policy::TypedData(TypedDataPolicy { + scope_hash: felt!("0x123"), + authorized: Some(true), + }), + ]; + + let session = session_with_requested_policies(requested); + assert!(is_equal_to_existing(&new_policies, &session)); + } + + #[test] + fn is_equal_to_existing_detects_call_set_changes() { + let new_policies = vec![ + PolicyMethod { target: felt!("0x1"), method: "spawn".into() }, + PolicyMethod { target: felt!("0x2"), method: "move".into() }, + ]; + + let requested = vec![Policy::Call(CallPolicy { + contract_address: felt!("0x1"), + selector: get_selector_from_name("spawn").unwrap(), + authorized: Some(true), + })]; + + let session = session_with_requested_policies(requested); + assert!(!is_equal_to_existing(&new_policies, &session)); + } } diff --git a/bin/sozo/src/commands/session.rs b/bin/sozo/src/commands/session.rs index 3007e523f8..abae99b9c3 100644 --- a/bin/sozo/src/commands/session.rs +++ b/bin/sozo/src/commands/session.rs @@ -17,6 +17,9 @@ use super::options::starknet::StarknetOptions; use super::options::world::WorldOptions; use crate::utils; +const LEGACY_SESSION_FILE_SUFFIX: &str = "-session.json"; +const MULTI_SESSION_FILE_INFIX: &str = "-session-"; + #[derive(Debug, Args)] pub struct SessionArgs { #[command(subcommand)] @@ -141,15 +144,27 @@ async fn status_session( } let session_path = session_file_path(&credentials.account.id, chain_id); + let context_hash = controller::current_session_context_hash(); + let session_variants = + chain_session_file_paths(&credentials.account.id, chain_id, Some(&context_hash))?; + let chain_variants = chain_session_file_paths(&credentials.account.id, chain_id, None)?; let session = slot::session::get(chain_id)?; if let Some(session) = session { ui.result("Session: active"); ui.print(format!("Policies : {}", session.session.proved_policies.len())); ui.print(format!("Expires at (unix) : {}", session.session.inner.expires_at)); + ui.print(format!("Stored variants : {}", session_variants.len())); + ui.print(format!("Chain variants : {}", chain_variants.len())); ui.print(format!("Stored session : {}", session_path.display())); } else { ui.warn("Session: not found for this network."); + if !session_variants.is_empty() { + ui.print(format!("Stored variants : {}", session_variants.len())); + } + if !chain_variants.is_empty() { + ui.print(format!("Chain variants : {}", chain_variants.len())); + } ui.print(format!("Expected path : {}", session_path.display())); } @@ -186,7 +201,7 @@ async fn discard_session( let is_session = path .file_name() .and_then(|name| name.to_str()) - .is_some_and(|name| name.ends_with("-session.json")); + .is_some_and(is_session_file_name); if is_session { fs::remove_file(&path)?; @@ -203,14 +218,22 @@ async fn discard_session( let rpc_url = starknet.url(profile_config.env.as_ref())?; let chain_id = CartridgeJsonRpcProvider::new(rpc_url).chain_id().await?; - let session_path = session_file_path(&credentials.account.id, chain_id); - if session_path.exists() { - fs::remove_file(&session_path)?; + let context_hash = controller::current_session_context_hash(); + let session_files = + chain_session_file_paths(&credentials.account.id, chain_id, Some(&context_hash))?; + if !session_files.is_empty() { + for path in &session_files { + fs::remove_file(path)?; + removed += 1; + } ui.result("Session discarded."); - ui.print(format!("Removed: {}", session_path.display())); + ui.print(format!("Removed {} file(s) for chain {chain_id:#x}.", removed)); } else { ui.warn("No stored session found for this network."); - ui.print(format!("Expected path: {}", session_path.display())); + ui.print(format!( + "Expected path: {}", + session_file_path(&credentials.account.id, chain_id).display() + )); } Ok(()) @@ -243,11 +266,72 @@ fn session_file_path(username: &str, chain_id: starknet::core::types::Felt) -> P slot::utils::config_dir().join(username).join(format!("{chain_id:#x}-session.json")) } +fn is_session_file_name(file_name: &str) -> bool { + file_name.ends_with(LEGACY_SESSION_FILE_SUFFIX) + || (file_name.contains(MULTI_SESSION_FILE_INFIX) && file_name.ends_with(".json")) +} + +fn is_chain_session_file_name(file_name: &str, chain_id: starknet::core::types::Felt) -> bool { + let chain_prefix = format!("{chain_id:#x}"); + if !file_name.starts_with(&chain_prefix) { + return false; + } + + file_name == format!("{chain_prefix}{LEGACY_SESSION_FILE_SUFFIX}") + || (file_name.starts_with(&format!("{chain_prefix}{MULTI_SESSION_FILE_INFIX}")) + && file_name.ends_with(".json")) +} + +fn is_chain_session_file_name_for_context( + file_name: &str, + chain_id: starknet::core::types::Felt, + context_hash: Option<&str>, +) -> bool { + if !is_chain_session_file_name(file_name, chain_id) { + return false; + } + + match context_hash { + Some(hash) => { + file_name == format!("{chain_id:#x}{LEGACY_SESSION_FILE_SUFFIX}") + || file_name.starts_with(&format!("{chain_id:#x}{MULTI_SESSION_FILE_INFIX}{hash}-")) + } + None => true, + } +} + +fn chain_session_file_paths( + username: &str, + chain_id: starknet::core::types::Felt, + context_hash: Option<&str>, +) -> Result> { + let user_dir = slot::utils::config_dir().join(username); + if !user_dir.exists() { + return Ok(Vec::new()); + } + + let mut paths = Vec::new(); + for entry in fs::read_dir(user_dir)? { + let path = entry?.path(); + let is_chain_session = + path.file_name().and_then(|name| name.to_str()).is_some_and(|name| { + is_chain_session_file_name_for_context(name, chain_id, context_hash) + }); + if is_chain_session { + paths.push(path); + } + } + Ok(paths) +} + #[cfg(test)] mod tests { use starknet::macros::felt; - use super::session_file_path; + use super::{ + is_chain_session_file_name, is_chain_session_file_name_for_context, is_session_file_name, + session_file_path, + }; #[test] fn session_file_path_contains_expected_suffix() { @@ -255,4 +339,42 @@ mod tests { let file = path.file_name().and_then(|name| name.to_str()).unwrap(); assert!(file.ends_with("-session.json")); } + + #[test] + fn is_session_file_name_matches_legacy_and_multi_formats() { + assert!(is_session_file_name("0x1-session.json")); + assert!(is_session_file_name( + "0x1-session-0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef.json" + )); + assert!(!is_session_file_name("notes.json")); + } + + #[test] + fn is_chain_session_file_name_filters_by_chain() { + let chain = felt!("0x534e5f5345504f4c4941"); + assert!(is_chain_session_file_name("0x534e5f5345504f4c4941-session.json", chain)); + assert!(is_chain_session_file_name("0x534e5f5345504f4c4941-session-deadbeef.json", chain)); + assert!(!is_chain_session_file_name("0x123-session.json", chain)); + assert!(!is_chain_session_file_name("0x534e5f5345504f4c4941-other.json", chain)); + } + + #[test] + fn is_chain_session_file_name_for_context_filters_context_hash() { + let chain = felt!("0x534e5f5345504f4c4941"); + assert!(is_chain_session_file_name_for_context( + "0x534e5f5345504f4c4941-session-feedbeef-deadbeef.json", + chain, + Some("feedbeef") + )); + assert!(!is_chain_session_file_name_for_context( + "0x534e5f5345504f4c4941-session-cafebabe-deadbeef.json", + chain, + Some("feedbeef") + )); + assert!(is_chain_session_file_name_for_context( + "0x534e5f5345504f4c4941-session.json", + chain, + Some("feedbeef") + )); + } } diff --git a/bin/sozo/src/main.rs b/bin/sozo/src/main.rs index 3035ceff08..4311340039 100644 --- a/bin/sozo/src/main.rs +++ b/bin/sozo/src/main.rs @@ -2,7 +2,7 @@ use std::process::exit; -use anyhow::{bail, Result}; +use anyhow::{Result, bail}; use args::SozoArgs; use camino::Utf8PathBuf; use clap::Parser; @@ -15,7 +15,7 @@ mod args; mod commands; mod features; mod utils; -use terminal_colorsaurus::{theme_mode, QueryOptions, ThemeMode}; +use terminal_colorsaurus::{QueryOptions, ThemeMode, theme_mode}; #[tokio::main] async fn main() { @@ -67,8 +67,10 @@ async fn cli_main(args: SozoArgs, ui: &SozoUi) -> Result<()> { bail!("Unable to find {}", &manifest_path); } - let scarb_metadata = - Metadata::load(manifest_path, args.profile_spec.determine()?.as_str(), args.offline)?; + let profile = args.profile_spec.determine()?; + std::env::set_var("SOZO_PROFILE", profile.as_str()); + + let scarb_metadata = Metadata::load(manifest_path, profile.as_str(), args.offline)?; trace!(%scarb_metadata.runtime_manifest, "Configuration built successfully.");