Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions crates/factor-outbound-pg/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ anyhow = { workspace = true }
bytes = {workspace = true }
chrono = { workspace = true }
deadpool-postgres = { version = "0.14", features = ["rt_tokio_1"] }
futures = { workspace = true }
moka = { version = "0.12", features = ["sync"] }
native-tls = "0.2"
postgres-native-tls = "0.5"
Expand All @@ -20,6 +21,7 @@ spin-factor-otel = { path = "../factor-otel" }
spin-factor-outbound-networking = { path = "../factor-outbound-networking" }
spin-factors = { path = "../factors" }
spin-resource-table = { path = "../table" }
spin-wasi-async = { path = "../wasi-async" }
spin-world = { path = "../world" }
tokio = { workspace = true, features = ["rt-multi-thread"] }
tokio-postgres = { version = "0.7", features = ["with-chrono-0_4", "with-serde_json-1", "with-uuid-1"] }
Expand Down
70 changes: 70 additions & 0 deletions crates/factor-outbound-pg/src/allowed_hosts.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
use std::sync::Arc;

use spin_factor_outbound_networking::config::allowed_hosts::OutboundAllowedHosts;
use spin_world::spin::postgres4_2_0::postgres::{self as v4};

/// Encapsulates checking of a PostgreSQL address/connection string against
/// an allow-list.
///
/// This is broken out as a distinct object to allow it to be synchronously retrieved
/// within a P3 Accessor block and then asynchronously queried outside the block.
#[derive(Clone)]
pub(crate) struct AllowedHostChecker {
allowed_hosts: Arc<OutboundAllowedHosts>,
}

impl AllowedHostChecker {
pub fn new(allowed_hosts: OutboundAllowedHosts) -> Self {
Self {
allowed_hosts: Arc::new(allowed_hosts),
}
}
#[allow(clippy::result_large_err)]
pub async fn ensure_address_allowed(&self, address: &str) -> Result<(), v4::Error> {
fn conn_failed(message: impl Into<String>) -> v4::Error {
v4::Error::ConnectionFailed(message.into())
}
fn err_other(err: anyhow::Error) -> v4::Error {
v4::Error::Other(err.to_string())
}

let config = address
.parse::<tokio_postgres::Config>()
.map_err(|e| conn_failed(e.to_string()))?;

for (i, host) in config.get_hosts().iter().enumerate() {
match host {
tokio_postgres::config::Host::Tcp(address) => {
let ports = config.get_ports();
// The port we use is either:
// * The port at the same index as the host
// * The first port if there is only one port
let port = ports.get(i).or_else(|| {
if ports.len() == 1 {
ports.first()
} else {
None
}
});
let port_str = port.map(|p| format!(":{p}")).unwrap_or_default();
let url = format!("{address}{port_str}");
if !self
.allowed_hosts
.check_url(&url, "postgres")
.await
.map_err(err_other)?
{
return Err(conn_failed(format!(
"address postgres://{url} is not permitted"
)));
}
}
#[cfg(unix)]
tokio_postgres::config::Host::Unix(_) => {
return Err(conn_failed("Unix sockets are not supported on WebAssembly"));
}
}
}
Ok(())
}
}
124 changes: 108 additions & 16 deletions crates/factor-outbound-pg/src/client.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
use std::sync::Arc;

