diff --git a/ipa-core/src/app.rs b/ipa-core/src/app.rs index a87c253a0..03d1959da 100644 --- a/ipa-core/src/app.rs +++ b/ipa-core/src/app.rs @@ -6,7 +6,7 @@ use crate::{ cli::LoggingHandle, executor::IpaRuntime, helpers::{ - query::{CompareStatusRequest, PrepareQuery, QueryConfig, QueryInput}, + query::{PrepareQuery, QueryConfig, QueryInput}, routing::{Addr, RouteId}, ApiError, BodyStream, HandlerBox, HandlerRef, HelperIdentity, HelperResponse, MpcTransportImpl, RequestHandler, ShardTransportImpl, Transport, TransportIdentity, @@ -208,8 +208,8 @@ impl RequestHandler for Inner { HelperResponse::from(qp.prepare_shard(&self.shard_transport, req)?) } RouteId::QueryStatus => { - let req = req.into::()?; - HelperResponse::from(qp.shard_status(&self.shard_transport, &req)?) + let query_id = ext_query_id(&req)?; + HelperResponse::from(qp.shard_status(&self.shard_transport, query_id)?) } RouteId::CompleteQuery => { // The processing flow for this API is exactly the same, regardless diff --git a/ipa-core/src/helpers/gateway/transport.rs b/ipa-core/src/helpers/gateway/transport.rs index 09a1053cb..7a870bdca 100644 --- a/ipa-core/src/helpers/gateway/transport.rs +++ b/ipa-core/src/helpers/gateway/transport.rs @@ -34,6 +34,7 @@ pub(super) struct Transports, S: Transport::RecordsStream; + type SendResponse = ::SendResponse; type Error = SendToRoleError; fn identity(&self) -> Role { @@ -60,7 +61,7 @@ impl Transport for RoleResolvingTransport { dest: Role, route: R, data: D, - ) -> Result<(), Self::Error> + ) -> Result, Self::Error> where Option: From, Option: From, diff --git a/ipa-core/src/helpers/transport/handler.rs b/ipa-core/src/helpers/transport/handler.rs index 9a1f1b457..130e752e8 100644 --- a/ipa-core/src/helpers/transport/handler.rs +++ b/ipa-core/src/helpers/transport/handler.rs @@ -1,9 +1,11 @@ use std::{fmt::Debug, future::Future, marker::PhantomData}; use async_trait::async_trait; -use serde::de::DeserializeOwned; +use futures_util::TryStreamExt; +use serde::{de::DeserializeOwned, Deserialize, Serialize}; use serde_json::json; +use super::BytesStream; use crate::{ error::BoxError, helpers::{ @@ -113,6 +115,18 @@ impl HelperResponse { pub fn try_into_owned(self) -> Result { serde_json::from_slice(&self.body) } + + /// Asynchronously collects and returns a newly created `HelperResponse`. + /// + /// # Errors + /// + /// If the `BytesStream` cannot be collected into a `BytesMut`, an error is returned. + pub async fn from_bytesstream(value: B) -> Result { + let bytes: bytes::BytesMut = value.try_collect().await?; + Ok(Self { + body: bytes.to_vec(), + }) + } } impl From for HelperResponse { @@ -128,13 +142,26 @@ impl From<()> for HelperResponse { } } +#[derive(Deserialize, Serialize)] +struct QueryStatusResponse { + status: QueryStatus, +} + impl From for HelperResponse { fn from(value: QueryStatus) -> Self { - let v = serde_json::to_vec(&json!({"status": value})).unwrap(); + let response = QueryStatusResponse { status: value }; + let v = serde_json::to_vec(&response).unwrap(); Self { body: v } } } +impl From for QueryStatus { + fn from(value: HelperResponse) -> Self { + let response: QueryStatusResponse = serde_json::from_slice(value.body.as_ref()).unwrap(); + response.status + } +} + impl From for HelperResponse { fn from(value: QueryKilled) -> Self { let v = serde_json::to_vec(&json!({"query_id": value.0, "status": "killed"})).unwrap(); diff --git a/ipa-core/src/helpers/transport/in_memory/transport.rs b/ipa-core/src/helpers/transport/in_memory/transport.rs index 504921eb5..8f26d07bb 100644 --- a/ipa-core/src/helpers/transport/in_memory/transport.rs +++ b/ipa-core/src/helpers/transport/in_memory/transport.rs @@ -12,7 +12,7 @@ use ::tokio::sync::{ }; use async_trait::async_trait; use bytes::Bytes; -use futures::{Stream, StreamExt}; +use futures::{stream, Stream, StreamExt}; #[cfg(all(feature = "shuttle", test))] use shuttle::future as tokio; use tokio_stream::wrappers::ReceiverStream; @@ -156,6 +156,7 @@ impl InMemoryTransport { impl Transport for Weak> { type Identity = I; type RecordsStream = ReceiveRecords; + type SendResponse = InMemoryStream; type Error = Error; fn identity(&self) -> I { @@ -182,7 +183,7 @@ impl Transport for Weak> { dest: I, route: R, data: D, - ) -> Result<(), Error> + ) -> Result, Error> where Option: From, Option: From, @@ -214,7 +215,7 @@ impl Transport for Weak> { io::Error::new::(io::ErrorKind::ConnectionAborted, "channel closed".into()) })?; - ack_rx + let res = ack_rx .await .map_err(|_recv_error| Error::Rejected { dest, @@ -224,8 +225,11 @@ impl Transport for Weak> { dest, inner: e.into(), })?; - - Ok(()) + let body_bytes = res.into_body(); + if body_bytes.is_empty() { + return Ok(None); + } + Ok(Some(InMemoryStream::wrap_bytes(body_bytes))) } fn receive>( @@ -247,6 +251,10 @@ pub struct InMemoryStream { } impl InMemoryStream { + fn wrap_bytes(bytes: Vec) -> Self { + InMemoryStream::wrap(stream::once(async { Ok(Bytes::from(bytes)) })) + } + fn wrap + Send + 'static>(value: S) -> Self { Self { inner: Box::pin(value), diff --git a/ipa-core/src/helpers/transport/mod.rs b/ipa-core/src/helpers/transport/mod.rs index 00021c62c..7cfe9a39f 100644 --- a/ipa-core/src/helpers/transport/mod.rs +++ b/ipa-core/src/helpers/transport/mod.rs @@ -310,7 +310,7 @@ impl RouteParams for (RouteId, QueryId) { } #[derive(thiserror::Error, Debug)] -#[error("One or more peers rejected the request: {failures:?}")] +#[error("One or more peer shards rejected the broadcast request: {failures:?}")] pub struct BroadcastError { pub failures: Vec<(I, E)>, } @@ -325,7 +325,10 @@ impl From> for BroadcastError #[async_trait] pub trait Transport: Clone + Send + Sync + 'static { type Identity: TransportIdentity; + /// They type used by [`receive`]. type RecordsStream: BytesStream; + /// The type used for responses to [`send`] and [`broadcast`]. + type SendResponse: BytesStream; type Error: Debug + Send; /// Return my identity in the network (MPC or Sharded) @@ -349,7 +352,7 @@ pub trait Transport: Clone + Send + Sync + 'static { dest: Self::Identity, route: R, data: D, - ) -> Result<(), Self::Error> + ) -> Result, Self::Error> where Option: From, Option: From, @@ -371,7 +374,10 @@ pub trait Transport: Clone + Send + Sync + 'static { async fn broadcast( &self, route: R, - ) -> Result<(), BroadcastError> + ) -> Result< + Vec<(Self::Identity, Option)>, + BroadcastError, + > where Option: From, Option: From, @@ -388,14 +394,16 @@ pub trait Transport: Clone + Send + Sync + 'static { } let mut errs = Vec::new(); + let mut responses = Vec::new(); while let Some(r) = futs.next().await { - if let Err(e) = r.1 { - errs.push((r.0, e)); + match r.1 { + Err(e) => errs.push((r.0, e)), + Ok(re) => responses.push((r.0, re)), } } if errs.is_empty() { - Ok(()) + Ok(responses) } else { Err(errs.into()) } diff --git a/ipa-core/src/helpers/transport/query/mod.rs b/ipa-core/src/helpers/transport/query/mod.rs index cd3e389d1..08cc4d4e3 100644 --- a/ipa-core/src/helpers/transport/query/mod.rs +++ b/ipa-core/src/helpers/transport/query/mod.rs @@ -15,7 +15,6 @@ use crate::{ RoleAssignment, RouteParams, }, protocol::QueryId, - query::QueryStatus, }; #[derive(Copy, Clone, Debug, Ord, PartialOrd, Eq, PartialEq, Serialize)] @@ -239,33 +238,6 @@ impl Debug for QueryInput { } } -#[derive(Clone, Debug, Serialize, Deserialize)] -#[cfg_attr(test, derive(PartialEq, Eq))] -pub struct CompareStatusRequest { - pub query_id: QueryId, - pub status: QueryStatus, -} - -impl RouteParams for CompareStatusRequest { - type Params = String; - - fn resource_identifier(&self) -> RouteId { - RouteId::QueryStatus - } - - fn query_id(&self) -> QueryId { - self.query_id - } - - fn gate(&self) -> NoStep { - NoStep - } - - fn extra(&self) -> Self::Params { - serde_json::to_string(self).unwrap() - } -} - #[derive(Copy, Clone, Debug, Serialize, Deserialize)] #[cfg_attr(test, derive(PartialEq, Eq))] pub enum QueryType { diff --git a/ipa-core/src/helpers/transport/stream/axum_body.rs b/ipa-core/src/helpers/transport/stream/axum_body.rs index 5560f326e..234eba88e 100644 --- a/ipa-core/src/helpers/transport/stream/axum_body.rs +++ b/ipa-core/src/helpers/transport/stream/axum_body.rs @@ -16,7 +16,8 @@ use crate::{error::BoxError, helpers::BytesStream}; pub struct WrappedAxumBodyStream(#[pin] BodyDataStream); impl WrappedAxumBodyStream { - pub(super) fn new(b: Body) -> Self { + #[must_use] + pub fn new(b: Body) -> Self { Self(b.into_data_stream()) } diff --git a/ipa-core/src/net/client/mod.rs b/ipa-core/src/net/client/mod.rs index 42d7c1377..36a4c3fe6 100644 --- a/ipa-core/src/net/client/mod.rs +++ b/ipa-core/src/net/client/mod.rs @@ -33,10 +33,10 @@ use crate::{ }, executor::IpaRuntime, helpers::{ - query::{CompareStatusRequest, PrepareQuery, QueryConfig, QueryInput}, - TransportIdentity, + query::{PrepareQuery, QueryConfig, QueryInput}, + BodyStream, TransportIdentity, }, - net::{error::ShardQueryStatusMismatchError, http_serde, Error, CRYPTO_PROVIDER}, + net::{http_serde, Error, CRYPTO_PROVIDER}, protocol::{Gate, QueryId}, }; @@ -385,30 +385,44 @@ impl IpaHttpClient { resp_ok(resp).await } - /// This API is used by leader shards in MPC to request query status information on peers. - /// If a given peer has status that doesn't match the one provided by the leader, it responds - /// with 412 error and encodes its status inside the response body. Otherwise, 200 is returned. + /// Sends a query status request and returns the response bytes. /// /// # Errors - /// If the request has illegal arguments, or fails to be delivered - pub async fn status_match(&self, data: CompareStatusRequest) -> Result<(), Error> { - let req = http_serde::query::status_match::try_into_http_request( - &data, - self.scheme.clone(), - self.authority.clone(), - )?; + /// If the request has illegal arguments, or fails to deliver to helper + async fn query_status_impl(&self, query_id: QueryId) -> Result { + let req = http_serde::query::status::Request::new(query_id); + let req = req.try_into_http_request(self.scheme.clone(), self.authority.clone())?; let resp = self.request(req).await?; - - match resp.status() { - StatusCode::OK => Ok(()), - StatusCode::PRECONDITION_FAILED => { - let bytes = response_to_bytes(resp).await?; - let err = serde_json::from_slice::(&bytes)?; - Err(err.into()) - } - _ => Err(Error::from_failed_resp(resp).await), + if resp.status().is_success() { + Ok(response_to_bytes(resp).await?) + } else { + Err(Error::from_failed_resp(resp).await) } } + /// Retrieves the status of a query as a byte stream. + /// + /// This function calls `query_status_impl` and returns the response bytes as a `BodyStream`. + /// + /// # Errors + /// If the request has illegal arguments, or fails to deliver to helper + pub async fn query_status_bytes(&self, query_id: QueryId) -> Result { + let bytes = self.query_status_impl(query_id).await?; + Ok(BodyStream::from(bytes.to_vec())) + } + /// Retrieves the status of a query. + /// + /// This function calls `query_status_impl` and deserializes the response bytes into a `QueryStatus` struct. + /// + /// # Errors + /// If the request has illegal arguments, or fails to deliver to helper + pub async fn query_status( + &self, + query_id: QueryId, + ) -> Result { + let bytes = self.query_status_impl(query_id).await?; + let http_serde::query::status::ResponseBody { status } = serde_json::from_slice(&bytes)?; + Ok(status) + } } impl IpaHttpClient { @@ -467,29 +481,6 @@ impl IpaHttpClient { resp_ok(resp).await } - /// Retrieve the status of a query. - /// - /// ## Errors - /// If the request has illegal arguments, or fails to deliver to helper - #[cfg(any(all(test, not(feature = "shuttle")), feature = "cli"))] - pub async fn query_status( - &self, - query_id: QueryId, - ) -> Result { - let req = http_serde::query::status::Request::new(query_id); - let req = req.try_into_http_request(self.scheme.clone(), self.authority.clone())?; - - let resp = self.request(req).await?; - if resp.status().is_success() { - let bytes = response_to_bytes(resp).await?; - let http_serde::query::status::ResponseBody { status } = - serde_json::from_slice(&bytes)?; - Ok(status) - } else { - Err(Error::from_failed_resp(resp).await) - } - } - /// Wait for completion of the query and pull the results of this query. This is a blocking /// API so it is not supposed to be used outside of CLI context. /// diff --git a/ipa-core/src/net/error.rs b/ipa-core/src/net/error.rs index f137c3232..6665f9d20 100644 --- a/ipa-core/src/net/error.rs +++ b/ipa-core/src/net/error.rs @@ -4,8 +4,7 @@ use axum::{ }; use crate::{ - error::BoxError, net::client::ResponseFromEndpoint, protocol::QueryId, query::QueryStatus, - sharding::ShardIndex, + error::BoxError, net::client::ResponseFromEndpoint, protocol::QueryId, sharding::ShardIndex, }; #[derive(thiserror::Error, Debug)] @@ -62,11 +61,6 @@ pub enum Error { }, #[error("{code}: {error}")] Application { code: StatusCode, error: BoxError }, - #[error(transparent)] - ShardQueryStatusMismatch { - #[from] - error: ShardQueryStatusMismatchError, - }, } impl Error { @@ -148,12 +142,6 @@ pub struct ShardError { pub source: Error, } -#[derive(Debug, thiserror::Error, serde::Deserialize, serde::Serialize)] -#[error("Query status mismatch. Actual status: {actual}")] -pub struct ShardQueryStatusMismatchError { - pub actual: QueryStatus, -} - impl IntoResponse for Error { fn into_response(self) -> Response { let status_code = match self { @@ -177,13 +165,6 @@ impl IntoResponse for Error { | Self::MissingExtension(_) => StatusCode::INTERNAL_SERVER_ERROR, Self::Application { code, .. } => code, - Self::ShardQueryStatusMismatch { error } => { - return ( - StatusCode::PRECONDITION_FAILED, - serde_json::to_string(&error).unwrap(), - ) - .into_response(); - } }; (status_code, self.to_string()).into_response() } diff --git a/ipa-core/src/net/http_serde.rs b/ipa-core/src/net/http_serde.rs index be2bc6d83..f6e095745 100644 --- a/ipa-core/src/net/http_serde.rs +++ b/ipa-core/src/net/http_serde.rs @@ -498,12 +498,10 @@ pub mod query { } impl Request { - #[cfg(any(all(test, not(feature = "shuttle")), feature = "cli"))] // needed because client is blocking; remove when non-blocking pub fn new(query_id: QueryId) -> Self { Self { query_id } } - #[cfg(any(all(test, not(feature = "shuttle")), feature = "cli"))] // needed because client is blocking; remove when non-blocking pub fn try_into_http_request( self, scheme: axum::http::uri::Scheme, @@ -670,48 +668,4 @@ pub mod query { pub const AXUM_PATH: &str = "/:query_id/kill"; } - - pub mod status_match { - use serde::{Deserialize, Serialize}; - - use crate::{helpers::query::CompareStatusRequest, query::QueryStatus}; - - #[derive(Serialize, Deserialize)] - pub struct StatusQueryString { - pub status: QueryStatus, - } - - impl StatusQueryString { - fn url_encode(&self) -> String { - // todo: serde urlencoded - format!("status={}", self.status) - } - } - - impl From for StatusQueryString { - fn from(value: QueryStatus) -> Self { - Self { status: value } - } - } - - pub fn try_into_http_request( - req: &CompareStatusRequest, - scheme: axum::http::uri::Scheme, - authority: axum::http::uri::Authority, - ) -> crate::net::http_serde::OutgoingRequest { - let uri = axum::http::uri::Uri::builder() - .scheme(scheme) - .authority(authority) - .path_and_query(format!( - "{}/{}/status-match?{}", - crate::net::http_serde::query::BASE_AXUM_PATH, - req.query_id.as_ref(), - StatusQueryString::from(req.status).url_encode(), - )) - .build()?; - Ok(hyper::Request::get(uri).body(axum::body::Body::empty())?) - } - - pub const AXUM_PATH: &str = "/:query_id/status-match"; - } } diff --git a/ipa-core/src/net/server/handlers/query/mod.rs b/ipa-core/src/net/server/handlers/query/mod.rs index 8a9881bb7..70d59db76 100644 --- a/ipa-core/src/net/server/handlers/query/mod.rs +++ b/ipa-core/src/net/server/handlers/query/mod.rs @@ -4,7 +4,6 @@ mod kill; mod prepare; mod results; mod status; -mod status_match; mod step; use std::marker::PhantomData; @@ -37,9 +36,9 @@ pub fn query_router(transport: MpcHttpTransport) -> Router { Router::new() .merge(create::router(transport.clone())) .merge(input::router(transport.clone())) - .merge(status::router(transport.clone())) .merge(kill::router(transport.clone())) - .merge(results::router(transport.inner_transport)) + .merge(results::router(Arc::clone(&transport.inner_transport))) + .merge(status::router(transport.inner_transport)) } /// Construct router for helper-to-helper communications @@ -61,8 +60,8 @@ pub fn s2s_router(transport: Arc>) -> Router { Router::new() .merge(step::router(Arc::clone(&transport))) .merge(prepare::router(Arc::clone(&transport))) - .merge(results::router(Arc::clone(&transport))) - .merge(status_match::router(transport)) + .merge(status::router(Arc::clone(&transport))) + .merge(results::router(transport)) .layer(layer_fn(HelperAuthentication::<_, Shard>::new)) } diff --git a/ipa-core/src/net/server/handlers/query/status.rs b/ipa-core/src/net/server/handlers/query/status.rs index 0056b76d0..7bae0e26e 100644 --- a/ipa-core/src/net/server/handlers/query/status.rs +++ b/ipa-core/src/net/server/handlers/query/status.rs @@ -6,25 +6,29 @@ use crate::{ net::{ http_serde::query::status::{self, Request}, server::Error, - transport::MpcHttpTransport, + ConnectionFlavor, HttpTransport, }, protocol::QueryId, + sync::Arc, }; -async fn handler( - transport: Extension, +async fn handler( + transport: Extension>>, Path(query_id): Path, ) -> Result, Error> { let req = Request { query_id }; - match transport.dispatch(req, BodyStream::empty()).await { + match Arc::clone(&transport) + .dispatch(req, BodyStream::empty()) + .await + { Ok(state) => Ok(Json(status::ResponseBody::from(state))), Err(e) => Err(Error::application(StatusCode::INTERNAL_SERVER_ERROR, e)), } } -pub fn router(transport: MpcHttpTransport) -> Router { +pub fn router(transport: Arc>) -> Router { Router::new() - .route(status::AXUM_PATH, get(handler)) + .route(status::AXUM_PATH, get(handler::)) .layer(Extension(transport)) } diff --git a/ipa-core/src/net/server/handlers/query/status_match.rs b/ipa-core/src/net/server/handlers/query/status_match.rs deleted file mode 100644 index 5b2081c5e..000000000 --- a/ipa-core/src/net/server/handlers/query/status_match.rs +++ /dev/null @@ -1,227 +0,0 @@ -use axum::{ - extract::{Path, Query}, - routing::get, - Extension, Router, -}; -use hyper::StatusCode; - -use crate::{ - helpers::{query::CompareStatusRequest, ApiError, BodyStream}, - net::{ - http_serde::query::status_match::{ - StatusQueryString, {self}, - }, - server::Error, - HttpTransport, Shard, - }, - protocol::QueryId, - query::QueryStatusError, - sync::Arc, -}; - -async fn handler( - transport: Extension>>, - Path(query_id): Path, - Query(StatusQueryString { status }): Query, -) -> Result<(), Error> { - let req = CompareStatusRequest { query_id, status }; - match Arc::clone(&transport) - .dispatch(req, BodyStream::empty()) - .await - { - Ok(_) => Ok(()), - Err(ApiError::QueryStatus(QueryStatusError::DifferentStatus { my_status, .. })) => { - Err(crate::net::error::ShardQueryStatusMismatchError { actual: my_status }.into()) - } - Err(e) => Err(Error::application(StatusCode::INTERNAL_SERVER_ERROR, e)), - } -} - -pub fn router(transport: Arc>) -> Router { - Router::new() - .route(status_match::AXUM_PATH, get(handler)) - .layer(Extension(transport)) -} - -#[cfg(all(test, unit_test))] -mod tests { - use std::{borrow::Borrow, sync::Arc}; - - use axum::{ - body::Body, - http::uri::{Authority, Scheme}, - }; - use hyper::StatusCode; - - use crate::{ - helpers::{ - make_owned_handler, - query::CompareStatusRequest, - routing::{Addr, RouteId}, - ApiError, BodyStream, HelperResponse, RequestHandler, - }, - net::{ - error::ShardQueryStatusMismatchError, - http_serde::query::status_match::try_into_http_request, - server::ClientIdentity, - test::{TestServer, TestServerBuilder}, - Error, Shard, - }, - protocol::QueryId, - query::{QueryStatus, QueryStatusError}, - sharding::ShardIndex, - }; - - fn for_status(status: QueryStatus) -> CompareStatusRequest { - CompareStatusRequest { - query_id: QueryId, - status, - } - } - - fn http_request>(req: B) -> hyper::Request { - try_into_http_request( - req.borrow(), - Scheme::HTTP, - Authority::from_static("localhost"), - ) - .unwrap() - } - - fn authenticated(mut req: hyper::Request) -> hyper::Request { - req.extensions_mut() - .insert(ClientIdentity(ShardIndex::from(2))); - req - } - - fn handler_status_match(expected_status: QueryStatus) -> Arc> { - make_owned_handler( - move |addr: Addr, _data: BodyStream| async move { - let RouteId::QueryStatus = addr.route else { - panic!("unexpected call"); - }; - let req = addr.into::().unwrap(); - assert_eq!(req.query_id, QueryId); - assert_eq!(req.status, expected_status); - Ok(HelperResponse::ok()) - }, - ) - } - - fn handler_status_mismatch( - expected_status: QueryStatus, - ) -> Arc> { - assert_ne!(expected_status, QueryStatus::Running); - - make_owned_handler( - move |addr: Addr, _data: BodyStream| async move { - let RouteId::QueryStatus = addr.route else { - panic!("unexpected call"); - }; - let req = addr.into::().unwrap(); - assert_eq!(req.query_id, QueryId); - Err(ApiError::QueryStatus(QueryStatusError::DifferentStatus { - query_id: QueryId, - my_status: QueryStatus::Running, - other_status: expected_status, - })) - }, - ) - } - - #[tokio::test] - async fn status_success() { - let expected_status = QueryStatus::Running; - let req = authenticated(http_request(for_status(expected_status))); - - TestServer::::oneshot_success(req, handler_status_match(expected_status)).await; - } - - #[tokio::test] - async fn status_client_success() { - let expected_status = QueryStatus::Running; - let test_server = TestServerBuilder::::default() - .with_request_handler(handler_status_match(expected_status)) - .build() - .await; - - test_server - .client - .status_match(for_status(expected_status)) - .await - .unwrap(); - } - - #[tokio::test] - async fn status_client_mismatch() { - let diff_status = QueryStatus::Preparing; - let test_server = TestServerBuilder::::default() - .with_request_handler(handler_status_mismatch(diff_status)) - .build() - .await; - let e = test_server - .client - .status_match(for_status(diff_status)) - .await - .unwrap_err(); - assert!(matches!( - e, - Error::ShardQueryStatusMismatch { - error: ShardQueryStatusMismatchError { - actual: QueryStatus::Running - }, - } - )); - } - - #[tokio::test] - async fn status_mismatch() { - let req_status = QueryStatus::Completed; - let handler = handler_status_mismatch(req_status); - let req = authenticated(http_request(for_status(req_status))); - - let resp = TestServer::::oneshot(req, handler).await; - assert_eq!(StatusCode::PRECONDITION_FAILED, resp.status()); - } - - #[tokio::test] - async fn other_query_error() { - let handler = make_owned_handler( - move |_addr: Addr, _data: BodyStream| async move { - Err(ApiError::QueryStatus(QueryStatusError::NoSuchQuery( - QueryId, - ))) - }, - ); - let req = authenticated(http_request(for_status(QueryStatus::Running))); - - let resp = TestServer::::oneshot(req, handler).await; - assert_eq!(StatusCode::INTERNAL_SERVER_ERROR, resp.status()); - } - - #[tokio::test] - async fn unauthenticated() { - assert_eq!( - StatusCode::UNAUTHORIZED, - TestServer::::oneshot( - http_request(for_status(QueryStatus::Running)), - make_owned_handler(|_, _| async move { unimplemented!() }), - ) - .await - .status() - ); - } - - #[tokio::test] - async fn server_error() { - assert_eq!( - StatusCode::INTERNAL_SERVER_ERROR, - TestServer::::oneshot( - authenticated(http_request(for_status(QueryStatus::Running))), - make_owned_handler(|_, _| async move { Err(ApiError::BadRequest("".into())) }), - ) - .await - .status() - ); - } -} diff --git a/ipa-core/src/net/transport.rs b/ipa-core/src/net/transport.rs index 599024271..7e3108faf 100644 --- a/ipa-core/src/net/transport.rs +++ b/ipa-core/src/net/transport.rs @@ -79,7 +79,7 @@ impl HttpTransport { dest: F::Identity, route: R, data: D, - ) -> Result<(), Error> + ) -> Result, Error> where Option: From, Option: From, @@ -100,20 +100,24 @@ impl HttpTransport { self.http_runtime .spawn(resp_future.map_err(Into::into).and_then(resp_ok)) .await?; - Ok(()) + Ok(None) } RouteId::PrepareQuery => { let req = serde_json::from_str(route.extra().borrow()).unwrap(); - self.clients[client_ix].prepare_query(req).await + self.clients[client_ix].prepare_query(req).await?; + Ok(None) } RouteId::CompleteQuery => { let query_id = >::from(route.query_id()) .expect("query_id is required to call complete query API"); - self.clients[client_ix].complete_query(query_id).await + self.clients[client_ix].complete_query(query_id).await?; + Ok(None) } RouteId::QueryStatus => { - let req = serde_json::from_str(route.extra().borrow())?; - self.clients[client_ix].status_match(req).await + let query_id = >::from(route.query_id()) + .expect("query_id is required to call complete query API"); + let response = self.clients[client_ix].query_status_bytes(query_id).await?; + Ok(Some(response)) } evt @ (RouteId::QueryInput | RouteId::ReceiveQuery @@ -273,6 +277,7 @@ impl MpcHttpTransport { impl Transport for MpcHttpTransport { type Identity = HelperIdentity; type RecordsStream = ReceiveRecords; + type SendResponse = BodyStream; type Error = Error; fn identity(&self) -> Self::Identity { @@ -300,7 +305,7 @@ impl Transport for MpcHttpTransport { dest: Self::Identity, route: R, data: D, - ) -> Result<(), Error> + ) -> Result, Error> where Option: From, Option: From, @@ -353,6 +358,7 @@ impl ShardHttpTransport { impl Transport for ShardHttpTransport { type Identity = ShardIndex; type RecordsStream = ReceiveRecords; + type SendResponse = BodyStream; type Error = ShardError; fn identity(&self) -> Self::Identity { @@ -373,7 +379,7 @@ impl Transport for ShardHttpTransport { dest: Self::Identity, route: R, data: D, - ) -> Result<(), Self::Error> + ) -> Result, Self::Error> where Option: From, Option: From, @@ -427,6 +433,7 @@ mod tests { client::ClientIdentity, test::{TestConfig, TestConfigBuilder, TestServer}, }, + query::QueryStatus, secret_sharing::{replicated::semi_honest::AdditiveShare, IntoShares}, test_fixture::Reconstruct, HelperApp, @@ -636,6 +643,11 @@ mod tests { .collect::>() }); + assert_eq!( + leader_client.query_status(QueryId).await.unwrap(), + QueryStatus::AwaitingInputs + ); + let _ = try_join_all(helper_shares.into_iter().enumerate().map( |(helper, shard_streams)| async move { @@ -653,6 +665,11 @@ mod tests { .await .unwrap(); + assert_eq!( + leader_client.query_status(QueryId).await.unwrap(), + QueryStatus::Running + ); + let result: [_; 3] = join_all(leader_ring_clients.each_ref().map(|client| async move { let r = client.query_results(query_id).await.unwrap(); AdditiveShare::::from_byte_slice_unchecked(&r).collect::>() diff --git a/ipa-core/src/query/processor.rs b/ipa-core/src/query/processor.rs index 95cfa0f44..1e0ff8ff1 100644 --- a/ipa-core/src/query/processor.rs +++ b/ipa-core/src/query/processor.rs @@ -11,10 +11,10 @@ use crate::{ error::Error as ProtocolError, executor::IpaRuntime, helpers::{ - query::{CompareStatusRequest, PrepareQuery, QueryConfig}, + query::{PrepareQuery, QueryConfig}, routing::RouteId, - BodyStream, BroadcastError, Gateway, GatewayConfig, MpcTransportError, MpcTransportImpl, - Role, RoleAssignment, ShardTransportError, ShardTransportImpl, Transport, + BodyStream, BroadcastError, Gateway, GatewayConfig, HelperResponse, MpcTransportError, + MpcTransportImpl, Role, RoleAssignment, ShardTransportError, ShardTransportImpl, Transport, }, hpke::{KeyRegistry, PrivateKeyOnly}, protocol::QueryId, @@ -118,11 +118,12 @@ pub enum QueryStatusError { NotLeader(ShardIndex), #[error("This is the leader shard")] Leader, - #[error("My status {my_status:?} for query {query_id:?} differs from {other_status:?}")] - DifferentStatus { - query_id: QueryId, - my_status: QueryStatus, - other_status: QueryStatus, + #[error("No response from shard {0:?}")] + NoResponse(ShardIndex), + #[error(transparent)] + UnexpectedResponse { + #[from] + source: crate::error::BoxError, }, } @@ -354,48 +355,6 @@ impl Processor { Some(status) } - /// This helper function is used to transform a [`BoxError`] into a - /// [`QueryStatusError::DifferentStatus`] and retrieve it's internal state. Returns [`None`] - /// if not possible. - #[cfg(feature = "in-memory-infra")] - fn downcast_state_error(box_error: &crate::error::BoxError) -> Option { - use crate::helpers::ApiError; - let api_error = box_error.downcast_ref::(); - if let Some(ApiError::QueryStatus(QueryStatusError::DifferentStatus { - my_status, .. - })) = api_error - { - return Some(*my_status); - } - None - } - - /// This helper is used by the in-memory stack to obtain the state of other shards via a - /// [`QueryStatusError::DifferentStatus`] error. - /// TODO: Ideally broadcast should return a value, that we could use to parse the state instead - /// of relying on errors. - #[cfg(feature = "in-memory-infra")] - fn get_state_from_error( - error: &crate::helpers::InMemoryTransportError, - ) -> Option { - if let crate::helpers::InMemoryTransportError::Rejected { inner, .. } = error { - return Self::downcast_state_error(inner); - } - None - } - - /// This helper is used by the HTTP stack to obtain the state of other shards via a - /// [`QueryStatusError::DifferentStatus`] error. - /// TODO: Ideally broadcast should return a value, that we could use to parse the state instead - /// of relying on errors. - #[cfg(feature = "real-world-infra")] - fn get_state_from_error(shard_error: &crate::net::ShardError) -> Option { - if let crate::net::Error::ShardQueryStatusMismatch { error, .. } = &shard_error.source { - return Some(error.actual); - } - None - } - /// Returns the query status in this helper, by querying all shards. /// /// ## Errors @@ -412,29 +371,26 @@ impl Processor { if shard_index != ShardIndex::FIRST { return Err(QueryStatusError::NotLeader(shard_index)); } - let mut status = self .get_status(query_id) .ok_or(QueryStatusError::NoSuchQuery(query_id))?; - let shard_query_status_req = CompareStatusRequest { query_id, status }; - - let shard_responses = shard_transport.broadcast(shard_query_status_req).await; - if let Err(e) = shard_responses { - for (shard, failure) in &e.failures { - if let Some(other) = Self::get_state_from_error(failure) { - status = min_status(status, other); - } else { - tracing::error!("failed to get status from shard {shard}: {failure:?}"); - return Err(e.into()); - } + let shard_responses = shard_transport + .broadcast((RouteId::QueryStatus, query_id)) + .await?; + for (i, o) in shard_responses { + if o.is_none() { + return Err(QueryStatusError::NoResponse(i)); } + let r = HelperResponse::from_bytesstream(o.unwrap()).await?; + let other = QueryStatus::from(r); + status = min_status(status, other); } Ok(status) } - /// Compares this shard status against the given type. Returns an error if different. + /// Returns the status of this shard for a query. /// /// ## Errors /// If query is not registered on this helper or @@ -444,22 +400,16 @@ impl Processor { pub fn shard_status( &self, shard_transport: &ShardTransportImpl, - req: &CompareStatusRequest, + query_id: QueryId, ) -> Result { let shard_index = shard_transport.identity(); if shard_index == ShardIndex::FIRST { return Err(QueryStatusError::Leader); } let status = self - .get_status(req.query_id) - .ok_or(QueryStatusError::NoSuchQuery(req.query_id))?; - if req.status != status { - return Err(QueryStatusError::DifferentStatus { - query_id: req.query_id, - my_status: status, - other_status: req.status, - }); - } + .get_status(query_id) + .ok_or(QueryStatusError::NoSuchQuery(query_id))?; + Ok(status) } @@ -593,7 +543,7 @@ mod tests { } fn shard_respond_ok(_si: ShardIndex) -> Arc> { - create_handler(|_| async { Ok(HelperResponse::ok()) }) + create_handler(|_| async { Ok(HelperResponse::from(QueryStatus::Completed)) }) } fn test_multiply_config() -> QueryConfig { @@ -1127,7 +1077,7 @@ mod tests { mod query_status { use super::*; - use crate::{helpers::query::CompareStatusRequest, protocol::QueryId}; + use crate::protocol::QueryId; /// * From the standpoint of leader shard in Helper 1 /// * On query_status @@ -1138,25 +1088,13 @@ mod tests { #[tokio::test] async fn combined_status_response() { fn shard_handle(si: ShardIndex) -> Arc> { - const FOURTH_SHARD: ShardIndex = ShardIndex::from_u32(3); - const THIRD_SHARD: ShardIndex = ShardIndex::from_u32(2); create_handler(move |_| async move { + const FOURTH_SHARD: ShardIndex = ShardIndex::from_u32(3); + const THIRD_SHARD: ShardIndex = ShardIndex::from_u32(2); match si { - FOURTH_SHARD => { - Err(ApiError::QueryStatus(QueryStatusError::DifferentStatus { - query_id: QueryId, - my_status: QueryStatus::Completed, - other_status: QueryStatus::Preparing, - })) - } - THIRD_SHARD => { - Err(ApiError::QueryStatus(QueryStatusError::DifferentStatus { - query_id: QueryId, - my_status: QueryStatus::Running, - other_status: QueryStatus::Preparing, - })) - } - _ => Ok(HelperResponse::ok()), + FOURTH_SHARD => Ok(HelperResponse::from(QueryStatus::Completed)), + THIRD_SHARD => Ok(HelperResponse::from(QueryStatus::Running)), + _ => Ok(HelperResponse::from(QueryStatus::AwaitingInputs)), } }) } @@ -1173,7 +1111,7 @@ mod tests { t.processor .prepare_shard( &t.shard_network - .transport(HelperIdentity::ONE, ShardIndex::from(1)), + .transport(HelperIdentity::ONE, ShardIndex::from_u32(1)), req, ) .unwrap(); @@ -1206,13 +1144,9 @@ mod tests { QueryId, ))) } else if si == ShardIndex::from(2) { - Err(ApiError::QueryStatus(QueryStatusError::DifferentStatus { - query_id: QueryId, - my_status: QueryStatus::Running, - other_status: QueryStatus::Preparing, - })) + Ok(HelperResponse::from(QueryStatus::Running)) } else { - Ok(HelperResponse::ok()) + Ok(HelperResponse::from(QueryStatus::AwaitingInputs)) } }) } @@ -1266,17 +1200,13 @@ mod tests { /// call. Only non-leaders (1,2,3...) should handle those calls. #[tokio::test] async fn shard_not_leader() { - let req = CompareStatusRequest { - query_id: QueryId, - status: QueryStatus::Running, - }; let t = TestComponents::new(TestComponentsArgs::default()); assert!(matches!( t.processor .shard_status( &t.shard_network .transport(HelperIdentity::TWO, ShardIndex::FIRST), - &req + QueryId ) .unwrap_err(), QueryStatusError::Leader