From 758f05467077db7603228b194051568735264fb3 Mon Sep 17 00:00:00 2001 From: itowlson Date: Fri, 14 Nov 2025 13:21:53 +1300 Subject: [PATCH 1/4] Async PostgreSQL API Signed-off-by: itowlson --- Cargo.lock | 10 + crates/factor-outbound-pg/Cargo.toml | 1 + .../factor-outbound-pg/src/allowed_hosts.rs | 70 ++++++ crates/factor-outbound-pg/src/client.rs | 91 +++++++- crates/factor-outbound-pg/src/host.rs | 201 ++++++++++++++---- crates/factor-outbound-pg/src/lib.rs | 31 ++- crates/factor-outbound-pg/src/types.rs | 25 ++- .../factor-outbound-pg/src/types/convert.rs | 2 +- .../factor-outbound-pg/src/types/interval.rs | 2 +- .../factor-outbound-pg/tests/factor_test.rs | 23 +- crates/wasi-async/Cargo.toml | 10 + crates/wasi-async/src/future.rs | 30 +++ crates/wasi-async/src/lib.rs | 2 + crates/wasi-async/src/stream.rs | 45 ++++ crates/world/src/conversions.rs | 4 +- crates/world/src/lib.rs | 4 +- wit/deps/spin-postgres@4.2.0/postgres.wit | 184 ++++++++++++++++ wit/world.wit | 2 +- 18 files changed, 656 insertions(+), 81 deletions(-) create mode 100644 crates/factor-outbound-pg/src/allowed_hosts.rs create mode 100644 crates/wasi-async/Cargo.toml create mode 100644 crates/wasi-async/src/future.rs create mode 100644 crates/wasi-async/src/lib.rs create mode 100644 crates/wasi-async/src/stream.rs create mode 100644 wit/deps/spin-postgres@4.2.0/postgres.wit diff --git a/Cargo.lock b/Cargo.lock index 15c0c591f..47499b423 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -8754,6 +8754,7 @@ dependencies = [ "spin-factors-test", "spin-locked-app", "spin-resource-table", + "spin-wasi-async", "spin-world", "tokio", "tokio-postgres", @@ -9439,6 +9440,15 @@ dependencies = [ "vaultrs", ] +[[package]] +name = "spin-wasi-async" +version = "3.7.0-pre0" +dependencies = [ + "anyhow", + "spin-core", + "tokio", +] + [[package]] name = "spin-world" version = "3.7.0-pre0" diff --git a/crates/factor-outbound-pg/Cargo.toml b/crates/factor-outbound-pg/Cargo.toml index 5a5654ff4..4ad764814 100644 --- a/crates/factor-outbound-pg/Cargo.toml +++ b/crates/factor-outbound-pg/Cargo.toml @@ -24,6 +24,7 @@ spin-factor-outbound-networking = { path = "../factor-outbound-networking" } spin-factors = { path = "../factors" } spin-locked-app = { path = "../locked-app" } spin-resource-table = { path = "../table" } +spin-wasi-async = { path = "../wasi-async" } spin-world = { path = "../world" } tokio = { workspace = true, features = ["rt-multi-thread"] } tokio-postgres = { version = "0.7", features = ["with-chrono-0_4", "with-serde_json-1", "with-uuid-1"] } diff --git a/crates/factor-outbound-pg/src/allowed_hosts.rs b/crates/factor-outbound-pg/src/allowed_hosts.rs new file mode 100644 index 000000000..fb1c68bb1 --- /dev/null +++ b/crates/factor-outbound-pg/src/allowed_hosts.rs @@ -0,0 +1,70 @@ +use std::sync::Arc; + +use spin_factor_outbound_networking::config::allowed_hosts::OutboundAllowedHosts; +use spin_world::spin::postgres4_2_0::postgres::{self as v4}; + +/// Encapsulates checking of a PostgreSQL address/connection string against +/// an allow-list. +/// +/// This is broken out as a distinct object to allow it to be synchronously retrieved +/// within a P3 Accessor block and then asynchronously queried outside the block. +#[derive(Clone)] +pub(crate) struct AllowedHostChecker { + allowed_hosts: Arc, +} + +impl AllowedHostChecker { + pub fn new(allowed_hosts: OutboundAllowedHosts) -> Self { + Self { + allowed_hosts: Arc::new(allowed_hosts), + } + } + #[allow(clippy::result_large_err)] + pub async fn ensure_address_allowed(&self, address: &str) -> Result<(), v4::Error> { + fn conn_failed(message: impl Into) -> v4::Error { + v4::Error::ConnectionFailed(message.into()) + } + fn err_other(err: anyhow::Error) -> v4::Error { + v4::Error::Other(err.to_string()) + } + + let config = address + .parse::() + .map_err(|e| conn_failed(e.to_string()))?; + + for (i, host) in config.get_hosts().iter().enumerate() { + match host { + tokio_postgres::config::Host::Tcp(address) => { + let ports = config.get_ports(); + // The port we use is either: + // * The port at the same index as the host + // * The first port if there is only one port + let port = ports.get(i).or_else(|| { + if ports.len() == 1 { + ports.first() + } else { + None + } + }); + let port_str = port.map(|p| format!(":{p}")).unwrap_or_default(); + let url = format!("{address}{port_str}"); + if !self + .allowed_hosts + .check_url(&url, "postgres") + .await + .map_err(err_other)? + { + return Err(conn_failed(format!( + "address postgres://{url} is not permitted" + ))); + } + } + #[cfg(unix)] + tokio_postgres::config::Host::Unix(_) => { + return Err(conn_failed("Unix sockets are not supported on WebAssembly")); + } + } + } + Ok(()) + } +} diff --git a/crates/factor-outbound-pg/src/client.rs b/crates/factor-outbound-pg/src/client.rs index 657081f4e..7bf75f268 100644 --- a/crates/factor-outbound-pg/src/client.rs +++ b/crates/factor-outbound-pg/src/client.rs @@ -1,9 +1,11 @@ +use std::sync::Arc; + use anyhow::{Context, Result}; use futures::stream::TryStreamExt as _; use native_tls::TlsConnector; use postgres_native_tls::MakeTlsConnector; use spin_world::async_trait; -use spin_world::spin::postgres4_1_0::postgres::{ +use spin_world::spin::postgres4_2_0::postgres::{ self as v4, Column, DbValue, ParameterValue, RowSet, }; use std::pin::pin; @@ -11,7 +13,9 @@ use tokio_postgres::config::SslMode; use tokio_postgres::types::ToSql; use tokio_postgres::{NoTls, Row}; -use crate::types::{convert_data_type, convert_entry, to_sql_parameter}; +use crate::types::{ + as_sql_parameter_refs, convert_data_type, convert_entry, to_sql_parameter, to_sql_parameters, +}; /// Max connections in a given address' connection pool const CONNECTION_POOL_SIZE: usize = 64; @@ -63,7 +67,7 @@ impl Default for PooledTokioClientFactory { #[async_trait] impl ClientFactory for PooledTokioClientFactory { - type Client = deadpool_postgres::Object; + type Client = Arc; async fn get_client( &self, @@ -81,7 +85,7 @@ impl ClientFactory for PooledTokioClientFactory { .map_err(ArcError) .context("establishing PostgreSQL connection pool")?; - Ok(pool.get().await?) + Ok(Arc::new(pool.get().await?)) } } @@ -123,7 +127,7 @@ fn create_connection_pool( } #[async_trait] -pub trait Client: Send + Sync + 'static { +pub trait Client: Clone + Send + Sync + 'static { async fn execute( &self, statement: String, @@ -136,6 +140,18 @@ pub trait Client: Send + Sync + 'static { params: Vec, max_result_bytes: usize, ) -> Result; + + async fn query_async( + &self, + statement: String, + params: Vec, + ) -> Result< + ( + tokio::sync::oneshot::Receiver>, + tokio::sync::mpsc::Receiver>, + ), + v4::Error, + >; } /// Extract weak-typed error data for WIT purposes @@ -181,7 +197,7 @@ fn query_failed(e: tokio_postgres::error::Error) -> v4::Error { } #[async_trait] -impl Client for deadpool_postgres::Object { +impl Client for Arc { async fn execute( &self, statement: String, @@ -210,11 +226,7 @@ impl Client for deadpool_postgres::Object { params: Vec, max_result_bytes: usize, ) -> Result { - let params = params - .iter() - .map(to_sql_parameter) - .collect::>>() - .map_err(|e| v4::Error::BadParameter(format!("{e:?}")))?; + let params = to_sql_parameters(params)?; let mut results = pin!(self .as_ref() @@ -248,6 +260,63 @@ impl Client for deadpool_postgres::Object { rows, }) } + + async fn query_async( + &self, + statement: String, + params: Vec, + ) -> Result< + ( + tokio::sync::oneshot::Receiver>, + tokio::sync::mpsc::Receiver>, + ), + v4::Error, + > { + let params = to_sql_parameters(params)?; + let params_refs = as_sql_parameter_refs(¶ms); + + let stm = self + .as_ref() + .query_raw(&statement, params_refs) + .await + .map_err(query_failed)?; + + let (rows_tx, rows_rx) = tokio::sync::mpsc::channel(1000); + let (cols_tx, cols_rx) = tokio::sync::oneshot::channel(); + let mut cols_tx_opt = Some(cols_tx); + + let mut stm = Box::pin(stm); + + tokio::spawn(async move { + use futures::StreamExt; + loop { + let Some(row) = stm.next().await else { + break; + }; + // TODO: figure out how to deal with errors here - I think there is like a FutureReader pattern? + let row = match row { + Ok(r) => r, + Err(e) => { + let err = query_failed(e); + rows_tx.send(Err(err)).await.unwrap(); + break; + } + }; + if let Some(cols_tx) = cols_tx_opt.take() { + cols_tx.send(infer_columns(&row)).unwrap(); + } + match convert_row(&row) { + Ok(row) => rows_tx.send(Ok(row)).await.unwrap(), + Err(e) => { + let err = v4::Error::QueryFailed(v4::QueryError::Text(format!("{e:?}"))); + rows_tx.send(Err(err)).await.unwrap(); + } + } + } + }); + + Ok((cols_rx, rows_rx)) + } } fn infer_columns(row: &Row) -> Vec { diff --git a/crates/factor-outbound-pg/src/host.rs b/crates/factor-outbound-pg/src/host.rs index 4c22dd984..8456600dc 100644 --- a/crates/factor-outbound-pg/src/host.rs +++ b/crates/factor-outbound-pg/src/host.rs @@ -1,7 +1,8 @@ use anyhow::Result; -use spin_core::wasmtime::component::Resource; +use spin_core::wasmtime; +use spin_core::wasmtime::component::{Accessor, FutureReader, Resource, StreamReader}; use spin_world::spin::postgres3_0_0::postgres::{self as v3}; -use spin_world::spin::postgres4_1_0::postgres::{self as v4}; +use spin_world::spin::postgres4_2_0::postgres::{self as v4}; use spin_world::v1::postgres as v1; use spin_world::v1::rdbms_types as v1_types; use spin_world::v2::postgres::{self as v2}; @@ -11,9 +12,14 @@ use tracing::field::Empty; use tracing::instrument; use tracing::Level; +use crate::allowed_hosts::AllowedHostChecker; use crate::client::{Client, ClientFactory, HashableCertificate}; use crate::InstanceState; +// Declare some types to make Clippy less mad +pub type RowStream = StreamReader>; +pub type ColumnsFuture = FutureReader>; + impl InstanceState { async fn open_connection( &mut self, @@ -40,53 +46,15 @@ impl InstanceState { .ok_or_else(|| v4::Error::ConnectionFailed("no connection found".into())) } + fn allowed_host_checker(&self) -> AllowedHostChecker { + self.allowed_host_checker.clone() + } + #[allow(clippy::result_large_err)] async fn ensure_address_allowed(&self, address: &str) -> Result<(), v4::Error> { - fn conn_failed(message: impl Into) -> v4::Error { - v4::Error::ConnectionFailed(message.into()) - } - fn err_other(err: anyhow::Error) -> v4::Error { - v4::Error::Other(err.to_string()) - } - - let config = address - .parse::() - .map_err(|e| conn_failed(e.to_string()))?; - - for (i, host) in config.get_hosts().iter().enumerate() { - match host { - tokio_postgres::config::Host::Tcp(address) => { - let ports = config.get_ports(); - // The port we use is either: - // * The port at the same index as the host - // * The first port if there is only one port - let port = ports.get(i).or_else(|| { - if ports.len() == 1 { - ports.first() - } else { - None - } - }); - let port_str = port.map(|p| format!(":{p}")).unwrap_or_default(); - let url = format!("{address}{port_str}"); - if !self - .allowed_hosts - .check_url(&url, "postgres") - .await - .map_err(err_other)? - { - return Err(conn_failed(format!( - "address postgres://{url} is not permitted" - ))); - } - } - #[cfg(unix)] - tokio_postgres::config::Host::Unix(_) => { - return Err(conn_failed("Unix sockets are not supported on WebAssembly")); - } - } - } - Ok(()) + self.allowed_host_checker + .ensure_address_allowed(address) + .await } } @@ -242,6 +210,145 @@ impl v4::HostConnection for InstanceState { } } +impl spin_world::spin::postgres4_2_0::postgres::HostConnectionWithStore + for crate::PgFactorData +{ + #[instrument(name = "spin_outbound_pg.open_async", skip(accessor, address), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", db.address = Empty, server.port = Empty, db.namespace = Empty))] + async fn open_async( + accessor: &Accessor, + address: String, + ) -> Result, v4::Error> { + spin_factor_outbound_networking::record_address_fields(&address); + + // A merry dance to avoid doing the async allow check under the accessor + let allowed_host_checker = accessor.with(|mut access| { + let host = access.get(); + host.allowed_host_checker() + }); + + allowed_host_checker + .ensure_address_allowed(&address) + .await?; + + let cf = accessor.with(|mut access| { + let host = access.get(); + host.client_factory.clone() + }); + let client = cf + .get_client(&address, None) + .await + .map_err(|e| v4::Error::ConnectionFailed(format!("{e:?}")))?; + let rsrc = accessor.with(|mut access| { + let host = access.get(); + host.connections + .push(client) + .map_err(|_| v4::Error::ConnectionFailed("too many connections".into())) + .map(wasmtime::component::Resource::new_own) + }); + rsrc + } + + #[instrument(name = "spin_outbound_pg.execute", skip(accessor, connection, params), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", otel.name = statement))] + async fn execute_async( + accessor: &Accessor, + connection: Resource, + statement: String, + params: Vec, + ) -> Result { + let client = accessor.with(|mut access| { + let host = access.get(); + host.connections.get(connection.rep()).unwrap().clone() + }); + + client.execute(statement, params).await + } + + #[instrument(name = "spin_outbound_pg.query", skip(accessor, params), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", otel.name = statement))] + async fn query_async( + accessor: &Accessor, + connection: Resource, + statement: String, + params: Vec, + ) -> Result<(ColumnsFuture, RowStream), v4::Error> { + use wasmtime::AsContextMut; + + let client = accessor.with(|mut access| { + let host = access.get(); + host.connections.get(connection.rep()).unwrap().clone() + }); + + let (col_rx, row_rx) = client.query_async(statement, params).await?; + + let row_producer = spin_wasi_async::stream::producer(row_rx); + let col_producer = spin_wasi_async::future::producer(col_rx); + + let (fr, sr) = accessor.with(|mut access| { + let fr = FutureReader::new(access.as_context_mut(), col_producer); + let sr = StreamReader::new(access.as_context_mut(), row_producer); + (fr, sr) + }); + + Ok((fr, sr)) + } +} + +impl spin_world::spin::postgres4_2_0::postgres::HostConnectionBuilderWithStore + for crate::PgFactorData +{ + async fn build_async( + accessor: &Accessor, + builder: Resource, + ) -> Result, v4::Error> { + // TODO: so much deduplicating + let rep = builder.rep(); + + let (address, root_ca) = accessor.with(|mut access| { + let host = access.get(); + + let builder = host + .builders + .get_mut(rep) + .ok_or_else(|| v4::Error::ConnectionFailed("no builder found".into()))?; + + let address = builder.address.clone(); + let root_ca = builder.root_ca.clone(); + + Ok((address, root_ca)) + })?; + + // TODO: this is from open_async. TODO: so much deduplication + + spin_factor_outbound_networking::record_address_fields(&address); + + // A merry dance to avoid doing the async allow check under the accessor + let allowed_host_checker = accessor.with(|mut access| { + let host = access.get(); + host.allowed_host_checker() + }); + + allowed_host_checker + .ensure_address_allowed(&address) + .await?; + + let cf = accessor.with(|mut access| { + let host = access.get(); + host.client_factory.clone() + }); + let client = cf + .get_client(&address, root_ca) + .await + .map_err(|e| v4::Error::ConnectionFailed(format!("{e:?}")))?; + let rsrc = accessor.with(|mut access| { + let host = access.get(); + host.connections + .push(client) + .map_err(|_| v4::Error::ConnectionFailed("too many connections".into())) + .map(wasmtime::component::Resource::new_own) + }); + rsrc + } +} + impl v2_types::Host for InstanceState { fn convert_error(&mut self, error: v2::Error) -> Result { Ok(error) diff --git a/crates/factor-outbound-pg/src/lib.rs b/crates/factor-outbound-pg/src/lib.rs index 8a8891e25..aae1c2443 100644 --- a/crates/factor-outbound-pg/src/lib.rs +++ b/crates/factor-outbound-pg/src/lib.rs @@ -1,17 +1,16 @@ +mod allowed_hosts; pub mod client; mod host; mod types; use std::{collections::HashMap, sync::Arc}; +use allowed_hosts::AllowedHostChecker; use client::ClientFactory; use spin_factor_otel::OtelFactorState; -use spin_factor_outbound_networking::{ - config::allowed_hosts::OutboundAllowedHosts, OutboundNetworkingFactor, -}; +use spin_factor_outbound_networking::OutboundNetworkingFactor; use spin_factors::{ - anyhow, ConfigureAppContext, Factor, FactorData, PrepareContext, RuntimeFactors, - SelfInstanceBuilder, + anyhow, ConfigureAppContext, Factor, PrepareContext, RuntimeFactors, SelfInstanceBuilder, }; pub struct OutboundPgFactor { @@ -24,13 +23,13 @@ impl Factor for OutboundPgFactor { type InstanceBuilder = InstanceState; fn init(&mut self, ctx: &mut impl spin_factors::InitContext) -> anyhow::Result<()> { - ctx.link_bindings(spin_world::v1::postgres::add_to_linker::<_, FactorData>)?; - ctx.link_bindings(spin_world::v2::postgres::add_to_linker::<_, FactorData>)?; + ctx.link_bindings(spin_world::v1::postgres::add_to_linker::<_, PgFactorData>)?; + ctx.link_bindings(spin_world::v2::postgres::add_to_linker::<_, PgFactorData>)?; ctx.link_bindings( - spin_world::spin::postgres3_0_0::postgres::add_to_linker::<_, FactorData>, + spin_world::spin::postgres3_0_0::postgres::add_to_linker::<_, PgFactorData>, )?; ctx.link_bindings( - spin_world::spin::postgres4_1_0::postgres::add_to_linker::<_, FactorData>, + spin_world::spin::postgres4_2_0::postgres::add_to_linker::<_, PgFactorData>, )?; Ok(()) } @@ -57,7 +56,7 @@ impl Factor for OutboundPgFactor { let cf = ctx.app_state().get(ctx.app_component().id()).unwrap(); Ok(InstanceState { - allowed_hosts, + allowed_host_checker: AllowedHostChecker::new(allowed_hosts), client_factory: cf.clone(), connections: Default::default(), otel, @@ -81,7 +80,7 @@ impl OutboundPgFactor { } pub struct InstanceState { - allowed_hosts: OutboundAllowedHosts, + allowed_host_checker: AllowedHostChecker, client_factory: Arc, connections: spin_resource_table::Table, otel: OtelFactorState, @@ -89,3 +88,13 @@ pub struct InstanceState { } impl SelfInstanceBuilder for InstanceState {} + +pub struct PgFactorData(OutboundPgFactor); + +impl spin_core::wasmtime::component::HasData for PgFactorData { + type Data<'a> = &'a mut InstanceState; +} + +impl spin_core::wasmtime::component::HasData for InstanceState { + type Data<'a> = &'a mut InstanceState; +} diff --git a/crates/factor-outbound-pg/src/types.rs b/crates/factor-outbound-pg/src/types.rs index 85608058f..b146cb27b 100644 --- a/crates/factor-outbound-pg/src/types.rs +++ b/crates/factor-outbound-pg/src/types.rs @@ -1,4 +1,5 @@ -use spin_world::spin::postgres4_1_0::postgres::{DbDataType, DbValue, ParameterValue}; +use anyhow::Result; +use spin_world::spin::postgres4_2_0::postgres::{self as v4, DbDataType, DbValue, ParameterValue}; use tokio_postgres::types::{FromSql, Type}; use tokio_postgres::{types::ToSql, Row}; @@ -162,3 +163,25 @@ pub fn to_sql_parameter(value: &ParameterValue) -> anyhow::Result Ok(Box::new(PgNull)), } } + +// The logic for "vector of ParameterValue to vector of &dyn ToSql" is +// used in multiple places, but needs to be broken into two functions +// because the return value of the first (the Vec) needs to be kept +// around to provide an owner for the refs. +#[allow(clippy::result_large_err)] +pub fn to_sql_parameters( + params: Vec, +) -> Result>, v4::Error> { + params + .iter() + .map(to_sql_parameter) + .collect::>>() + .map_err(|e| v4::Error::BadParameter(format!("{e:?}"))) +} + +pub fn as_sql_parameter_refs(params: &[Box]) -> Vec<&(dyn ToSql + Sync)> { + params + .iter() + .map(|b| b.as_ref() as &(dyn ToSql + Sync)) + .collect() +} diff --git a/crates/factor-outbound-pg/src/types/convert.rs b/crates/factor-outbound-pg/src/types/convert.rs index 5d325f8fe..303cc53b1 100644 --- a/crates/factor-outbound-pg/src/types/convert.rs +++ b/crates/factor-outbound-pg/src/types/convert.rs @@ -2,7 +2,7 @@ //! the tokio_postgres driver. use anyhow::{anyhow, Context}; -use spin_world::spin::postgres4_1_0::postgres::{self as v4}; +use spin_world::spin::postgres4_2_0::postgres::{self as v4}; use super::decimal::RangeableDecimal; diff --git a/crates/factor-outbound-pg/src/types/interval.rs b/crates/factor-outbound-pg/src/types/interval.rs index cd6632d6e..a87bdbde0 100644 --- a/crates/factor-outbound-pg/src/types/interval.rs +++ b/crates/factor-outbound-pg/src/types/interval.rs @@ -1,5 +1,5 @@ use anyhow::Result; -use spin_world::spin::postgres4_1_0::postgres::{self as v4}; +use spin_world::spin::postgres4_2_0::postgres::{self as v4}; use tokio_postgres::types::{FromSql, ToSql, Type}; #[derive(Debug)] diff --git a/crates/factor-outbound-pg/tests/factor_test.rs b/crates/factor-outbound-pg/tests/factor_test.rs index 0c4b6500e..6a475030c 100644 --- a/crates/factor-outbound-pg/tests/factor_test.rs +++ b/crates/factor-outbound-pg/tests/factor_test.rs @@ -8,10 +8,10 @@ use spin_factor_variables::VariablesFactor; use spin_factors::{anyhow, RuntimeFactors}; use spin_factors_test::{toml, TestEnvironment}; use spin_world::async_trait; -use spin_world::spin::postgres4_1_0::postgres::Error as PgError; -use spin_world::spin::postgres4_1_0::postgres::HostConnection; -use spin_world::spin::postgres4_1_0::postgres::{self as v2}; -use spin_world::spin::postgres4_1_0::postgres::{ParameterValue, RowSet}; +use spin_world::spin::postgres4_2_0::postgres::Error as PgError; +use spin_world::spin::postgres4_2_0::postgres::HostConnection; +use spin_world::spin::postgres4_2_0::postgres::{self as v2}; +use spin_world::spin::postgres4_2_0::postgres::{ParameterValue, RowSet}; #[derive(RuntimeFactors)] struct TestFactors { @@ -108,6 +108,7 @@ async fn exercise_query() -> anyhow::Result<()> { // TODO: We can expand this mock to track calls and simulate return values #[derive(Default)] pub struct MockClientFactory {} +#[derive(Clone)] pub struct MockClient {} #[async_trait] @@ -143,4 +144,18 @@ impl Client for MockClient { rows: vec![], }) } + + async fn query_async( + &self, + _statement: String, + _params: Vec, + ) -> Result< + ( + tokio::sync::oneshot::Receiver>, + tokio::sync::mpsc::Receiver>, + ), + v2::Error, + > { + panic!("not implemented"); + } } diff --git a/crates/wasi-async/Cargo.toml b/crates/wasi-async/Cargo.toml new file mode 100644 index 000000000..d352b7c57 --- /dev/null +++ b/crates/wasi-async/Cargo.toml @@ -0,0 +1,10 @@ +[package] +name = "spin-wasi-async" +version.workspace = true +authors.workspace = true +edition.workspace = true + +[dependencies] +anyhow = { workspace = true } +spin-core = { path = "../core" } +tokio = { workspace = true } diff --git a/crates/wasi-async/src/future.rs b/crates/wasi-async/src/future.rs new file mode 100644 index 000000000..b761d6a8e --- /dev/null +++ b/crates/wasi-async/src/future.rs @@ -0,0 +1,30 @@ +use spin_core::wasmtime; + +pub fn producer(rx: tokio::sync::oneshot::Receiver) -> FutureProducer { + FutureProducer { rx } +} + +pub struct FutureProducer { + rx: tokio::sync::oneshot::Receiver, +} + +impl wasmtime::component::FutureProducer for FutureProducer { + type Item = T; + + fn poll_produce( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + _store: wasmtime::StoreContextMut, + _finish: bool, + ) -> std::task::Poll>> { + use std::future::Future; + use std::task::Poll; + + let pinned_rx = std::pin::Pin::new(&mut self.get_mut().rx); + match pinned_rx.poll(cx) { + Poll::Ready(Err(e)) => Poll::Ready(Err(anyhow::anyhow!("{e:#}"))), + Poll::Ready(Ok(cols)) => Poll::Ready(Ok(Some(cols))), + Poll::Pending => Poll::Pending, + } + } +} diff --git a/crates/wasi-async/src/lib.rs b/crates/wasi-async/src/lib.rs new file mode 100644 index 000000000..fab6801fa --- /dev/null +++ b/crates/wasi-async/src/lib.rs @@ -0,0 +1,2 @@ +pub mod future; +pub mod stream; diff --git a/crates/wasi-async/src/stream.rs b/crates/wasi-async/src/stream.rs new file mode 100644 index 000000000..a7bb86f01 --- /dev/null +++ b/crates/wasi-async/src/stream.rs @@ -0,0 +1,45 @@ +use spin_core::wasmtime; + +pub fn producer(rx: tokio::sync::mpsc::Receiver) -> StreamProducer { + StreamProducer { rx } +} + +pub struct StreamProducer { + rx: tokio::sync::mpsc::Receiver, +} + +impl wasmtime::component::StreamProducer for StreamProducer { + type Item = T; + + type Buffer = Option; + + fn poll_produce<'a>( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + store: wasmtime::StoreContextMut<'a, D>, + mut destination: wasmtime::component::Destination<'a, Self::Item, Self::Buffer>, + finish: bool, + ) -> std::task::Poll> { + use std::task::Poll; + use wasmtime::component::StreamResult; + + if finish { + return Poll::Ready(Ok(StreamResult::Cancelled)); + } + + let remaining = destination.remaining(store); + if remaining.is_some_and(|r| r == 0) { + return Poll::Ready(Ok(StreamResult::Completed)); + } + + let recv = self.get_mut().rx.poll_recv(cx); + match recv { + Poll::Ready(None) => Poll::Ready(Ok(StreamResult::Dropped)), + Poll::Pending => Poll::Pending, + Poll::Ready(Some(row)) => { + destination.set_buffer(Some(row)); + Poll::Ready(Ok(StreamResult::Completed)) + } + } + } +} diff --git a/crates/world/src/conversions.rs b/crates/world/src/conversions.rs index ec977bdfc..53bc78772 100644 --- a/crates/world/src/conversions.rs +++ b/crates/world/src/conversions.rs @@ -3,7 +3,7 @@ use super::*; mod rdbms_types { use super::*; use spin::postgres3_0_0::postgres as pg3; - use spin::postgres4_1_0::postgres as pg4; + use spin::postgres4_2_0::postgres as pg4; impl From for v1::rdbms_types::Column { fn from(value: v2::rdbms_types::Column) -> Self { @@ -422,7 +422,7 @@ mod rdbms_types { mod postgres { use super::*; use spin::postgres3_0_0::postgres as pg3; - use spin::postgres4_1_0::postgres as pg4; + use spin::postgres4_2_0::postgres as pg4; impl From for v1::postgres::RowSet { fn from(value: pg4::RowSet) -> v1::postgres::RowSet { diff --git a/crates/world/src/lib.rs b/crates/world/src/lib.rs index 2d9dc1fe5..099337707 100644 --- a/crates/world/src/lib.rs +++ b/crates/world/src/lib.rs @@ -35,7 +35,7 @@ wasmtime::component::bindgen!({ "fermyon:spin/sqlite.error" => v1::sqlite::Error, "fermyon:spin/variables@2.0.0.error" => v2::variables::Error, "spin:postgres/postgres@3.0.0.error" => spin::postgres3_0_0::postgres::Error, - "spin:postgres/postgres@4.1.0.error" => spin::postgres4_1_0::postgres::Error, + "spin:postgres/postgres@4.2.0.error" => spin::postgres4_2_0::postgres::Error, "spin:sqlite/sqlite.error" => spin::sqlite::sqlite::Error, "wasi:config/store@0.2.0-draft-2024-09-27.error" => wasi::config::store::Error, "wasi:keyvalue/store.error" => wasi::keyvalue::store::Error, @@ -66,7 +66,7 @@ impl spin::sqlite::sqlite::Value { } } -impl spin::postgres4_1_0::postgres::DbValue { +impl spin::postgres4_2_0::postgres::DbValue { pub fn memory_size(&self) -> usize { match self { Self::DbNull diff --git a/wit/deps/spin-postgres@4.2.0/postgres.wit b/wit/deps/spin-postgres@4.2.0/postgres.wit new file mode 100644 index 000000000..feee97d83 --- /dev/null +++ b/wit/deps/spin-postgres@4.2.0/postgres.wit @@ -0,0 +1,184 @@ +package spin:postgres@4.2.0; + +interface postgres { + /// Errors related to interacting with a database. + variant error { + connection-failed(string), + bad-parameter(string), + query-failed(query-error), + value-conversion-failed(string), + other(string) + } + + variant query-error { + /// An error occurred but we do not have structured info for it + text(string), + /// Postgres returned a structured database error + db-error(db-error), + } + + record db-error { + /// Stringised version of the error. This is primarily to facilitate migration of older code. + as-text: string, + severity: string, + code: string, + message: string, + detail: option, + /// Any error information provided by Postgres and not captured above. + extras: list>, + } + + /// Data types for a database column + variant db-data-type { + boolean, + int8, + int16, + int32, + int64, + floating32, + floating64, + str, + binary, + date, + time, + datetime, + timestamp, + uuid, + jsonb, + decimal, + range-int32, + range-int64, + range-decimal, + array-int32, + array-int64, + array-decimal, + array-str, + interval, + other(string), + } + + /// Database values + variant db-value { + boolean(bool), + int8(s8), + int16(s16), + int32(s32), + int64(s64), + floating32(f32), + floating64(f64), + str(string), + binary(list), + date(tuple), // (year, month, day) + time(tuple), // (hour, minute, second, nanosecond) + /// Date-time types are always treated as UTC (without timezone info). + /// The instant is represented as a (year, month, day, hour, minute, second, nanosecond) tuple. + datetime(tuple), + /// Unix timestamp (seconds since epoch) + timestamp(s64), + uuid(string), + jsonb(list), + decimal(string), // I admit defeat. Base 10 + range-int32(tuple>, option>>), + range-int64(tuple>, option>>), + range-decimal(tuple>, option>>), + array-int32(list>), + array-int64(list>), + array-decimal(list>), + array-str(list>), + interval(interval), + db-null, + unsupported(list), + } + + /// Values used in parameterized queries + variant parameter-value { + boolean(bool), + int8(s8), + int16(s16), + int32(s32), + int64(s64), + floating32(f32), + floating64(f64), + str(string), + binary(list), + date(tuple), // (year, month, day) + time(tuple), // (hour, minute, second, nanosecond) + /// Date-time types are always treated as UTC (without timezone info). + /// The instant is represented as a (year, month, day, hour, minute, second, nanosecond) tuple. + datetime(tuple), + /// Unix timestamp (seconds since epoch) + timestamp(s64), + uuid(string), + jsonb(list), + decimal(string), // base 10 + range-int32(tuple>, option>>), + range-int64(tuple>, option>>), + range-decimal(tuple>, option>>), + array-int32(list>), + array-int64(list>), + array-decimal(list>), + array-str(list>), + interval(interval), + db-null, + } + + record interval { + micros: s64, + days: s32, + months: s32, + } + + /// A database column + record column { + name: string, + data-type: db-data-type, + } + + /// A database row + type row = list; + + /// A set of database rows + record row-set { + columns: list, + rows: list, + } + + /// For range types, indicates if each bound is inclusive or exclusive + enum range-bound-kind { + inclusive, + exclusive, + } + + @since(version = 4.1.0) + resource connection-builder { + constructor(address: string); + set-ca-root: func(certificate: string) -> result<_, error>; + build: func() -> result; + @since(version = 4.2.0) + build-async: async func() -> result; + } + + /// A connection to a postgres database. + resource connection { + /// Open a connection to the Postgres instance at `address`. + open: static func(address: string) -> result; + + /// Open a connection to the Postgres instance at `address`. + @since(version = 4.2.0) + open-async: static async func(address: string) -> result; + + /// Query the database. + query: func(statement: string, params: list) -> result; + + /// Query the database. + @since(version = 4.2.0) + query-async: async func(statement: string, params: list) -> result>, stream>>, error>; + + /// Execute command to the database. + execute: func(statement: string, params: list) -> result; + + /// Execute command to the database. + @since(version = 4.2.0) + execute-async: async func(statement: string, params: list) -> result; + } +} diff --git a/wit/world.wit b/wit/world.wit index e66348cc1..0922fac6b 100644 --- a/wit/world.wit +++ b/wit/world.wit @@ -20,7 +20,7 @@ world platform { include fermyon:spin/platform@2.0.0; include wasi:keyvalue/imports@0.2.0-draft2; import spin:postgres/postgres@3.0.0; - import spin:postgres/postgres@4.1.0; + import spin:postgres/postgres@4.2.0; import spin:sqlite/sqlite@3.0.0; import wasi:config/store@0.2.0-draft-2024-09-27; } From 9ea2eb530c842a5942484559c6628920f68b79ae Mon Sep 17 00:00:00 2001 From: itowlson Date: Tue, 24 Feb 2026 15:28:15 +1300 Subject: [PATCH 2/4] Feedback from review Signed-off-by: itowlson --- crates/factor-outbound-pg/src/client.rs | 18 ++++++++++----- crates/factor-outbound-pg/src/host.rs | 3 +-- crates/wasi-async/src/future.rs | 30 ------------------------- crates/wasi-async/src/lib.rs | 1 - crates/wasi-async/src/stream.rs | 12 +++++----- 5 files changed, 21 insertions(+), 43 deletions(-) delete mode 100644 crates/wasi-async/src/future.rs diff --git a/crates/factor-outbound-pg/src/client.rs b/crates/factor-outbound-pg/src/client.rs index 7bf75f268..de393539b 100644 --- a/crates/factor-outbound-pg/src/client.rs +++ b/crates/factor-outbound-pg/src/client.rs @@ -293,23 +293,31 @@ impl Client for Arc { let Some(row) = stm.next().await else { break; }; - // TODO: figure out how to deal with errors here - I think there is like a FutureReader pattern? + let row = match row { Ok(r) => r, Err(e) => { let err = query_failed(e); - rows_tx.send(Err(err)).await.unwrap(); + _ = rows_tx.send(Err(err)).await; break; } }; + if let Some(cols_tx) = cols_tx_opt.take() { - cols_tx.send(infer_columns(&row)).unwrap(); + _ = cols_tx.send(infer_columns(&row)); } + match convert_row(&row) { - Ok(row) => rows_tx.send(Ok(row)).await.unwrap(), + Ok(row) => { + let send_res = rows_tx.send(Ok(row)).await; + if send_res.is_err() { + break; + } + } Err(e) => { let err = v4::Error::QueryFailed(v4::QueryError::Text(format!("{e:?}"))); - rows_tx.send(Err(err)).await.unwrap(); + _ = rows_tx.send(Err(err)).await; + break; } } } diff --git a/crates/factor-outbound-pg/src/host.rs b/crates/factor-outbound-pg/src/host.rs index 8456600dc..40eaa78a9 100644 --- a/crates/factor-outbound-pg/src/host.rs +++ b/crates/factor-outbound-pg/src/host.rs @@ -280,10 +280,9 @@ impl spin_world::spin::postgres4_2_0::postgres::HostConnectio let (col_rx, row_rx) = client.query_async(statement, params).await?; let row_producer = spin_wasi_async::stream::producer(row_rx); - let col_producer = spin_wasi_async::future::producer(col_rx); let (fr, sr) = accessor.with(|mut access| { - let fr = FutureReader::new(access.as_context_mut(), col_producer); + let fr = FutureReader::new(access.as_context_mut(), col_rx); let sr = StreamReader::new(access.as_context_mut(), row_producer); (fr, sr) }); diff --git a/crates/wasi-async/src/future.rs b/crates/wasi-async/src/future.rs deleted file mode 100644 index b761d6a8e..000000000 --- a/crates/wasi-async/src/future.rs +++ /dev/null @@ -1,30 +0,0 @@ -use spin_core::wasmtime; - -pub fn producer(rx: tokio::sync::oneshot::Receiver) -> FutureProducer { - FutureProducer { rx } -} - -pub struct FutureProducer { - rx: tokio::sync::oneshot::Receiver, -} - -impl wasmtime::component::FutureProducer for FutureProducer { - type Item = T; - - fn poll_produce( - self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - _store: wasmtime::StoreContextMut, - _finish: bool, - ) -> std::task::Poll>> { - use std::future::Future; - use std::task::Poll; - - let pinned_rx = std::pin::Pin::new(&mut self.get_mut().rx); - match pinned_rx.poll(cx) { - Poll::Ready(Err(e)) => Poll::Ready(Err(anyhow::anyhow!("{e:#}"))), - Poll::Ready(Ok(cols)) => Poll::Ready(Ok(Some(cols))), - Poll::Pending => Poll::Pending, - } - } -} diff --git a/crates/wasi-async/src/lib.rs b/crates/wasi-async/src/lib.rs index fab6801fa..baf29e06a 100644 --- a/crates/wasi-async/src/lib.rs +++ b/crates/wasi-async/src/lib.rs @@ -1,2 +1 @@ -pub mod future; pub mod stream; diff --git a/crates/wasi-async/src/stream.rs b/crates/wasi-async/src/stream.rs index a7bb86f01..1e7860cce 100644 --- a/crates/wasi-async/src/stream.rs +++ b/crates/wasi-async/src/stream.rs @@ -23,10 +23,6 @@ impl wasmtime::component::StreamProducer for Str use std::task::Poll; use wasmtime::component::StreamResult; - if finish { - return Poll::Ready(Ok(StreamResult::Cancelled)); - } - let remaining = destination.remaining(store); if remaining.is_some_and(|r| r == 0) { return Poll::Ready(Ok(StreamResult::Completed)); @@ -35,7 +31,13 @@ impl wasmtime::component::StreamProducer for Str let recv = self.get_mut().rx.poll_recv(cx); match recv { Poll::Ready(None) => Poll::Ready(Ok(StreamResult::Dropped)), - Poll::Pending => Poll::Pending, + Poll::Pending => { + if finish { + Poll::Ready(Ok(StreamResult::Cancelled)) + } else { + Poll::Pending + } + } Poll::Ready(Some(row)) => { destination.set_buffer(Some(row)); Poll::Ready(Ok(StreamResult::Completed)) From 44bdd3b4ad5a3f36cdfaac42693cccaab635207b Mon Sep 17 00:00:00 2001 From: itowlson Date: Thu, 26 Feb 2026 09:14:19 +1300 Subject: [PATCH 3/4] Unlovely Signed-off-by: itowlson --- crates/factor-outbound-pg/src/client.rs | 86 ++++++++++--------- crates/factor-outbound-pg/src/host.rs | 32 ++++--- .../factor-outbound-pg/tests/factor_test.rs | 13 +-- wit/deps/spin-postgres@4.2.0/postgres.wit | 2 +- 4 files changed, 69 insertions(+), 64 deletions(-) diff --git a/crates/factor-outbound-pg/src/client.rs b/crates/factor-outbound-pg/src/client.rs index de393539b..b59def844 100644 --- a/crates/factor-outbound-pg/src/client.rs +++ b/crates/factor-outbound-pg/src/client.rs @@ -141,17 +141,13 @@ pub trait Client: Clone + Send + Sync + 'static { max_result_bytes: usize, ) -> Result; - async fn query_async( - &self, - statement: String, - params: Vec, - ) -> Result< - ( - tokio::sync::oneshot::Receiver>, - tokio::sync::mpsc::Receiver>, - ), - v4::Error, - >; + fn query_async(&self, statement: String, params: Vec) -> QueryAsyncResult; +} + +pub struct QueryAsyncResult { + pub columns: tokio::sync::oneshot::Receiver>, + pub rows: tokio::sync::mpsc::Receiver, + pub error: tokio::sync::oneshot::Receiver>, } /// Extract weak-typed error data for WIT purposes @@ -261,34 +257,40 @@ impl Client for Arc { }) } - async fn query_async( - &self, - statement: String, - params: Vec, - ) -> Result< - ( - tokio::sync::oneshot::Receiver>, - tokio::sync::mpsc::Receiver>, - ), - v4::Error, - > { - let params = to_sql_parameters(params)?; - let params_refs = as_sql_parameter_refs(¶ms); - - let stm = self - .as_ref() - .query_raw(&statement, params_refs) - .await - .map_err(query_failed)?; + fn query_async(&self, statement: String, params: Vec) -> QueryAsyncResult { + let this = self.clone(); let (rows_tx, rows_rx) = tokio::sync::mpsc::channel(1000); let (cols_tx, cols_rx) = tokio::sync::oneshot::channel(); + let (err_tx, err_rx) = tokio::sync::oneshot::channel(); let mut cols_tx_opt = Some(cols_tx); - let mut stm = Box::pin(stm); - tokio::spawn(async move { use futures::StreamExt; + + let params = match to_sql_parameters(params) { + Ok(p) => p, + Err(e) => { + _ = err_tx.send(Err(e)); + return; + } + }; + let params_refs = as_sql_parameter_refs(¶ms); + + let stm = match this + .as_ref() + .query_raw(&statement, params_refs) + .await + .map_err(query_failed) + { + Ok(stm) => stm, + Err(e) => { + _ = err_tx.send(Err(e)); + return; + } + }; + let mut stm = Box::pin(stm); + loop { let Some(row) = stm.next().await else { break; @@ -298,8 +300,8 @@ impl Client for Arc { Ok(r) => r, Err(e) => { let err = query_failed(e); - _ = rows_tx.send(Err(err)).await; - break; + _ = err_tx.send(Err(err)); + return; } }; @@ -309,21 +311,27 @@ impl Client for Arc { match convert_row(&row) { Ok(row) => { - let send_res = rows_tx.send(Ok(row)).await; + let send_res = rows_tx.send(row).await; if send_res.is_err() { - break; + return; } } Err(e) => { let err = v4::Error::QueryFailed(v4::QueryError::Text(format!("{e:?}"))); - _ = rows_tx.send(Err(err)).await; - break; + _ = err_tx.send(Err(err)); + return; } } } + + _ = err_tx.send(Ok(())); }); - Ok((cols_rx, rows_rx)) + QueryAsyncResult { + columns: cols_rx, + rows: rows_rx, + error: err_rx, + } } } diff --git a/crates/factor-outbound-pg/src/host.rs b/crates/factor-outbound-pg/src/host.rs index 40eaa78a9..b7c0cea6f 100644 --- a/crates/factor-outbound-pg/src/host.rs +++ b/crates/factor-outbound-pg/src/host.rs @@ -13,13 +13,9 @@ use tracing::instrument; use tracing::Level; use crate::allowed_hosts::AllowedHostChecker; -use crate::client::{Client, ClientFactory, HashableCertificate}; +use crate::client::{Client, ClientFactory, HashableCertificate, QueryAsyncResult}; use crate::InstanceState; -// Declare some types to make Clippy less mad -pub type RowStream = StreamReader>; -pub type ColumnsFuture = FutureReader>; - impl InstanceState { async fn open_connection( &mut self, @@ -263,13 +259,18 @@ impl spin_world::spin::postgres4_2_0::postgres::HostConnectio client.execute(statement, params).await } - #[instrument(name = "spin_outbound_pg.query", skip(accessor, params), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", otel.name = statement))] + #[allow(clippy::type_complexity)] // blame bindgen, clippy, blame bindgen + #[instrument(name = "spin_outbound_pg.query_async", skip(accessor, params), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", otel.name = statement))] async fn query_async( accessor: &Accessor, connection: Resource, statement: String, params: Vec, - ) -> Result<(ColumnsFuture, RowStream), v4::Error> { + ) -> anyhow::Result<( + FutureReader>, + StreamReader, + FutureReader>, + )> { use wasmtime::AsContextMut; let client = accessor.with(|mut access| { @@ -277,17 +278,22 @@ impl spin_world::spin::postgres4_2_0::postgres::HostConnectio host.connections.get(connection.rep()).unwrap().clone() }); - let (col_rx, row_rx) = client.query_async(statement, params).await?; + let QueryAsyncResult { + columns, + rows, + error, + } = client.query_async(statement, params); - let row_producer = spin_wasi_async::stream::producer(row_rx); + let row_producer = spin_wasi_async::stream::producer(rows); - let (fr, sr) = accessor.with(|mut access| { - let fr = FutureReader::new(access.as_context_mut(), col_rx); + let (fr, sr, efr) = accessor.with(|mut access| { + let fr = FutureReader::new(access.as_context_mut(), columns); let sr = StreamReader::new(access.as_context_mut(), row_producer); - (fr, sr) + let efr = FutureReader::new(access.as_context_mut(), error); + (fr, sr, efr) }); - Ok((fr, sr)) + Ok((fr, sr, efr)) } } diff --git a/crates/factor-outbound-pg/tests/factor_test.rs b/crates/factor-outbound-pg/tests/factor_test.rs index 6a475030c..cc3b723d3 100644 --- a/crates/factor-outbound-pg/tests/factor_test.rs +++ b/crates/factor-outbound-pg/tests/factor_test.rs @@ -3,6 +3,7 @@ use spin_factor_outbound_networking::OutboundNetworkingFactor; use spin_factor_outbound_pg::client::Client; use spin_factor_outbound_pg::client::ClientFactory; use spin_factor_outbound_pg::client::HashableCertificate; +use spin_factor_outbound_pg::client::QueryAsyncResult; use spin_factor_outbound_pg::OutboundPgFactor; use spin_factor_variables::VariablesFactor; use spin_factors::{anyhow, RuntimeFactors}; @@ -145,17 +146,7 @@ impl Client for MockClient { }) } - async fn query_async( - &self, - _statement: String, - _params: Vec, - ) -> Result< - ( - tokio::sync::oneshot::Receiver>, - tokio::sync::mpsc::Receiver>, - ), - v2::Error, - > { + fn query_async(&self, _statement: String, _params: Vec) -> QueryAsyncResult { panic!("not implemented"); } } diff --git a/wit/deps/spin-postgres@4.2.0/postgres.wit b/wit/deps/spin-postgres@4.2.0/postgres.wit index feee97d83..1c756d366 100644 --- a/wit/deps/spin-postgres@4.2.0/postgres.wit +++ b/wit/deps/spin-postgres@4.2.0/postgres.wit @@ -172,7 +172,7 @@ interface postgres { /// Query the database. @since(version = 4.2.0) - query-async: async func(statement: string, params: list) -> result>, stream>>, error>; + query-async: async func(statement: string, params: list) -> tuple>, stream, future>>; /// Execute command to the database. execute: func(statement: string, params: list) -> result; From 96c06a8c5cf4ec68fce7dc500c7a36728071c9d8 Mon Sep 17 00:00:00 2001 From: itowlson Date: Thu, 26 Feb 2026 10:47:44 +1300 Subject: [PATCH 4/4] This is less unlovely Signed-off-by: itowlson --- crates/factor-outbound-pg/src/client.rs | 105 ++++++++++-------- crates/factor-outbound-pg/src/host.rs | 23 ++-- .../factor-outbound-pg/tests/factor_test.rs | 6 +- wit/deps/spin-postgres@4.2.0/postgres.wit | 2 +- 4 files changed, 77 insertions(+), 59 deletions(-) diff --git a/crates/factor-outbound-pg/src/client.rs b/crates/factor-outbound-pg/src/client.rs index b59def844..fdfcb1c5d 100644 --- a/crates/factor-outbound-pg/src/client.rs +++ b/crates/factor-outbound-pg/src/client.rs @@ -141,11 +141,15 @@ pub trait Client: Clone + Send + Sync + 'static { max_result_bytes: usize, ) -> Result; - fn query_async(&self, statement: String, params: Vec) -> QueryAsyncResult; + async fn query_async( + &self, + statement: String, + params: Vec, + ) -> Result; } pub struct QueryAsyncResult { - pub columns: tokio::sync::oneshot::Receiver>, + pub columns: Vec, pub rows: tokio::sync::mpsc::Receiver, pub error: tokio::sync::oneshot::Receiver>, } @@ -257,39 +261,62 @@ impl Client for Arc { }) } - fn query_async(&self, statement: String, params: Vec) -> QueryAsyncResult { + async fn query_async( + &self, + statement: String, + params: Vec, + ) -> Result { + use futures::StreamExt; + + let params = to_sql_parameters(params)?; + let params_refs = as_sql_parameter_refs(¶ms); + let this = self.clone(); + let stm = this + .as_ref() + .query_raw(&statement, params_refs) + .await + .map_err(query_failed)?; + let mut stm = Box::pin(stm); + let (rows_tx, rows_rx) = tokio::sync::mpsc::channel(1000); - let (cols_tx, cols_rx) = tokio::sync::oneshot::channel(); let (err_tx, err_rx) = tokio::sync::oneshot::channel(); - let mut cols_tx_opt = Some(cols_tx); - tokio::spawn(async move { - use futures::StreamExt; + let Some(row) = stm.next().await else { + _ = err_tx.send(Ok(())); + return Ok(QueryAsyncResult { + columns: vec![], + rows: rows_rx, + error: err_rx, + }); + }; - let params = match to_sql_parameters(params) { - Ok(p) => p, - Err(e) => { - _ = err_tx.send(Err(e)); - return; - } - }; - let params_refs = as_sql_parameter_refs(¶ms); - - let stm = match this - .as_ref() - .query_raw(&statement, params_refs) - .await - .map_err(query_failed) - { - Ok(stm) => stm, - Err(e) => { - _ = err_tx.send(Err(e)); - return; + let row = row.map_err(query_failed)?; + + let cols = infer_columns(&row); + + // macro rather than closure to avoid taking ownership of err_tx + macro_rules! try_send_row { + ($row:ident) => { + match convert_row(&$row) { + Ok(row) => { + let send_res = rows_tx.send(row).await; + if send_res.is_err() { + return; + } + } + Err(e) => { + let err = v4::Error::QueryFailed(v4::QueryError::Text(format!("{e:?}"))); + _ = err_tx.send(Err(err)); + return; + } } }; - let mut stm = Box::pin(stm); + } + + tokio::spawn(async move { + try_send_row!(row); loop { let Some(row) = stm.next().await else { @@ -305,33 +332,17 @@ impl Client for Arc { } }; - if let Some(cols_tx) = cols_tx_opt.take() { - _ = cols_tx.send(infer_columns(&row)); - } - - match convert_row(&row) { - Ok(row) => { - let send_res = rows_tx.send(row).await; - if send_res.is_err() { - return; - } - } - Err(e) => { - let err = v4::Error::QueryFailed(v4::QueryError::Text(format!("{e:?}"))); - _ = err_tx.send(Err(err)); - return; - } - } + try_send_row!(row); } _ = err_tx.send(Ok(())); }); - QueryAsyncResult { - columns: cols_rx, + Ok(QueryAsyncResult { + columns: cols, rows: rows_rx, error: err_rx, - } + }) } } diff --git a/crates/factor-outbound-pg/src/host.rs b/crates/factor-outbound-pg/src/host.rs index b7c0cea6f..d44a93d3f 100644 --- a/crates/factor-outbound-pg/src/host.rs +++ b/crates/factor-outbound-pg/src/host.rs @@ -266,11 +266,14 @@ impl spin_world::spin::postgres4_2_0::postgres::HostConnectio connection: Resource, statement: String, params: Vec, - ) -> anyhow::Result<( - FutureReader>, - StreamReader, - FutureReader>, - )> { + ) -> Result< + ( + Vec, + StreamReader, + FutureReader>, + ), + v4::Error, + > { use wasmtime::AsContextMut; let client = accessor.with(|mut access| { @@ -282,18 +285,18 @@ impl spin_world::spin::postgres4_2_0::postgres::HostConnectio columns, rows, error, - } = client.query_async(statement, params); + } = client.query_async(statement, params).await?; let row_producer = spin_wasi_async::stream::producer(rows); - let (fr, sr, efr) = accessor.with(|mut access| { - let fr = FutureReader::new(access.as_context_mut(), columns); + let (sr, efr) = accessor.with(|mut access| { + //let fr = FutureReader::new(access.as_context_mut(), columns); let sr = StreamReader::new(access.as_context_mut(), row_producer); let efr = FutureReader::new(access.as_context_mut(), error); - (fr, sr, efr) + (sr, efr) }); - Ok((fr, sr, efr)) + Ok((columns, sr, efr)) } } diff --git a/crates/factor-outbound-pg/tests/factor_test.rs b/crates/factor-outbound-pg/tests/factor_test.rs index cc3b723d3..96c67e7a4 100644 --- a/crates/factor-outbound-pg/tests/factor_test.rs +++ b/crates/factor-outbound-pg/tests/factor_test.rs @@ -146,7 +146,11 @@ impl Client for MockClient { }) } - fn query_async(&self, _statement: String, _params: Vec) -> QueryAsyncResult { + async fn query_async( + &self, + _statement: String, + _params: Vec, + ) -> Result { panic!("not implemented"); } } diff --git a/wit/deps/spin-postgres@4.2.0/postgres.wit b/wit/deps/spin-postgres@4.2.0/postgres.wit index 1c756d366..093e8a125 100644 --- a/wit/deps/spin-postgres@4.2.0/postgres.wit +++ b/wit/deps/spin-postgres@4.2.0/postgres.wit @@ -172,7 +172,7 @@ interface postgres { /// Query the database. @since(version = 4.2.0) - query-async: async func(statement: string, params: list) -> tuple>, stream, future>>; + query-async: async func(statement: string, params: list) -> result, stream, future>>, error>; /// Execute command to the database. execute: func(statement: string, params: list) -> result;