use anyhow::{Context, Result};
use native_tls::TlsConnector;
use postgres_native_tls::MakeTlsConnector;
use spin_world::async_trait;
use spin_world::spin::postgres4_0_0::postgres::{
use spin_world::spin::postgres4_2_0::postgres::{
self as v4, Column, DbValue, ParameterValue, RowSet,
};
use tokio_postgres::types::ToSql;
use tokio_postgres::{config::SslMode, NoTls, Row};

use crate::types::{convert_data_type, convert_entry, to_sql_parameter};
use crate::types::{
as_sql_parameter_refs, convert_data_type, convert_entry, to_sql_parameter, to_sql_parameters,
};

/// Max connections in a given address' connection pool
const CONNECTION_POOL_SIZE: usize = 64;
Expand Down Expand Up @@ -40,7 +44,7 @@ impl Default for PooledTokioClientFactory {

#[async_trait]
impl ClientFactory for PooledTokioClientFactory {
type Client = deadpool_postgres::Object;
type Client = Arc<deadpool_postgres::Object>;

async fn get_client(&self, address: &str) -> Result<Self::Client> {
let pool = self
Expand All @@ -49,7 +53,7 @@ impl ClientFactory for PooledTokioClientFactory {
.map_err(ArcError)
.context("establishing PostgreSQL connection pool")?;

Ok(pool.get().await?)
Ok(Arc::new(pool.get().await?))
}
}

Expand Down Expand Up @@ -85,7 +89,7 @@ fn create_connection_pool(address: &str) -> Result<deadpool_postgres::Pool> {
}

#[async_trait]
pub trait Client: Send + Sync + 'static {
pub trait Client: Clone + Send + Sync + 'static {
async fn execute(
&self,
statement: String,
Expand All @@ -97,6 +101,18 @@ pub trait Client: Send + Sync + 'static {
statement: String,
params: Vec<ParameterValue>,
) -> Result<RowSet, v4::Error>;

async fn query_async(
&self,
statement: String,
params: Vec<ParameterValue>,
) -> Result<QueryAsyncResult, v4::Error>;
}

pub struct QueryAsyncResult {
pub columns: Vec<v4::Column>,
pub rows: tokio::sync::mpsc::Receiver<v4::Row>,
pub error: tokio::sync::oneshot::Receiver<Result<(), v4::Error>>,
}

/// Extract weak-typed error data for WIT purposes
Expand Down Expand Up @@ -142,7 +158,7 @@ fn query_failed(e: tokio_postgres::error::Error) -> v4::Error {
}

#[async_trait]
impl Client for deadpool_postgres::Object {
impl Client for Arc<deadpool_postgres::Object> {
async fn execute(
&self,
statement: String,
Expand Down Expand Up @@ -170,16 +186,8 @@ impl Client for deadpool_postgres::Object {
statement: String,
params: Vec<ParameterValue>,
) -> Result<RowSet, v4::Error> {
let params = params
.iter()
.map(to_sql_parameter)
.collect::<Result<Vec<_>>>()
.map_err(|e| v4::Error::BadParameter(format!("{e:?}")))?;

let params_refs: Vec<&(dyn ToSql + Sync)> = params
.iter()
.map(|b| b.as_ref() as &(dyn ToSql + Sync))
.collect();
let params = to_sql_parameters(params)?;
let params_refs = as_sql_parameter_refs(&params);

let results = self
.as_ref()
Expand All @@ -203,6 +211,90 @@ impl Client for deadpool_postgres::Object {

Ok(RowSet { columns, rows })
}

async fn query_async(
&self,
statement: String,
params: Vec<ParameterValue>,
) -> Result<QueryAsyncResult, v4::Error> {
use futures::StreamExt;

let params = to_sql_parameters(params)?;
let params_refs = as_sql_parameter_refs(&params);

let this = self.clone();

let stm = this
.as_ref()
.query_raw(&statement, params_refs)
.await
.map_err(query_failed)?;
let mut stm = Box::pin(stm);

let (rows_tx, rows_rx) = tokio::sync::mpsc::channel(1000);
let (err_tx, err_rx) = tokio::sync::oneshot::channel();

let Some(row) = stm.next().await else {
_ = err_tx.send(Ok(()));
return Ok(QueryAsyncResult {
columns: vec![],
rows: rows_rx,
error: err_rx,
});
};

let row = row.map_err(query_failed)?;

let cols = infer_columns(&row);

// macro rather than closure to avoid taking ownership of err_tx
macro_rules! try_send_row {
($row:ident) => {
match convert_row(&$row) {
Ok(row) => {
let send_res = rows_tx.send(row).await;
if send_res.is_err() {
return;
}
}
Err(e) => {
let err = v4::Error::QueryFailed(v4::QueryError::Text(format!("{e:?}")));
_ = err_tx.send(Err(err));
return;
}
}
};
}

tokio::spawn(async move {
try_send_row!(row);

loop {
let Some(row) = stm.next().await else {
break;
};

let row = match row {
Ok(r) => r,
Err(e) => {
let err = query_failed(e);
_ = err_tx.send(Err(err));
return;
}
};

try_send_row!(row);
}

_ = err_tx.send(Ok(()));
});

Ok(QueryAsyncResult {
columns: cols,
rows: rows_rx,
error: err_rx,
})
}
}

fn infer_columns(row: &Row) -> Vec<Column> {
Expand Down
Loading