diff --git a/Cargo.lock b/Cargo.lock index 8eaaad283b3bc..bad318e3f77a5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1890,10 +1890,12 @@ dependencies = [ "object_store", "parquet", "paste", + "prost", "pyo3", "rand 0.8.5", "recursive", "sqlparser", + "substrait", "tokio", "web-time", ] @@ -2003,6 +2005,7 @@ dependencies = [ "datafusion-physical-expr-common", "env_logger", "indexmap 2.7.1", + "log", "paste", "recursive", "serde_json", @@ -2179,7 +2182,9 @@ dependencies = [ "ctor", "datafusion-common", "datafusion-expr", + "datafusion-functions", "datafusion-functions-aggregate", + "datafusion-functions-nested", "datafusion-functions-window", "datafusion-functions-window-common", "datafusion-physical-expr", diff --git a/datafusion/catalog/src/table.rs b/datafusion/catalog/src/table.rs index ecc792f73d308..ef85cbc710035 100644 --- a/datafusion/catalog/src/table.rs +++ b/datafusion/catalog/src/table.rs @@ -18,6 +18,7 @@ use std::any::Any; use std::borrow::Cow; use std::fmt::Debug; +use std::collections::HashMap; use std::sync::Arc; use crate::session::Session; @@ -171,6 +172,28 @@ pub trait TableProvider: Debug + Sync + Send { limit: Option, ) -> Result>; + /// Create an [`ExecutionPlan`] with an extra parameter + /// specifying the deep column projections + /// # Deep column projection + /// + /// If specified, a datasource such as Parquet can do deep projection pushdown. + /// In the case of deeply nested schemas (lists in structs etc), the + /// implementation can return a smaller schema that rewrites the entire file + /// schema to return only the necessary fields, no matter where they are (top-level + /// or deep) + /// + async fn scan_deep( + &self, + state: &dyn Session, + projection: Option<&Vec>, + _projection_deep: Option<&HashMap>>, + filters: &[Expr], + limit: Option, + ) -> Result> { + self.scan(state, projection, filters, limit).await + } + + /// Specify if DataFusion should provide filter expressions to the /// TableProvider to apply *during* the scan. /// diff --git a/datafusion/common/Cargo.toml b/datafusion/common/Cargo.toml index 76f07be95c601..ef765a8316c23 100644 --- a/datafusion/common/Cargo.toml +++ b/datafusion/common/Cargo.toml @@ -68,6 +68,9 @@ recursive = { workspace = true, optional = true } sqlparser = { workspace = true } tokio = { workspace = true } +prost = "0.13" +substrait = { version = "0.53", features = ["serde"] } + [target.'cfg(target_family = "wasm")'.dependencies] web-time = "1.1.0" diff --git a/datafusion/common/src/config.rs b/datafusion/common/src/config.rs index 65ecd4032729a..7287ffd544c00 100644 --- a/datafusion/common/src/config.rs +++ b/datafusion/common/src/config.rs @@ -686,6 +686,9 @@ config_namespace! { /// then the output will be coerced to a non-view. /// Coerces `Utf8View` to `LargeUtf8`, and `BinaryView` to `LargeBinary`. 
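
To make the new `scan_deep` contract concrete, here is a minimal sketch (not part of the patch) of how a caller could express a deep projection; column indices and field names are taken from the `cross_industry_demo_data` test table used later in this PR.

```rust
use std::collections::HashMap;

// `projection` selects top-level columns by index; `projection_deep` narrows a
// selected column to specific nested sub-paths (dotted names). An empty Vec
// means "keep the whole column".
fn example_deep_projection() -> (Vec<usize>, HashMap<usize, Vec<String>>) {
    // Column 2 = `endUserIDs` (struct), column 3 = `_experience` (struct).
    let projection = vec![2, 3];
    let projection_deep = HashMap::from([
        (2, vec!["aaid_id".to_string()]), // only endUserIDs.aaid_id is read
        (3, vec!["eVar56".to_string()]),  // only _experience.eVar56 is read
    ]);
    (projection, projection_deep)
}
```
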
pub expand_views_at_output: bool, default = false + + /// disable deep column pruning + pub deep_column_pruning_flags:usize, default = 7 } } diff --git a/datafusion/common/src/lib.rs b/datafusion/common/src/lib.rs index 7fe47e1d29db2..c745b75403128 100644 --- a/datafusion/common/src/lib.rs +++ b/datafusion/common/src/lib.rs @@ -56,6 +56,8 @@ pub mod types; pub mod utils; pub mod deep; +pub mod substrait_tree; + /// Reexport arrow crate pub use arrow; pub use column::Column; diff --git a/datafusion/common/src/substrait_tree.rs b/datafusion/common/src/substrait_tree.rs new file mode 100644 index 0000000000000..4a9acc06c357f --- /dev/null +++ b/datafusion/common/src/substrait_tree.rs @@ -0,0 +1,634 @@ +use crate::tree_node::{Transformed, TreeNode, TreeNodeIterator, TreeNodeRecursion}; +use crate::{DataFusionError, Result}; +use substrait::proto::{ + rel::RelType, AggregateRel, ConsistentPartitionWindowRel, CrossRel, DdlRel, + ExchangeRel, ExpandRel, ExtensionMultiRel, ExtensionSingleRel, FetchRel, FilterRel, + HashJoinRel, JoinRel, MergeJoinRel, NestedLoopJoinRel, ProjectRel, Rel, SetRel, + SortRel, WriteRel, +}; + +fn inputs(rel: &Rel) -> Vec<&Rel> { + match &rel.rel_type { + Some(rel_type) => match rel_type { + RelType::Read(_) => vec![], + RelType::Project(project_rel) => { + project_rel.input.as_deref().into_iter().collect() + } + RelType::Filter(filter_rel) => { + filter_rel.input.as_deref().into_iter().collect() + } + RelType::Fetch(fetch_rel) => fetch_rel.input.as_deref().into_iter().collect(), + RelType::Aggregate(aggregate_rel) => { + aggregate_rel.input.as_deref().into_iter().collect() + } + RelType::Sort(sort_rel) => sort_rel.input.as_deref().into_iter().collect(), + // FIXME + RelType::Join(join_rel) => { + let mut output: Vec<&Rel> = vec![]; + if let Some(left) = join_rel.left.as_ref() { + output.push(left.as_ref()); + } + if let Some(right) = join_rel.right.as_ref() { + output.push(right.as_ref()); + } + output + } + RelType::Set(set_rel) => set_rel.inputs.iter().map(|input| input).collect(), + RelType::ExtensionSingle(extension_single_rel) => { + extension_single_rel.input.as_deref().into_iter().collect() + } + RelType::ExtensionMulti(extension_multi_rel) => extension_multi_rel + .inputs + .iter() + .map(|input| input) + .collect(), + RelType::ExtensionLeaf(_) => vec![], + RelType::Cross(cross_rel) => { + let mut output: Vec<&Rel> = vec![]; + if let Some(left) = cross_rel.left.as_ref() { + output.push(left.as_ref()); + } + if let Some(right) = cross_rel.right.as_ref() { + output.push(right.as_ref()); + } + output + } + RelType::Exchange(exchange_rel) => { + exchange_rel.input.as_deref().into_iter().collect() + } + // FIXME - add all the others + RelType::Reference(ref_rel) => vec![], + RelType::Write(write_rel) => write_rel.input.as_deref().into_iter().collect(), + RelType::Ddl(ddl_rel) => { + ddl_rel.view_definition.as_deref().into_iter().collect() + } + RelType::HashJoin(hash_join_rel) => { + let mut output: Vec<&Rel> = vec![]; + if let Some(left) = hash_join_rel.left.as_ref() { + output.push(left.as_ref()); + } + if let Some(right) = hash_join_rel.right.as_ref() { + output.push(right.as_ref()); + } + output + } + RelType::MergeJoin(merge_join_rel) => { + let mut output: Vec<&Rel> = vec![]; + if let Some(left) = merge_join_rel.left.as_ref() { + output.push(left.as_ref()); + } + if let Some(right) = merge_join_rel.right.as_ref() { + output.push(right.as_ref()); + } + output + } + RelType::NestedLoopJoin(nested_loop_join) => { + let mut output: Vec<&Rel> = vec![]; + if let 
Some(left) = nested_loop_join.left.as_ref() { + output.push(left.as_ref()); + } + if let Some(right) = nested_loop_join.right.as_ref() { + output.push(right.as_ref()); + } + output + } + RelType::Window(window_rel) => { + window_rel.input.as_deref().into_iter().collect() + } + RelType::Expand(expand_rel) => { + expand_rel.input.as_deref().into_iter().collect() + } + RelType::Update(update_rel) => vec![], + }, + None => vec![], + } +} + +fn transform_box Result>>( + br: Box, + f: &mut F, +) -> Result>> { + Ok(f(*br)?.update_data(Box::new)) +} + +fn transform_option_box Result>>( + obr: Option>, + f: &mut F, +) -> Result>>> { + obr.map_or(Ok(Transformed::no(None)), |be| { + Ok(transform_box(be, f)?.update_data(Some)) + }) +} + +impl TreeNode for Rel { + fn apply_children<'n, F: FnMut(&'n Self) -> Result>( + &'n self, + f: F, + ) -> Result { + inputs(self).into_iter().apply_until_stop(f) + } + + fn map_children Result>>( + self, + mut f: F, + ) -> Result> { + if let Some(rel_type) = self.rel_type { + let t = match rel_type { + RelType::Read(_) => Transformed::no(rel_type), + RelType::Project(p) => { + let ProjectRel { + common, + input, + expressions, + advanced_extension, + } = *p; + transform_option_box(input, &mut f)?.update_data(|input| { + RelType::Project(Box::new(ProjectRel { + common, + input, + expressions, + advanced_extension, + })) + }) + } + RelType::Filter(p) => { + let FilterRel { + common, + input, + condition, + advanced_extension, + } = *p; + transform_option_box(input, &mut f)?.update_data(|input| { + RelType::Filter(Box::new(FilterRel { + common, + input, + condition, + advanced_extension, + })) + }) + } + + RelType::Fetch(p) => { + let FetchRel { + common, + input, + advanced_extension, + offset_mode, + count_mode, + } = *p; + transform_option_box(input, &mut f)?.update_data(|input| { + RelType::Fetch(Box::new(FetchRel { + common, + input, + advanced_extension, + offset_mode, + count_mode, + })) + }) + } + RelType::Aggregate(p) => { + let AggregateRel { + common, + input, + groupings, + measures, + grouping_expressions, + advanced_extension, + } = *p; + transform_option_box(input, &mut f)?.update_data(|input| { + RelType::Aggregate(Box::new(AggregateRel { + common, + input, + groupings, + measures, + grouping_expressions, + advanced_extension, + })) + }) + } + RelType::Sort(p) => { + let SortRel { + common, + input, + sorts, + advanced_extension, + } = *p; + transform_option_box(input, &mut f)?.update_data(|input| { + RelType::Sort(Box::new(SortRel { + common, + input, + sorts, + advanced_extension, + })) + }) + } + // FIXME + RelType::Set(p) => { + let SetRel { + common, + inputs, + op, + advanced_extension, + } = p; + let mut transformed_any = false; + let new_inputs: Vec<_> = inputs + .into_iter() + .map(|input| { + let transformed = + transform_box(Box::new(input), &mut f).unwrap(); + if transformed.transformed { + transformed_any = true; + } + *transformed.data + }) + .collect(); + if transformed_any { + Transformed::yes(RelType::Set(SetRel { + common, + inputs: new_inputs, + op, + advanced_extension, + })) + } else { + Transformed::no(RelType::Set(SetRel { + common, + inputs: new_inputs, + op, + advanced_extension, + })) + } + } + RelType::ExtensionSingle(p) => { + let ExtensionSingleRel { + common, + input, + detail, + } = *p; + transform_option_box(input, &mut f)?.update_data(|input| { + RelType::ExtensionSingle(Box::new(ExtensionSingleRel { + common, + input, + detail, + })) + }) + } + RelType::ExtensionMulti(p) => { + let ExtensionMultiRel { + common, + inputs, 
+ detail, + } = p; + let mut transformed_any = false; + let new_inputs: Vec = inputs + .into_iter() + .map(|input| { + let transformed = + transform_box(Box::new(input), &mut f).unwrap(); + if transformed.transformed { + transformed_any = true; + } + *transformed.data + }) + .collect(); + if transformed_any { + Transformed::yes(RelType::ExtensionMulti(ExtensionMultiRel { + common, + inputs: new_inputs, + detail, + })) + } else { + Transformed::no(RelType::ExtensionMulti(ExtensionMultiRel { + common, + inputs: new_inputs, + detail, + })) + } + } + RelType::Join(p) => { + let JoinRel { + common, + left, + right, + expression, + post_join_filter, + r#type, + advanced_extension, + } = *p; + let mut transformed_any = false; + let new_left = transform_option_box(left, &mut f)?; + if new_left.transformed { + transformed_any = true; + } + let new_right = transform_option_box(right, &mut f)?; + if new_right.transformed { + transformed_any = true; + } + + if transformed_any { + Transformed::yes(RelType::Join(Box::new(JoinRel { + common, + left: new_left.data, + right: new_right.data, + expression, + post_join_filter, + r#type, + advanced_extension, + }))) + } else { + Transformed::no(RelType::Join(Box::new(JoinRel { + common, + left: new_left.data, + right: new_right.data, + expression, + post_join_filter, + r#type, + advanced_extension, + }))) + } + } + RelType::ExtensionLeaf(inner) => { + Transformed::no(RelType::ExtensionLeaf(inner)) + } + RelType::Cross(p) => { + let CrossRel { + common, + left, + right, + advanced_extension, + } = *p; + let mut transformed_any = false; + let new_left = transform_option_box(left, &mut f)?; + if new_left.transformed { + transformed_any = true; + } + let new_right = transform_option_box(right, &mut f)?; + if new_right.transformed { + transformed_any = true; + } + + if transformed_any { + Transformed::yes(RelType::Cross(Box::new(CrossRel { + common, + left: new_left.data, + right: new_right.data, + advanced_extension, + }))) + } else { + Transformed::no(RelType::Cross(Box::new(CrossRel { + common, + left: new_left.data, + right: new_right.data, + advanced_extension, + }))) + } + } + RelType::Reference(inner) => Transformed::no(RelType::Reference(inner)), + RelType::Write(p) => { + let WriteRel { + table_schema, + op, + input, + create_mode, + output, + common, + advanced_extension, + write_type, + } = *p; + transform_option_box(input, &mut f)?.update_data(|input| { + RelType::Write(Box::new(WriteRel { + table_schema, + op, + input, + create_mode, + output, + common, + advanced_extension, + write_type, + })) + }) + } + RelType::Ddl(p) => { + let DdlRel { + table_schema, + table_defaults, + object, + op, + view_definition, + common, + advanced_extension, + write_type, + } = *p; + transform_option_box(view_definition, &mut f)?.update_data(|input| { + RelType::Ddl(Box::new(DdlRel { + table_schema, + table_defaults, + object, + op, + view_definition: input, + common, + advanced_extension, + write_type, + })) + }) + } + RelType::HashJoin(p) => { + let HashJoinRel { + common, + left, + right, + left_keys, + right_keys, + keys, + post_join_filter, + r#type, + advanced_extension, + } = *p; + let mut transformed_any = false; + let new_left = transform_option_box(left, &mut f)?; + if new_left.transformed { + transformed_any = true; + } + let new_right = transform_option_box(right, &mut f)?; + if new_right.transformed { + transformed_any = true; + } + + if transformed_any { + Transformed::yes(RelType::HashJoin(Box::new(HashJoinRel { + common, + left: new_left.data, + right: 
new_right.data, + left_keys, + right_keys, + keys, + post_join_filter, + r#type, + advanced_extension, + }))) + } else { + Transformed::no(RelType::HashJoin(Box::new(HashJoinRel { + common, + left: new_left.data, + right: new_right.data, + left_keys, + right_keys, + keys, + post_join_filter, + r#type, + advanced_extension, + }))) + } + } + RelType::MergeJoin(p) => { + let MergeJoinRel { + common, + left, + right, + left_keys, + right_keys, + keys, + post_join_filter, + r#type, + advanced_extension, + } = *p; + let mut transformed_any = false; + let new_left = transform_option_box(left, &mut f)?; + if new_left.transformed { + transformed_any = true; + } + let new_right = transform_option_box(right, &mut f)?; + if new_right.transformed { + transformed_any = true; + } + + if transformed_any { + Transformed::yes(RelType::MergeJoin(Box::new(MergeJoinRel { + common, + left: new_left.data, + right: new_right.data, + left_keys, + right_keys, + keys, + post_join_filter, + r#type, + advanced_extension, + }))) + } else { + Transformed::no(RelType::MergeJoin(Box::new(MergeJoinRel { + common, + left: new_left.data, + right: new_right.data, + left_keys, + right_keys, + keys, + post_join_filter, + r#type, + advanced_extension, + }))) + } + } + RelType::NestedLoopJoin(p) => { + let NestedLoopJoinRel { + common, + left, + right, + expression, + r#type, + advanced_extension, + } = *p; + let mut transformed_any = false; + let new_left = transform_option_box(left, &mut f)?; + if new_left.transformed { + transformed_any = true; + } + let new_right = transform_option_box(right, &mut f)?; + if new_right.transformed { + transformed_any = true; + } + + if transformed_any { + Transformed::yes(RelType::NestedLoopJoin(Box::new( + NestedLoopJoinRel { + common, + left: new_left.data, + right: new_right.data, + expression, + r#type, + advanced_extension, + }, + ))) + } else { + Transformed::no(RelType::NestedLoopJoin(Box::new( + NestedLoopJoinRel { + common, + left: new_left.data, + right: new_right.data, + expression, + r#type, + advanced_extension, + }, + ))) + } + } + RelType::Window(p) => { + let ConsistentPartitionWindowRel { + common, + input, + window_functions, + partition_expressions, + sorts, + advanced_extension, + } = *p; + transform_option_box(input, &mut f)?.update_data(|input| { + RelType::Window(Box::new(ConsistentPartitionWindowRel { + common, + input, + window_functions, + partition_expressions, + sorts, + advanced_extension, + })) + }) + } + RelType::Exchange(p) => { + let ExchangeRel { + common, + input, + partition_count, + targets, + advanced_extension, + exchange_kind, + } = *p; + transform_option_box(input, &mut f)?.update_data(|input| { + RelType::Exchange(Box::new(ExchangeRel { + common, + input, + partition_count, + targets, + advanced_extension, + exchange_kind, + })) + }) + } + RelType::Expand(p) => { + let ExpandRel { + common, + input, + fields, + } = *p; + transform_option_box(input, &mut f)?.update_data(|input| { + RelType::Expand(Box::new(ExpandRel { + common, + input, + fields, + })) + }) + } + RelType::Update(_) => Transformed::no(rel_type), + }; + Ok(t.update_data(|rt| Rel { rel_type: Some(rt) })) + } else { + Err(DataFusionError::Plan("RelType is None".into())) + } + } +} diff --git a/datafusion/core/src/datasource/listing/table.rs b/datafusion/core/src/datasource/listing/table.rs index a983f0696e83b..506a2829c05db 100644 --- a/datafusion/core/src/datasource/listing/table.rs +++ b/datafusion/core/src/datasource/listing/table.rs @@ -958,6 +958,118 @@ impl TableProvider for 
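
The `TreeNode` implementation for `substrait::proto::Rel` above makes DataFusion's usual tree-traversal APIs available on raw Substrait plans. A small sketch of using it; the `count_read_rels` helper is hypothetical and only relies on the `apply_children` logic added in `substrait_tree.rs`.

```rust
use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion};
use datafusion_common::Result;
use substrait::proto::{rel::RelType, Rel};

// Walk a Substrait relation tree top-down and count its ReadRel leaves.
fn count_read_rels(plan: &Rel) -> Result<usize> {
    let mut reads = 0;
    plan.apply(|rel| {
        if matches!(rel.rel_type, Some(RelType::Read(_))) {
            reads += 1;
        }
        Ok(TreeNodeRecursion::Continue)
    })?;
    Ok(reads)
}
```
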
ListingTable { .await } + async fn scan_deep( + &self, + state: &dyn Session, + projection: Option<&Vec>, + projection_deep: Option<&HashMap>>, + filters: &[Expr], + limit: Option, + ) -> Result> { + // extract types of partition columns + let table_partition_cols = self + .options + .table_partition_cols + .iter() + .map(|col| Ok(self.table_schema.field_with_name(&col.0)?.clone())) + .collect::>>()?; + + let table_partition_col_names = table_partition_cols + .iter() + .map(|field| field.name().as_str()) + .collect::>(); + // If the filters can be resolved using only partition cols, there is no need to + // pushdown it to TableScan, otherwise, `unhandled` pruning predicates will be generated + let (partition_filters, filters): (Vec<_>, Vec<_>) = + filters.iter().cloned().partition(|filter| { + can_be_evaluted_for_partition_pruning(&table_partition_col_names, filter) + }); + // TODO (https://github.com/apache/datafusion/issues/11600) remove downcast_ref from here? + let session_state = state.as_any().downcast_ref::().unwrap(); + + // We should not limit the number of partitioned files to scan if there are filters and limit + // at the same time. This is because the limit should be applied after the filters are applied. + let statistic_file_limit = if filters.is_empty() { limit } else { None }; + + let (mut partitioned_file_lists, statistics) = self + .list_files_for_scan(session_state, &partition_filters, statistic_file_limit) + .await?; + + // if no files need to be read, return an `EmptyExec` + if partitioned_file_lists.is_empty() { + let projected_schema = project_schema(&self.schema(), projection)?; + return Ok(Arc::new(EmptyExec::new(projected_schema))); + } + + let output_ordering = self.try_create_output_ordering()?; + match state + .config_options() + .execution + .split_file_groups_by_statistics + .then(|| { + output_ordering.first().map(|output_ordering| { + FileScanConfig::split_groups_by_statistics( + &self.table_schema, + &partitioned_file_lists, + output_ordering, + ) + }) + }) + .flatten() + { + Some(Err(e)) => log::debug!("failed to split file groups by statistics: {e}"), + Some(Ok(new_groups)) => { + if new_groups.len() <= self.options.target_partitions { + partitioned_file_lists = new_groups; + } else { + log::debug!("attempted to split file groups by statistics, but there were more file groups than target_partitions; falling back to unordered") + } + } + None => {} // no ordering required + }; + + let filters = match conjunction(filters.to_vec()) { + Some(expr) => { + let table_df_schema = self.table_schema.as_ref().clone().to_dfschema()?; + let filters = create_physical_expr( + &expr, + &table_df_schema, + state.execution_props(), + )?; + Some(filters) + } + None => None, + }; + + let Some(object_store_url) = + self.table_paths.first().map(ListingTableUrl::object_store) + else { + return Ok(Arc::new(EmptyExec::new(Arc::new(Schema::empty())))); + }; + + // create the execution plan + self.options + .format + .create_physical_plan( + session_state, + FileScanConfig::new( + object_store_url, + Arc::clone(&self.file_schema), + self.options.format.file_source(), + ) + .with_file_groups(partitioned_file_lists) + .with_constraints(self.constraints.clone()) + .with_statistics(statistics) + .with_projection(projection.cloned()) + .with_projection_deep(projection_deep.cloned()) + .with_limit(limit) + .with_output_ordering(output_ordering) + .with_table_partition_cols(table_partition_cols), + filters.as_ref(), + ) + .await + } + fn supports_filters_pushdown( &self, filters: 
&[&Expr], diff --git a/datafusion/core/src/datasource/physical_plan/parquet/mod.rs b/datafusion/core/src/datasource/physical_plan/parquet/mod.rs index f677c73cc8819..1883a38193f8b 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/mod.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/mod.rs @@ -233,8 +233,8 @@ impl ParquetExecBuilder { } let base_config = file_scan_config.with_source(Arc::new(parquet.clone())); - debug!("Creating ParquetExec, files: {:?}, projection {:?}, predicate: {:?}, limit: {:?}", - base_config.file_groups, base_config.projection, predicate, base_config.limit); + debug!("Creating ParquetExec, files: {:?}, projection {:?}, projection deep {:?}, predicate: {:?}, limit: {:?}", + base_config.file_groups, base_config.projection, base_config.projection_deep, predicate, base_config.limit); ParquetExec { inner: DataSourceExec::new(Arc::new(base_config.clone())), diff --git a/datafusion/core/src/datasource/physical_plan/parquet/opener.rs b/datafusion/core/src/datasource/physical_plan/parquet/opener.rs index 4230a1bdce388..b8cba2ec09ec1 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/opener.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/opener.rs @@ -17,8 +17,6 @@ //! [`ParquetOpener`] for opening Parquet files -use std::sync::Arc; - use crate::datasource::file_format::parquet::{ coerce_file_schema_to_string_type, coerce_file_schema_to_view_type, }; @@ -31,6 +29,9 @@ use crate::datasource::physical_plan::{ FileMeta, FileOpenFuture, FileOpener, ParquetFileMetrics, ParquetFileReaderFactory, }; use crate::datasource::schema_adapter::SchemaAdapterFactory; +use std::cmp::min; +use std::collections::HashMap; +use std::sync::Arc; use arrow::datatypes::SchemaRef; use arrow::error::ArrowError; @@ -40,10 +41,13 @@ use datafusion_physical_optimizer::pruning::PruningPredicate; use datafusion_physical_plan::metrics::ExecutionPlanMetricsSet; use futures::{StreamExt, TryStreamExt}; -use log::debug; +use log::{debug, info, trace}; use parquet::arrow::arrow_reader::{ArrowReaderMetadata, ArrowReaderOptions}; use parquet::arrow::async_reader::AsyncFileReader; use parquet::arrow::{ParquetRecordBatchStreamBuilder, ProjectionMask}; +use parquet::schema::types::SchemaDescriptor; +// use datafusion_common::DataFusionError; +use datafusion_common::deep::{has_deep_projection, rewrite_schema, splat_columns}; /// Implements [`FileOpener`] for a parquet file pub(super) struct ParquetOpener { @@ -51,7 +55,7 @@ pub(super) struct ParquetOpener { pub partition_index: usize, /// Column indexes in `table_schema` needed by the query pub projection: Arc<[usize]>, - /// Target number of rows in each output RecordBatch + pub projection_deep: Arc>>, pub batch_size: usize, /// Optional limit on the number of rows to read pub limit: Option, @@ -105,11 +109,28 @@ impl FileOpener for ParquetOpener { let batch_size = self.batch_size; - let projected_schema = - SchemaRef::from(self.table_schema.project(&self.projection)?); + let projection = self.projection.clone(); + let projection_vec = projection + .as_ref() + .iter() + .map(|i| *i) + .collect::>(); + info!("ParquetOpener::open projection={:?}", projection); + // FIXME @HStack: ADR: why do we need to do this ? our function needs another param maybe ? + // In the case when the projections requested are empty, we should return an empty schema + let projected_schema = if projection_vec.len() == 0 { + SchemaRef::from(self.table_schema.project(&projection)?) 
+ } else { + rewrite_schema( + self.table_schema.clone(), + &projection_vec, + self.projection_deep.as_ref(), + ) + }; let schema_adapter = self .schema_adapter_factory .create(projected_schema, Arc::clone(&self.table_schema)); + let projection_deep = self.projection_deep.clone(); let predicate = self.predicate.clone(); let pruning_predicate = self.pruning_predicate.clone(); let page_pruning_predicate = self.page_pruning_predicate.clone(); @@ -159,11 +180,32 @@ impl FileOpener for ParquetOpener { let (schema_mapping, adapted_projections) = schema_adapter.map_schema(&file_schema)?; - let mask = ProjectionMask::roots( - builder.parquet_schema(), - adapted_projections.iter().cloned(), - ); - + // let mask = ProjectionMask::roots( + // builder.parquet_schema(), + // adapted_projections.iter().cloned(), + // ); + let mask = if has_deep_projection(Some(projection_deep.clone().as_ref())) { + let leaves = generate_leaf_paths( + table_schema.clone(), + builder.parquet_schema(), + &projection_vec, + projection_deep.clone().as_ref(), + ); + info!( + "ParquetOpener::open, using deep projection parquet leaves: {:?}", + leaves.clone() + ); + // let tmp = builder.parquet_schema(); + // for (i, col) in tmp.columns().iter().enumerate() { + // info!(" {} {}= {:?}", i, col.path(), col); + // } + ProjectionMask::leaves(builder.parquet_schema(), leaves) + } else { + ProjectionMask::roots( + builder.parquet_schema(), + adapted_projections.iter().cloned(), + ) + }; // Filter pushdown: evaluate predicates during scan if let Some(predicate) = pushdown_filters.then_some(predicate).flatten() { let row_filter = row_filter::build_row_filter( @@ -303,3 +345,103 @@ fn create_initial_plan( // default to scanning all row groups Ok(ParquetAccessPlan::new_all(row_group_count)) } + +// FIXME: @HStack ACTUALLY look at the arrow schema and handle map types correctly +// Right now, we are matching "map-like" parquet leaves like "key_value.key" etc +// But, we neeed to walk through both the arrow schema (which KNOWS about the map type) +// and the parquet leaves to do this correctly. 
+fn equivalent_projection_paths_from_parquet_schema( + _arrow_schema: SchemaRef, + parquet_schema: &SchemaDescriptor, +) -> Vec<(usize, (String, String))> { + let mut output: Vec<(usize, (String, String))> = vec![]; + for (i, col) in parquet_schema.columns().iter().enumerate() { + let original_path = col.path().string(); + let converted_path = + convert_parquet_path_to_deep_projection_path(&original_path.as_str()); + output.push((i, (original_path.clone(), converted_path))); + } + output +} + +fn convert_parquet_path_to_deep_projection_path(parquet_path: &str) -> String { + if parquet_path.contains(".key_value.key") + || parquet_path.contains(".key_value.value") + || parquet_path.contains(".entries.keys") + || parquet_path.contains(".entries.values") + || parquet_path.contains(".list.element") + { + let tmp = parquet_path + .replace("key_value.key", "*") + .replace("key_value.value", "*") + .replace("entries.keys", "*") + .replace("entries.values", "*") + .replace("list.element", "*"); + tmp + } else { + parquet_path.to_string() + } +} + +fn generate_leaf_paths( + arrow_schema: SchemaRef, + parquet_schema: &SchemaDescriptor, + projection: &Vec, + projection_deep: &HashMap>, +) -> Vec { + let actual_projection = if projection.len() == 0 { + (0..arrow_schema.fields().len()).collect() + } else { + projection.clone() + }; + let splatted = + splat_columns(arrow_schema.clone(), &actual_projection, &projection_deep); + trace!(target: "deep", "generate_leaf_paths: splatted: {:?}", &splatted); + + let mut out: Vec = vec![]; + for (i, (original, converted)) in + equivalent_projection_paths_from_parquet_schema(arrow_schema, parquet_schema) + { + // FIXME: @HStack + // for map fields, the actual parquet paths look like x.y.z.key_value.key, x.y.z.key_value.value + // since we are ignoring these names in the paths, we need to actually collapse this access to a * + // so we can filter for them + // also, we need BOTH the key and the value for maps otherwise we run into an arrow-rs error + // "partial projection of MapArray is not supported" + + trace!(target: "deep", " generate_leaf_paths looking at index {} {} = {}", i, &original, &converted); + + let mut found = false; + for filter in splatted.iter() { + // check if this filter matches this leaf path + let filter_pieces = filter.split(".").collect::>(); + // let col_pieces = col_path.parts(); + let col_pieces = converted.split(".").collect::>(); + // let's check + let mut filter_found = true; + for i in 0..min(filter_pieces.len(), col_pieces.len()) { + if i >= filter_pieces.len() { + // we are at the end of the filter, and we matched until now, so we break, we match ! + break; + } + if i >= col_pieces.len() { + // we have a longer filter, we matched until now, we match ! + break; + } + // we can actually check + if !(col_pieces[i] == filter_pieces[i] || filter_pieces[i] == "*") { + filter_found = false; + break; + } + } + if filter_found { + found = true; + break; + } + } + if found { + out.push(i); + } + } + out +} diff --git a/datafusion/core/src/datasource/physical_plan/parquet/source.rs b/datafusion/core/src/datasource/physical_plan/parquet/source.rs index 142725524f1bb..dad593f9c3458 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/source.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/source.rs @@ -17,6 +17,7 @@ //! 
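
To clarify the leaf-matching step above: parquet flattens list and map types into synthetic path segments (`list.element`, `key_value.key`, ...), so before leaf indices are matched against the splatted deep-projection paths those segments are collapsed to `*`. The following is a standalone sketch of that normalization, a simplified restatement of `convert_parquet_path_to_deep_projection_path` rather than the function itself.

```rust
// Collapse parquet's synthetic list/map path segments to "*" so a deep
// projection such as "contentFeaturization.objects.*" can match the leaves.
fn normalize_parquet_leaf_path(path: &str) -> String {
    path.replace("key_value.key", "*")
        .replace("key_value.value", "*")
        .replace("entries.keys", "*")
        .replace("entries.values", "*")
        .replace("list.element", "*")
}

#[cfg(test)]
mod tests {
    use super::normalize_parquet_leaf_path;

    #[test]
    fn collapses_list_and_map_segments() {
        assert_eq!(
            normalize_parquet_leaf_path("contentFeaturization.objects.list.element"),
            "contentFeaturization.objects.*"
        );
        assert_eq!(normalize_parquet_leaf_path("attrs.key_value.key"), "attrs.*");
        // Plain struct paths are left untouched.
        assert_eq!(
            normalize_parquet_leaf_path("endUserIDs._experience.mcid.id"),
            "endUserIDs._experience.mcid.id"
        );
    }
}
```
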
ParquetSource implementation for reading parquet files use std::any::Any; +use std::collections::HashMap; use std::fmt::Formatter; use std::sync::Arc; @@ -39,7 +40,7 @@ use datafusion_physical_plan::metrics::{ExecutionPlanMetricsSet, MetricBuilder}; use datafusion_physical_plan::DisplayFormatType; use itertools::Itertools; -use log::debug; +use log::{debug, trace}; use object_store::ObjectStore; /// Execution plan for reading one or more Parquet files. @@ -469,6 +470,28 @@ impl FileSource for ParquetSource { let projection = base_config .file_column_projection_indices() .unwrap_or_else(|| (0..base_config.file_schema.fields().len()).collect()); + + let projection_deep = match &base_config.projection_deep { + None => HashMap::new(), + Some(pd) => { + let mut out: HashMap> = HashMap::new(); + for npi in &projection { + match pd.get(npi) { + None => {} + Some(v) => { + out.insert(*npi, v.clone()); + } + } + } + out + } + }; + trace!( + "ParquetExec::execute projection={:#?}, projection_deep={:#?}", + &projection, + &projection_deep + ); + let schema_adapter_factory = self .schema_adapter_factory .clone() @@ -482,6 +505,7 @@ impl FileSource for ParquetSource { Arc::new(ParquetOpener { partition_index: partition, projection: Arc::from(projection), + projection_deep: Arc::new(projection_deep), batch_size: self .batch_size .expect("Batch size must set before creating ParquetOpener"), diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index c27d1e4fd46b4..aec9695786340 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -62,7 +62,7 @@ use datafusion_expr::{ expr_rewriter::FunctionRewrite, logical_plan::{DdlStatement, Statement}, planner::ExprPlanner, - Expr, UserDefinedLogicalNode, WindowUDF, + Expr, WindowUDF, }; // backwards compatibility @@ -1701,27 +1701,7 @@ pub enum RegisterFunction { #[derive(Debug)] pub struct EmptySerializerRegistry; -impl SerializerRegistry for EmptySerializerRegistry { - fn serialize_logical_plan( - &self, - node: &dyn UserDefinedLogicalNode, - ) -> Result> { - not_impl_err!( - "Serializing user defined logical plan node `{}` is not supported", - node.name() - ) - } - - fn deserialize_logical_plan( - &self, - name: &str, - _bytes: &[u8], - ) -> Result> { - not_impl_err!( - "Deserializing user defined logical plan node `{name}` is not supported" - ) - } -} +impl SerializerRegistry for EmptySerializerRegistry {} /// Describes which SQL statements can be run. /// diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index a74cdcc5920bf..e64e7f9acb2c9 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -90,6 +90,7 @@ use datafusion_physical_plan::unnest::ListUnnest; use crate::schema_equivalence::schema_satisfied_by; use async_trait::async_trait; +use datafusion_common::deep::can_rewrite_schema; use futures::{StreamExt, TryStreamExt}; use itertools::{multiunzip, Itertools}; use log::{debug, trace}; @@ -442,6 +443,7 @@ impl DefaultPhysicalPlanner { LogicalPlan::TableScan(TableScan { source, projection, + projection_deep, filters, fetch, .. @@ -452,7 +454,13 @@ impl DefaultPhysicalPlanner { // referred to in the query let filters = unnormalize_cols(filters.iter().cloned()); source - .scan(session_state, projection.as_ref(), &filters, *fetch) + .scan_deep( + session_state, + projection.as_ref(), + projection_deep.as_ref(), + &filters, + *fetch, + ) .await? 
} LogicalPlan::Values(Values { values, schema }) => { @@ -651,11 +659,23 @@ impl DefaultPhysicalPlanner { let logical_input_schema = input.as_ref().schema(); let physical_input_schema_from_logical = logical_input_schema.inner(); + let normal_schema_satisfied_by = schema_satisfied_by( + physical_input_schema_from_logical, + &physical_input_schema, + ); + let deep_schema_is_satisfied = + if options.optimizer.deep_column_pruning_flags > 0 { + can_rewrite_schema( + physical_input_schema_from_logical, + &physical_input_schema, + false, + ) + } else { + false + }; + if !options.execution.skip_physical_aggregate_schema_check - && !schema_satisfied_by( - physical_input_schema_from_logical, - &physical_input_schema, - ) + && !(normal_schema_satisfied_by || deep_schema_is_satisfied) { let mut differences = Vec::new(); if physical_input_schema.fields().len() diff --git a/datafusion/core/src/schema_equivalence.rs b/datafusion/core/src/schema_equivalence.rs index 70bee206655bf..c9adeee4aacb7 100644 --- a/datafusion/core/src/schema_equivalence.rs +++ b/datafusion/core/src/schema_equivalence.rs @@ -16,6 +16,9 @@ // under the License. use arrow::datatypes::{DataType, Field, Fields, Schema}; +use datafusion_common::deep::can_rewrite_field; +use log::info; +use std::sync::Arc; /// Verifies whether the original planned schema can be satisfied with data /// adhering to the candidate schema. In practice, this is equality check on the @@ -37,10 +40,24 @@ fn fields_satisfied_by(original: &Fields, candidate: &Fields) -> bool { /// See [`schema_satisfied_by`] for the contract. fn field_satisfied_by(original: &Field, candidate: &Field) -> bool { - original.name() == candidate.name() + let plain_match = original.name() == candidate.name() && (original.is_nullable() || !candidate.is_nullable()) && original.metadata() == candidate.metadata() - && data_type_satisfied_by(original.data_type(), candidate.data_type()) + && data_type_satisfied_by(original.data_type(), candidate.data_type()); + + let deep_match = can_rewrite_field( + &Arc::new(original.clone()), + &Arc::new(candidate.clone()), + false, + ); + info!( + "field satisfied by {} = {}, plain={}, deep={}", + original.name(), + candidate.name(), + plain_match, + deep_match + ); + plain_match || deep_match } /// See [`schema_satisfied_by`] for the contract. 
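
For the relaxed schema check above: with deep pruning the physical scan can legitimately produce a narrower struct than the logical plan's schema, which the plain field-equality check would reject. A sketch of the shape of such a pair, with the semantics of `can_rewrite_field` assumed from its use here and the `endUserIDs` column borrowed from the test data:

```rust
use std::sync::Arc;
use arrow::datatypes::{DataType, Field, Fields};

// The plan still carries the full struct, while the deep-pruned scan only
// materializes the sub-field that is actually referenced.
fn planned_and_pruned_fields() -> (Arc<Field>, Arc<Field>) {
    let planned = Arc::new(Field::new(
        "endUserIDs",
        DataType::Struct(Fields::from(vec![
            Field::new("aaid_id", DataType::Utf8, true),
            Field::new("extra", DataType::Int32, true),
        ])),
        true,
    ));
    let pruned = Arc::new(Field::new(
        "endUserIDs",
        DataType::Struct(Fields::from(vec![Field::new(
            "aaid_id",
            DataType::Utf8,
            true,
        )])),
        true,
    ));
    (planned, pruned)
}
```
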
diff --git a/datafusion/core/tests/core_integration.rs b/datafusion/core/tests/core_integration.rs index 66b4103160e7b..6dbaa37c9f769 100644 --- a/datafusion/core/tests/core_integration.rs +++ b/datafusion/core/tests/core_integration.rs @@ -42,6 +42,8 @@ mod custom_sources_cases; /// Run all tests that are found in the `optimizer` directory mod optimizer; +mod optimizer_deep_indices; + mod physical_optimizer; mod catalog; diff --git a/datafusion/core/tests/data/deep_projections/first-duckdb.sql b/datafusion/core/tests/data/deep_projections/first-duckdb.sql new file mode 100644 index 0000000000000..ec699c32d4069 --- /dev/null +++ b/datafusion/core/tests/data/deep_projections/first-duckdb.sql @@ -0,0 +1,66 @@ +CREATE OR REPLACE TABLE cross_industry_demo_data ( + _ACP_DATE DATE, + timestamp TIMESTAMP_S, + endUserIDs STRUCT( + aaid_id VARCHAR, + extra INT4 + ), + _experience STRUCT( + eVar56 VARCHAR, + extra VARCHAR + ) +); + +INSERT INTO cross_industry_demo_data VALUES +( + '2025-01-03', + '2025-01-03 02:00:30', + ROW( + 'd1', + 1, + ), + ROW( + 'u1', + 'extra1', + ), +); +INSERT INTO cross_industry_demo_data VALUES +( + '2025-01-03', + '2025-01-03 02:00:30', + ROW( + 'd1', + 1, + ), + ROW( + 'u2', + 'extra1', + ), +); +INSERT INTO cross_industry_demo_data VALUES +( + '2025-01-03', + '2025-01-03 02:00:30', + ROW( + 'd2', + 1, + ), + ROW( + 'u1', + 'extra1', + ), +); +INSERT INTO cross_industry_demo_data VALUES +( + '2025-01-03', + '2025-01-03 02:00:30', + ROW( + 'd2', + 1, + ), + ROW( + 'u2', + 'extra1', + ), +); +COPY cross_industry_demo_data TO 'output.parquet' (FORMAT PARQUET); diff --git a/datafusion/core/tests/data/deep_projections/first.parquet b/datafusion/core/tests/data/deep_projections/first.parquet new file mode 100644 index 0000000000000..9cb5a18df5875 Binary files /dev/null and b/datafusion/core/tests/data/deep_projections/first.parquet differ diff --git a/datafusion/core/tests/data/deep_projections/genstudio/generate.sql b/datafusion/core/tests/data/deep_projections/genstudio/generate.sql new file mode 100644 index 0000000000000..fe331d77e505c --- /dev/null +++ b/datafusion/core/tests/data/deep_projections/genstudio/generate.sql @@ -0,0 +1,140 @@ +CREATE OR REPLACE TABLE meta_asset_summary_metrics( + timestamp TIMESTAMP, -- 0 + _acp_system_metadata STRUCT( -- 1 + acp_sourceBatchId VARCHAR, -- 2 + commitBatchId VARCHAR, --3 + ingestTime INT8, -- 4 + isDeleted BOOL, -- 5 + rowId VARCHAR, + rowVersion INT8, + trackingId VARCHAR + ), + _aresstagevalidationco STRUCT( + genStudioInsights STRUCT( + accountID VARCHAR, + adGroupID VARCHAR, + adID VARCHAR, + assetID VARCHAR, + campaignID VARCHAR, + metrics STRUCT( + engagement STRUCT( + addsToCart STRUCT(value INT8), + addsToWishList STRUCT(value INT8), + page STRUCT( + engagement STRUCT(value INT8), + likes STRUCT(value INT8) + ), + photoViews STRUCT(value INT8), + post STRUCT( + comments STRUCT(value INT8), + engagement STRUCT(value INT8), + reactions STRUCT(value INT8), + saves STRUCT(value INT8), + shares STRUCT(value INT8) + ) + ), + performance STRUCT( + clicks STRUCT(value INT8), + conversionCount STRUCT(value INT8), + conversionValue STRUCT(value INT8), + impressions STRUCT(value INT8) + ), + spend STRUCT(value INT8) + ) + ), + network VARCHAR + ), + _ACP_DATE DATE, + _ACP_BATCHID VARCHAR +); +COPY meta_asset_summary_metrics TO 'meta_asset_summary_metrics.parquet' (FORMAT PARQUET); + +CREATE OR REPLACE TABLE meta_asset_summary_metrics_by_age_and_gender( + timestamp TIMESTAMP, + _acp_system_metadata STRUCT( + acp_sourceBatchId 
VARCHAR, + commitBatchId VARCHAR, + ingestTime INT8, + isDeleted BOOL, + rowId VARCHAR, + rowVersion INT8, + trackingId VARCHAR + ), + _aresstagevalidationco STRUCT( + genStudioInsights STRUCT( + accountID VARCHAR, + adGroupID VARCHAR, + adID VARCHAR, + age VARCHAR, + assetID VARCHAR, + campaignID VARCHAR, + gender VARCHAR, + metrics STRUCT( + engagement STRUCT( + addsToCart STRUCT(value INT8), + addsToWishList STRUCT(value INT8), + page STRUCT( + engagement STRUCT(value INT8), + likes STRUCT(value INT8) + ), + photoViews STRUCT(value INT8), + post STRUCT( + comments STRUCT(value INT8), + engagement STRUCT(value INT8), + reactions STRUCT(value INT8), + saves STRUCT(value INT8), + shares STRUCT(value INT8) + ) + ), + performance STRUCT( + clicks STRUCT(value INT8), + conversionCount STRUCT(value INT8), + conversionValue STRUCT(value INT8), + impressions STRUCT(value INT8) + ), + spend STRUCT(value INT8) + ) + ), + network VARCHAR + ), + _ACP_DATE DATE, + _ACP_BATCHID VARCHAR +); +COPY meta_asset_summary_metrics_by_age_and_gender TO 'meta_asset_summary_metrics_by_age_and_gender.parquet' (FORMAT PARQUET); + +CREATE OR REPLACE TABLE meta_asset_featurization ( + _acp_system_metadata STRUCT( + acp_sourceBatchId VARCHAR, + commitBatchId VARCHAR, + ingestTime INT8, + isDeleted BOOL, + rowId VARCHAR, + rowVersion INT8, + trackingId VARCHAR + ), + _aresstagevalidationco STRUCT( + contentAssets STRUCT( + assetID VARCHAR, + assetPerceptionID VARCHAR, + assetThumbnailURL VARCHAR, + assetType VARCHAR, + version TIMESTAMP + ), + contentFeaturization STRUCT( + audioGenre VARCHAR, + audioGenreCategory VARCHAR, + audioMood VARCHAR, + audioTypes VARCHAR[], + categories VARCHAR[], + objects VARCHAR[], + orientation VARCHAR, + peopleCategories VARCHAR[], + version VARCHAR + ) + ), + _id VARCHAR, + _ACP_BATCHID VARCHAR +); + +COPY meta_asset_featurization TO 'meta_asset_featurization.parquet' (FORMAT PARQUET); + diff --git a/datafusion/core/tests/data/deep_projections/genstudio/meta_asset_featurization.parquet b/datafusion/core/tests/data/deep_projections/genstudio/meta_asset_featurization.parquet new file mode 100644 index 0000000000000..3af369c14c95e Binary files /dev/null and b/datafusion/core/tests/data/deep_projections/genstudio/meta_asset_featurization.parquet differ diff --git a/datafusion/core/tests/data/deep_projections/genstudio/meta_asset_summary_metrics.parquet b/datafusion/core/tests/data/deep_projections/genstudio/meta_asset_summary_metrics.parquet new file mode 100644 index 0000000000000..4a6c9de169386 Binary files /dev/null and b/datafusion/core/tests/data/deep_projections/genstudio/meta_asset_summary_metrics.parquet differ diff --git a/datafusion/core/tests/data/deep_projections/genstudio/meta_asset_summary_metrics_by_age_and_gender.parquet b/datafusion/core/tests/data/deep_projections/genstudio/meta_asset_summary_metrics_by_age_and_gender.parquet new file mode 100644 index 0000000000000..07774b4f0c6ff Binary files /dev/null and b/datafusion/core/tests/data/deep_projections/genstudio/meta_asset_summary_metrics_by_age_and_gender.parquet differ diff --git a/datafusion/core/tests/data/deep_projections/triplea/midvalues.parquet b/datafusion/core/tests/data/deep_projections/triplea/midvalues.parquet new file mode 100644 index 0000000000000..ef55eea6d765c Binary files /dev/null and b/datafusion/core/tests/data/deep_projections/triplea/midvalues.parquet differ diff --git a/datafusion/core/tests/data/deep_projections/triplea/midvalues.sql 
b/datafusion/core/tests/data/deep_projections/triplea/midvalues.sql new file mode 100644 index 0000000000000..3d67d406b54bd --- /dev/null +++ b/datafusion/core/tests/data/deep_projections/triplea/midvalues.sql @@ -0,0 +1,45 @@ +CREATE OR REPLACE TABLE midvalues( + timestamp TIMESTAMP_S, -- 0 + web STRUCT( + webPageDetails STRUCT( + pageViews STRUCT(value INT8) + ) + ), + endUserIDs STRUCT( + _experience STRUCT( + mcid STRUCT( + id VARCHAR, + extra1 VARCHAR + ), + aaid STRUCT( + id VARCHAR, + extra1 VARCHAR + ) + ) + ) +); + +INSERT INTO midvalues VALUES +( + '2025-01-15 00:00:01', + struct_pack( + webPageDetails := struct_pack( + pageViews := struct_pack(value := 100) + ) + ), + struct_pack( + _experience := struct_pack( + mcid := struct_pack( + id := 'mcid1', + extra1 := 'extram1' + ), + aaid := struct_pack( + id := 'mcid1', + extra1 := 'extram1' + ) + ) + ) +); + +COPY midvalues TO 'midvalues.parquet' (FORMAT PARQUET); + diff --git a/datafusion/core/tests/optimizer_deep_indices/mod.rs b/datafusion/core/tests/optimizer_deep_indices/mod.rs new file mode 100644 index 0000000000000..544069bead7d3 --- /dev/null +++ b/datafusion/core/tests/optimizer_deep_indices/mod.rs @@ -0,0 +1,759 @@ +use arrow_schema::{DataType, Field, Fields, Schema}; +use datafusion::datasource::physical_plan::ParquetExec; +use datafusion::logical_expr::Operator; +use datafusion::prelude::{ParquetReadOptions, SessionContext}; +use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; +use datafusion_common::Result; +use datafusion_common::{DFSchema, JoinType}; +use datafusion_datasource::file_scan_config::FileScanConfig; +use datafusion_datasource::source::DataSourceExec; +use datafusion_execution::config::SessionConfig; +use datafusion_expr::{col, lit, BinaryExpr, Expr, Literal, LogicalPlanBuilder}; +use datafusion_functions::expr_fn::get_field; +use datafusion_optimizer::common_subexpr_eliminate::CommonSubexprEliminate; +use datafusion_optimizer::optimize_projections_deep::{ + DeepColumnIndexMap, PlanWithDeepColumnMap, FLAG_ENABLE, + FLAG_ENABLE_PROJECTION_MERGING, FLAG_ENABLE_SUBQUERY_TRANSLATION, +}; +use datafusion_optimizer::push_down_filter::PushDownFilter; +use datafusion_optimizer::push_down_limit::PushDownLimit; +use datafusion_physical_plan::work_table::WorkTableExec; +use log::info; +use parquet::arrow::parquet_to_arrow_schema; +use parquet::file::reader::{FileReader, SerializedFileReader}; +use std::collections::HashMap; +use std::fs::File; +use std::path::Path; +use std::sync::Arc; + +#[cfg(test)] +#[ctor::ctor] +fn init() { + // enable logging so RUST_LOG works + let _ = env_logger::try_init(); +} + +pub fn make_get_field(from: Expr, sub_col_name: &str) -> Expr { + get_field(from, sub_col_name) +} + +pub fn build_deep_schema() -> Schema { + Schema::new(vec![ + Field::new("sc1", DataType::Int64, true), + Field::new( + "st1", + DataType::Struct(Fields::from(vec![ + Field::new("sc1", DataType::Utf8, true), + Field::new( + "st1", + DataType::Struct(Fields::from(vec![ + Field::new("sc1", DataType::Int64, true), + Field::new("sc2", DataType::Utf8, true), + ])), + true, + ), + ])), + true, + ), + Field::new( + "st2", + DataType::Struct(Fields::from(vec![Field::new( + "st2_sc1", + DataType::Utf8, + true, + )])), + true, + ), + ]) +} + +#[test] +pub fn test_make_required_indices() { + let _ = env_logger::try_init(); + let schema = build_deep_schema(); + let df_schema = Arc::new(DFSchema::try_from(schema.clone()).unwrap()); + /// let field_a = Field::new("a", DataType::Int64, false); + /// let field_b = 
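
The `build_context` helper below drives the new `datafusion.optimizer.deep_column_pruning_flags` option. A minimal sketch of enabling it on a session, using the flag constants imported by this test module; the config default of 7 appears to correspond to all three flags being set (assuming the usual 1/2/4 bit values), with 0 disabling the rewrite entirely.

```rust
use datafusion::prelude::SessionContext;
use datafusion_execution::config::SessionConfig;
use datafusion_optimizer::optimize_projections_deep::{
    FLAG_ENABLE, FLAG_ENABLE_PROJECTION_MERGING, FLAG_ENABLE_SUBQUERY_TRANSLATION,
};

// Build a session with deep column pruning, projection merging and subquery
// translation all enabled (the same combination the tests below use).
fn context_with_deep_pruning() -> SessionContext {
    let config = SessionConfig::new().set_usize(
        "datafusion.optimizer.deep_column_pruning_flags",
        FLAG_ENABLE | FLAG_ENABLE_PROJECTION_MERGING | FLAG_ENABLE_SUBQUERY_TRANSLATION,
    );
    SessionContext::new_with_config(config)
}
```
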
Field::new("b", DataType::Boolean, false); + let get_st1_sc1 = make_get_field(col("st1"), "sc1"); + let get_st1_st1_sc1 = make_get_field(make_get_field(col("st1"), "st1"), "sc1"); + let get_st1_st1_sc2 = make_get_field(make_get_field(col("st1"), "st1"), "sc2"); + let st1_sc1_not_null = Expr::IsNotNull(Box::new(get_st1_sc1.clone())); + let st1_st1_sc1_not_null = Expr::IsNotNull(Box::new(get_st1_st1_sc1.clone())); + let test_expr = Expr::BinaryExpr(BinaryExpr::new( + Box::new(st1_sc1_not_null.clone()), + Operator::And, + Box::new(st1_st1_sc1_not_null.clone()), + )); +} + +fn build_context() -> SessionContext { + let config = SessionConfig::new() + .set_bool("datafusion.sql_parser.enable_ident_normalization", false) + .set_usize("datafusion.optimizer.max_passes", 2) + // 0 - disabled + // 1 - just main merging + // 2 - enable projection merging + // 1 | 2 == 3 - all + .set_usize( + "datafusion.optimizer.deep_column_pruning_flags", + FLAG_ENABLE + | FLAG_ENABLE_PROJECTION_MERGING + | FLAG_ENABLE_SUBQUERY_TRANSLATION, + ); + // .set_usize("datafusion.optimizer.deep_column_pruning_flags", FLAG_ENABLE | FLAG_ENABLE_PROJECTION_MERGING); + // .set_bool("datafusion.execution.skip_physical_aggregate_schema_check", true); + SessionContext::new_with_config(config) +} + +#[tokio::test] +async fn test_deep_projections_1() -> Result<()> { + let parquet_path = format!( + "{}/tests/data/deep_projections/first.parquet", + env!("CARGO_MANIFEST_DIR") + ); + + // { + // let file = File::open(Path::new(parquet_path.as_str()))?; + // let reader = SerializedFileReader::new(file).unwrap(); + // let parquet_schema = reader.metadata().file_metadata().schema_descr(); + // let arrow_schema = parquet_to_arrow_schema(parquet_schema, None).unwrap(); + // let df_schema = DFSchema::try_from(arrow_schema.clone()).unwrap(); + // + // let filters = vec![ + // get_field(col("cross_industry_demo_data.endUserIDs"), "aaid_id") + // .is_not_null(), + // get_field(col("cross_industry_demo_data.endUserIDs"), "aaid_id") + // .not_eq(lit("")), + // get_field(col("cross_industry_demo_data._experience"), "eVar56") + // .is_not_null(), + // get_field(col("cross_industry_demo_data._experience"), "eVar56") + // .not_eq(lit("")), + // ]; + // } + + let ctx = build_context(); + ctx.register_parquet( + "cross_industry_demo_data", + parquet_path, + ParquetReadOptions::default(), + ) + .await?; + let _ = run_deep_projection_optimize_test( + &ctx, + r#" + WITH events AS ( + SELECT + endUserIDs.aaid_id as DeviceId, + _experience.eVar56 as UserId, + timestamp + FROM + cross_industry_demo_data + WHERE _ACP_DATE='2025-01-03' + ) + SELECT + events.*, + LAG(UserId, 1) OVER (PARTITION BY DeviceId ORDER BY events.timestamp) AS PreviousUserColName, + cross_industry_demo_data._experience.eVar56 + FROM events + INNER JOIN cross_industry_demo_data on events.DeviceId = cross_industry_demo_data.endUserIDs.aaid_id + LIMIT 100 + "#, + vec![ + Some(HashMap::from([(0, vec![]), (1, vec![]), (2, vec!["aaid_id".to_string()]), (3, vec!["eVar56".to_string()])])), + Some(HashMap::from([(2, vec!["aaid_id".to_string()]), (3, vec!["eVar56".to_string()])])) + ] + ).await; + let _ = run_deep_projection_optimize_test( + &ctx, + r#" + SELECT + count(*) as count_events + FROM cross_industry_demo_data + WHERE + (_ACP_DATE BETWEEN '2023-01-01' AND '2025-02-02') + AND _experience.eVar56 is not null + LIMIT 100 + "#, + vec![Some(HashMap::from([ + (0, vec![]), + (3, vec!["eVar56".to_string()]), + ]))], + ) + .await; + let _ = run_deep_projection_optimize_test( + &ctx, + r#" + SELECT 
+ endUserIDs + FROM cross_industry_demo_data + WHERE + (_ACP_DATE BETWEEN '2025-01-01' AND '2025-01-02') + AND _experience.eVar56 is not null + LIMIT 10 + "#, + vec![Some(HashMap::from([ + (0, vec![]), + (2, vec![]), + (3, vec!["eVar56".to_string()]), + ]))], + ) + .await; + let _ = run_deep_projection_optimize_test( + &ctx, + r#" + SELECT + * + FROM cross_industry_demo_data + WHERE + (_ACP_DATE BETWEEN '2023-01-01' AND '2025-02-02') + AND _experience.eVar56 is not null + LIMIT 100 + "#, + vec![None], + ) + .await; + + Ok(()) +} + +#[tokio::test] +async fn test_deep_projections_genstudio() -> Result<()> { + let ctx = build_context(); + let _ = ctx.register_parquet( + "meta_asset_summary_metrics", + format!("{}/tests/data/deep_projections/genstudio/meta_asset_summary_metrics.parquet", env!("CARGO_MANIFEST_DIR")), + ParquetReadOptions::default(), + ).await?; + let _ = ctx.register_parquet( + "meta_asset_summary_metrics_by_age_and_gender", + format!("{}/tests/data/deep_projections/genstudio/meta_asset_summary_metrics_by_age_and_gender.parquet", env!("CARGO_MANIFEST_DIR")), + ParquetReadOptions::default(), + ).await?; + let _ = ctx.register_parquet( + "meta_asset_featurization", + format!("{}/tests/data/deep_projections/genstudio/meta_asset_featurization.parquet", env!("CARGO_MANIFEST_DIR")), + ParquetReadOptions::default(), + ).await?; + + // Stats: Asset summary metrics + let _ = run_deep_projection_optimize_test( + &ctx, + r#" + SELECT + count(*) AS cnt + FROM + meta_asset_summary_metrics + WHERE + _ACP_DATE = '2024-12-01' + "#, + vec![Some(HashMap::from([(3, vec![])]))], + ) + .await?; + + // Preview: Asset summary metrics + let _ = run_deep_projection_optimize_test( + &ctx, + r#" + SELECT + * + FROM + meta_asset_summary_metrics + LIMIT 100 + "#, + vec![None], + ) + .await?; + + // Agg: Count assets by age + let _ = run_deep_projection_optimize_test( + &ctx, + r#" + SELECT + count (*) AS cnt, + _aresstagevalidationco['genStudioInsights']['age'] AS age + FROM + meta_asset_summary_metrics_by_age_and_gender + WHERE + _ACP_DATE = '2024-12-01' + GROUP BY + age + ORDER BY + cnt DESC + LIMIT + 10 + "#, + vec![Some(HashMap::from([ + (2, vec!["genStudioInsights.age".to_string()]), + (3, vec![]), + ]))], + ) + .await?; + + // Agg: clicks by url + let _ = run_deep_projection_optimize_test( + &ctx, + r#" + SELECT + AVG( + asset_metrics._aresstagevalidationco.genStudioInsights.metrics.performance.clicks.value + ) AS clicks, + asset_meta._aresstagevalidationco.contentAssets.assetThumbnailURL AS asset_url + FROM + (meta_asset_featurization AS asset_meta + INNER JOIN meta_asset_summary_metrics AS asset_metrics ON ( + asset_meta._aresstagevalidationco['contentAssets']['assetID'] = asset_metrics._aresstagevalidationco['genStudioInsights']['assetID'] + )) + WHERE + _ACP_DATE = '2024-12-01' + GROUP BY + asset_url + ORDER BY + clicks DESC + "#, + vec![ + Some( + HashMap::from([ + (1, vec!["contentAssets.assetThumbnailURL".to_string(), "contentAssets.assetID".to_string()]), + ]) + ), + Some( + HashMap::from([ + (2, vec!["genStudioInsights.metrics.performance.clicks.value".to_string(), "genStudioInsights.assetID".to_string()]), + (3, vec![]) + ]) + ), + ], + ) + .await?; + + // Agg: clicks by url + let _ = run_deep_projection_optimize_test( + &ctx, + r#" + SELECT + talias.tmetrics_aresstagevalidationco.genStudioInsights.metrics.performance.clicks.value as clicks, + talias.tfeatures_aresstagevalidationco.contentAssets.assetThumbnailURL AS asset_url + FROM ( + SELECT + asset_metrics._aresstagevalidationco AS 
tmetrics_aresstagevalidationco, + asset_meta._aresstagevalidationco AS tfeatures_aresstagevalidationco + FROM + meta_asset_featurization AS asset_meta + INNER JOIN meta_asset_summary_metrics AS asset_metrics ON ( + asset_meta._aresstagevalidationco.contentAssets.assetID = asset_metrics._aresstagevalidationco.genStudioInsights.assetID + ) + WHERE + _ACP_DATE = '2024-12-01' + ) AS talias + ORDER BY + clicks DESC + "#, + vec![ + Some( + HashMap::from([ + (1, vec!["contentAssets.assetThumbnailURL".to_string(), "contentAssets.assetID".to_string()]), + ]) + ), + Some( + HashMap::from([ + (2, vec!["genStudioInsights.metrics.performance.clicks.value".to_string(), "genStudioInsights.assetID".to_string()]), + (3, vec![]) + ]) + ), + ], + ) + .await?; + + // SQL Editor + let _ = run_deep_projection_optimize_test( + &ctx, + r#" + SELECT + _ACP_DATE DAY, + _aresstagevalidationco.genStudioInsights.campaignID campaign_id, + SUM( + _aresstagevalidationco.genStudioInsights.metrics.spend.value + ) total_spend + FROM + meta_asset_summary_metrics + WHERE + _ACP_DATE BETWEEN '2024-12-01' AND '2024-12-15' + GROUP BY + DAY, + campaign_id + ORDER BY + DAY, + total_spend DESC, + campaign_id + "#, + vec![Some(HashMap::from([ + ( + 2, + vec![ + "genStudioInsights.campaignID".to_string(), + "genStudioInsights.metrics.spend.value".to_string(), + ], + ), + (3, vec![]), + ]))], + ) + .await?; + + Ok(()) +} + +async fn run_deep_projection_optimize_test( + ctx: &SessionContext, + query: &str, + tests: Vec>, +) -> Result<()> { + let plan = ctx.state().create_logical_plan(query).await?; + let optimized_plan = ctx.state().optimize(&plan)?; + let state = ctx.state(); + let query_planner = state.query_planner().clone(); + let physical_plan = query_planner + .create_physical_plan(&optimized_plan, &state) + .await?; + let mut deep_projections: Vec> = vec![]; + let _ = physical_plan.apply(|pp| { + if let Some(pe) = pp.as_any().downcast_ref::() { + deep_projections.push(pe.base_config().projection_deep.clone()); + // pe.base_config().projection_deep + } + if let Some(dse) = pp.as_any().downcast_ref::() { + let data_source_dyn = dse.data_source(); + if let Some(data_source_file_scan_config) = + data_source_dyn.as_any().downcast_ref::() + { + deep_projections + .push(data_source_file_scan_config.projection_deep.clone()); + // pe.base_config().projection_deep + } + } + Ok(TreeNodeRecursion::Continue) + }); + info!( + "Checking if plan has these deep projections: {:?}", + &deep_projections + ); + assert_eq!(deep_projections.len(), tests.len()); + for i in 0..deep_projections.len() { + assert_eq!( + deep_projections[i], tests[i], + "Deep projections should be equal at index {}: got={:?} != expected={:?}", + i, deep_projections[i], tests[i] + ) + } + Ok(()) +} + +#[tokio::test] +async fn test_very_complicated_plan() -> Result<()> { + let _ = env_logger::try_init(); + + let config = SessionConfig::new() + .set_bool("datafusion.sql_parser.enable_ident_normalization", false) + .set_usize("datafusion.optimizer.max_passes", 2); + + let ctx = SessionContext::new_with_config(config); + // ctx.register_parquet("cross_industry_demo_data", "/Users/adragomi/output.parquet", ParquetReadOptions::default()).await?; + let _ = ctx + .sql( + r#" + CREATE OR REPLACE TABLE fact_profile_overlap_of_namespace ( + merge_policy_id INT8, + date_key DATE, + overlap_id INT8, + count_of_profiles INT8 + ); + "#, + ) + .await?; + + let _ = ctx + .sql( + r#" + CREATE OR REPLACE TABLE dim_overlap_namespaces ( + overlap_id INT8, + merge_policy_id INT8, + 
overlap_namespaces VARCHAR + ); + "#, + ) + .await?; + + let _ = ctx + .sql( + r#" + CREATE OR REPLACE TABLE fact_profile_by_namespace_trendlines ( + namespace_id INT8, + merge_policy_id INT8, + date_key DATE, + count_of_profiles INT8 + ); + "#, + ) + .await?; + + let _ = ctx + .sql( + r#" + CREATE OR REPLACE TABLE dim_namespaces ( + namespace_id INT8, + namespace_description VARCHAR, + merge_policy_id INT8 + ); + "#, + ) + .await?; + info!("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"); + let query = r#" +SELECT + sum(overlap_col1) overlap_col1, + sum(overlap_col2) overlap_col2, + coalesce(Sum(overlap_count), 0) overlap_count +FROM + ( + SELECT + 0 overlap_col1, + 0 overlap_col2, + Sum(count_of_profiles) overlap_count + FROM + fact_profile_overlap_of_namespace + WHERE + fact_profile_overlap_of_namespace.merge_policy_id = -115008144 + AND fact_profile_overlap_of_namespace.date_key = '2024-11-06' + AND fact_profile_overlap_of_namespace.overlap_id IN ( + SELECT + a.overlap_id + FROM + ( + SELECT + dim_overlap_namespaces.overlap_id overlap_id, + count(*) cnt_num + FROM + dim_overlap_namespaces + WHERE + dim_overlap_namespaces.merge_policy_id = -115008144 + AND dim_overlap_namespaces.overlap_namespaces IN ( + 'aaid', + 'ecid' + ) + GROUP BY + dim_overlap_namespaces.overlap_id + ) a + WHERE + a.cnt_num > 1 + ) + UNION + ALL + SELECT + count_of_profiles overlap_col1, + 0 overlap_col2, + 0 overlap_count + FROM + fact_profile_by_namespace_trendlines + JOIN dim_namespaces ON fact_profile_by_namespace_trendlines.namespace_id = dim_namespaces.namespace_id + AND fact_profile_by_namespace_trendlines.merge_policy_id = dim_namespaces.merge_policy_id + WHERE + fact_profile_by_namespace_trendlines.merge_policy_id = -115008144 + AND fact_profile_by_namespace_trendlines.date_key = '2024-11-06' + AND dim_namespaces.namespace_description = 'aaid' + UNION + ALL + SELECT + 0 overlap_col1, + count_of_profiles overlap_col2, + 0 overlap_count + FROM + fact_profile_by_namespace_trendlines + JOIN dim_namespaces ON fact_profile_by_namespace_trendlines.namespace_id = dim_namespaces.namespace_id + AND fact_profile_by_namespace_trendlines.merge_policy_id = dim_namespaces.merge_policy_id + WHERE + fact_profile_by_namespace_trendlines.merge_policy_id = -115008144 + AND fact_profile_by_namespace_trendlines.date_key = '2024-11-06' + AND dim_namespaces.namespace_description = 'ecid' + ) a; + "#; + let plan = ctx.state().create_logical_plan(query).await?; + info!("plan: {}", &plan); + let optimized_plan = ctx.state().optimize(&plan)?; + info!("optimized: {}", &optimized_plan.display_indent()); + let result = ctx.execute_logical_plan(optimized_plan).await?; + // let result = ctx.sql(query).await?; + result.show().await?; + + let push_down_limit = Arc::new(PushDownLimit::new()); + let push_down_filter = Arc::new(PushDownFilter::new()); + let subexpr_eliminator = Arc::new(CommonSubexprEliminate::new()); + let state = ctx.state(); + + Ok(()) +} + +#[tokio::test] +async fn test_mid_values_window() -> Result<()> { + let _ = env_logger::try_init(); + + let config = SessionConfig::new() + .set_bool("datafusion.sql_parser.enable_ident_normalization", false) + .set_usize("datafusion.optimizer.max_passes", 1) + .set_usize("datafusion.optimizer.deep_column_pruning_flags", 7); + + let ctx = SessionContext::new_with_config(config); + let _ = ctx + .register_parquet( + "midvalues", + format!( + "{}/tests/data/deep_projections/triplea/midvalues.parquet", + env!("CARGO_MANIFEST_DIR") + ), + ParquetReadOptions::default(), + ) + .await?; + let query = 
r#" + SELECT + timestamp, + web.webPageDetails.pageViews.value AS pageview, + endUserIDs._experience.mcid.id AS mcid, + endUserIDs._experience.aaid.id AS aaid, + COALESCE( + endUserIDs._experience.mcid.id, + endUserIDs._experience.aaid.id + ) AS partitionCol, + LAG(timestamp) OVER( + PARTITION BY COALESCE( + endUserIDs._experience.mcid.id, + endUserIDs._experience.aaid.id + ) + ORDER BY timestamp + ) AS last_event + FROM + midvalues + WHERE + timestamp >= TO_TIMESTAMP('2025-01-15') + AND timestamp < TO_TIMESTAMP('2025-01-16') + + "#; + + let _ = run_deep_projection_optimize_test( + &ctx, + query, + vec![Some(HashMap::from([ + (0, vec![]), + (1, vec!["webPageDetails.pageViews.value".to_string()]), + ( + 2, + vec![ + "_experience.mcid.id".to_string(), + "_experience.aaid.id".to_string(), + ], + ), + ]))], + ) + .await; + // let plan = ctx.state().create_logical_plan(query).await?; + // info!("plan: {}", &plan); + // let optimized_plan = ctx.state().optimize(&plan)?; + // info!("optimized: {}", &optimized_plan.display_indent()); + // let result = ctx.execute_logical_plan(optimized_plan).await?; + // // let result = ctx.sql(query).await?; + // result.show().await?; + + Ok(()) +} + +#[tokio::test] +async fn test_mid_values_window_execution() -> Result<()> { + let _ = env_logger::try_init(); + + let config = SessionConfig::new() + .set_bool("datafusion.sql_parser.enable_ident_normalization", false) + .set_usize("datafusion.optimizer.max_passes", 2) + .set_usize("datafusion.optimizer.deep_column_pruning_flags", 7); + + let ctx = SessionContext::new_with_config(config); + let _ = ctx + .register_parquet( + "midvalues", + format!( + "{}/tests/data/deep_projections/triplea/midvalues.parquet", + env!("CARGO_MANIFEST_DIR") + ), + ParquetReadOptions::default(), + ) + .await?; + let query = r#" + SELECT + timestamp, + web.webPageDetails.pageViews.value AS pageview, + endUserIDs._experience.mcid.id AS mcid, + endUserIDs._experience.aaid.id AS aaid, + COALESCE( + endUserIDs._experience.mcid.id, + endUserIDs._experience.aaid.id + ) AS partitionCol, + LAG(timestamp) OVER( + PARTITION BY COALESCE( + endUserIDs._experience.mcid.id, + endUserIDs._experience.aaid.id + ) + ORDER BY timestamp + ) AS last_event + FROM + midvalues + WHERE + timestamp >= TO_TIMESTAMP('2025-01-15') + AND timestamp < TO_TIMESTAMP('2025-01-16') + + "#; + let result = ctx.sql(query).await?.collect().await?; + + Ok(()) +} + +// #[test] +// fn test_adr() -> datafusion_common::Result<()> { +// let tmp = datafusion_functions::expr_fn::get_field( +// datafusion_functions::expr_fn::get_field( +// col("aa"), +// "bb" +// ), +// "cc" +// ); +// let kk = expr_to_deep_columns(&tmp); +// info!("kk: {:#?}", kk); +// +// let tmp = +// datafusion_functions::expr_fn::get_field( +// datafusion_functions_nested::expr_fn::array_element( +// col("list_struct"), +// 0_i32.lit() +// ), +// "cc" +// ) +// ; +// let kk = expr_to_deep_columns(&tmp); +// info!("kk: {:#?}", kk); +// +// let tmp = datafusion_functions::expr_fn::nullif( +// datafusion_functions::expr_fn::get_field( +// datafusion_functions_nested::expr_fn::array_element( +// col("list_struct"), +// 0_i32.lit() +// ), +// "cc" +// ), +// datafusion_functions::expr_fn::get_field( +// datafusion_functions::expr_fn::get_field( +// col("othercol"), +// "bb" +// ), +// "cc" +// ) +// ); +// let kk = expr_to_deep_columns(&tmp); +// info!("kk: {:#?}", kk); +// +// Ok(()) +// } diff --git a/datafusion/datasource/src/file_groups.rs b/datafusion/datasource/src/file_groups.rs index 
2c2b791f23657..9cc96f6de034c 100644 --- a/datafusion/datasource/src/file_groups.rs +++ b/datafusion/datasource/src/file_groups.rs @@ -213,7 +213,13 @@ impl FileGroupPartitioner { .iter() .map(|f| f.object_meta.size as i64) .sum::(); - if total_size < (repartition_file_min_size as i64) || total_size == 0 { + + // bail if we are asked to *split* a set of files that are already too small + // if we are being asked to consolidate, we proceed + if (total_size < (repartition_file_min_size as i64) + && target_partitions >= file_groups.len()) + || total_size == 0 + { return None; } @@ -228,30 +234,46 @@ impl FileGroupPartitioner { .scan( (current_partition_index, current_partition_size), |state, source_file| { - let mut produced_files = vec![]; - let mut range_start = 0; - while range_start < source_file.object_meta.size { - let range_end = min( - range_start + (target_partition_size - state.1), - source_file.object_meta.size, - ); - - let mut produced_file = source_file.clone(); - produced_file.range = Some(FileRange { - start: range_start as i64, - end: range_end as i64, - }); - produced_files.push((state.0, produced_file)); - - if state.1 + (range_end - range_start) >= target_partition_size { + // Skip splitting files smaller than repartition_file_min_size + // This may result in a few more partitions than requested (maybe 1 more) + if source_file.object_meta.size > 0 + && source_file.object_meta.size < repartition_file_min_size + { + state.1 += source_file.object_meta.size; + if state.1 > target_partition_size { state.0 += 1; state.1 = 0; - } else { - state.1 += range_end - range_start; } - range_start = range_end; + let small_file = (state.0, source_file.clone()); + Some(vec![small_file]) + } else { + let mut produced_files = vec![]; + let mut range_start = 0; + while range_start < source_file.object_meta.size { + let range_end = min( + range_start + (target_partition_size - state.1), + source_file.object_meta.size, + ); + + let mut produced_file = source_file.clone(); + produced_file.range = Some(FileRange { + start: range_start as i64, + end: range_end as i64, + }); + produced_files.push((state.0, produced_file)); + + if state.1 + (range_end - range_start) + >= target_partition_size + { + state.0 += 1; + state.1 = 0; + } else { + state.1 += range_end - range_start; + } + range_start = range_end; + } + Some(produced_files) } - Some(produced_files) }, ) .flatten() diff --git a/datafusion/datasource/src/file_scan_config.rs b/datafusion/datasource/src/file_scan_config.rs index bee74e042f220..c9ef0e3164198 100644 --- a/datafusion/datasource/src/file_scan_config.rs +++ b/datafusion/datasource/src/file_scan_config.rs @@ -31,6 +31,7 @@ use arrow::{ buffer::Buffer, datatypes::{ArrowNativeType, DataType, Field, Schema, SchemaRef, UInt16Type}, }; +use datafusion_common::deep::rewrite_field_projection; use datafusion_common::{exec_err, ColumnStatistics, Constraints, Result, Statistics}; use datafusion_common::{DataFusionError, ScalarValue}; use datafusion_execution::{ @@ -46,7 +47,7 @@ use datafusion_physical_plan::{ projection::{all_alias_free_columns, new_projections_for_columns, ProjectionExec}, DisplayAs, DisplayFormatType, ExecutionPlan, }; -use log::{debug, warn}; +use log::{debug, info, trace, warn}; use crate::{ display::FileGroupsDisplay, @@ -154,6 +155,9 @@ pub struct FileScanConfig { /// Columns on which to project the data. Indexes that are higher than the /// number of columns of `file_schema` refer to `table_partition_cols`. 
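// Hypothetical illustration of the small-file handling above, assuming the existing
// FileGroupPartitioner builder API (`with_target_partitions`, `with_repartition_file_min_size`,
// `repartition_file_groups`); module paths are assumptions. Files below the minimum size are
// no longer range-split: they are assigned whole to the current partition, which may produce
// one extra partition but avoids tiny byte ranges.
use datafusion_datasource::file_groups::FileGroupPartitioner;
use datafusion_datasource::PartitionedFile;

fn repartition_example(groups: &[Vec<PartitionedFile>]) {
    let repartitioned = FileGroupPartitioner::new()
        .with_target_partitions(4)
        .with_repartition_file_min_size(10 * 1024 * 1024) // 10 MiB
        .repartition_file_groups(groups);
    // A 1 MiB file now stays whole inside one partition instead of being split
    // into ranges; only files at or above the minimum size are range-split.
    let _ = repartitioned;
}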
pub projection: Option>, + /// Columns on which to project the data. Indexes that are higher than the + /// number of columns of `file_schema` refer to `table_partition_cols`. + pub projection_deep: Option>>, /// The maximum number of records to read from this plan. If `None`, /// all records after filtering are returned. pub limit: Option, @@ -321,6 +325,7 @@ impl FileScanConfig { constraints: Constraints::empty(), statistics, projection: None, + projection_deep: None, limit: None, table_partition_cols: vec![], output_ordering: vec![], @@ -394,7 +399,23 @@ impl FileScanConfig { .into_iter() .map(|idx| { if idx < self.file_schema.fields().len() { - self.file_schema.field(idx).clone() + let output_field = match &self.projection_deep { + None => self.file_schema.field(idx).clone(), + Some(projection_deep) => { + trace!("FileScanConfig::project DEEP PROJECT"); + let rewritten_field_arc = rewrite_field_projection( + self.file_schema.clone(), + idx, + &projection_deep, + ); + trace!( + "FileScanConfig::project DEEP PROJECT {:#?}", + rewritten_field_arc + ); + rewritten_field_arc.as_ref().clone() + } + }; + output_field } else { let partition_idx = idx - self.file_schema.fields().len(); self.table_partition_cols[partition_idx].clone() @@ -422,6 +443,15 @@ impl FileScanConfig { self } + /// Set the projection of the files + pub fn with_projection_deep( + mut self, + projection_deep: Option>>, + ) -> Self { + self.projection_deep = projection_deep; + self + } + /// Set the limit of the files pub fn with_limit(mut self, limit: Option) -> Self { self.limit = limit; diff --git a/datafusion/expr/Cargo.toml b/datafusion/expr/Cargo.toml index 37e1ed1936fb4..e2ef17b57f895 100644 --- a/datafusion/expr/Cargo.toml +++ b/datafusion/expr/Cargo.toml @@ -54,6 +54,7 @@ paste = "^1.0" recursive = { workspace = true, optional = true } serde_json = { workspace = true } sqlparser = { workspace = true } +log = "0.4.21" [dev-dependencies] ctor = { workspace = true } diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 72b82fc219eb6..745de663d015e 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -60,9 +60,11 @@ use datafusion_common::{ UnnestOptions, }; use indexmap::IndexSet; +use log::trace; // backwards compatibility use crate::display::PgJsonVisitor; +use datafusion_common::deep::rewrite_schema; pub use datafusion_common::display::{PlanType, StringifiedPlan, ToStringifiedPlan}; pub use datafusion_common::{JoinConstraint, JoinType}; @@ -1743,6 +1745,7 @@ impl LogicalPlan { ref source, ref table_name, ref projection, + ref projection_deep, ref filters, ref fetch, .. 
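// Hypothetical sketch of wiring a deep projection into a FileScanConfig (how the config is
// constructed is unchanged by this patch; the module path is an assumption). Keys of the deep
// map are the same file-schema column indices used in `projection`, and the values are
// dot-separated nested field paths.
use std::collections::HashMap;
use datafusion_datasource::file_scan_config::FileScanConfig;

fn narrow_to_mcid(config: FileScanConfig) -> FileScanConfig {
    config
        // read top-level columns 0 and 2 ...
        .with_projection(Some(vec![0, 2]))
        // ... but keep only this nested field of column 2
        .with_projection_deep(Some(HashMap::from([(
            2,
            vec!["_experience.mcid.id".to_string()],
        )])))
}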
@@ -1758,8 +1761,16 @@ impl LogicalPlan { } _ => "".to_string(), }; + let projected_fields_deep = match projection_deep { + Some(deep_indices) => { + // breaks tests + // format!(" projection_deep=[{:?}]", deep_indices) + "".to_string() + } + _ => "".to_string(), + }; - write!(f, "TableScan: {table_name}{projected_fields}")?; + write!(f, "TableScan: {table_name}{projected_fields}{projected_fields_deep}")?; if !filters.is_empty() { let mut full_filter = vec![]; @@ -2506,6 +2517,8 @@ pub struct TableScan { pub source: Arc, /// Optional column indices to use as a projection pub projection: Option>, + /// Optional column indices to use as a projection + pub projection_deep: Option>>, /// The schema description of the output pub projected_schema: DFSchemaRef, /// Optional expressions to be used as filters by the table provider @@ -2520,6 +2533,7 @@ impl Debug for TableScan { .field("table_name", &self.table_name) .field("source", &"...") .field("projection", &self.projection) + // .field("projection_deep", &self.projection_deep) .field("projected_schema", &self.projected_schema) .field("filters", &self.filters) .field("fetch", &self.fetch) @@ -2627,11 +2641,109 @@ impl TableScan { table_name, source: table_source, projection, + projection_deep: None, projected_schema, filters, fetch, }) } + + /// Initialize TableScan with appropriate schema from the given + /// arguments. + pub fn try_new_with_deep_projection( + table_name: impl Into, + table_source: Arc, + projection: Option>, + projection_deep: Option>>, + filters: Vec, + fetch: Option, + ) -> Result { + trace!(target: "deep", "TableScan::try_new_with_deep_projection: {:#?}, {:#?}", projection, projection_deep); + let table_name = table_name.into(); + + if table_name.table().is_empty() { + return plan_err!("table_name cannot be empty"); + } + let schema = table_source.schema(); + let func_dependencies = FunctionalDependencies::new_from_constraints( + table_source.constraints(), + schema.fields.len(), + ); + let projected_schema = projection + .as_ref() + .map(|p| { + let projected_func_dependencies = + func_dependencies.project_functional_dependencies(p, p.len()); + + let df_schema = DFSchema::new_with_metadata( + p.iter() + .map(|i| { + (Some(table_name.clone()), Arc::new(schema.field(*i).clone())) + }) + .collect(), + schema.metadata.clone(), + )?; + df_schema.with_functional_dependencies(projected_func_dependencies) + }) + .unwrap_or_else(|| { + let df_schema = + DFSchema::try_from_qualified_schema(table_name.clone(), &schema)?; + df_schema.with_functional_dependencies(func_dependencies) + })?; + let mut projected_schema = Arc::new(projected_schema); + + // reproject for deep schema + if projection.is_some() && projection_deep.is_some() { + let projection_clone = projection.unwrap().clone(); + let projection_deep_clone = projection_deep.unwrap().clone(); + let mut new_projection_deep: HashMap> = HashMap::new(); + projection_clone.iter().enumerate().for_each(|(ip, elp)| { + let empty_vec: Vec = vec![]; + let deep = projection_deep_clone.get(elp).or(Some(&empty_vec)).unwrap(); + new_projection_deep.insert(ip, deep.clone()); + }); + let new_projection = (0..projection_clone.len()).collect::>(); + let inner_projected_schema = projected_schema.inner().clone(); + let new_inner_projected_schema = rewrite_schema( + inner_projected_schema, + &new_projection, + &new_projection_deep, + ); + let mut new_projected_schema_df = DFSchema::new_with_metadata( + new_inner_projected_schema + .fields() + .iter() + .map(|fi| (Some(table_name.clone()), 
fi.clone())) + .collect(), + schema.metadata.clone(), + )?; + new_projected_schema_df = new_projected_schema_df + .with_functional_dependencies( + projected_schema.functional_dependencies().clone(), + )?; + projected_schema = Arc::new(new_projected_schema_df); + + Ok(Self { + table_name, + source: table_source, + projection: Some(projection_clone), + projection_deep: Some(projection_deep_clone), + projected_schema, + filters, + fetch, + }) + } else { + Ok(Self { + table_name, + source: table_source, + projection, + projection_deep, + projected_schema, + filters, + fetch, + }) + } + } } // Repartition the plan based on a partitioning scheme. @@ -4309,6 +4421,7 @@ digraph { table_name: TableReference::bare("tab"), source: Arc::clone(&source) as Arc, projection: None, + projection_deep: None, projected_schema: Arc::clone(&schema), filters: vec![], fetch: None, @@ -4339,6 +4452,7 @@ digraph { table_name: TableReference::bare("tab"), source, projection: None, + projection_deep: None, projected_schema: Arc::clone(&unique_schema), filters: vec![], fetch: None, diff --git a/datafusion/expr/src/logical_plan/tree_node.rs b/datafusion/expr/src/logical_plan/tree_node.rs index dfc18c74c70aa..1509e14952922 100644 --- a/datafusion/expr/src/logical_plan/tree_node.rs +++ b/datafusion/expr/src/logical_plan/tree_node.rs @@ -605,14 +605,17 @@ impl LogicalPlan { table_name, source, projection, + projection_deep, projected_schema, filters, fetch, + .. }) => filters.map_elements(f)?.update_data(|filters| { LogicalPlan::TableScan(TableScan { table_name, source, projection, + projection_deep, projected_schema, filters, fetch, diff --git a/datafusion/expr/src/registry.rs b/datafusion/expr/src/registry.rs index 4eb49710bcf85..a2f5a45e7b9b7 100644 --- a/datafusion/expr/src/registry.rs +++ b/datafusion/expr/src/registry.rs @@ -19,7 +19,7 @@ use crate::expr_rewriter::FunctionRewrite; use crate::planner::ExprPlanner; -use crate::{AggregateUDF, ScalarUDF, UserDefinedLogicalNode, WindowUDF}; +use crate::{AggregateUDF, ScalarUDF, TableSource, UserDefinedLogicalNode, WindowUDF}; use datafusion_common::{not_impl_err, plan_datafusion_err, HashMap, Result}; use std::collections::HashSet; use std::fmt::Debug; @@ -123,24 +123,58 @@ pub trait FunctionRegistry { } } -/// Serializer and deserializer registry for extensions like [UserDefinedLogicalNode]. +/// Serializer and deserializer registry for extensions like [UserDefinedLogicalNode] +/// and custom table providers for which the name alone is meaningless in the target +/// execution context, e.g. UDTFs, manually registered tables etc. pub trait SerializerRegistry: Debug + Send + Sync { /// Serialize this node to a byte array. This serialization should not include /// input plans. fn serialize_logical_plan( &self, node: &dyn UserDefinedLogicalNode, - ) -> Result>; + ) -> Result { + not_impl_err!( + "Serializing user defined logical plan node `{}` is not supported", + node.name() + ) + } /// Deserialize user defined logical plan node ([UserDefinedLogicalNode]) from /// bytes. fn deserialize_logical_plan( &self, name: &str, - bytes: &[u8], - ) -> Result>; + _bytes: &[u8], + ) -> Result> { + not_impl_err!( + "Deserializing user defined logical plan node `{name}` is not supported" + ) + } + + /// Serialized table definition for UDTFs or some other table provider implementation that + /// can't be marshaled by reference. + fn serialize_custom_table( + &self, + _table: &dyn TableSource, + ) -> Result> { + Ok(None) + } + + /// Deserialize a custom table. 
+ fn deserialize_custom_table( + &self, + name: &str, + _bytes: &[u8], + ) -> Result> { + not_impl_err!("Deserializing custom table `{name}` is not supported") + } } +/// A sequence of bytes with a string qualifier. Meant to encapsulate serialized extensions +/// that need to carry their type, e.g. the `type_url` for protobuf messages. +#[derive(Debug, Clone)] +pub struct NamedBytes(pub String, pub Vec); + /// A [`FunctionRegistry`] that uses in memory [`HashMap`]s #[derive(Default, Debug)] pub struct MemoryFunctionRegistry { diff --git a/datafusion/functions/src/core/getfield.rs b/datafusion/functions/src/core/getfield.rs index 3ac26b98359bb..05c6b53d394cf 100644 --- a/datafusion/functions/src/core/getfield.rs +++ b/datafusion/functions/src/core/getfield.rs @@ -15,14 +15,11 @@ // specific language governing permissions and limitations // under the License. -use arrow::array::{ - make_array, make_comparator, Array, BooleanArray, Capacities, MutableArrayData, - Scalar, -}; +use arrow::array::{make_array, make_comparator, Array, BooleanArray, Capacities, ListArray, MutableArrayData, Scalar, StructArray}; use arrow::compute::SortOptions; use arrow::datatypes::DataType; use arrow_buffer::NullBuffer; -use datafusion_common::cast::{as_map_array, as_struct_array}; +use datafusion_common::cast::{as_list_array, as_map_array, as_struct_array}; use datafusion_common::{ exec_err, internal_err, plan_datafusion_err, utils::take_function_args, Result, ScalarValue, @@ -138,6 +135,15 @@ impl ScalarUDFImpl for GetFieldFunc { debug_assert_eq!(args.scalar_arguments.len(), 2); match (&args.arg_types[0], args.scalar_arguments[1].as_ref()) { + (DataType::List(field), Some(ScalarValue::Utf8(Some(field_name)))) => { + if let DataType::Struct(fields) = field.data_type() { + fields.iter().find(|f| f.name() == field_name) + .ok_or(plan_datafusion_err!("Field {field_name} not found in struct")) + .map(|f| ReturnInfo::new_nullable(DataType::List(f.clone()))) + } else { + exec_err!("Expected a List of Structs") + } + } (DataType::Map(fields, _), _) => { match fields.data_type() { DataType::Struct(fields) if fields.len() == 2 => { @@ -185,6 +191,39 @@ impl ScalarUDFImpl for GetFieldFunc { } }; + pub fn get_field_from_list( + array: Arc, + field_name: &str, + ) -> Result { + let list_array = as_list_array(array.as_ref())?; + match list_array.value_type() { + DataType::Struct(fields) => { + let struct_array = as_struct_array(list_array.values()).or_else(|_| { + exec_err!("Expected a StructArray inside the ListArray") + })?; + let Some(field_index) = fields + .iter() + .position(|f| f.name() == field_name) + else { + return exec_err!("Field {field_name} not found in struct") + }; + let projection_array = struct_array.column(field_index); + + let (_, offsets, _, nulls) = list_array.clone().into_parts(); + + let new_list = ListArray::new( + fields[field_index].clone(), + offsets, + projection_array.to_owned(), + nulls, + ); + + Ok(ColumnarValue::Array(Arc::new(new_list))) + } + _ => exec_err!("Expected a ListArray of Structs"), + } + } + fn process_map_array( array: Arc, key_array: Arc, @@ -235,6 +274,13 @@ impl ScalarUDFImpl for GetFieldFunc { } match (array.data_type(), name) { + (DataType::List(field), ScalarValue::Utf8(Some(k))) => { + if let DataType::Struct(_) = field.data_type() { + get_field_from_list(array, &k) + } else { + exec_err!("Expected a List of Structs") + } + } (DataType::Map(_, _), ScalarValue::List(arr)) => { let key_array: Arc = arr; process_map_array(array, key_array) diff --git 
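// Hypothetical sketch of what the new List<Struct> branch in get_field enables, mirroring the
// commented-out test_adr example earlier in this diff: projecting a struct field out of every
// element of a list column.
use datafusion_expr::{col, Expr};
use datafusion_functions::expr_fn::get_field;

fn project_cc_from_list_struct() -> Expr {
    // For a column `list_struct: List<Struct<{aa: Int64, cc: Utf8}>>`, this now resolves to
    // List<Utf8>, reusing the input list's offsets and null buffer.
    get_field(col("list_struct"), "cc")
}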
a/datafusion/optimizer/Cargo.toml b/datafusion/optimizer/Cargo.toml index 3413b365f67de..5d8be72dd9ce5 100644 --- a/datafusion/optimizer/Cargo.toml +++ b/datafusion/optimizer/Cargo.toml @@ -46,6 +46,9 @@ chrono = { workspace = true } datafusion-common = { workspace = true, default-features = true } datafusion-expr = { workspace = true } datafusion-physical-expr = { workspace = true } +datafusion-functions = { workspace = true } +datafusion-functions-nested = { workspace = true } +datafusion-functions-aggregate = { workspace = true } indexmap = { workspace = true } itertools = { workspace = true } log = { workspace = true } diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index 5dc1a7e5ac5b3..6ad5f411389db 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -689,11 +689,24 @@ impl CSEController for ExprCSEController<'_> { | Expr::Wildcard { .. } ); + // ADR: Option 1 Fix get field screwing up the plan - this makes it so that get_field no longer breaks the plan + let mut is_get_field = false; + if true { + match node { + Expr::ScalarFunction(ScalarFunction { func, args: _args }) => { + if func.name() == "get_field" { + is_get_field = true; + } + } + _ => {} + } + } + let is_aggr = matches!(node, Expr::AggregateFunction(..)); match self.mask { - ExprMask::Normal => is_normal_minus_aggregates || is_aggr, - ExprMask::NormalAndAggregates => is_normal_minus_aggregates, + ExprMask::Normal => is_get_field || is_normal_minus_aggregates || is_aggr, + ExprMask::NormalAndAggregates => is_get_field || is_normal_minus_aggregates, } } diff --git a/datafusion/optimizer/src/lib.rs b/datafusion/optimizer/src/lib.rs index 61ca9b31cd29b..f314b626183dd 100644 --- a/datafusion/optimizer/src/lib.rs +++ b/datafusion/optimizer/src/lib.rs @@ -52,6 +52,7 @@ pub mod eliminate_outer_join; pub mod extract_equijoin_predicate; pub mod filter_null_join_keys; pub mod optimize_projections; +pub mod optimize_projections_deep; pub mod optimizer; pub mod propagate_empty_relation; pub mod push_down_filter; diff --git a/datafusion/optimizer/src/optimize_projections/mod.rs b/datafusion/optimizer/src/optimize_projections/mod.rs index b7dd391586a18..1ec0767c404be 100644 --- a/datafusion/optimizer/src/optimize_projections/mod.rs +++ b/datafusion/optimizer/src/optimize_projections/mod.rs @@ -249,7 +249,7 @@ fn optimize_projections( projection, filters, fetch, - projected_schema: _, + projected_schema: _, .. 
} = table_scan; // Get indices referred to in the original (schema with all fields) diff --git a/datafusion/optimizer/src/optimize_projections_deep.rs b/datafusion/optimizer/src/optimize_projections_deep.rs new file mode 100644 index 0000000000000..970cfa86e9def --- /dev/null +++ b/datafusion/optimizer/src/optimize_projections_deep.rs @@ -0,0 +1,1138 @@ +use crate::optimizer::ApplyOrder; +use crate::{OptimizerConfig, OptimizerRule}; +use arrow::datatypes::{DataType, Field, FieldRef, SchemaRef}; +use datafusion_common::config::ConfigOptions; +use datafusion_common::deep::{rewrite_schema, try_rewrite_schema_opt}; +use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion}; +use datafusion_common::{ + Column, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, TableReference, +}; +use datafusion_expr::expr::{Alias, ScalarFunction}; +use datafusion_expr::{ + build_join_schema, Expr, Join, LogicalPlan, Projection, Subquery, SubqueryAlias, + TableScan, Union, +}; +use log::{error, info, trace, warn}; +use std::collections::{HashMap, VecDeque}; +use std::fmt::{Debug, Formatter}; +use std::hash::Hash; +use std::sync::Arc; + +#[derive(Default, Debug)] +pub struct OptimizeProjectionsDeep {} + +impl OptimizeProjectionsDeep { + #[allow(missing_docs)] + pub fn new() -> Self { + Self {} + } +} + +pub const FLAG_ENABLE: usize = 1; +pub const FLAG_ENABLE_PROJECTION_MERGING: usize = 2; +pub const FLAG_ENABLE_SUBQUERY_TRANSLATION: usize = 4; + +impl OptimizerRule for OptimizeProjectionsDeep { + fn name(&self) -> &str { + "optimize_projections_deep" + } + + fn apply_order(&self) -> Option { + None + } + + fn supports_rewrite(&self) -> bool { + true + } + + fn rewrite( + &self, + plan: LogicalPlan, + config: &dyn OptimizerConfig, + ) -> Result> { + // All output fields are necessary: + info!("optimize_projections_deep: plan={}", plan.display_indent()); + let options = config.options(); + if options.optimizer.deep_column_pruning_flags & FLAG_ENABLE != 0 { + let new_plan_transformed = deep_plan_transformer(plan, vec![], options)?; + if new_plan_transformed.transformed { + let new_plan = new_plan_transformed.data; + let new_plan = new_plan.transform_up(|p| { + // ADR: just doing recompute schema for all breaks push_down_filter::tests::multi_combined_filter_exact + match p { + LogicalPlan::Window(_) => { + Ok(Transformed::yes(p.recompute_schema()?)) + } + LogicalPlan::Aggregate(_) => { + Ok(Transformed::yes(p.recompute_schema()?)) + } + _ => Ok(Transformed::no(p.clone())), + } + // Ok(Transformed::yes(p.recompute_schema()?)) + })?; + Ok(new_plan) + } else { + Ok(Transformed::no(new_plan_transformed.data)) + } + } else { + Ok(Transformed::no(plan)) + } + } +} + +pub type DeepColumnColumnMap = HashMap>; +pub type DeepColumnIndexMap = HashMap>; + +#[derive(Clone)] +pub struct PlanWithDeepColumnMap { + pub plan: Arc, + pub columns: HashMap>, +} + +impl PlanWithDeepColumnMap { + pub fn from_plan(input: &LogicalPlan) -> Self { + Self { + plan: Arc::new(input.clone()), + columns: get_columns_referenced_in_plan(input), + } + } +} + +impl Debug for PlanWithDeepColumnMap { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.write_str(format!("{} {:?}", get_plan_name(&self.plan), self.columns).as_str()) + } +} + +pub fn is_empty_deep_projection(p: &Vec) -> bool { + p.len() == 1 && p[0] == "*" +} + +pub fn deep_plan_transformer( + plan: LogicalPlan, + cols_from_above: Vec, + options: &ConfigOptions, +) -> Result> { + let mut cols_level: Vec = cols_from_above.clone(); + 
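// Hypothetical sketch of how the deep_column_pruning_flags option (default 7, i.e. all of the
// bits defined above) is composed, matching the tests earlier in this diff that set
// "datafusion.optimizer.deep_column_pruning_flags" via set_usize.
use datafusion_common::config::ConfigOptions;

fn enable_all_deep_pruning(options: &mut ConfigOptions) {
    options.optimizer.deep_column_pruning_flags =
        FLAG_ENABLE | FLAG_ENABLE_PROJECTION_MERGING | FLAG_ENABLE_SUBQUERY_TRANSLATION;
}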
plan.transform_down(|plan| { + info!( + ">>>>>>>>>>>>>>\ndeep_plan_transformer working on plan={}, cols_from_above ={:?}", + &plan.display(), &cols_from_above + ); + if options.optimizer.deep_column_pruning_flags & FLAG_ENABLE_SUBQUERY_TRANSLATION != 0 { + // try to find aliases in current plan + // for each alias + // try to look at above projections where the table is empty. + // if the column name is the same as the name of the alias, we need to change that pdc + // the key is the alias, the column is the new column + let mut aliases_in_current_plan: HashMap = HashMap::new(); + if matches!(plan, LogicalPlan::TableScan(_)) { + // all the null ones are for me ! + // a table scan has no aliases, we are at the end + let LogicalPlan::TableScan(TableScan { + table_name, .. + }) = &plan else {unreachable!()}; + let this_table_relation = table_name.clone(); + for PlanWithDeepColumnMap { plan: _plan, columns } in cols_level.iter() { + for (col, _projections) in columns.iter() { + if col.relation.is_none() { + // the key for this column is in the map ! + let new_col = Column { + relation: Some(this_table_relation.clone()), + name: col.name.clone(), + spans: Default::default(), + }; + aliases_in_current_plan.insert(col.name.clone(), new_col); + } + } + } + } else if matches!(plan, LogicalPlan::Projection(_)) { + for expr in plan.expressions().iter() { + match expr { + Expr::Alias(Alias { expr: alias_source, relation: _relation, name }) => { + // FIXME: what is the correct relation ? the one in the alias or the one in the expr ? + // We'll use the one in the expr + let alias_source_column_map = expr_to_deep_columns(alias_source.as_ref()); + if alias_source_column_map.len() != 1 { + warn!("subquery translation: fix aliased cols from above - strange result for map: {:?}", &alias_source_column_map); + } else { + let tmp_map_vec = alias_source_column_map.into_iter().collect::>(); + // we have a single entry + let (new_col, _new_deep_projection) = tmp_map_vec[0].clone(); + aliases_in_current_plan.insert(name.clone(), new_col); + } + } + _ => { + let column_map = expr_to_deep_columns(expr); + if column_map.len() != 1 { + warn!("subquery translation: fix aliased cols from above - strange result for map: {:?}", &column_map); + } else { + let tmp_map_vec = column_map.into_iter().collect::>(); + // we have a single entry + let (new_col, _new_deep_projection) = tmp_map_vec[0].clone(); + aliases_in_current_plan.insert(new_col.name.clone(), new_col); + } + } + } + } + } + // info!("subquery translation: found in current plan: {:?}", aliases_in_current_plan); + // info!("subquery translation: found in current plan: {:?}", &cols_from_above); + // now, we go through all the pdcs that we got, and we see whether their relation is null and their column name is in the aliases_in_current_plan map + // if they are, then we rewrite them so that the column + let mut new_cols_level: Vec = vec![]; + for PlanWithDeepColumnMap { plan, columns } in cols_level.iter() { + let mut new_columns: DeepColumnColumnMap = HashMap::new(); + for (col, projections) in columns.iter() { + if col.relation.is_none() && aliases_in_current_plan.contains_key(&col.name) { + // the key for this column is in the map ! 
+ let actual_column = aliases_in_current_plan.get(&col.name).unwrap(); + // info!("subquery translation: replacing column {:?} with {:?} = {:?}", col, actual_column, projections); + new_columns.insert(actual_column.clone(), projections.clone()); + } else { + // copy the column + new_columns.insert(col.clone(), projections.clone()); + } + } + new_cols_level.push(PlanWithDeepColumnMap { + plan: plan.clone(), + columns: new_columns, + }); + } + cols_level = new_cols_level; + info!("subquery translation: REPLACED: {:?}", &cols_level); + } + + let current_pdc = PlanWithDeepColumnMap::from_plan(&plan); + cols_level.push(current_pdc); + // info!("BEFORE: {:?}", &cols_level); + + match plan { + LogicalPlan::Projection(proj) => { + let Projection { + expr, input, schema, .. + } = proj; + + let mut new_proj_schema: Option = None; + if options.optimizer.deep_column_pruning_flags & FLAG_ENABLE_PROJECTION_MERGING != 0 { + // info!("projection merging: EXECUTING FLAG_ENABLE_PROJECTION_MERGING code path, cols_level = {:?}", &cols_level); + // info!("projection merging: EXECUTING FLAG_ENABLE_PROJECTION_MERGING expr = {:?}", &expr); + if are_exprs_plain(&expr) && cols_level.len() > 0 { + // info!("projection merging: can proceed, projection is plain"); + // this projection is a candidate, it "looks" like an inserted projection + // for each of the columns in this projection, try to find previous references to this column + // get the largest of these references and set it to this projection + let my_pdc = cols_level.last().unwrap(); + let mut my_new_modified_columns: DeepColumnColumnMap = HashMap::new(); + 'col_iterate: for (my_col, my_deep_projections) in my_pdc.columns.iter() { + // info!("projection merging searching previous projections for column {:?}, current projections={:?}", my_col, my_deep_projections); + for my_deep_projection in my_deep_projections.iter() { + // col + projection + if my_deep_projection == "*" { + // we only do this if it's a top-level projection + let previous_deep_projections = find_all_projections_for_column(my_col, &cols_level, 0, cols_level.len() - 1); + // info!("projection merging found previous deep projections for column {:?}: {:?}", my_col, &previous_deep_projections); + if previous_deep_projections.len() == 0 { + // we don't specify this column in the previous projection, so we need to leave it + continue 'col_iterate; + } + // we have previous deep projections + // info!("projection merging rewrite projection for col {:?}: {:?}", my_col, previous_deep_projections); + // change the hashmap entry at this level + // ADR: WTF, simplify this, don't know how to write in the friggin vec + my_new_modified_columns.insert(my_col.clone(), previous_deep_projections); + let new_pdc = PlanWithDeepColumnMap { + plan: my_pdc.plan.clone(), + columns: HashMap::new(), + }; + } + } + } + // rewrite it ! 
+ if my_new_modified_columns.len() > 0 { + let mut my_new_pdc = PlanWithDeepColumnMap { + plan: my_pdc.plan.clone(), + columns: HashMap::new(), + }; + for (my_pdc_col, my_pdc_deep_projection) in my_pdc.columns.iter() { + // we just change the projections for this col + // it doesn't matter that it doesn't reflect what's in the actual projection, because we don't use that anymore below + if let Some(new_deep_projection) = my_new_modified_columns.get(my_pdc_col) { + my_new_pdc.columns.insert(my_pdc_col.clone(), new_deep_projection.clone()); + } else { + my_new_pdc.columns.insert(my_pdc_col.clone(), my_pdc_deep_projection.clone()); + } + } + let last_index = cols_level.len() - 1; + cols_level[last_index] = my_new_pdc; + // rewrite the projection schema accordingly !!!! + // ADR: this shouldn't be needed. But it is, there is an extra aggregation step that verifies that the physical schema is the same as the logical schema + // which means they need to match + // the option is options.execution.skip_physical_aggregate_schema_check - see the code for that + // Also, we could CHANGE the actual verification to check that the schema can be rewritten - deep functions, instead of checking for exact equality + // the code WILL work when we + // - remove the check by disabling the option + // - change the check so that it verifies that the schemas can be casted / rewritten, instead of being equal + { + let proj_inner_schema = schema.inner(); + let proj_metadata = schema.metadata().clone(); + let _proj_functional_dependencies = schema.functional_dependencies(); + // compute qualified fields + let mut proj_qualified_fields: Vec> = vec![]; + let mut all_projection_indices: Vec = vec![]; + for (idx, (c, _f)) in schema.iter().enumerate() { + let newc = c.cloned(); + proj_qualified_fields.push(newc); + all_projection_indices.push(idx); + } + + let deep_projection = transform_column_deep_projection_map_to_usize_index_type_and_clean_up_stars_for_plain_schema( + &cols_level.last().unwrap().columns, + proj_inner_schema, + ); + let new_proj_inner_schema = try_rewrite_schema_opt(proj_inner_schema.clone(), Some(&all_projection_indices), Some(&deep_projection))?; + info!("REWRITTEN SCHEMA: {:?}", new_proj_inner_schema); + // remake a DFSchema for the Projection + // zip the fields + let new_qualified_fields = proj_qualified_fields + .into_iter() + .zip(new_proj_inner_schema.fields.iter().map(|f|f.clone())) + .collect::, FieldRef)>>(); + let new_df_schema = Arc::new(DFSchema::new_with_metadata( + new_qualified_fields, + proj_metadata + )?); + new_proj_schema = Some(new_df_schema); + } + } + } + } + + if let Some(new_proj_schema) = new_proj_schema { + let new_proj = Projection::try_new_with_schema(expr, input, new_proj_schema); + Ok(Transformed::new(LogicalPlan::Projection(new_proj?), false, TreeNodeRecursion::Continue)) + } else { + Ok(Transformed::new(LogicalPlan::Projection(Projection::try_new(expr, input)?), false, TreeNodeRecursion::Continue)) + } + } + LogicalPlan::Filter(_) => Ok(Transformed::new(plan, false, TreeNodeRecursion::Continue)), + LogicalPlan::Window(_) => Ok(Transformed::new(plan, false, TreeNodeRecursion::Continue)), + LogicalPlan::Aggregate(_) => Ok(Transformed::new(plan, false, TreeNodeRecursion::Continue)), + LogicalPlan::Limit(_) => Ok(Transformed::new(plan, false, TreeNodeRecursion::Continue)), + LogicalPlan::Distinct(_) => Ok(Transformed::new(plan, false, TreeNodeRecursion::Continue)), + LogicalPlan::Sort(_) => Ok(Transformed::new(plan, false, TreeNodeRecursion::Continue)), + 
LogicalPlan::Repartition(_) => Ok(Transformed::new(plan, false, TreeNodeRecursion::Continue)), + + LogicalPlan::Join(join) => { + let Join { + left, + right, + on, + filter, + join_type, + join_constraint, + schema: _, + null_equals_null, + } = join; + let new_left = deep_plan_transformer(left.as_ref().clone(), cols_level.clone(), options)?; + let new_right = deep_plan_transformer(right.as_ref().clone(), cols_level.clone(), options)?; + let is_transformed = new_left.transformed || new_right.transformed; + let new_join_schema = Arc::new(build_join_schema(new_left.data.schema(), new_right.data.schema(), &join_type)?); + return Ok( + Transformed::new( + LogicalPlan::Join(Join { + left: Arc::new(new_left.data), + right: Arc::new(new_right.data), + on, + filter, + join_type, + join_constraint, + schema: new_join_schema, // FIXME ? + null_equals_null + }), + is_transformed, + TreeNodeRecursion::Jump, + ) + ) + } + LogicalPlan::Union(union) => { + let mut new_children: Vec> = vec![]; + let Union { + inputs, schema + } = union; + let mut any_transformed = false; + for input in inputs.iter() { + let tmp = deep_plan_transformer(input.as_ref().clone(), vec![], options)?; + any_transformed = any_transformed || tmp.transformed; + new_children.push(Arc::new(tmp.data)); + } + return Ok( + Transformed::new( + LogicalPlan::Union(Union { + inputs: new_children, + schema, + }), + any_transformed, + TreeNodeRecursion::Jump, + ) + ) + } + LogicalPlan::TableScan(table_scan) => { + let TableScan { + table_name, + source, + projection, + projection_deep: _, + filters, + fetch, + projected_schema, + } = table_scan; + + info!("TABLE SCAN: input cols_level= {:?}", &cols_level); + // can we actually do the deep merging ? + // ADR: this no longer works after passing in stuff through SubqueryAlias + let mut projections_before_indices: Vec = vec![]; + for (idx, x) in cols_level.iter().enumerate() { + if matches!(x.plan.as_ref(), LogicalPlan::Projection(_)) { + projections_before_indices.push(idx); + } + } + if projections_before_indices.len() == 0 && cols_from_above.len() == 0{ + return Ok(Transformed::new(LogicalPlan::TableScan(TableScan { + table_name, + source, + projection, + projection_deep: None, + projected_schema, + filters, + fetch, + }), true, TreeNodeRecursion::Continue)); + } + + let df_schema = Arc::new(DFSchema::try_from(source.schema())?); + let cols_level_as_simple_map = convert_plans_with_deep_column_to_columns_map(&cols_level); + // get only the expressions for me + let columns_for_this_scan = filter_expressions_for_table(&cols_level_as_simple_map, table_name.table().to_string()); + info!("TABLE SCAN COLUMNS FOR ME({}): {:?}", table_name.table(), columns_for_this_scan); + // compact all the separate levels to a single map + let single_columns_for_this_scan = compact_list_of_column_maps_to_single_map(&columns_for_this_scan); + info!("TABLE SCAN COLUMNS FOR ME({}): {:?}", table_name.table(), single_columns_for_this_scan); + // fix deep projections - we might have things that look like struct access, but they are map access + let fixed_single_columns_for_this_scan = fix_deep_projection_according_to_table_schema(&single_columns_for_this_scan, &df_schema); + info!("TABLE SCAN COLUMNS FOR ME({}): {:?}", table_name.table(), fixed_single_columns_for_this_scan); + // transform to actual deep projection - use the column index instead of the column, and replace ["*"] with empty + let projection_deep = transform_column_deep_projection_map_to_usize_index_type_and_clean_up_stars(&fixed_single_columns_for_this_scan, 
&df_schema); + trace!(target: "deep", "Rewriting deep projections for table {}: {:?}", table_name.table(), projection_deep); + let reprojected_schema = reproject_for_deep_schema( + &table_name, + projection.clone(), + Some(projection_deep.clone()), + &projected_schema, + &source.schema(), + )?; + // info!("reprojected schema: {:#?}", reprojected_schema.inner()); + Ok(Transformed::new(LogicalPlan::TableScan(TableScan { + table_name, + source, + projection, + projection_deep: Some(projection_deep), + projected_schema: reprojected_schema, + filters, + fetch, + }), true, TreeNodeRecursion::Continue)) + } + // we don't know what to do with an empty relation + LogicalPlan::EmptyRelation(_) => return Ok(Transformed::new(plan, false, TreeNodeRecursion::Jump)), + LogicalPlan::Subquery(sq) => { + // ADR: we cannot have Subquery plans after optimization, they're all SubqueryAlias + let Subquery { + subquery, + outer_ref_columns + } = sq; + let subquery_result = deep_plan_transformer(subquery.as_ref().clone(), vec![], options)?; + let transformed = subquery_result.transformed; + Ok(Transformed::new(LogicalPlan::Subquery(Subquery { + subquery: Arc::new(subquery_result.data), + outer_ref_columns, + }), transformed, TreeNodeRecursion::Jump)) + + // unreachable!() + } + LogicalPlan::SubqueryAlias(sqa) => { + let SubqueryAlias { + input, + alias, + schema: _, + .. + } = sqa; + let mut translated_pdcs:Vec = vec![]; + if options.optimizer.deep_column_pruning_flags & FLAG_ENABLE_SUBQUERY_TRANSLATION != 0 { + // info!("subquery translation: activating subquery translation, for alias {}", &alias); + // info!("subquery translation: activating subquery translation CHECKING {:?}", &cols_level); + for pdc in cols_level.iter() { + let mut translated_pdc = PlanWithDeepColumnMap { + plan: pdc.plan.clone(), + columns: HashMap::new(), + }; + for (col, projections) in pdc.columns.iter() { + if let Some(col_table_reference) = &col.relation { + if col_table_reference == &alias { + let mut new_col = col.clone(); + new_col.relation = None; + // don't add the projection if it's empty + if !is_empty_deep_projection(projections) { + translated_pdc.columns.insert(new_col, projections.clone()); + } + } + } + } + if translated_pdc.columns.len() > 0 { + translated_pdcs.push(translated_pdc); + } + } + info!("subquery translation: translated pdcs = {:?}", translated_pdcs); + } + + let sq_input_result = deep_plan_transformer(input.as_ref().clone(), translated_pdcs, options)?; + let input_transformed = sq_input_result.transformed; + let new_sqa = LogicalPlan::SubqueryAlias(SubqueryAlias::try_new(Arc::new(sq_input_result.data), alias)?); + Ok(Transformed::new(new_sqa, input_transformed, TreeNodeRecursion::Jump)) + } + // ignore DDL like statements + LogicalPlan::Statement(_) + | LogicalPlan::Values(_) + | LogicalPlan::Explain(_) + | LogicalPlan::Analyze(_) + | LogicalPlan::Extension(_) + | LogicalPlan::Dml(_) + | LogicalPlan::Ddl(_) + | LogicalPlan::DescribeTable(_) + | LogicalPlan::Copy(_) => return Ok(Transformed::new(plan, false, TreeNodeRecursion::Jump)), + LogicalPlan::Unnest(_) => { // WTF + Ok(Transformed::new(plan, false, TreeNodeRecursion::Continue)) + } + LogicalPlan::RecursiveQuery(rq) => { + // FIXME: not tested + // ADR: this makes some tests break + // let recursive_result = deep_plan_transformer(rq.recursive_term.as_ref().clone(), vec![])?; + // Ok(Transformed::new(recursive_result.data, recursive_result.transformed, TreeNodeRecursion::Jump)) + Ok(Transformed::new(LogicalPlan::RecursiveQuery(rq), false, 
TreeNodeRecursion::Continue)) + } + } + }) +} + +pub fn reproject_for_deep_schema( + table_name: &TableReference, + projection: Option>, + projection_deep: Option, + projected_schema: &DFSchemaRef, + source_schema: &SchemaRef, +) -> Result { + if projection.is_some() && projection_deep.is_some() { + let projection_clone = projection.unwrap().clone(); + let projection_deep_clone = projection_deep.unwrap().clone(); + let mut new_projection_deep: DeepColumnIndexMap = HashMap::new(); + projection_clone.iter().enumerate().for_each(|(ip, elp)| { + let empty_vec: Vec = vec![]; + let deep = projection_deep_clone.get(elp).or(Some(&empty_vec)).unwrap(); + new_projection_deep.insert(ip, deep.clone()); + }); + let new_projection = (0..projection_clone.len()).collect::>(); + let inner_projected_schema = projected_schema.inner().clone(); + let new_inner_projected_schema = rewrite_schema( + inner_projected_schema, + &new_projection, + &new_projection_deep, + ); + let new_fields = new_inner_projected_schema + .fields() + .iter() + .map(|fi| (Some(table_name.clone()), fi.clone())) + .collect::, Arc)>>(); + let mut new_projected_schema_df = + DFSchema::new_with_metadata(new_fields, source_schema.metadata.clone())?; + new_projected_schema_df = new_projected_schema_df.with_functional_dependencies( + projected_schema.functional_dependencies().clone(), + )?; + let new_projected_schema = Arc::new(new_projected_schema_df); + Ok(new_projected_schema) + } else { + Ok(projected_schema.clone()) + } +} + +pub fn get_columns_referenced_in_plan(plan: &LogicalPlan) -> DeepColumnColumnMap { + let expressions = plan.expressions(); //get_plan_expressions(p); + // info!(" expressions = {:?}", &expressions); + let mut deep_column_map: DeepColumnColumnMap = HashMap::new(); + for expr in expressions.iter() { + let tmp = expr_to_deep_columns(expr); + for (k, vs) in tmp.iter() { + // if we have an aggregation below, we might get here a column that looks like + // Alias(Alias { expr: Column(Column { relation: None, name: "lag(events.UserId,Int64(1)) PARTITION BY [events.DeviceId] ORDER BY [events.timestamp ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW" }), relation: None, name: "PreviousUserColName" }), + // we just filter out the None relations, since we care about tables + // we rely on a lower level field having this. + if k.relation.is_none() { + continue; + } + // info!(" column = {:?}, vs: {:?}", k, vs); + for v in vs { + merge_value_in_column_map(&mut deep_column_map, k, v.clone()); + } + // if deep_column_map.contains_key(k) { + // let mut current_values = deep_column_map.get(k).unwrap(); + // + // current_values.push(v.clone()); + // } else { + // deep_column_map.insert(k.clone(), v.clone()); + // } + } + } + // info!("columns : {:?}", &deep_column_map); + deep_column_map +} + +pub fn convert_plans_with_deep_column_to_columns_map( + input: &Vec, +) -> Vec { + input.iter().map(|pdc| pdc.columns.clone()).collect() +} + +pub fn filter_expressions_for_table( + input: &Vec, + table_name: String, +) -> Vec { + let mut output: Vec = vec![]; + for (_level, column_map) in input.iter().enumerate() { + let mut level_map: DeepColumnColumnMap = HashMap::new(); + for (column, deep) in column_map.iter() { + if let Some(table_reference) = &column.relation { + match table_reference { + TableReference::Bare { table, .. } => { + if table.to_string() == table_name { + level_map.insert(column.clone(), deep.clone()); + } + } + TableReference::Partial { table, .. 
} => { + if table.to_string() == table_name { + level_map.insert(column.clone(), deep.clone()); + } + } + TableReference::Full { table, .. } => { + if table.to_string() == table_name { + level_map.insert(column.clone(), deep.clone()); + } + } + } + } + } + if level_map.len() > 0 { + output.push(level_map); + } + } + output +} + +pub fn compact_list_of_column_maps_to_single_map( + input: &Vec, +) -> DeepColumnColumnMap { + let mut output: DeepColumnColumnMap = HashMap::new(); + for (_level, column_map) in input.iter().enumerate() { + for (column, deep) in column_map.iter() { + for value in deep.iter() { + merge_value_in_column_map(&mut output, column, value.clone()); + } + } + } + output +} + +pub fn transform_column_deep_projection_map_to_usize_index_type_and_clean_up_stars( + input: &DeepColumnColumnMap, + schema: &DFSchemaRef, +) -> DeepColumnIndexMap { + let mut output: DeepColumnIndexMap = HashMap::new(); + for (col, values) in input.iter() { + // This is needed because of the way DF checks qualified columns + let col_to_check = Column::new_unqualified(col.name.clone()); + if let Some(col_idx_in_schema) = schema.maybe_index_of_column(&col_to_check) { + if is_empty_deep_projection(values) { + output.insert(col_idx_in_schema, vec![]); + } else { + output.insert(col_idx_in_schema, values.clone()); + } + } + } + output +} + +pub fn transform_column_deep_projection_map_to_usize_index_type_and_clean_up_stars_for_plain_schema( + input: &DeepColumnColumnMap, + schema: &SchemaRef, +) -> DeepColumnIndexMap { + let mut output: DeepColumnIndexMap = HashMap::new(); + for (col, values) in input.iter() { + // This is needed because of the way DF checks qualified columns + let name = col.name(); + if let Ok(col_idx_in_schema) = schema.index_of(name) { + if is_empty_deep_projection(values) { + output.insert(col_idx_in_schema, vec![]); + } else { + output.insert(col_idx_in_schema, values.clone()); + } + } + } + output +} + +pub fn fix_deep_projection_according_to_table_schema( + input: &DeepColumnColumnMap, + schema: &DFSchemaRef, +) -> DeepColumnColumnMap { + // info!("fix_deep_projection_according_to_table_schema SCHEMA: {:?}", schema); + + let mut output: DeepColumnColumnMap = HashMap::new(); + for (col, deep_projection_specifiers) in input { + // info!("fix_deep_projection_according_to_table_schema checking {:?}", col); + let col_to_check = Column::new_unqualified(col.name.clone()); + if let Some(col_idx_in_schema) = schema.maybe_index_of_column(&col_to_check) { + // info!("COL IDX {}", col_idx_in_schema); + // get the rest and see whether the column type + // we iterate through the rest specifiers and we fix them + // that is, if we see something that looks like a get field, but we know the field in the schema + // is a map, that means that we need to replace it with * + // map_field['val'], projection_rest = ["val"] => projection_rest=["*"] + let mut fixed_deep_projection_specifiers: Vec = vec![]; + for projection_specifier in deep_projection_specifiers { + // ADR: FIXME + let fixed_projection_specifier_pieces = fix_possible_field_accesses( + &schema, + col_idx_in_schema, + projection_specifier, + ) + .unwrap(); + let fixed_projection_specifier = + fixed_projection_specifier_pieces.join("."); + fixed_deep_projection_specifiers.push(fixed_projection_specifier); + } + output.insert(col.clone(), fixed_deep_projection_specifiers); + } + } + output +} + +pub fn merge_value_in_column_map( + map: &mut DeepColumnColumnMap, + col: &Column, + new_value: String, +) { + assert_ne!(new_value, ""); // the value 
should be something, we don't use empty strings anymore + if map.contains_key(col) { + let current_values = map.get(col).unwrap(); + if new_value == "*" { + // replace the existing values + map.insert(col.clone(), vec!["*".to_string()]); + return; + } else if is_empty_deep_projection(current_values) { + // we already read all from above, ignore anything that comes later + return; + } else { + // we actually add something + + // do we already have it ? then return, don't modify the map + for value in current_values.iter() { + if *value == new_value { + return; + } + } + // a.b.c, but we already have a, or a.b + for value in current_values.iter() { + if new_value.starts_with(value) { + return; + } + } + let mut values = current_values.clone(); + values.push(new_value); + map.insert(col.clone(), values); + } + } else { + map.insert(col.clone(), vec![new_value]); + } +} + +pub fn find_all_projections_for_column( + needle: &Column, + haystack: &Vec, + start: usize, + end: usize, +) -> Vec { + let mut end = end; + if start == end { + end += 1; + } + let mut output: DeepColumnColumnMap = HashMap::new(); + output.insert(needle.clone(), vec![]); + for idx in start..end { + let pdc = haystack.get(idx).unwrap(); + if let Some(projections_for_col) = pdc.columns.get(needle) { + for projection in projections_for_col { + merge_value_in_column_map(&mut output, needle, projection.clone()); + } + } + } + output.get(needle).unwrap().clone() +} + +pub fn get_plan_expressions(plan: &LogicalPlan) -> Vec { + match plan { + LogicalPlan::Join(join) => { + let mut output: Vec = vec![]; + if let Some(join_filter) = &join.filter { + output.push(join_filter.clone()) + } + for (e1, e2) in join.on.iter() { + output.push(e1.clone()); + output.push(e2.clone()); + } + output + } + _ => plan.expressions(), + } +} + +pub fn expr_to_deep_columns(expr: &Expr) -> DeepColumnColumnMap { + let mut accum: DeepColumnColumnMap = HashMap::new(); + let mut field_accum: VecDeque = VecDeque::new(); + let mut in_make_struct_call: bool = false; + let mut in_other_literal_call: bool = false; + let _ = expr + .apply(|expr| { + match expr { + Expr::Column(qc) => { + // @HStack FIXME: ADR: we should have a test case + // ignore deep columns if we have a in_make_struct_call + // case: struct(a, b, c)['col'] - we were getting 'col' in the accum stack + // FIXME Will this work for struct(get_field(a, 'substruct'))['col'] ????? + if in_make_struct_call || in_other_literal_call { + field_accum.clear() + } + // at the end, unwind the field_accum and push all to accum + let mut tmp: Vec = vec![]; + // if we didn't just save a "*" - which means the entire column + if !(field_accum.len() == 1 && field_accum.get(0).unwrap() == "*") { + for f in field_accum.iter().rev() { + tmp.push(f.to_owned()); + } + } + field_accum.clear(); + if tmp.len() == 0 { + // entire column + append_column::(&mut accum, qc, vec!["*".to_string()]); + } else { + append_column::(&mut accum, qc, tmp); + } + } + Expr::ScalarFunction(sf) => { + // TODO what about maps ? 
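// Hypothetical walk-through of the merge semantics above: "*" (whole column) subsumes any
// nested path, and an existing prefix subsumes longer paths under it.
use std::collections::HashMap;
use datafusion_common::Column;

fn merge_semantics_example() {
    let mut map: DeepColumnColumnMap = HashMap::new();
    let col = Column::new_unqualified("endUserIDs");
    merge_value_in_column_map(&mut map, &col, "_experience.mcid".to_string());
    // Already covered by the "_experience.mcid" prefix, so this is a no-op:
    merge_value_in_column_map(&mut map, &col, "_experience.mcid.id".to_string());
    // Requesting the whole column collapses everything to ["*"]:
    merge_value_in_column_map(&mut map, &col, "*".to_string());
    assert_eq!(map.get(&col), Some(&vec!["*".to_string()]));
}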
what's the operator + match sf.name() { + "get_field" => { + // get field, append the second argument to the stack and continue + match sf.args[1].clone() { + Expr::Literal(lit_expr) => match lit_expr { + ScalarValue::Utf8(str) => { + let tmp = str.unwrap(); + field_accum.push_back(tmp); + } + _ => { + error!( + "Can't handle expression 1 {:?}", + sf.args[1] + ); + in_other_literal_call = true + } + }, + _ => { + error!("Can't handle expression 2 {:?}", sf.args[1]); + // panic!() + } + }; + // + // let literal_expr: String = match sf.args[1].clone() { + // Expr::Literal(lit_expr) => match lit_expr { + // ScalarValue::Utf8(str) => str.unwrap(), + // _ => { + // error!( + // "Can't handle expression 1 {:?}", + // sf.args[1] + // ); + // in_other_literal_call = true + // // panic!() + // } + // }, + // _ => { + // error!("Can't handle expression 2 {:?}", sf.args[1]); + // panic!() + // } + // }; + // field_accum.push_back(literal_expr); + } + "array_element" => { + // We don't have the schema, but when splatting the column, we need to actually push the list inner field name here + field_accum.push_back("*".to_owned()); + } + "struct" => { + in_make_struct_call = true; + } + _ => {} + } + } + Expr::Unnest(_) + | Expr::ScalarVariable(_, _) + | Expr::Alias(_) + | Expr::Literal(_) + | Expr::BinaryExpr { .. } + | Expr::Like { .. } + | Expr::SimilarTo { .. } + | Expr::Not(_) + | Expr::IsNotNull(_) + | Expr::IsNull(_) + | Expr::IsTrue(_) + | Expr::IsFalse(_) + | Expr::IsUnknown(_) + | Expr::IsNotTrue(_) + | Expr::IsNotFalse(_) + | Expr::IsNotUnknown(_) + | Expr::Negative(_) + | Expr::Between { .. } + | Expr::Case { .. } + | Expr::Cast { .. } + | Expr::TryCast { .. } + | Expr::WindowFunction { .. } + | Expr::AggregateFunction { .. } + | Expr::GroupingSet(_) + | Expr::InList { .. } + | Expr::Exists { .. } + | Expr::InSubquery(_) + | Expr::ScalarSubquery(_) + | Expr::Placeholder(_) + | Expr::OuterReferenceColumn { .. } => {} + Expr::Wildcard { .. 
} => {} + } + Ok(TreeNodeRecursion::Continue) + }) + .map(|_| ()); + accum +} + +pub fn fix_possible_field_accesses( + schema: &DFSchemaRef, + field_idx: usize, + deep_projection_specifier: &String, +) -> Result> { + // info!("fix_possible_field_accesses {} {}", field_idx, deep_projection_specifier); + if deep_projection_specifier == "*" { + return Ok(vec!["*".to_string()]); + } + let rest = deep_projection_specifier + .split(".") + .map(|x| x.to_string()) + .collect::>(); + let mut field = Arc::new(schema.field(field_idx).clone()); + let mut rest_idx = 0 as usize; + let mut out = rest.clone(); + while rest_idx < out.len() { + let (fix_non_star_access, should_continue, new_field) = match field.data_type() { + DataType::Null + | DataType::Boolean + | DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 + | DataType::Float16 + | DataType::Float32 + | DataType::Float64 + | DataType::Timestamp(_, _) + | DataType::Date32 + | DataType::Date64 + | DataType::Time32(_) + | DataType::Time64(_) + | DataType::Duration(_) + | DataType::Interval(_) + | DataType::Binary + | DataType::FixedSizeBinary(_) + | DataType::LargeBinary + | DataType::BinaryView + | DataType::Utf8 + | DataType::LargeUtf8 + | DataType::Utf8View + | DataType::Dictionary(_, _) + | DataType::Decimal128(_, _) + | DataType::Decimal256(_, _) + | DataType::RunEndEncoded(_, _) => (false, false, None), + DataType::Union(_, _) => { + // FIXME @HStack + // don't know what to do here + (false, false, None) + } + DataType::List(inner) + | DataType::ListView(inner) + | DataType::FixedSizeList(inner, _) + | DataType::LargeList(inner) + | DataType::LargeListView(inner) => { + let new_field = inner.clone(); + (true, true, Some(new_field)) + } + DataType::Struct(inner_struct) => { + let mut new_field: Option = None; + for f in inner_struct.iter() { + if f.name() == &out[rest_idx] { + new_field = Some(f.clone()); + } + } + (false, true, new_field) + } + DataType::Map(inner_map, _) => { + let new_field: Option; + match inner_map.data_type() { + DataType::Struct(inner_map_struct) => { + new_field = Some(inner_map_struct[1].clone()); + } + _ => { + return Err(DataFusionError::Internal(String::from( + "Invalid inner map type", + ))); + } + } + (true, true, new_field) + } + }; + if fix_non_star_access && rest[rest_idx] != "*" { + out[rest_idx] = "*".to_string(); + } + if !should_continue { + break; + } + field = new_field.unwrap(); + rest_idx += 1; + } + Ok(out) +} + +pub fn append_column(acc: &mut HashMap>, column: &T, rest: Vec) +where + T: Debug + Clone + Eq + Hash, +{ + let final_name = rest.join("."); + match acc.get_mut(column) { + None => { + let column_clone = column.clone(); + if rest.len() > 0 { + acc.insert(column_clone, vec![final_name]); + } else { + acc.insert(column_clone, vec![]); + } + } + Some(cc) => { + if cc.len() == 0 { + // we already had this column in full + } else { + if rest.len() > 0 { + if !cc.contains(&final_name) { + cc.push(final_name); + } + } else { + // we are getting the entire column, and we already had something + // we should delete everything + cc.clear(); + } + } + } + } +} + +pub fn are_exprs_plain(exprs: &Vec) -> bool { + for pexpr in exprs.iter() { + match pexpr { + Expr::Alias(Alias { + expr: alias_inner_expr, + .. 
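// Hypothetical sketch matching the commented-out test_adr case earlier in this diff: nested
// get_field calls are flattened by expr_to_deep_columns into one dotted path per column.
use datafusion_expr::col;
use datafusion_functions::expr_fn::get_field;

fn deep_columns_example() {
    let expr = get_field(get_field(col("aa"), "bb"), "cc");
    let deep = expr_to_deep_columns(&expr);
    // deep now maps Column("aa") -> ["bb.cc"]: the two get_field hops are collapsed into a
    // single dot-separated access path.
    assert_eq!(deep.len(), 1);
}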
+ }) => match alias_inner_expr.as_ref() { + Expr::Column(_) => {} + Expr::Literal(_) => {} + Expr::ScalarFunction(ScalarFunction { func, args: _args }) => { + if func.name() != "get_field" { + return false; + } + } + _ => { + return false; + } + }, + Expr::Column(_) => {} + Expr::Literal(_) => {} + Expr::ScalarFunction(ScalarFunction { func, args: _args }) => { + if func.name() != "get_field" { + return false; + } + } + _ => { + return false; + } + } + } + true +} + +pub fn get_plan_name(plan: &LogicalPlan) -> String { + match plan { + LogicalPlan::Projection(_) => "Projection".to_string(), + LogicalPlan::Filter(_) => "Filter".to_string(), + LogicalPlan::Window(_) => "Window".to_string(), + LogicalPlan::Aggregate(_) => "Aggregate".to_string(), + LogicalPlan::Sort(_) => "Sort".to_string(), + LogicalPlan::Join(_) => "Join".to_string(), + LogicalPlan::Repartition(_) => "Repartition".to_string(), + LogicalPlan::Union(_) => "Union".to_string(), + LogicalPlan::TableScan(_) => "TableScan".to_string(), + LogicalPlan::EmptyRelation(_) => "EmptyRelation".to_string(), + LogicalPlan::Subquery(_) => "Subquery".to_string(), + LogicalPlan::SubqueryAlias(_) => "SubqueryAlias".to_string(), + LogicalPlan::Limit(_) => "Limit".to_string(), + LogicalPlan::Statement(_) => "Statement".to_string(), + LogicalPlan::Values(_) => "Values".to_string(), + LogicalPlan::Explain(_) => "Explain".to_string(), + LogicalPlan::Analyze(_) => "Analyze".to_string(), + LogicalPlan::Extension(_) => "Extension".to_string(), + LogicalPlan::Distinct(_) => "Distinct".to_string(), + LogicalPlan::Dml(_) => "Dml".to_string(), + LogicalPlan::Ddl(_) => "Ddl".to_string(), + LogicalPlan::Copy(_) => "Copy".to_string(), + LogicalPlan::DescribeTable(_) => "DescribeTable".to_string(), + LogicalPlan::Unnest(_) => "Unnest".to_string(), + LogicalPlan::RecursiveQuery(_) => "RecursiveQuery".to_string(), + } +} diff --git a/datafusion/optimizer/src/optimizer.rs b/datafusion/optimizer/src/optimizer.rs index 49bce3c1ce82c..014a2c75779b8 100644 --- a/datafusion/optimizer/src/optimizer.rs +++ b/datafusion/optimizer/src/optimizer.rs @@ -46,6 +46,7 @@ use crate::eliminate_outer_join::EliminateOuterJoin; use crate::extract_equijoin_predicate::ExtractEquijoinPredicate; use crate::filter_null_join_keys::FilterNullJoinKeys; use crate::optimize_projections::OptimizeProjections; +use crate::optimize_projections_deep::OptimizeProjectionsDeep; use crate::plan_signature::LogicalPlanSignature; use crate::propagate_empty_relation::PropagateEmptyRelation; use crate::push_down_filter::PushDownFilter; @@ -460,6 +461,11 @@ impl Optimizer { i += 1; } + // This rule conflicts with OptimizeProjections, so it runs once here, after all other passes + let optimize_deep_projections_rule = Arc::new(OptimizeProjectionsDeep::new()); + let last_plan = optimize_plan_node(new_plan, optimize_deep_projections_rule.as_ref(), config)?; + new_plan = last_plan.data; + // verify that the optimizer passes only mutated what was permitted.
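A note on the path-accumulation helpers in datafusion/common/src/deep.rs above, since the deep-projection map is the data structure the rest of this patch carries around: each get_field call contributes one dot-separated segment, array_element accesses are splatted to "*", and append_column merges the resulting specifiers per column, where an empty list means "the whole column is needed". This is presumably what the OptimizeProjectionsDeep pass invoked just above feeds into TableScan::projection_deep. A minimal, self-contained sketch of the merge logic (not part of the patch; the generic column key is simplified to a String):

use std::collections::HashMap;

fn append_column(acc: &mut HashMap<String, Vec<String>>, column: &str, rest: Vec<String>) {
    let final_name = rest.join(".");
    match acc.get_mut(column) {
        None => {
            // First sighting of this column: record either the sub-path or "whole column" (empty vec).
            let paths = if rest.is_empty() { vec![] } else { vec![final_name] };
            acc.insert(column.to_string(), paths);
        }
        // The whole column was already requested; narrower paths add nothing.
        Some(paths) if paths.is_empty() => {}
        Some(paths) => {
            if rest.is_empty() {
                // The whole column is now needed, so the narrower sub-paths become redundant.
                paths.clear();
            } else if !paths.contains(&final_name) {
                paths.push(final_name);
            }
        }
    }
}

fn main() {
    let mut acc = HashMap::new();
    // e.g. SELECT details['name'], details['address']['city'], tags[1] FROM t
    append_column(&mut acc, "details", vec!["name".into()]);
    append_column(&mut acc, "details", vec!["address".into(), "city".into()]);
    append_column(&mut acc, "tags", vec!["*".into()]); // list element access is splatted to "*"
    assert_eq!(acc["details"], vec!["name".to_string(), "address.city".to_string()]);
    assert_eq!(acc["tags"], vec!["*".to_string()]);
}

fix_possible_field_accesses then normalizes the accumulated specifiers against the actual schema: any segment that steps through a List or Map is rewritten to "*", while Struct segments keep their field names, so a path survives verbatim only if every hop is a struct field.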
assert_valid_optimization(&new_plan, &starting_schema).map_err(|e| { e.context("Check optimizer-specific invariants after all passes") diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index 6b408521c5cf9..41959197bb9ad 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -2861,6 +2861,7 @@ mod tests { (*test_provider.schema()).clone(), )?), projection, + projection_deep: None, source: Arc::new(test_provider), fetch: None, }); @@ -2928,14 +2929,26 @@ mod tests { #[test] fn multi_combined_filter() -> Result<()> { - let plan = table_scan_with_pushdown_provider_builder( - TableProviderFilterPushDown::Inexact, - vec![col("a").eq(lit(10i64)), col("b").gt(lit(11i64))], - Some(vec![0]), - )? - .filter(and(col("a").eq(lit(10i64)), col("b").gt(lit(11i64))))? - .project(vec![col("a"), col("b")])? - .build()?; + let test_provider = PushDownProvider { + filter_support: TableProviderFilterPushDown::Inexact, + }; + + let table_scan = LogicalPlan::TableScan(TableScan { + table_name: "test".into(), + filters: vec![col("a").eq(lit(10i64)), col("b").gt(lit(11i64))], + projected_schema: Arc::new(DFSchema::try_from( + (*test_provider.schema()).clone(), + )?), + projection: Some(vec![0]), + projection_deep: Some(HashMap::new()), + source: Arc::new(test_provider), + fetch: None, + }); + + let plan = LogicalPlanBuilder::from(table_scan) + .filter(and(col("a").eq(lit(10i64)), col("b").gt(lit(11i64))))? + .project(vec![col("a"), col("b")])? + .build()?; let expected = "Projection: a, b\ \n Filter: a = Int64(10) AND b > Int64(11)\ @@ -2946,14 +2959,26 @@ mod tests { #[test] fn multi_combined_filter_exact() -> Result<()> { - let plan = table_scan_with_pushdown_provider_builder( - TableProviderFilterPushDown::Exact, - vec![], - Some(vec![0]), - )? - .filter(and(col("a").eq(lit(10i64)), col("b").gt(lit(11i64))))? - .project(vec![col("a"), col("b")])? - .build()?; + let test_provider = PushDownProvider { + filter_support: TableProviderFilterPushDown::Exact, + }; + + let table_scan = LogicalPlan::TableScan(TableScan { + table_name: "test".into(), + filters: vec![], + projected_schema: Arc::new(DFSchema::try_from( + (*test_provider.schema()).clone(), + )?), + projection: Some(vec![0]), + projection_deep: Some(HashMap::new()), + source: Arc::new(test_provider), + fetch: None, + }); + + let plan = LogicalPlanBuilder::from(table_scan) + .filter(and(col("a").eq(lit(10i64)), col("b").gt(lit(11i64))))? + .project(vec![col("a"), col("b")])? 
+ .build()?; let expected = r#" Projection: a, b diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 1cdfe6d216e32..0dabd66f8cc2f 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -986,6 +986,8 @@ message FileScanExecConf { repeated FileGroup file_groups = 1; datafusion_common.Schema schema = 2; repeated uint32 projection = 4; + // FIXME somewhat abusively using ProjectionColumns to serialize map> + map projection_deep = 20; ScanLimit limit = 5; datafusion_common.Statistics statistics = 6; repeated string table_partition_cols = 7; diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 6e09e9a797ea0..9e8b7253e74c1 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -5743,6 +5743,9 @@ impl serde::Serialize for FileScanExecConf { if !self.projection.is_empty() { len += 1; } + if !self.projection_deep.is_empty() { + len += 1; + } if self.limit.is_some() { len += 1; } @@ -5771,6 +5774,9 @@ impl serde::Serialize for FileScanExecConf { if !self.projection.is_empty() { struct_ser.serialize_field("projection", &self.projection)?; } + if !self.projection_deep.is_empty() { + struct_ser.serialize_field("projectionDeep", &self.projection_deep)?; + } if let Some(v) = self.limit.as_ref() { struct_ser.serialize_field("limit", v)?; } @@ -5803,6 +5809,8 @@ impl<'de> serde::Deserialize<'de> for FileScanExecConf { "fileGroups", "schema", "projection", + "projection_deep", + "projectionDeep", "limit", "statistics", "table_partition_cols", @@ -5819,6 +5827,7 @@ impl<'de> serde::Deserialize<'de> for FileScanExecConf { FileGroups, Schema, Projection, + ProjectionDeep, Limit, Statistics, TablePartitionCols, @@ -5849,6 +5858,7 @@ impl<'de> serde::Deserialize<'de> for FileScanExecConf { "fileGroups" | "file_groups" => Ok(GeneratedField::FileGroups), "schema" => Ok(GeneratedField::Schema), "projection" => Ok(GeneratedField::Projection), + "projectionDeep" | "projection_deep" => Ok(GeneratedField::ProjectionDeep), "limit" => Ok(GeneratedField::Limit), "statistics" => Ok(GeneratedField::Statistics), "tablePartitionCols" | "table_partition_cols" => Ok(GeneratedField::TablePartitionCols), @@ -5877,6 +5887,7 @@ impl<'de> serde::Deserialize<'de> for FileScanExecConf { let mut file_groups__ = None; let mut schema__ = None; let mut projection__ = None; + let mut projection_deep__ = None; let mut limit__ = None; let mut statistics__ = None; let mut table_partition_cols__ = None; @@ -5906,6 +5917,15 @@ impl<'de> serde::Deserialize<'de> for FileScanExecConf { .into_iter().map(|x| x.0).collect()) ; } + GeneratedField::ProjectionDeep => { + if projection_deep__.is_some() { + return Err(serde::de::Error::duplicate_field("projectionDeep")); + } + projection_deep__ = Some( + map_.next_value::, _>>()? 
+ .into_iter().map(|(k,v)| (k.0, v)).collect() + ); + } GeneratedField::Limit => { if limit__.is_some() { return Err(serde::de::Error::duplicate_field("limit")); @@ -5948,6 +5968,7 @@ impl<'de> serde::Deserialize<'de> for FileScanExecConf { file_groups: file_groups__.unwrap_or_default(), schema: schema__, projection: projection__.unwrap_or_default(), + projection_deep: projection_deep__.unwrap_or_default(), limit: limit__, statistics: statistics__, table_partition_cols: table_partition_cols__.unwrap_or_default(), diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index f5ec45da48f2a..96a458ded7460 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -1498,6 +1498,9 @@ pub struct FileScanExecConf { pub schema: ::core::option::Option, #[prost(uint32, repeated, tag = "4")] pub projection: ::prost::alloc::vec::Vec, + /// FIXME abusively using projection columns to serialize map> + #[prost(map = "uint32, message", tag = "20")] + pub projection_deep: ::std::collections::HashMap, #[prost(message, optional, tag = "5")] pub limit: ::core::option::Option, #[prost(message, optional, tag = "6")] diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs index 641dfe7b5fb84..af995c2c14b13 100644 --- a/datafusion/proto/src/logical_plan/mod.rs +++ b/datafusion/proto/src/logical_plan/mod.rs @@ -267,6 +267,7 @@ fn from_table_source( table_name, source: target, projection: None, + projection_deep: None, projected_schema, filters: vec![], fetch: None, diff --git a/datafusion/proto/src/physical_plan/from_proto.rs b/datafusion/proto/src/physical_plan/from_proto.rs index 6331b7fb31144..bf727bf0c946c 100644 --- a/datafusion/proto/src/physical_plan/from_proto.rs +++ b/datafusion/proto/src/physical_plan/from_proto.rs @@ -17,6 +17,7 @@ //! Serde code to convert from protocol buffers to Rust data structures. +use std::collections::HashMap; use std::sync::Arc; use arrow::compute::SortOptions; @@ -492,7 +493,16 @@ pub fn parse_protobuf_file_scan_config( } else { Some(projection) }; - + let projection_deep = proto + .projection_deep + .iter() + .map(|(i, cols)| (*i as usize, cols.columns.clone())) + .collect::>(); + let projection_deep = if projection_deep.is_empty() { + None + } else { + Some(projection_deep) + }; let constraints = convert_required!(proto.constraints)?; let statistics = convert_required!(proto.statistics)?; @@ -542,6 +552,7 @@ pub fn parse_protobuf_file_scan_config( .with_constraints(constraints) .with_statistics(statistics) .with_projection(projection) + .with_projection_deep(projection_deep) .with_limit(proto.limit.as_ref().map(|sl| sl.limit as usize)) .with_table_partition_cols(table_partition_cols) .with_output_ordering(output_ordering); diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index 3f67842fe625c..0f6ff3a3f03ff 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. 
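The parse_protobuf_file_scan_config change above and the serialize_file_scan_config change below convert projection_deep between its runtime shape, Option<HashMap<usize, Vec<String>>>, and the new map<uint32, ProjectionColumns> protobuf field, with an empty map standing in for None. A rough sketch of that pair of conversions, using a stand-in struct for the prost-generated ProjectionColumns message:

use std::collections::HashMap;

// Stand-in for the generated message; the real type lives in the datafusion-proto protobuf module.
struct ProjectionColumns {
    columns: Vec<String>,
}

fn projection_deep_to_proto(
    deep: &Option<HashMap<usize, Vec<String>>>,
) -> HashMap<u32, ProjectionColumns> {
    deep.as_ref()
        .map(|m| {
            m.iter()
                .map(|(i, cols)| (*i as u32, ProjectionColumns { columns: cols.clone() }))
                .collect()
        })
        .unwrap_or_default()
}

fn projection_deep_from_proto(
    proto: &HashMap<u32, ProjectionColumns>,
) -> Option<HashMap<usize, Vec<String>>> {
    if proto.is_empty() {
        None
    } else {
        Some(
            proto
                .iter()
                .map(|(i, pc)| (*i as usize, pc.columns.clone()))
                .collect(),
        )
    }
}

Encoding None as an empty map keeps the wire format simple, at the cost of not distinguishing "no deep projection" from "a deep projection with zero entries"; the deserializer side folds both back to None.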
+use std::collections::HashMap; use std::sync::Arc; #[cfg(feature = "parquet")] @@ -42,7 +43,7 @@ use datafusion_expr::WindowFrame; use crate::protobuf::{ self, physical_aggregate_expr_node, physical_window_expr_node, PhysicalSortExprNode, - PhysicalSortExprNodeCollection, + PhysicalSortExprNodeCollection, ProjectionColumns }; use super::PhysicalExtensionCodec; @@ -516,6 +517,13 @@ pub fn serialize_file_scan_config( .iter() .map(|n| *n as u32) .collect(), + projection_deep: conf + .projection_deep + .as_ref() + .unwrap_or(&HashMap::new()) + .iter() + .map(|(n, v)| (*n as u32, ProjectionColumns { columns: v.clone() })) + .collect(), schema: Some(schema.as_ref().try_into()?), table_partition_cols: conf .table_partition_cols diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index a8ee213653086..ef7102530d047 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -758,6 +758,7 @@ fn roundtrip_parquet_exec_with_pruning_predicate() -> Result<()> { ]))), }, projection: None, + projection_deep: None, limit: None, table_partition_cols: vec![], output_ordering: vec![], @@ -822,6 +823,7 @@ fn roundtrip_parquet_exec_with_custom_predicate_expr() -> Result<()> { ]))), }, projection: None, + projection_deep: None, limit: None, table_partition_cols: vec![], output_ordering: vec![], @@ -1619,6 +1621,7 @@ async fn roundtrip_projection_source() -> Result<()> { statistics, file_schema: schema.clone(), projection: Some(vec![0, 1, 2]), + projection_deep: None, limit: None, table_partition_cols: vec![], output_ordering: vec![], diff --git a/datafusion/sqllogictest/test_files/information_schema.slt b/datafusion/sqllogictest/test_files/information_schema.slt index b0538b5e65020..95de841102ca3 100644 --- a/datafusion/sqllogictest/test_files/information_schema.slt +++ b/datafusion/sqllogictest/test_files/information_schema.slt @@ -238,6 +238,7 @@ datafusion.explain.show_schema false datafusion.explain.show_sizes true datafusion.explain.show_statistics false datafusion.optimizer.allow_symmetric_joins_without_pruning true +datafusion.optimizer.deep_column_pruning_flags 7 datafusion.optimizer.default_filter_selectivity 20 datafusion.optimizer.enable_distinct_aggregation_soft_limit true datafusion.optimizer.enable_round_robin_repartition true @@ -335,6 +336,7 @@ datafusion.explain.show_schema false When set to true, the explain statement wil datafusion.explain.show_sizes true When set to true, the explain statement will print the partition sizes datafusion.explain.show_statistics false When set to true, the explain statement will print operator statistics for physical plans datafusion.optimizer.allow_symmetric_joins_without_pruning true Should DataFusion allow symmetric hash joins for unbounded data sources even when its inputs do not have any ordering or filtering If the flag is not enabled, the SymmetricHashJoin operator will be unable to prune its internal buffers, resulting in certain join types - such as Full, Left, LeftAnti, LeftSemi, Right, RightAnti, and RightSemi - being produced only at the end of the execution. This is not typical in stream processing. Additionally, without proper design for long runner execution, all types of joins may encounter out-of-memory errors. 
+datafusion.optimizer.deep_column_pruning_flags 7 disable deep column pruning datafusion.optimizer.default_filter_selectivity 20 The default filter selectivity used by Filter Statistics when an exact selectivity cannot be determined. Valid values are between 0 (no selectivity) and 100 (all rows are selected). datafusion.optimizer.enable_distinct_aggregation_soft_limit true When set to true, the optimizer will push a limit operation into grouped aggregations which have no aggregate expressions, as a soft limit, emitting groups once the limit is reached, before all rows in the group are read. datafusion.optimizer.enable_round_robin_repartition true When set to true, the physical plan optimizer will try to add round robin repartitioning to increase parallelism to leverage more CPU cores diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index ffeff3e9df47f..e4385d74f2bb2 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -31,7 +31,7 @@ use datafusion::logical_expr::expr::{Exists, InSubquery, Sort, WindowFunctionPar use datafusion::logical_expr::{ Aggregate, BinaryExpr, Case, Cast, EmptyRelation, Expr, ExprSchemable, Extension, - LogicalPlan, Operator, Projection, SortExpr, Subquery, TryCast, Values, + LogicalPlan, Operator, Projection, SortExpr, Subquery, TableSource, TryCast, Values, }; use substrait::proto::aggregate_rel::Grouping; use substrait::proto::expression as substrait_expression; @@ -86,6 +86,7 @@ use substrait::proto::expression::{ SingularOrList, SwitchExpression, WindowFunction, }; use substrait::proto::read_rel::local_files::file_or_files::PathType::UriFile; +use substrait::proto::read_rel::ExtensionTable; use substrait::proto::rel_common::{Emit, EmitKind}; use substrait::proto::set_rel::SetOp; use substrait::proto::{ @@ -457,6 +458,20 @@ pub trait SubstraitConsumer: Send + Sync + Sized { user_defined_literal.type_reference ) } + + fn consume_extension_table( + &self, + extension_table: &ExtensionTable, + ) -> Result> { + if let Some(ext_detail) = extension_table.detail.as_ref() { + substrait_err!( + "Missing handler for extension table: {}", + &ext_detail.type_url + ) + } else { + substrait_err!("Unexpected empty detail in ExtensionTable") + } + } } /// Convert Substrait Rel to DataFusion DataFrame @@ -578,6 +593,19 @@ impl SubstraitConsumer for DefaultSubstraitConsumer<'_> { let plan = plan.with_exprs_and_inputs(plan.expressions(), inputs)?; Ok(LogicalPlan::Extension(Extension { node: plan })) } + + fn consume_extension_table( + &self, + extension_table: &ExtensionTable, + ) -> Result> { + if let Some(ext_detail) = &extension_table.detail { + self.state + .serializer_registry() + .deserialize_custom_table(&ext_detail.type_url, &ext_detail.value) + } else { + substrait_err!("Unexpected empty detail in ExtensionTable") + } + } } // Substrait PrecisionTimestampTz indicates that the timestamp is relative to UTC, which @@ -1323,26 +1351,14 @@ pub async fn from_read_rel( read: &ReadRel, ) -> Result { async fn read_with_schema( - consumer: &impl SubstraitConsumer, table_ref: TableReference, + table_source: Arc, schema: DFSchema, projection: &Option, ) -> Result { let schema = schema.replace_qualifier(table_ref.clone()); - let plan = { - let provider = match consumer.resolve_table_ref(&table_ref).await? 
{ - Some(ref provider) => Arc::clone(provider), - _ => return plan_err!("No table named '{table_ref}'"), - }; - - LogicalPlanBuilder::scan( - table_ref, - provider_as_source(Arc::clone(&provider)), - None, - )? - .build()? - }; + let plan = { LogicalPlanBuilder::scan(table_ref, table_source, None)?.build()? }; ensure_schema_compatibility(plan.schema(), schema.clone())?; @@ -1351,6 +1367,17 @@ pub async fn from_read_rel( apply_projection(plan, schema) } + async fn table_source( + consumer: &impl SubstraitConsumer, + table_ref: &TableReference, + ) -> Result> { + if let Some(provider) = consumer.resolve_table_ref(table_ref).await? { + Ok(provider_as_source(provider)) + } else { + plan_err!("No table named '{table_ref}'") + } + } + let named_struct = read.base_schema.as_ref().ok_or_else(|| { substrait_datafusion_err!("No base schema provided for Read Relation") })?; @@ -1376,10 +1403,10 @@ pub async fn from_read_rel( table: nt.names[2].clone().into(), }, }; - + let table_source = table_source(consumer, &table_reference).await?; read_with_schema( - consumer, table_reference, + table_source, substrait_schema, &read.projection, ) @@ -1458,17 +1485,41 @@ pub async fn from_read_rel( let name = filename.unwrap(); // directly use unwrap here since we could determine it is a valid one let table_reference = TableReference::Bare { table: name.into() }; + let table_source = table_source(consumer, &table_reference).await?; read_with_schema( - consumer, table_reference, + table_source, substrait_schema, &read.projection, ) .await } - _ => { - not_impl_err!("Unsupported ReadType: {:?}", read.read_type) + Some(ReadType::ExtensionTable(ext)) => { + // look for the original table name under `rel.common.hint.alias` + // in case the producer was kind enough to put it there. 
+ let name_hint = read + .common + .as_ref() + .and_then(|rel_common| rel_common.hint.as_ref()) + .map(|hint| hint.alias.as_str().trim()) + .filter(|alias| !alias.is_empty()); + // if no name hint was provided, use the name that datafusion + // sets for UDTFs + let table_name = name_hint.unwrap_or("tmp_table"); + read_with_schema( + TableReference::from(table_name), + consumer.consume_extension_table(ext)?, + substrait_schema, + &read.projection, + ) + .await + } + Some(ReadType::IcebergTable(_)) => { + substrait_err!("Don't know how to handle Iceberg table") + } + None => { + substrait_err!("Unexpected empty read_type") } } } @@ -1871,7 +1922,7 @@ pub async fn from_substrait_sorts( }, None => not_impl_err!("Sort without sort kind is invalid"), }; - let (asc, nulls_first) = asc_nullfirst.unwrap(); + let (asc, nulls_first) = asc_nullfirst?; sorts.push(Sort { expr, asc, diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index e4df9703b20ca..6d4049048b3f9 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -25,7 +25,7 @@ use datafusion::arrow::datatypes::{Field, IntervalUnit}; use datafusion::logical_expr::{ Aggregate, Distinct, EmptyRelation, Extension, Filter, Join, Like, Limit, Partitioning, Projection, Repartition, Sort, SortExpr, SubqueryAlias, TableScan, - TryCast, Union, Values, Window, WindowFrameUnits, + TableSource, TryCast, Union, Values, Window, WindowFrameUnits, }; use datafusion::{ arrow::datatypes::{DataType, TimeUnit}, @@ -55,6 +55,7 @@ use datafusion::logical_expr::expr::{ AggregateFunctionParams, Alias, BinaryExpr, Case, Cast, GroupingSet, InList, InSubquery, WindowFunction, WindowFunctionParams, }; +use datafusion::logical_expr::registry::NamedBytes; use datafusion::logical_expr::{expr, Between, JoinConstraint, LogicalPlan, Operator}; use datafusion::prelude::Expr; use pbjson_types::Any as ProtoAny; @@ -70,9 +71,9 @@ use substrait::proto::expression::literal::{ use substrait::proto::expression::subquery::InPredicate; use substrait::proto::expression::window_function::BoundsType; use substrait::proto::expression::ScalarFunction; -use substrait::proto::read_rel::VirtualTable; -use substrait::proto::rel_common::EmitKind; +use substrait::proto::read_rel::{ExtensionTable, VirtualTable}; use substrait::proto::rel_common::EmitKind::Emit; +use substrait::proto::rel_common::{EmitKind, Hint}; use substrait::proto::{ fetch_rel, rel_common, ExchangeRel, ExpressionReference, ExtendedExpression, RelCommon, @@ -367,6 +368,13 @@ pub trait SubstraitProducer: Send + Sync + Sized { ) -> Result { from_in_subquery(self, in_subquery, schema) } + + fn handle_custom_table( + &mut self, + _table: &dyn TableSource, + ) -> Result> { + not_impl_err!("Not implemented") + } } pub struct DefaultSubstraitProducer<'a> { @@ -393,12 +401,12 @@ impl SubstraitProducer for DefaultSubstraitProducer<'_> { } fn handle_extension(&mut self, plan: &Extension) -> Result> { - let extension_bytes = self + let NamedBytes(type_url, bytes) = self .serializer_registry .serialize_logical_plan(plan.node.as_ref())?; let detail = ProtoAny { - type_url: plan.node.name().to_string(), - value: extension_bytes.into(), + type_url, + value: bytes.into(), }; let mut inputs_rel = plan .node @@ -426,6 +434,24 @@ impl SubstraitProducer for DefaultSubstraitProducer<'_> { rel_type: Some(rel_type), })) } + + fn handle_custom_table( + &mut self, + table: &dyn TableSource, + ) -> Result> { + if let 
Some(NamedBytes(type_url, bytes)) = + self.serializer_registry.serialize_custom_table(table)? + { + Ok(Some(ExtensionTable { + detail: Some(ProtoAny { + type_url, + value: bytes.into(), + }), + })) + } else { + Ok(None) + } + } } /// Convert DataFusion LogicalPlan to Substrait Plan @@ -540,7 +566,7 @@ pub fn to_substrait_rel( } pub fn from_table_scan( - _producer: &mut impl SubstraitProducer, + producer: &mut impl SubstraitProducer, scan: &TableScan, ) -> Result> { let projection = scan.projection.as_ref().map(|p| { @@ -560,18 +586,38 @@ pub fn from_table_scan( let table_schema = scan.source.schema().to_dfschema_ref()?; let base_schema = to_substrait_named_struct(&table_schema)?; + let (table, common) = + if let Ok(Some(ext_table)) = producer.handle_custom_table(scan.source.as_ref()) { + ( + ReadType::ExtensionTable(ext_table), + Some(RelCommon { + hint: Some(Hint { + // store the original table name as rel.common.hint.alias + alias: scan.table_name.to_string(), + ..Default::default() + }), + ..Default::default() + }), + ) + } else { + ( + ReadType::NamedTable(NamedTable { + names: scan.table_name.to_vec(), + advanced_extension: None, + }), + None, + ) + }; + Ok(Box::new(Rel { rel_type: Some(RelType::Read(Box::new(ReadRel { - common: None, + common, base_schema: Some(base_schema), filter: None, best_effort_filter: None, projection, advanced_extension: None, - read_type: Some(ReadType::NamedTable(NamedTable { - names: scan.table_name.to_vec(), - advanced_extension: None, - })), + read_type: Some(table), }))), })) } @@ -1701,7 +1747,7 @@ pub fn from_in_subquery( subquery_type: Some( substrait::proto::expression::subquery::SubqueryType::InPredicate( Box::new(InPredicate { - needles: (vec![substrait_expr]), + needles: vec![substrait_expr], haystack: Some(subquery_plan), }), ), @@ -2540,8 +2586,8 @@ mod test { use super::*; use crate::logical_plan::consumer::{ from_substrait_extended_expr, from_substrait_literal_without_names, - from_substrait_named_struct, from_substrait_type_without_names, - DefaultSubstraitConsumer, + from_substrait_named_struct, from_substrait_plan, + from_substrait_type_without_names, DefaultSubstraitConsumer, }; use arrow::array::types::{IntervalDayTime, IntervalMonthDayNano}; use datafusion::arrow; @@ -2550,8 +2596,12 @@ mod test { }; use datafusion::arrow::datatypes::{Field, Fields, Schema}; use datafusion::common::scalar::ScalarStructBuilder; - use datafusion::common::DFSchema; + use datafusion::common::{assert_contains, DFSchema}; + use datafusion::datasource::empty::EmptyTable; + use datafusion::datasource::{DefaultTableSource, TableProvider}; use datafusion::execution::{SessionState, SessionStateBuilder}; + use datafusion::logical_expr::registry::SerializerRegistry; + use datafusion::logical_expr::TableSource; use datafusion::prelude::SessionContext; use std::sync::LazyLock; @@ -2889,4 +2939,114 @@ mod test { assert!(matches!(err, Err(DataFusionError::SchemaError(_, _)))); } + + #[tokio::test] + async fn round_trip_extension_table() { + const TABLE_NAME: &str = "custom_table"; + const TYPE_URL: &str = "/substrait.test.CustomTable"; + const SERIALIZED: &[u8] = "table definition".as_bytes(); + + fn custom_table() -> Arc { + Arc::new(EmptyTable::new(Arc::new(Schema::new([ + Arc::new(Field::new("id", DataType::Int32, false)), + Arc::new(Field::new("name", DataType::Utf8, false)), + ])))) + } + + #[derive(Debug)] + struct Registry; + impl SerializerRegistry for Registry { + fn serialize_custom_table( + &self, + table: &dyn TableSource, + ) -> Result> { + if 
table.schema() == custom_table().schema() { + Ok(Some(NamedBytes(TYPE_URL.to_string(), SERIALIZED.to_vec()))) + } else { + Err(DataFusionError::Internal("Not our table".into())) + } + } + fn deserialize_custom_table( + &self, + name: &str, + bytes: &[u8], + ) -> Result> { + if name == TYPE_URL && bytes == SERIALIZED { + Ok(Arc::new(DefaultTableSource::new(custom_table()))) + } else { + panic!("Unexpected extension table: {name}"); + } + } + } + + async fn round_trip_logical_plans( + local: &SessionContext, + remote: &SessionContext, + ) -> Result<()> { + local.register_table(TABLE_NAME, custom_table())?; + remote.table_provider(TABLE_NAME).await.expect_err( + "The remote context is not supposed to know about custom_table", + ); + let initial_plan = local + .sql(&format!("select id from {TABLE_NAME}")) + .await? + .logical_plan() + .clone(); + + // write substrait locally + let substrait = to_substrait_plan(&initial_plan, &local.state())?; + + // read substrait remotely + // since we know there's no `custom_table` registered in the remote context, this will only succeed + // if our table got encoded as an ExtensionTable and is now decoded back to a table source. + let restored = from_substrait_plan(&remote.state(), &substrait).await?; + assert_contains!( + // confirm that the Substrait plan contains our custom_table as an ExtensionTable + serde_json::to_string(substrait.as_ref()).unwrap(), + format!(r#""extensionTable":{{"detail":{{"typeUrl":"{TYPE_URL}","#) + ); + remote // make sure the restored plan is fully working in the remote context + .execute_logical_plan(restored.clone()) + .await? + .collect() + .await + .expect("Restored plan cannot be executed remotely"); + assert_eq!( + // check that the restored plan is functionally equivalent (and almost identical) to the initial one + initial_plan.to_string(), + restored.to_string().replace( + // substrait will add an explicit full-schema projection if the original table had none + &format!("TableScan: {TABLE_NAME} projection=[id, name]"), + &format!("TableScan: {TABLE_NAME}"), + ) + ); + Ok(()) + } + + // take 1 + let failed_attempt = + round_trip_logical_plans(&SessionContext::new(), &SessionContext::new()) + .await + .expect_err( + "The round trip should fail in the absence of a SerializerRegistry", + ); + assert_contains!( + failed_attempt.message(), + format!("No table named '{TABLE_NAME}'") + ); + + // take 2 + fn proper_context() -> SessionContext { + SessionContext::new_with_state( + SessionStateBuilder::new() + // This will transport our custom_table as a Substrait ExtensionTable + .with_serializer_registry(Arc::new(Registry)) + .build(), + ) + } + + round_trip_logical_plans(&proper_context(), &proper_context()) + .await + .expect("Local plan could not be restored remotely"); + } } diff --git a/datafusion/substrait/tests/cases/mod.rs b/datafusion/substrait/tests/cases/mod.rs index 777246e4139bf..42a9416c9fffb 100644 --- a/datafusion/substrait/tests/cases/mod.rs +++ b/datafusion/substrait/tests/cases/mod.rs @@ -23,4 +23,6 @@ mod roundtrip_logical_plan; #[cfg(feature = "physical")] mod roundtrip_physical_plan; mod serialize; + mod substrait_validations; +mod tree_node; diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs index e6b8bdbc047e3..41c840267fa9e 100644 --- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs @@ -31,6 +31,7 @@ use datafusion::error::Result; use 
datafusion::execution::registry::SerializerRegistry; use datafusion::execution::runtime_env::RuntimeEnv; use datafusion::execution::session_state::SessionStateBuilder; +use datafusion::logical_expr::registry::NamedBytes; use datafusion::logical_expr::{ Extension, InvariantLevel, LogicalPlan, PartitionEvaluator, Repartition, UserDefinedLogicalNode, Values, Volatility, @@ -50,13 +51,13 @@ impl SerializerRegistry for MockSerializerRegistry { fn serialize_logical_plan( &self, node: &dyn UserDefinedLogicalNode, - ) -> Result> { + ) -> Result { if node.name() == "MockUserDefinedLogicalPlan" { let node = node .as_any() .downcast_ref::() .unwrap(); - node.serialize() + Ok(NamedBytes(node.name().to_string(), node.serialize()?)) } else { unreachable!() } diff --git a/datafusion/substrait/tests/cases/tree_node.rs b/datafusion/substrait/tests/cases/tree_node.rs new file mode 100644 index 0000000000000..399e9a1d20829 --- /dev/null +++ b/datafusion/substrait/tests/cases/tree_node.rs @@ -0,0 +1,87 @@ +//! Tests for TreeNode Compatibility + +#[cfg(test)] +mod tests { + use datafusion::common::tree_node::{Transformed, TreeNode, TreeNodeRecursion}; + use datafusion::common::Result; + use datafusion::common::substrait_tree::*; + use std::fs::File; + use std::io::BufReader; + use substrait::proto::plan_rel::RelType; + use substrait::proto::rel::RelType::Project; + use substrait::proto::{Plan, ProjectRel, Rel}; + + #[test] + fn tree_visit() -> Result<()> { + let path = "tests/testdata/contains_plan.substrait.json"; + let proto_plan = serde_json::from_reader::<_, Plan>(BufReader::new( + File::open(path).expect("file not found"), + )) + .expect("failed to parse json"); + + for r in proto_plan.relations { + let rel = match r.rel_type.unwrap() { + RelType::Rel(rel) => rel, + RelType::Root(root_rel) => root_rel.input.unwrap(), + }; + + rel.apply(|r| { + println!("REL: {:#?}", r); + Ok(TreeNodeRecursion::Continue) + })?; + } + + Ok(()) + } + #[test] + fn tree_map() -> Result<()> { + let path = "tests/testdata/contains_plan.substrait.json"; + let proto_plan = serde_json::from_reader::<_, Plan>(BufReader::new( + File::open(path).expect("file not found"), + )) + .expect("failed to parse json"); + + for r in proto_plan.relations { + let rel = match r.rel_type.unwrap() { + RelType::Rel(rel) => rel, + RelType::Root(root_rel) => root_rel.input.unwrap(), + }; + + rel.apply(|r| { + if let Some(Project(p)) = &r.rel_type { + println!("PROJECT REL: {:#?}", p); + } + Ok(TreeNodeRecursion::Continue) + })?; + + // rewrite ProjectRel node - remove common field + let t = rel + .transform(|r| { + if let Some(Project(p)) = &r.rel_type { + let updated = Project(Box::new(ProjectRel { + common: None, + input: p.input.clone(), + expressions: p.expressions.clone(), + advanced_extension: p.advanced_extension.clone(), + })); + Ok(Transformed::yes(Rel { + rel_type: Some(updated), + })) + } else { + Ok(Transformed::no(r)) + } + })? + .data; + + println!("AFTER"); + t.apply(|r| { + if let Some(Project(p)) = &r.rel_type { + println!("PROJECT REL: {:#?}", p); + } + Ok(TreeNodeRecursion::Continue) + })?; + } + + Ok(()) + } +}
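The tree_visit and tree_map tests above exercise the TreeNode implementation for substrait::proto::Rel that this patch adds in datafusion::common::substrait_tree. A short usage sketch of the same API, assuming those module paths are exported as shown: counting the read relations (scan leaves) anywhere in a decoded plan tree.

use datafusion::common::tree_node::{TreeNode, TreeNodeRecursion};
use datafusion::common::Result;
use substrait::proto::rel::RelType;
use substrait::proto::Rel;

/// Count how many ReadRel nodes appear anywhere in a Substrait Rel tree.
fn count_read_rels(rel: &Rel) -> Result<usize> {
    let mut reads = 0usize;
    rel.apply(|r| {
        if matches!(r.rel_type, Some(RelType::Read(_))) {
            reads += 1;
        }
        Ok(TreeNodeRecursion::Continue)
    })?;
    Ok(reads)
}

The rewriting direction works the same way via transform, as the tree_map test shows: return Transformed::yes with a rebuilt Rel to replace a node, or Transformed::no to keep it unchanged.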