diff --git a/Cargo.lock b/Cargo.lock index 33cbe08dd..25d93f609 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4357,7 +4357,6 @@ version = "0.1.0" dependencies = [ "datafusion", "datasets-common", - "js-runtime", "schemars 1.2.1", "serde", "serde_json", diff --git a/crates/core/common/src/dataset_schema_provider.rs b/crates/core/common/src/dataset_schema_provider.rs index 0e29a2988..99fdbe0a4 100644 --- a/crates/core/common/src/dataset_schema_provider.rs +++ b/crates/core/common/src/dataset_schema_provider.rs @@ -17,11 +17,11 @@ use datafusion::{ TableProvider, }, error::DataFusionError, - logical_expr::ScalarUDF, + logical_expr::{ScalarUDF, async_udf::AsyncScalarUDF}, }; use datasets_common::{dataset::Dataset, table_name::TableName}; use datasets_derived::{dataset::Dataset as DerivedDataset, func_name::ETH_CALL_FUNCTION_NAME}; -use js_runtime::isolate_pool::IsolatePool; +use js_runtime::{isolate_pool::IsolatePool, js_udf::JsUdf}; use parking_lot::RwLock; use crate::{ @@ -177,7 +177,22 @@ impl FuncSchemaProvider for DatasetSchemaProvider { // Try to get UDF from derived dataset let udf = self.dataset.downcast_ref::().and_then(|d| { - d.function_by_name(self.schema_name.clone(), name, self.isolate_pool.clone()) + d.function_by_name(name).map(|function| { + AsyncScalarUDF::new(Arc::new(JsUdf::new( + self.isolate_pool.clone(), + self.schema_name.clone(), + function.source.source.clone(), + function.source.filename.clone(), + Arc::from(name), + function + .input_types + .iter() + .map(|dt| dt.clone().into_arrow()) + .collect(), + function.output_type.clone().into_arrow(), + ))) + .into_scalar_udf() + }) }); if let Some(udf) = udf { diff --git a/crates/core/common/src/self_schema_provider.rs b/crates/core/common/src/self_schema_provider.rs index 62da84591..75ffe7a47 100644 --- a/crates/core/common/src/self_schema_provider.rs +++ b/crates/core/common/src/self_schema_provider.rs @@ -15,7 +15,7 @@ use datafusion::{ logical_expr::{ScalarUDF, async_udf::AsyncScalarUDF}, }; use datasets_common::table_name::TableName; -use datasets_derived::{deps::SELF_REF_KEYWORD, func_name::FuncName, manifest::Function}; +use datasets_derived::{deps::SELF_REF_KEYWORD, func_name::FuncName, function::Function}; use js_runtime::{isolate_pool::IsolatePool, js_udf::JsUdf}; use parking_lot::RwLock; @@ -57,35 +57,34 @@ impl SelfSchemaProvider { &self.udfs } - /// Creates a provider from manifest function definitions (no tables). + /// Creates a provider from manifest functions (no tables). /// - /// Builds UDFs from all manifest functions. + /// Functions are already validated at deserialization time. pub fn from_manifest_udfs( - schema_name: String, isolate_pool: IsolatePool, - manifest_udfs: &BTreeMap, + functions: &BTreeMap, ) -> Self { - let udfs: Vec = manifest_udfs + let scalar_udfs: Vec = functions .iter() - .map(|(func_name, func_def)| { + .map(|(name, function)| { AsyncScalarUDF::new(Arc::new(JsUdf::new( isolate_pool.clone(), Some(SELF_REF_KEYWORD.to_string()), - func_def.source.source.clone(), - func_def.source.filename.clone().into(), - Arc::from(func_name.as_str()), - func_def + function.source.source.clone(), + function.source.filename.clone(), + Arc::from(name.as_str()), + function .input_types .iter() .map(|dt| dt.clone().into_arrow()) .collect(), - func_def.output_type.clone().into_arrow(), + function.output_type.clone().into_arrow(), ))) .into_scalar_udf() }) .collect(); - Self::new(schema_name, vec![], udfs) + Self::new(SELF_REF_KEYWORD.to_string(), vec![], scalar_udfs) } } diff --git a/crates/core/datasets-common/src/manifest.rs b/crates/core/datasets-common/src/manifest.rs index 3982b59c0..a07615580 100644 --- a/crates/core/datasets-common/src/manifest.rs +++ b/crates/core/datasets-common/src/manifest.rs @@ -173,29 +173,3 @@ pub struct Field { /// Whether the field can contain null values pub nullable: bool, } - -/// User-defined function specification. -/// -/// Defines a custom function with input/output types and implementation source. -#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] -#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] -#[serde(rename_all = "camelCase")] -pub struct Function { - // TODO: Support SQL type names, see https://datafusion.apache.org/user-guide/sql/data_types.html - /// Arrow data types for function input parameters - pub input_types: Vec, - /// Arrow data type for function return value - pub output_type: DataType, - /// Function implementation source code and metadata - pub source: FunctionSource, -} - -/// Source code and metadata for a user-defined function. -#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] -#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] -pub struct FunctionSource { - /// Function implementation source code - pub source: Arc, - /// Filename where the function is defined - pub filename: String, -} diff --git a/crates/core/datasets-derived/Cargo.toml b/crates/core/datasets-derived/Cargo.toml index e04971523..5af0d47d1 100644 --- a/crates/core/datasets-derived/Cargo.toml +++ b/crates/core/datasets-derived/Cargo.toml @@ -10,9 +10,11 @@ schemars = ["dep:schemars", "datasets-common/schemars", "dep:serde_json"] [dependencies] datafusion.workspace = true datasets-common = { path = "../datasets-common" } -js-runtime = { path = "../js-runtime" } schemars = { workspace = true, optional = true } serde.workspace = true serde_json = {workspace = true, optional = true} thiserror.workspace = true +[dev-dependencies] +serde_json.workspace = true + diff --git a/crates/core/datasets-derived/src/dataset.rs b/crates/core/datasets-derived/src/dataset.rs index c53f68205..361bd25d4 100644 --- a/crates/core/datasets-derived/src/dataset.rs +++ b/crates/core/datasets-derived/src/dataset.rs @@ -1,19 +1,16 @@ -use std::{collections::BTreeMap, sync::Arc}; +use std::collections::BTreeMap; -use datafusion::{ - logical_expr::{ScalarUDF, async_udf::AsyncScalarUDF}, - sql::parser, -}; +use datafusion::sql::parser; use datasets_common::{ block_num::BlockNum, dataset::Table, dataset_kind_str::DatasetKindStr, hash_reference::HashReference, table_name::TableName, }; -use js_runtime::{isolate_pool::IsolatePool, js_udf::JsUdf}; use crate::{ DerivedDatasetKind, Manifest, deps::{DepAlias, DepReference}, - function::{Function, FunctionSource}, + func_name::FuncName, + function::Function, manifest::TableInput, sql::{ResolveTableReferencesError, TableReference, resolve_table_references}, }; @@ -46,28 +43,13 @@ pub fn dataset(reference: HashReference, manifest: Manifest) -> Result, tables: Vec, - functions: Vec, + functions: BTreeMap, finalized_blocks_only: bool, } @@ -92,7 +74,7 @@ impl Dataset { kind: DerivedDatasetKind, finalized_blocks_only: bool, tables: Vec
, - functions: Vec, + functions: BTreeMap, ) -> Self { Self { reference, @@ -112,30 +94,11 @@ impl Dataset { &self.dependencies } - /// Looks up a user-defined function by name. - /// - /// Returns the [`ScalarUDF`] for the function if found. This is used - /// for derived datasets that define custom JavaScript functions. + /// Looks up a function by name. /// - /// Returns `None` if the function name is not found. - pub fn function_by_name( - &self, - schema: String, - name: &str, - isolate_pool: IsolatePool, - ) -> Option { - self.functions.iter().find(|f| f.name == name).map(|f| { - AsyncScalarUDF::new(Arc::new(JsUdf::new( - isolate_pool, - schema, - f.source.source.clone(), - f.source.filename.clone().into(), - f.name.clone().into(), - f.input_types.clone(), - f.output_type.clone(), - ))) - .into_scalar_udf() - }) + /// Returns the [`Function`] definition if found. + pub fn function_by_name(&self, name: &str) -> Option<&Function> { + self.functions.get(name) } } diff --git a/crates/core/datasets-derived/src/func_name.rs b/crates/core/datasets-derived/src/func_name.rs index eb92357e2..7be873a62 100644 --- a/crates/core/datasets-derived/src/func_name.rs +++ b/crates/core/datasets-derived/src/func_name.rs @@ -58,6 +58,12 @@ impl AsRef for FuncName { } } +impl std::borrow::Borrow for FuncName { + fn borrow(&self) -> &str { + &self.0 + } +} + impl std::ops::Deref for FuncName { type Target = str; diff --git a/crates/core/datasets-derived/src/function.rs b/crates/core/datasets-derived/src/function.rs index dd4a8cc2a..bd1b61fc7 100644 --- a/crates/core/datasets-derived/src/function.rs +++ b/crates/core/datasets-derived/src/function.rs @@ -1,21 +1,23 @@ -//! User-defined function types for logical catalog. +//! User-defined function types for derived datasets. //! -//! This module provides function representations for the logical catalog, -//! which include the function name field (unlike manifest types where the name is the map key). +//! This module provides function representations used for derived datasets, +//! with Arrow type validation performed during deserialization so that a +//! successfully deserialized `Function` is always valid for the JS UDF runtime. use std::sync::Arc; -use datafusion::arrow::datatypes::DataType; +use datafusion::arrow::datatypes::DataType as ArrowDataType; +use datasets_common::manifest::DataType; -/// User-defined function specification for logical catalog. +/// User-defined function specification. /// -/// This type includes the function name and is used in the logical catalog -/// representation after manifest deserialization. -#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +/// Defines a custom function with input/output types and implementation source. +/// Arrow type validation is performed during deserialization — a successfully +/// deserialized `Function` is guaranteed to have JS UDF-compatible types. +#[derive(Debug, Clone, serde::Serialize)] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +#[serde(rename_all = "camelCase")] pub struct Function { - /// Function name - pub name: String, - // TODO: Support SQL type names, see https://datafusion.apache.org/user-guide/sql/data_types.html /// Arrow data types for function input parameters pub input_types: Vec, @@ -25,11 +27,681 @@ pub struct Function { pub source: FunctionSource, } -/// Source code and metadata for a user-defined function in logical catalog. +impl<'de> serde::Deserialize<'de> for Function { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + // Internal deserialization struct to perform Arrow type validation + #[derive(serde::Deserialize)] + #[serde(rename_all = "camelCase")] + struct Inner { + input_types: Vec, + output_type: DataType, + source: FunctionSource, + } + + let Inner { + input_types, + output_type, + source, + } = serde::Deserialize::deserialize(deserializer)?; + + // Validate that all input types are supported + for (index, input_type) in input_types.iter().enumerate() { + validate_js_udf_input_type(input_type.as_arrow()).map_err(|err| { + serde::de::Error::custom(format!( + "input parameter at index {index} uses an unsupported type: {err}" + )) + })?; + } + + // Validate that the output type is supported + validate_js_udf_output_type(output_type.as_arrow()).map_err(|err| { + serde::de::Error::custom(format!("output uses an unsupported type: {err}")) + })?; + + Ok(Self { + input_types, + output_type, + source, + }) + } +} + +/// Source code and metadata for a user-defined function. #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] pub struct FunctionSource { /// Function implementation source code pub source: Arc, /// Filename where the function is defined - pub filename: String, + pub filename: Arc, +} + +/// Errors from validating Arrow types for the JS UDF runtime. +#[derive(Debug, thiserror::Error)] +pub enum JsUdfTypeError { + /// The Arrow data type is not supported by the JS UDF runtime. + #[error("Arrow data type '{0:?}' is not supported by the JS UDF runtime")] + UnsupportedType(ArrowDataType), + + /// Decimal type has a fractional scale, which the JS UDF runtime cannot represent. + #[error( + "Decimal type with scale {scale} is not supported; \ + only scale 0 (integer decimals) can be converted to BigInt" + )] + FractionalDecimal { + /// The non-zero scale value + scale: i8, + }, + + /// A field within a Struct has an unsupported type. + #[error("field '{name}' has an unsupported type: {source}")] + UnsupportedFieldType { + /// Field name + name: String, + /// The underlying type error + #[source] + source: Box, + }, + + /// The output type is not supported by the JS UDF runtime's FromV8 conversion. + #[error("type '{0:?}' cannot be converted from JavaScript back to Arrow")] + UnsupportedOutputType(ArrowDataType), + + /// A list element type is unsupported. + #[error("list element has an unsupported type: {0}")] + UnsupportedListElement(#[source] Box), +} + +/// Validates an Arrow data type for use as a JS UDF input parameter. +/// +/// Accepts types that the js-runtime `ToV8` implementation can convert +/// from Arrow `ScalarValue` to V8 JavaScript values. +pub fn validate_js_udf_input_type(dt: &ArrowDataType) -> Result<(), JsUdfTypeError> { + match dt { + // Primitives + ArrowDataType::Null + | ArrowDataType::Boolean + | ArrowDataType::Int8 + | ArrowDataType::Int16 + | ArrowDataType::Int32 + | ArrowDataType::Int64 + | ArrowDataType::UInt8 + | ArrowDataType::UInt16 + | ArrowDataType::UInt32 + | ArrowDataType::UInt64 + | ArrowDataType::Float32 + | ArrowDataType::Float64 => Ok(()), + + // String types + ArrowDataType::Utf8 | ArrowDataType::Utf8View | ArrowDataType::LargeUtf8 => Ok(()), + + // Binary types + ArrowDataType::Binary + | ArrowDataType::BinaryView + | ArrowDataType::LargeBinary + | ArrowDataType::FixedSizeBinary(_) => Ok(()), + + // Decimal — only scale 0 (integer decimals that map to BigInt) + ArrowDataType::Decimal128(_, 0) | ArrowDataType::Decimal256(_, 0) => Ok(()), + ArrowDataType::Decimal128(_, scale) | ArrowDataType::Decimal256(_, scale) => { + Err(JsUdfTypeError::FractionalDecimal { scale: *scale }) + } + + // Struct — validate each field recursively + ArrowDataType::Struct(fields) => { + for field in fields.iter() { + validate_js_udf_input_type(field.data_type()).map_err(|source| { + JsUdfTypeError::UnsupportedFieldType { + name: field.name().clone(), + source: Box::new(source), + } + })?; + } + Ok(()) + } + + // List types — validate element type recursively + ArrowDataType::List(field) + | ArrowDataType::LargeList(field) + | ArrowDataType::FixedSizeList(field, _) => validate_js_udf_input_type(field.data_type()) + .map_err(|e| JsUdfTypeError::UnsupportedListElement(Box::new(e))), + + other => Err(JsUdfTypeError::UnsupportedType(other.clone())), + } +} + +/// Validates an Arrow data type for use as a JS UDF output type. +/// +/// Rejects types that the js-runtime `FromV8` implementation cannot convert +/// from V8 JavaScript values back to Arrow `ScalarValue`. List and binary +/// types are input-only (JS Array/TypedArray cannot be converted back). +pub fn validate_js_udf_output_type(dt: &ArrowDataType) -> Result<(), JsUdfTypeError> { + match dt { + // List types cannot be converted back from JS + ArrowDataType::List(_) + | ArrowDataType::LargeList(_) + | ArrowDataType::FixedSizeList(_, _) => { + Err(JsUdfTypeError::UnsupportedOutputType(dt.clone())) + } + + // Binary types cannot be converted back from JS + ArrowDataType::Binary + | ArrowDataType::BinaryView + | ArrowDataType::LargeBinary + | ArrowDataType::FixedSizeBinary(_) => { + Err(JsUdfTypeError::UnsupportedOutputType(dt.clone())) + } + + // Struct — validate each field recursively with output rules + ArrowDataType::Struct(fields) => { + for field in fields.iter() { + validate_js_udf_output_type(field.data_type()).map_err(|source| { + JsUdfTypeError::UnsupportedFieldType { + name: field.name().clone(), + source: Box::new(source), + } + })?; + } + Ok(()) + } + + // Everything else: delegate to input validator + other => validate_js_udf_input_type(other), + } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use datafusion::arrow::datatypes::{DataType, Field, Fields}; + + use super::*; + + #[test] + fn validate_js_udf_input_type_with_primitive_scalars_succeeds() { + // Null + assert!( + validate_js_udf_input_type(&DataType::Null).is_ok(), + "Null should be accepted" + ); + + // Boolean + assert!( + validate_js_udf_input_type(&DataType::Boolean).is_ok(), + "Boolean should be accepted" + ); + + // Integer types + assert!( + validate_js_udf_input_type(&DataType::Int8).is_ok(), + "Int8 should be accepted" + ); + assert!( + validate_js_udf_input_type(&DataType::Int16).is_ok(), + "Int16 should be accepted" + ); + assert!( + validate_js_udf_input_type(&DataType::Int32).is_ok(), + "Int32 should be accepted" + ); + assert!( + validate_js_udf_input_type(&DataType::Int64).is_ok(), + "Int64 should be accepted" + ); + assert!( + validate_js_udf_input_type(&DataType::UInt8).is_ok(), + "UInt8 should be accepted" + ); + assert!( + validate_js_udf_input_type(&DataType::UInt16).is_ok(), + "UInt16 should be accepted" + ); + assert!( + validate_js_udf_input_type(&DataType::UInt32).is_ok(), + "UInt32 should be accepted" + ); + assert!( + validate_js_udf_input_type(&DataType::UInt64).is_ok(), + "UInt64 should be accepted" + ); + + // Float types + assert!( + validate_js_udf_input_type(&DataType::Float32).is_ok(), + "Float32 should be accepted" + ); + assert!( + validate_js_udf_input_type(&DataType::Float64).is_ok(), + "Float64 should be accepted" + ); + + // String types + assert!( + validate_js_udf_input_type(&DataType::Utf8).is_ok(), + "Utf8 should be accepted" + ); + assert!( + validate_js_udf_input_type(&DataType::Utf8View).is_ok(), + "Utf8View should be accepted" + ); + assert!( + validate_js_udf_input_type(&DataType::LargeUtf8).is_ok(), + "LargeUtf8 should be accepted" + ); + + // Binary types + assert!( + validate_js_udf_input_type(&DataType::Binary).is_ok(), + "Binary should be accepted" + ); + assert!( + validate_js_udf_input_type(&DataType::BinaryView).is_ok(), + "BinaryView should be accepted" + ); + assert!( + validate_js_udf_input_type(&DataType::LargeBinary).is_ok(), + "LargeBinary should be accepted" + ); + assert!( + validate_js_udf_input_type(&DataType::FixedSizeBinary(32)).is_ok(), + "FixedSizeBinary should be accepted" + ); + + // Decimal types (scale 0 only) + assert!( + validate_js_udf_input_type(&DataType::Decimal128(38, 0)).is_ok(), + "Decimal128 with scale 0 should be accepted" + ); + assert!( + validate_js_udf_input_type(&DataType::Decimal256(76, 0)).is_ok(), + "Decimal256 with scale 0 should be accepted" + ); + } + + #[test] + fn validate_js_udf_input_type_with_fractional_decimal_fails() { + //* Given + let dt = DataType::Decimal128(38, 2); + + //* When + let result = validate_js_udf_input_type(&dt); + + //* Then + let err = result.expect_err("Decimal128 with scale 2 should be rejected"); + assert!( + matches!(err, JsUdfTypeError::FractionalDecimal { scale: 2 }), + "expected FractionalDecimal, got {err:?}" + ); + } + + #[test] + fn validate_js_udf_input_type_with_valid_struct_succeeds() { + //* Given + let dt = DataType::Struct(Fields::from(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Utf8, true), + ])); + + //* When + let result = validate_js_udf_input_type(&dt); + + //* Then + assert!( + result.is_ok(), + "Struct with valid fields should be accepted" + ); + } + + #[test] + fn validate_js_udf_input_type_with_invalid_struct_field_fails() { + //* Given + let dt = DataType::Struct(Fields::from(vec![Field::new( + "bad", + DataType::Date32, + false, + )])); + + //* When + let result = validate_js_udf_input_type(&dt); + + //* Then + let err = result.expect_err("Struct with Date32 field should be rejected"); + assert!( + matches!(err, JsUdfTypeError::UnsupportedFieldType { ref name, .. } if name == "bad"), + "expected UnsupportedFieldType for 'bad', got {err:?}" + ); + } + + #[test] + fn validate_js_udf_input_type_with_list_types_succeeds() { + //* Given + let list = DataType::List(Arc::new(Field::new("item", DataType::Int32, true))); + let large = DataType::LargeList(Arc::new(Field::new("item", DataType::Utf8, true))); + let fixed = + DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float64, true)), 3); + + //* Then + assert!( + validate_js_udf_input_type(&list).is_ok(), + "List should be accepted" + ); + assert!( + validate_js_udf_input_type(&large).is_ok(), + "LargeList should be accepted" + ); + assert!( + validate_js_udf_input_type(&fixed).is_ok(), + "FixedSizeList should be accepted" + ); + } + + #[test] + fn validate_js_udf_input_type_with_unsupported_list_element_fails() { + //* Given + let dt = DataType::List(Arc::new(Field::new("item", DataType::Date32, true))); + + //* When + let result = validate_js_udf_input_type(&dt); + + //* Then + let err = result.expect_err("List with Date32 element should be rejected"); + assert!( + matches!(err, JsUdfTypeError::UnsupportedListElement(_)), + "expected UnsupportedListElement, got {err:?}" + ); + } + + #[test] + fn validate_js_udf_input_type_with_map_fails() { + //* Given + let dt = DataType::Map( + Arc::new(Field::new( + "entries", + DataType::Struct(Fields::from(vec![ + Field::new("key", DataType::Utf8, false), + Field::new("value", DataType::Int32, true), + ])), + false, + )), + false, + ); + + //* When + let result = validate_js_udf_input_type(&dt); + + //* Then + let err = result.expect_err("Map should be rejected"); + assert!( + matches!(err, JsUdfTypeError::UnsupportedType(_)), + "expected UnsupportedType, got {err:?}" + ); + } + + #[test] + fn validate_js_udf_output_type_with_primitive_scalars_succeeds() { + // Null + assert!( + validate_js_udf_output_type(&DataType::Null).is_ok(), + "Null should be accepted" + ); + + // Boolean + assert!( + validate_js_udf_output_type(&DataType::Boolean).is_ok(), + "Boolean should be accepted" + ); + + // Integer types + assert!( + validate_js_udf_output_type(&DataType::Int8).is_ok(), + "Int8 should be accepted" + ); + assert!( + validate_js_udf_output_type(&DataType::Int16).is_ok(), + "Int16 should be accepted" + ); + assert!( + validate_js_udf_output_type(&DataType::Int32).is_ok(), + "Int32 should be accepted" + ); + assert!( + validate_js_udf_output_type(&DataType::Int64).is_ok(), + "Int64 should be accepted" + ); + assert!( + validate_js_udf_output_type(&DataType::UInt8).is_ok(), + "UInt8 should be accepted" + ); + assert!( + validate_js_udf_output_type(&DataType::UInt16).is_ok(), + "UInt16 should be accepted" + ); + assert!( + validate_js_udf_output_type(&DataType::UInt32).is_ok(), + "UInt32 should be accepted" + ); + assert!( + validate_js_udf_output_type(&DataType::UInt64).is_ok(), + "UInt64 should be accepted" + ); + + // Float types + assert!( + validate_js_udf_output_type(&DataType::Float32).is_ok(), + "Float32 should be accepted" + ); + assert!( + validate_js_udf_output_type(&DataType::Float64).is_ok(), + "Float64 should be accepted" + ); + + // String types + assert!( + validate_js_udf_output_type(&DataType::Utf8).is_ok(), + "Utf8 should be accepted" + ); + assert!( + validate_js_udf_output_type(&DataType::Utf8View).is_ok(), + "Utf8View should be accepted" + ); + assert!( + validate_js_udf_output_type(&DataType::LargeUtf8).is_ok(), + "LargeUtf8 should be accepted" + ); + + // Decimal types (scale 0 only) + assert!( + validate_js_udf_output_type(&DataType::Decimal128(38, 0)).is_ok(), + "Decimal128 with scale 0 should be accepted" + ); + assert!( + validate_js_udf_output_type(&DataType::Decimal256(76, 0)).is_ok(), + "Decimal256 with scale 0 should be accepted" + ); + } + + #[test] + fn validate_js_udf_output_type_with_fractional_decimal_fails() { + //* Given + let dt = DataType::Decimal128(38, 2); + + //* When + let result = validate_js_udf_output_type(&dt); + + //* Then + let err = result.expect_err("Decimal128 with scale 2 should be rejected"); + assert!( + matches!(err, JsUdfTypeError::FractionalDecimal { scale: 2 }), + "expected FractionalDecimal, got {err:?}" + ); + } + + #[test] + fn validate_js_udf_output_type_with_valid_struct_succeeds() { + //* Given + let dt = DataType::Struct(Fields::from(vec![ + Field::new("x", DataType::Int32, false), + Field::new("y", DataType::Utf8, true), + ])); + + //* When + let result = validate_js_udf_output_type(&dt); + + //* Then + assert!( + result.is_ok(), + "Struct with valid fields should be accepted as output" + ); + } + + #[test] + fn validate_js_udf_output_type_with_invalid_struct_field_fails() { + //* Given + let dt = DataType::Struct(Fields::from(vec![Field::new( + "bad", + DataType::Date32, + false, + )])); + + //* When + let result = validate_js_udf_output_type(&dt); + + //* Then + let err = result.expect_err("Struct with Date32 field should be rejected as output"); + assert!( + matches!(err, JsUdfTypeError::UnsupportedFieldType { ref name, .. } if name == "bad"), + "expected UnsupportedFieldType for 'bad', got {err:?}" + ); + } + + #[test] + fn validate_js_udf_output_type_with_list_types_fails() { + //* Given + let cases = [ + DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), + DataType::LargeList(Arc::new(Field::new("item", DataType::Int32, true))), + DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Int32, true)), 3), + ]; + + //* Then + for dt in &cases { + let err = validate_js_udf_output_type(dt) + .expect_err(&format!("{dt:?} should be rejected as output")); + assert!( + matches!(err, JsUdfTypeError::UnsupportedOutputType(_)), + "{dt:?}: expected UnsupportedOutputType, got {err:?}" + ); + } + } + + #[test] + fn validate_js_udf_output_type_with_binary_types_fails() { + //* Given + let cases = [ + DataType::Binary, + DataType::BinaryView, + DataType::LargeBinary, + DataType::FixedSizeBinary(32), + ]; + + //* Then + for dt in &cases { + let err = validate_js_udf_output_type(dt) + .expect_err(&format!("{dt:?} should be rejected as output")); + assert!( + matches!(err, JsUdfTypeError::UnsupportedOutputType(_)), + "{dt:?}: expected UnsupportedOutputType, got {err:?}" + ); + } + } + + #[test] + fn validate_js_udf_output_type_with_map_fails() { + //* Given + let dt = DataType::Map( + Arc::new(Field::new( + "entries", + DataType::Struct(Fields::from(vec![ + Field::new("key", DataType::Utf8, false), + Field::new("value", DataType::Int32, true), + ])), + false, + )), + false, + ); + + //* When + let result = validate_js_udf_output_type(&dt); + + //* Then + let err = result.expect_err("Map should be rejected"); + assert!( + matches!(err, JsUdfTypeError::UnsupportedType(_)), + "expected UnsupportedType, got {err:?}" + ); + } + + #[test] + fn deserialize_with_valid_types_succeeds() { + //* Given + let json = make_json(&["Int32", "Utf8"], "Boolean"); + + //* When + let result = serde_json::from_str::(&json); + + //* Then + let func = result.expect("should deserialize valid function"); + assert_eq!(func.input_types.len(), 2, "should have two input types"); + } + + #[test] + fn deserialize_with_unsupported_input_type_fails() { + //* Given + let json = make_json(&["Date32"], "Int32"); + + //* When + let result = serde_json::from_str::(&json); + + //* Then + let err = result.expect_err("should reject unsupported input type"); + let msg = err.to_string(); + assert!( + msg.contains("input parameter at index 0"), + "error should mention input index: {msg}" + ); + } + + #[test] + fn deserialize_with_unsupported_output_type_fails() { + //* Given + // Binary is valid as input but not as output + let json = make_json(&["Int32"], "Binary"); + + //* When + let result = serde_json::from_str::(&json); + + //* Then + let err = result.expect_err("should reject unsupported output type"); + let msg = err.to_string(); + assert!( + msg.contains("output uses an unsupported type"), + "error should mention output: {msg}" + ); + } + + fn make_json(input_types: &[&str], output_type: &str) -> String { + let inputs: Vec = input_types.iter().map(|t| format!("\"{t}\"")).collect(); + format!( + r#"{{ + "inputTypes": [{inputs}], + "outputType": "{output_type}", + "source": {{ "source": "function f() {{}}", "filename": "test.js" }} + }}"#, + inputs = inputs.join(", ") + ) + } } diff --git a/crates/core/datasets-derived/src/lib.rs b/crates/core/datasets-derived/src/lib.rs index aca6901c7..bf45ed97a 100644 --- a/crates/core/datasets-derived/src/lib.rs +++ b/crates/core/datasets-derived/src/lib.rs @@ -25,5 +25,7 @@ pub mod sql_str; pub use self::{ dataset::Dataset, dataset_kind::{DerivedDatasetKind, DerivedDatasetKindError}, + func_name::FuncName, + function::{Function, FunctionSource}, manifest::Manifest, }; diff --git a/crates/core/datasets-derived/src/manifest.rs b/crates/core/datasets-derived/src/manifest.rs index 57571e44f..329ceb428 100644 --- a/crates/core/datasets-derived/src/manifest.rs +++ b/crates/core/datasets-derived/src/manifest.rs @@ -7,13 +7,14 @@ use std::collections::BTreeMap; // Re-export schema types from datasets-common -pub use datasets_common::manifest::{ArrowSchema, Field, Function, FunctionSource, TableSchema}; +pub use datasets_common::manifest::{ArrowSchema, Field, TableSchema}; use datasets_common::{network_id::NetworkId, table_name::TableName}; use crate::{ dataset_kind::DerivedDatasetKind, deps::{DepAlias, DepReference}, func_name::FuncName, + function::Function, sql_str::SqlStr, }; diff --git a/crates/core/worker-datasets-derived/src/job_impl/table.rs b/crates/core/worker-datasets-derived/src/job_impl/table.rs index f88b7c395..3f5805346 100644 --- a/crates/core/worker-datasets-derived/src/job_impl/table.rs +++ b/crates/core/worker-datasets-derived/src/job_impl/table.rs @@ -28,12 +28,12 @@ use datasets_derived::{ deps::{DepAlias, DepAliasError}, manifest::TableInput, }; -use tracing::{Instrument, instrument}; +use tracing::Instrument as _; use super::query::{MaterializeSqlQueryError, materialize_sql_query}; /// Materializes a derived dataset table -#[instrument(skip_all, fields(table = %table.table_name()), err)] +#[tracing::instrument(skip_all, fields(table = %table.table_name()), err)] #[expect(clippy::too_many_arguments)] pub async fn materialize_table( ctx: Ctx, @@ -94,11 +94,8 @@ pub async fn materialize_table( let mut join_set = tasks::FailFastJoinSet::>::new(); - let self_schema_provider = SelfSchemaProvider::from_manifest_udfs( - datasets_derived::deps::SELF_REF_KEYWORD.to_string(), - env.isolate_pool.clone(), - &manifest.functions, - ); + let self_schema_provider = + SelfSchemaProvider::from_manifest_udfs(env.isolate_pool.clone(), &manifest.functions); let catalog = { let table_refs = resolve_table_references::(&query) diff --git a/crates/services/admin-api/src/handlers/common.rs b/crates/services/admin-api/src/handlers/common.rs index f2c94e3ea..23384df4e 100644 --- a/crates/services/admin-api/src/handlers/common.rs +++ b/crates/services/admin-api/src/handlers/common.rs @@ -5,7 +5,7 @@ use std::{collections::BTreeMap, sync::Arc}; use amp_data_store::{DataStore, PhyTableRevision}; use amp_datasets_registry::error::ResolveRevisionError; use common::{ - amp_catalog_provider::{AMP_CATALOG_NAME, AmpCatalogProvider}, + amp_catalog_provider::{AMP_CATALOG_NAME, AmpCatalogProvider, AsyncSchemaProvider}, context::plan::PlanContextBuilder, datasets_cache::{DatasetsCache, GetDatasetError}, ethcall_udfs_cache::EthCallUdfsCache, @@ -348,12 +348,9 @@ pub async fn validate_derived_manifest( .iter() .map(|(alias, hash_ref)| (alias.to_string(), hash_ref.clone())) .collect(); - let self_schema: Arc = - Arc::new(SelfSchemaProvider::from_manifest_udfs( - datasets_derived::deps::SELF_REF_KEYWORD.to_string(), - IsolatePool::dummy(), - &manifest.functions, - )); + let self_schema: Arc = Arc::new( + SelfSchemaProvider::from_manifest_udfs(IsolatePool::dummy(), &manifest.functions), + ); let amp_catalog = Arc::new( AmpCatalogProvider::new( datasets_cache.clone(), diff --git a/crates/services/admin-api/src/handlers/schema.rs b/crates/services/admin-api/src/handlers/schema.rs index c7ff54b4a..0a70d6f07 100644 --- a/crates/services/admin-api/src/handlers/schema.rs +++ b/crates/services/admin-api/src/handlers/schema.rs @@ -6,7 +6,7 @@ use axum::{ http::StatusCode, }; use common::{ - amp_catalog_provider::{AMP_CATALOG_NAME, AmpCatalogProvider}, + amp_catalog_provider::{AMP_CATALOG_NAME, AmpCatalogProvider, AsyncSchemaProvider}, context::plan::{PlanContextBuilder, is_user_input_error}, exec_env::default_session_config, incrementalizer::NonIncrementalQueryError, @@ -21,7 +21,8 @@ use datasets_common::{ use datasets_derived::{ deps::{DepAlias, DepReference, HashOrVersion}, func_name::FuncName, - manifest::{Function, TableSchema}, + function::Function, + manifest::TableSchema, }; use js_runtime::isolate_pool::IsolatePool; use tracing::instrument; @@ -212,12 +213,9 @@ pub async fn handler( // Create planning context with self-schema provider let session_config = default_session_config().map_err(Error::SessionConfig)?; - let self_schema: Arc = - Arc::new(SelfSchemaProvider::from_manifest_udfs( - datasets_derived::deps::SELF_REF_KEYWORD.to_string(), - IsolatePool::dummy(), - &functions, - )); + let self_schema: Arc = Arc::new( + SelfSchemaProvider::from_manifest_udfs(IsolatePool::dummy(), &functions), + ); let amp_catalog = Arc::new( AmpCatalogProvider::new( ctx.datasets_cache.clone(), diff --git a/docs/manifest-schemas/derived.spec.json b/docs/manifest-schemas/derived.spec.json index be8fa8626..e4bd49daa 100644 --- a/docs/manifest-schemas/derived.spec.json +++ b/docs/manifest-schemas/derived.spec.json @@ -100,7 +100,7 @@ ] }, "Function": { - "description": "User-defined function specification.\n\nDefines a custom function with input/output types and implementation source.", + "description": "User-defined function specification.\n\nDefines a custom function with input/output types and implementation source.\nArrow type validation is performed during deserialization — a successfully\ndeserialized `Function` is guaranteed to have JS UDF-compatible types.", "type": "object", "properties": { "inputTypes": {