MANIFEST_FORMAT =
key("manifest.format")
.stringType()
diff --git a/paimon-common/src/main/java/org/apache/paimon/format/variant/InferVariantShreddingWriter.java b/paimon-common/src/main/java/org/apache/paimon/format/variant/InferVariantShreddingWriter.java
new file mode 100644
index 000000000000..2337a63b2e0b
--- /dev/null
+++ b/paimon-common/src/main/java/org/apache/paimon/format/variant/InferVariantShreddingWriter.java
@@ -0,0 +1,150 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.paimon.format.variant;
+
+import org.apache.paimon.data.InternalRow;
+import org.apache.paimon.data.variant.InferVariantShreddingSchema;
+import org.apache.paimon.format.BundleFormatWriter;
+import org.apache.paimon.format.FormatWriter;
+import org.apache.paimon.fs.PositionOutputStream;
+import org.apache.paimon.io.BundleRecords;
+import org.apache.paimon.types.RowType;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * A generic writer that infers the shredding schema from buffered rows before writing.
+ *
+ * This writer buffers rows up to a threshold, infers the optimal schema from them, then writes
+ * all data using the inferred schema. It works with any writer factory that implements {@link
+ * SupportsVariantInference}.
+ */
+public class InferVariantShreddingWriter implements BundleFormatWriter {
+
+ private final SupportsVariantInference writerFactory;
+ private final InferVariantShreddingSchema shreddingSchemaInfer;
+ private final int maxBufferRow;
+ private final PositionOutputStream out;
+ private final String compression;
+
+ private final List<InternalRow> bufferedRows;
+ private final List<BundleRecords> bufferedBundles;
+
+ private FormatWriter actualWriter;
+ private boolean schemaFinalized = false;
+ private long totalBufferedRowCount = 0;
+
+ public InferVariantShreddingWriter(
+ SupportsVariantInference writerFactory,
+ InferVariantShreddingSchema shreddingSchemaInfer,
+ int maxBufferRow,
+ PositionOutputStream out,
+ String compression) {
+ this.writerFactory = writerFactory;
+ this.shreddingSchemaInfer = shreddingSchemaInfer;
+ this.maxBufferRow = maxBufferRow;
+ this.out = out;
+ this.compression = compression;
+ this.bufferedRows = new ArrayList<>();
+ this.bufferedBundles = new ArrayList<>();
+ }
+
+ @Override
+ public void addElement(InternalRow row) throws IOException {
+ if (!schemaFinalized) {
+ bufferedRows.add(row);
+ totalBufferedRowCount++;
+ if (totalBufferedRowCount >= maxBufferRow) {
+ finalizeSchemaAndFlush();
+ }
+ } else {
+ actualWriter.addElement(row);
+ }
+ }
+
+ @Override
+ public void writeBundle(BundleRecords bundle) throws IOException {
+ if (!schemaFinalized) {
+ bufferedBundles.add(bundle);
+ totalBufferedRowCount += bundle.rowCount();
+ if (totalBufferedRowCount >= maxBufferRow) {
+ finalizeSchemaAndFlush();
+ }
+ } else {
+ ((BundleFormatWriter) actualWriter).writeBundle(bundle);
+ }
+ }
+
+ @Override
+ public boolean reachTargetSize(boolean suggestedCheck, long targetSize) throws IOException {
+ if (!schemaFinalized) {
+ return false;
+ }
+ return actualWriter.reachTargetSize(suggestedCheck, targetSize);
+ }
+
+ @Override
+ public void close() throws IOException {
+ try {
+ if (!schemaFinalized) {
+ finalizeSchemaAndFlush();
+ }
+ } finally {
+ if (actualWriter != null) {
+ actualWriter.close();
+ }
+ }
+ }
+
+ private void finalizeSchemaAndFlush() throws IOException {
+ RowType inferredShreddingSchema = shreddingSchemaInfer.inferSchema(collectAllRows());
+ actualWriter =
+ writerFactory.createWithShreddingSchema(out, compression, inferredShreddingSchema);
+ schemaFinalized = true;
+
+ if (!bufferedBundles.isEmpty()) {
+ BundleFormatWriter bundleWriter = (BundleFormatWriter) actualWriter;
+ for (BundleRecords bundle : bufferedBundles) {
+ bundleWriter.writeBundle(bundle);
+ }
+ bufferedBundles.clear();
+ } else {
+ for (InternalRow row : bufferedRows) {
+ actualWriter.addElement(row);
+ }
+ bufferedRows.clear();
+ }
+ }
+
+ private List<InternalRow> collectAllRows() {
+ if (!bufferedBundles.isEmpty()) {
+ List<InternalRow> allRows = new ArrayList<>();
+ for (BundleRecords bundle : bufferedBundles) {
+ for (InternalRow row : bundle) {
+ allRows.add(row);
+ }
+ }
+ return allRows;
+ } else {
+ return bufferedRows;
+ }
+ }
+}
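
The control flow above reduces to a two-phase state machine: buffer until `maxBufferRow` rows arrive (or `close()` is called), infer the schema exactly once, replay the buffer, then pass subsequent rows straight through to the real writer. A minimal, self-contained sketch of that pattern, with `String` standing in for `InternalRow` and the inferred `RowType` (all names here are illustrative, not Paimon API):

```java
import java.util.ArrayList;
import java.util.List;

class BufferThenFlushSketch {
    private final int maxBufferRows;
    private final List<String> buffered = new ArrayList<>();
    private String inferredSchema; // null until finalized; stands in for RowType

    BufferThenFlushSketch(int maxBufferRows) {
        this.maxBufferRows = maxBufferRows;
    }

    void add(String row) {
        if (inferredSchema == null) {
            buffered.add(row);
            if (buffered.size() >= maxBufferRows) {
                finalizeAndFlush();
            }
        } else {
            write(row); // schema is fixed: write straight through
        }
    }

    void close() {
        if (inferredSchema == null) {
            finalizeAndFlush(); // small files: infer at close() time
        }
    }

    private void finalizeAndFlush() {
        // placeholder for InferVariantShreddingSchema.inferSchema(...)
        inferredSchema = "schema-from-" + buffered.size() + "-rows";
        buffered.forEach(this::write);
        buffered.clear();
    }

    private void write(String row) {
        System.out.println(inferredSchema + " <- " + row);
    }
}
```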
diff --git a/paimon-common/src/main/java/org/apache/paimon/format/variant/SupportsVariantInference.java b/paimon-common/src/main/java/org/apache/paimon/format/variant/SupportsVariantInference.java
new file mode 100644
index 000000000000..1fe404080e63
--- /dev/null
+++ b/paimon-common/src/main/java/org/apache/paimon/format/variant/SupportsVariantInference.java
@@ -0,0 +1,48 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.paimon.format.variant;
+
+import org.apache.paimon.format.FormatWriter;
+import org.apache.paimon.fs.PositionOutputStream;
+import org.apache.paimon.types.RowType;
+
+import java.io.IOException;
+
+/**
+ * Interface for FormatWriterFactory implementations that support variant schema inference.
+ *
+ * Factories implementing this interface can recreate their writers with a schema that
+ * reflects an inferred variant shredding schema.
+ */
+public interface SupportsVariantInference {
+
+ /**
+ * Create the writer with the inferred shredding schema using the same output stream and
+ * compression settings.
+ *
+ * @param out The output stream to write to
+ * @param compression The compression codec
+ * @param inferredShreddingSchema The inferred shredding schema for variant fields
+ * @return A new FormatWriter configured with the inferred schema
+ * @throws IOException If the writer cannot be created
+ */
+ FormatWriter createWithShreddingSchema(
+ PositionOutputStream out, String compression, RowType inferredShreddingSchema)
+ throws IOException;
+}
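
A hedged sketch of what an implementing factory looks like: the plain `create()` path simply passes no shredding schema, while `createWithShreddingSchema()` threads the inferred one into the concrete writer. `MyWriterFactory` is hypothetical and elides actual writer construction:

```java
import org.apache.paimon.format.FormatWriter;
import org.apache.paimon.format.FormatWriterFactory;
import org.apache.paimon.format.variant.SupportsVariantInference;
import org.apache.paimon.fs.PositionOutputStream;
import org.apache.paimon.types.RowType;

import java.io.IOException;

public class MyWriterFactory implements FormatWriterFactory, SupportsVariantInference {

    @Override
    public FormatWriter create(PositionOutputStream out, String compression) throws IOException {
        // plain path: no shredding schema
        return createWithShreddingSchema(out, compression, null);
    }

    @Override
    public FormatWriter createWithShreddingSchema(
            PositionOutputStream out, String compression, RowType inferredShreddingSchema)
            throws IOException {
        // A real factory would build its concrete writer here, configured with
        // inferredShreddingSchema (or without shredding when it is null).
        throw new UnsupportedOperationException("writer construction elided in this sketch");
    }
}
```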
diff --git a/paimon-common/src/main/java/org/apache/paimon/format/variant/VariantInferenceConfig.java b/paimon-common/src/main/java/org/apache/paimon/format/variant/VariantInferenceConfig.java
new file mode 100644
index 000000000000..8c289d605629
--- /dev/null
+++ b/paimon-common/src/main/java/org/apache/paimon/format/variant/VariantInferenceConfig.java
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.paimon.format.variant;
+
+import org.apache.paimon.CoreOptions;
+import org.apache.paimon.data.variant.InferVariantShreddingSchema;
+import org.apache.paimon.options.Options;
+import org.apache.paimon.types.DataField;
+import org.apache.paimon.types.RowType;
+import org.apache.paimon.types.VariantType;
+
+/** Variant schema inference configuration. */
+public class VariantInferenceConfig {
+
+ private final RowType rowType;
+ private final Options options;
+
+ public VariantInferenceConfig(RowType rowType, Options options) {
+ this.rowType = rowType;
+ this.options = options;
+ }
+
+ /** Determines whether variant schema inference should be enabled. */
+ public boolean shouldEnableInference() {
+ if (options.contains(CoreOptions.VARIANT_SHREDDING_SCHEMA)) {
+ return false;
+ }
+
+ if (!options.get(CoreOptions.VARIANT_INFER_SHREDDING_SCHEMA)) {
+ return false;
+ }
+
+ return containsVariantFields(rowType);
+ }
+
+ private boolean containsVariantFields(RowType rowType) {
+ for (DataField field : rowType.getFields()) {
+ if (field.type() instanceof VariantType) {
+ return true;
+ }
+ }
+ return false;
+ }
+
+ /** Create a schema inferrer. */
+ public InferVariantShreddingSchema createInferrer() {
+ return new InferVariantShreddingSchema(
+ rowType,
+ options.get(CoreOptions.VARIANT_SHREDDING_MAX_SCHEMA_WIDTH),
+ options.get(CoreOptions.VARIANT_SHREDDING_MAX_SCHEMA_DEPTH),
+ options.get(CoreOptions.VARIANT_SHREDDING_MIN_FIELD_CARDINALITY_RATIO));
+ }
+
+ /** Get the maximum number of rows to buffer for inference. */
+ public int getMaxBufferRow() {
+ return options.get(CoreOptions.VARIANT_SHREDDING_MAX_INFER_BUFFER_ROW);
+ }
+}
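
The gate above has a fixed precedence: an explicit `CoreOptions.VARIANT_SHREDDING_SCHEMA` always disables inference, then the infer flag must be on, then the row type must actually contain a VARIANT field. A small sketch exercising that order (`InferenceGateExample` is illustrative; the placeholder JSON value only needs to make `options.contains(...)` true, since `shouldEnableInference` never parses it):

```java
import org.apache.paimon.CoreOptions;
import org.apache.paimon.format.variant.VariantInferenceConfig;
import org.apache.paimon.options.Options;
import org.apache.paimon.types.DataTypes;
import org.apache.paimon.types.RowType;

public class InferenceGateExample {
    public static void main(String[] args) {
        // A VARIANT column plus the infer flag, with no explicit shredding schema.
        RowType rowType = DataTypes.ROW(DataTypes.FIELD(0, "v", DataTypes.VARIANT()));
        Options options = new Options();
        options.set(CoreOptions.VARIANT_INFER_SHREDDING_SCHEMA.key(), "true");

        VariantInferenceConfig config = new VariantInferenceConfig(rowType, options);
        System.out.println(config.shouldEnableInference()); // true

        // An explicit shredding schema always wins: inference switches off again.
        options.set(CoreOptions.VARIANT_SHREDDING_SCHEMA.key(), "{\"fields\":[]}");
        System.out.println(config.shouldEnableInference()); // false
    }
}
```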
diff --git a/paimon-common/src/main/java/org/apache/paimon/format/variant/VariantInferenceWriterFactory.java b/paimon-common/src/main/java/org/apache/paimon/format/variant/VariantInferenceWriterFactory.java
new file mode 100644
index 000000000000..8945c8c47433
--- /dev/null
+++ b/paimon-common/src/main/java/org/apache/paimon/format/variant/VariantInferenceWriterFactory.java
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.paimon.format.variant;
+
+import org.apache.paimon.format.FormatWriter;
+import org.apache.paimon.format.FormatWriterFactory;
+import org.apache.paimon.fs.PositionOutputStream;
+
+import java.io.IOException;
+
+/**
+ * A decorator factory that adds variant schema inference capability to any {@link
+ * FormatWriterFactory}.
+ *
+ *
+ * This factory wraps an existing FormatWriterFactory and automatically enables variant schema
+ * inference if the delegate factory supports it (implements {@link SupportsVariantInference}) and
+ * the configuration enables inference.
+ */
+public class VariantInferenceWriterFactory implements FormatWriterFactory {
+
+ private final FormatWriterFactory delegate;
+ private final VariantInferenceConfig config;
+
+ public VariantInferenceWriterFactory(
+ FormatWriterFactory delegate, VariantInferenceConfig config) {
+ this.delegate = delegate;
+ this.config = config;
+ }
+
+ @Override
+ public FormatWriter create(PositionOutputStream out, String compression) throws IOException {
+ if (!config.shouldEnableInference()) {
+ return delegate.create(out, compression);
+ }
+
+ if (!(delegate instanceof SupportsVariantInference)) {
+ return delegate.create(out, compression);
+ }
+
+ return new InferVariantShreddingWriter(
+ (SupportsVariantInference) delegate,
+ config.createInferrer(),
+ config.getMaxBufferRow(),
+ out,
+ compression);
+ }
+}
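
Since the decorator only consults the config gate and the capability check inside `create()`, wrapping is safe to do unconditionally, which is exactly what `ParquetFileFormat` does below. A sketch of that wiring (`WrapFactoryExample` is an illustrative helper, not part of the patch):

```java
import org.apache.paimon.format.FormatWriterFactory;
import org.apache.paimon.format.variant.VariantInferenceConfig;
import org.apache.paimon.format.variant.VariantInferenceWriterFactory;
import org.apache.paimon.options.Options;
import org.apache.paimon.types.RowType;

public final class WrapFactoryExample {

    // Decorating is cheap and unconditional: the per-writer decision happens in
    // create(), which falls back to the plain delegate writer unless both the
    // config gate and the SupportsVariantInference capability check pass.
    public static FormatWriterFactory wrap(
            FormatWriterFactory base, RowType type, Options options) {
        return new VariantInferenceWriterFactory(base, new VariantInferenceConfig(type, options));
    }
}
```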
diff --git a/paimon-format/src/main/java/org/apache/paimon/format/parquet/ParquetFileFormat.java b/paimon-format/src/main/java/org/apache/paimon/format/parquet/ParquetFileFormat.java
index 0350496d84d5..4e6c9de9a6f7 100644
--- a/paimon-format/src/main/java/org/apache/paimon/format/parquet/ParquetFileFormat.java
+++ b/paimon-format/src/main/java/org/apache/paimon/format/parquet/ParquetFileFormat.java
@@ -26,6 +26,8 @@
import org.apache.paimon.format.FormatWriterFactory;
import org.apache.paimon.format.SimpleStatsExtractor;
import org.apache.paimon.format.parquet.writer.RowDataParquetBuilder;
+import org.apache.paimon.format.variant.VariantInferenceConfig;
+import org.apache.paimon.format.variant.VariantInferenceWriterFactory;
import org.apache.paimon.options.MemorySize;
import org.apache.paimon.options.Options;
import org.apache.paimon.predicate.Predicate;
@@ -45,12 +47,14 @@
/** Parquet {@link FileFormat}. */
public class ParquetFileFormat extends FileFormat {
+ private final FormatContext formatContext;
private final Options options;
private final int readBatchSize;
public ParquetFileFormat(FormatContext formatContext) {
super(IDENTIFIER);
+ this.formatContext = formatContext;
this.options = getParquetConfiguration(formatContext);
this.readBatchSize = formatContext.readBatchSize();
}
@@ -85,7 +89,11 @@ public FormatReaderFactory createReaderFactory(
@Override
public FormatWriterFactory createWriterFactory(RowType type) {
- return new ParquetWriterFactory(new RowDataParquetBuilder(type, options));
+ ParquetWriterFactory baseFactory =
+ new ParquetWriterFactory(new RowDataParquetBuilder(type, options));
+ // Wrap with variant inference decorator
+ return new VariantInferenceWriterFactory(
+ baseFactory, new VariantInferenceConfig(type, formatContext.options()));
}
@Override
diff --git a/paimon-format/src/main/java/org/apache/paimon/format/parquet/ParquetOptions.java b/paimon-format/src/main/java/org/apache/paimon/format/parquet/ParquetOptions.java
deleted file mode 100644
index d264226b7e2c..000000000000
--- a/paimon-format/src/main/java/org/apache/paimon/format/parquet/ParquetOptions.java
+++ /dev/null
@@ -1,34 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.paimon.format.parquet;
-
-import org.apache.paimon.options.ConfigOption;
-
-import static org.apache.paimon.options.ConfigOptions.key;
-
-/** Options for parquet format. */
-public class ParquetOptions {
-
- public static final ConfigOption<String> PARQUET_VARIANT_SHREDDING_SCHEMA =
- key("parquet.variant.shreddingSchema")
- .stringType()
- .noDefaultValue()
- .withDescription(
- "Specify the variant shredding schema for writing parquet files.");
-}
diff --git a/paimon-format/src/main/java/org/apache/paimon/format/parquet/ParquetWriterFactory.java b/paimon-format/src/main/java/org/apache/paimon/format/parquet/ParquetWriterFactory.java
index 006c2121aa74..282805897a5f 100644
--- a/paimon-format/src/main/java/org/apache/paimon/format/parquet/ParquetWriterFactory.java
+++ b/paimon-format/src/main/java/org/apache/paimon/format/parquet/ParquetWriterFactory.java
@@ -24,8 +24,11 @@
import org.apache.paimon.format.HadoopCompressionType;
import org.apache.paimon.format.parquet.writer.ParquetBuilder;
import org.apache.paimon.format.parquet.writer.ParquetBulkWriter;
+import org.apache.paimon.format.parquet.writer.RowDataParquetBuilder;
import org.apache.paimon.format.parquet.writer.StreamOutputFile;
+import org.apache.paimon.format.variant.SupportsVariantInference;
import org.apache.paimon.fs.PositionOutputStream;
+import org.apache.paimon.types.RowType;
import org.apache.parquet.hadoop.ParquetWriter;
import org.apache.parquet.io.OutputFile;
@@ -33,7 +36,7 @@
import java.io.IOException;
/** A factory that creates a Parquet {@link FormatWriter}. */
-public class ParquetWriterFactory implements FormatWriterFactory {
+public class ParquetWriterFactory implements FormatWriterFactory, SupportsVariantInference {
/** The builder to construct the ParquetWriter. */
private final ParquetBuilder<InternalRow> writerBuilder;
@@ -53,7 +56,24 @@ public FormatWriter create(PositionOutputStream stream, String compression) thro
if (HadoopCompressionType.NONE.value().equals(compression)) {
compression = null;
}
+
final ParquetWriter<InternalRow> writer = writerBuilder.createWriter(out, compression);
return new ParquetBulkWriter(writer);
}
+
+ @Override
+ public FormatWriter createWithShreddingSchema(
+ PositionOutputStream stream, String compression, RowType inferredShreddingSchema)
+ throws IOException {
+ final OutputFile out = new StreamOutputFile(stream);
+ if (HadoopCompressionType.NONE.value().equals(compression)) {
+ compression = null;
+ }
+
+ ParquetBuilder<InternalRow> newBuilder =
+ ((RowDataParquetBuilder) writerBuilder)
+ .withShreddingSchemas(inferredShreddingSchema);
+ final ParquetWriter<InternalRow> writer = newBuilder.createWriter(out, compression);
+ return new ParquetBulkWriter(writer);
+ }
}
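
`createWithShreddingSchema` above relies on an unchecked cast of `writerBuilder` to `RowDataParquetBuilder`, which holds for the factory as constructed by `ParquetFileFormat` but would fail at runtime for any other builder. A hypothetical defensive guard, not part of the patch:

```java
import org.apache.paimon.data.InternalRow;
import org.apache.paimon.format.parquet.writer.ParquetBuilder;
import org.apache.paimon.format.parquet.writer.RowDataParquetBuilder;

final class ShreddingBuilderCheck {

    // Hypothetical guard for the unchecked cast in createWithShreddingSchema:
    // fail fast if the factory was built around a different ParquetBuilder.
    static RowDataParquetBuilder requireRowDataBuilder(ParquetBuilder<InternalRow> builder) {
        if (!(builder instanceof RowDataParquetBuilder)) {
            throw new IllegalStateException(
                    "Variant shredding inference requires RowDataParquetBuilder, got "
                            + builder.getClass().getName());
        }
        return (RowDataParquetBuilder) builder;
    }
}
```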
diff --git a/paimon-format/src/main/java/org/apache/paimon/format/parquet/VariantUtils.java b/paimon-format/src/main/java/org/apache/paimon/format/parquet/VariantUtils.java
index 86267052e134..38b23edaddc5 100644
--- a/paimon-format/src/main/java/org/apache/paimon/format/parquet/VariantUtils.java
+++ b/paimon-format/src/main/java/org/apache/paimon/format/parquet/VariantUtils.java
@@ -18,8 +18,10 @@
package org.apache.paimon.format.parquet;
+import org.apache.paimon.CoreOptions;
import org.apache.paimon.data.variant.PaimonShreddingUtils;
import org.apache.paimon.data.variant.VariantAccessInfo;
+import org.apache.paimon.options.Options;
import org.apache.paimon.types.DataField;
import org.apache.paimon.types.DataType;
import org.apache.paimon.types.DataTypes;
@@ -27,7 +29,6 @@
import org.apache.paimon.types.VariantType;
import org.apache.paimon.utils.JsonSerdeUtil;
-import org.apache.hadoop.conf.Configuration;
import org.apache.parquet.schema.MessageType;
import javax.annotation.Nullable;
@@ -71,42 +72,34 @@ public static RowType[] extractShreddingSchemasFromParquetSchema(
return shreddingSchemas;
}
+ /** For writer, extract shredding schemas from options. */
@Nullable
- public static RowType shreddingFields(Configuration conf) {
- String shreddingSchema =
- conf.get(ParquetOptions.PARQUET_VARIANT_SHREDDING_SCHEMA.key(), "");
- if (shreddingSchema.isEmpty()) {
+ public static RowType shreddingSchemasFromOptions(Options options) {
+ if (!options.contains(CoreOptions.VARIANT_SHREDDING_SCHEMA)) {
return null;
- } else {
- return (RowType) JsonSerdeUtil.fromJson(shreddingSchema, DataType.class);
}
- }
- /** For writer, extract shredding schemas from conf. */
- @Nullable
- public static RowType extractShreddingSchemaFromConf(Configuration conf, String fieldName) {
- RowType shreddingFields = shreddingFields(conf);
- if (shreddingFields != null && shreddingFields.containsField(fieldName)) {
- return PaimonShreddingUtils.variantShreddingSchema(
- shreddingFields.getField(fieldName).type());
- } else {
- return null;
+ String shreddingSchema = options.get(CoreOptions.VARIANT_SHREDDING_SCHEMA);
+ RowType rowType = (RowType) JsonSerdeUtil.fromJson(shreddingSchema, DataType.class);
+ ArrayList<DataField> fields = new ArrayList<>();
+ for (DataField field : rowType.getFields()) {
+ fields.add(field.newType(PaimonShreddingUtils.variantShreddingSchema(field.type())));
}
+ return new RowType(fields);
}
- public static RowType replaceWithShreddingType(Configuration conf, RowType rowType) {
- RowType shreddingFields = shreddingFields(conf);
- if (shreddingFields == null) {
+ public static RowType replaceWithShreddingType(
+ RowType rowType, @Nullable RowType shreddingSchemas) {
+ if (shreddingSchemas == null) {
return rowType;
}
List<DataField> newFields = new ArrayList<>();
for (DataField field : rowType.getFields()) {
+ // todo: support nested variant.
if (field.type() instanceof VariantType
- && shreddingFields.containsField(field.name())) {
- RowType shreddingSchema =
- PaimonShreddingUtils.variantShreddingSchema(
- shreddingFields.getField(field.name()).type());
+ && shreddingSchemas.containsField(field.name())) {
+ RowType shreddingSchema = (RowType) shreddingSchemas.getField(field.name()).type();
newFields.add(field.newType(shreddingSchema));
} else {
newFields.add(field);
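
A worked example of the reshaped `replaceWithShreddingType`: a VARIANT field whose name appears in `shreddingSchemas` has its type swapped for the corresponding shredding row type, and everything else passes through unchanged. Field ids, names, and the pre-wrapped shredding row below are illustrative (real callers get the wrapping from `PaimonShreddingUtils.variantShreddingSchema`):

```java
import org.apache.paimon.format.parquet.VariantUtils;
import org.apache.paimon.types.DataTypes;
import org.apache.paimon.types.RowType;

public class ReplaceShreddingExample {
    public static void main(String[] args) {
        RowType rowType =
                DataTypes.ROW(
                        DataTypes.FIELD(0, "v", DataTypes.VARIANT()),
                        DataTypes.FIELD(1, "id", DataTypes.INT()));

        // Pretend "v" was inferred (and already wrapped) as a row with one BIGINT field "age".
        RowType shreddingSchemas =
                DataTypes.ROW(
                        DataTypes.FIELD(
                                0,
                                "v",
                                DataTypes.ROW(DataTypes.FIELD(0, "age", DataTypes.BIGINT()))));

        // "v" gets its type swapped for the shredding row type; "id" passes through.
        RowType replaced = VariantUtils.replaceWithShreddingType(rowType, shreddingSchemas);
        System.out.println(replaced);
    }
}
```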
diff --git a/paimon-format/src/main/java/org/apache/paimon/format/parquet/writer/ParquetRowDataBuilder.java b/paimon-format/src/main/java/org/apache/paimon/format/parquet/writer/ParquetRowDataBuilder.java
index b6a5ea361c23..14970e548e9b 100644
--- a/paimon-format/src/main/java/org/apache/paimon/format/parquet/writer/ParquetRowDataBuilder.java
+++ b/paimon-format/src/main/java/org/apache/paimon/format/parquet/writer/ParquetRowDataBuilder.java
@@ -29,6 +29,8 @@
import org.apache.parquet.io.api.RecordConsumer;
import org.apache.parquet.schema.MessageType;
+import javax.annotation.Nullable;
+
import java.util.HashMap;
import static org.apache.paimon.format.parquet.ParquetSchemaConverter.convertToParquetMessageType;
@@ -38,10 +40,13 @@ public class ParquetRowDataBuilder
extends ParquetWriter.Builder<InternalRow, ParquetRowDataBuilder> {
private final RowType rowType;
+ @Nullable private final RowType shreddingSchemas;
- public ParquetRowDataBuilder(OutputFile path, RowType rowType) {
+ public ParquetRowDataBuilder(
+ OutputFile path, RowType rowType, @Nullable RowType shreddingSchemas) {
super(path);
this.rowType = rowType;
+ this.shreddingSchemas = shreddingSchemas;
}
@Override
@@ -65,7 +70,7 @@ private ParquetWriteSupport(Configuration conf) {
this.conf = conf;
this.schema =
convertToParquetMessageType(
- VariantUtils.replaceWithShreddingType(conf, rowType));
+ VariantUtils.replaceWithShreddingType(rowType, shreddingSchemas));
}
@Override
@@ -75,7 +80,9 @@ public WriteContext init(Configuration configuration) {
@Override
public void prepareForWrite(RecordConsumer recordConsumer) {
- this.writer = new ParquetRowDataWriter(recordConsumer, rowType, schema, conf);
+ this.writer =
+ new ParquetRowDataWriter(
+ recordConsumer, rowType, schema, conf, shreddingSchemas);
}
@Override
diff --git a/paimon-format/src/main/java/org/apache/paimon/format/parquet/writer/ParquetRowDataWriter.java b/paimon-format/src/main/java/org/apache/paimon/format/parquet/writer/ParquetRowDataWriter.java
index c00734b02434..3629bb852b94 100644
--- a/paimon-format/src/main/java/org/apache/paimon/format/parquet/writer/ParquetRowDataWriter.java
+++ b/paimon-format/src/main/java/org/apache/paimon/format/parquet/writer/ParquetRowDataWriter.java
@@ -29,7 +29,6 @@
import org.apache.paimon.data.variant.Variant;
import org.apache.paimon.data.variant.VariantSchema;
import org.apache.paimon.format.parquet.ParquetSchemaConverter;
-import org.apache.paimon.format.parquet.VariantUtils;
import org.apache.paimon.types.ArrayType;
import org.apache.paimon.types.DataType;
import org.apache.paimon.types.DecimalType;
@@ -70,11 +69,17 @@ public class ParquetRowDataWriter {
private final Configuration conf;
private final RowWriter rowWriter;
private final RecordConsumer recordConsumer;
+ @Nullable private final RowType shreddingSchemas;
public ParquetRowDataWriter(
- RecordConsumer recordConsumer, RowType rowType, GroupType schema, Configuration conf) {
+ RecordConsumer recordConsumer,
+ RowType rowType,
+ GroupType schema,
+ Configuration conf,
+ @Nullable RowType shreddingSchemas) {
this.conf = conf;
this.recordConsumer = recordConsumer;
+ this.shreddingSchemas = shreddingSchemas;
this.rowWriter = new RowWriter(rowType, schema);
}
@@ -144,9 +149,11 @@ private FieldWriter createWriter(DataType t, Type type) {
} else if (t instanceof RowType && type instanceof GroupType) {
return new RowWriter((RowType) t, groupType);
} else if (t instanceof VariantType && type instanceof GroupType) {
- return new VariantWriter(
- groupType,
- VariantUtils.extractShreddingSchemaFromConf(conf, type.getName()));
+ RowType shreddingSchema =
+ shreddingSchemas != null && shreddingSchemas.containsField(type.getName())
+ ? (RowType) shreddingSchemas.getField(type.getName()).type()
+ : null;
+ return new VariantWriter(groupType, shreddingSchema);
} else {
throw new UnsupportedOperationException("Unsupported type: " + type);
}
diff --git a/paimon-format/src/main/java/org/apache/paimon/format/parquet/writer/RowDataParquetBuilder.java b/paimon-format/src/main/java/org/apache/paimon/format/parquet/writer/RowDataParquetBuilder.java
index 6ec349c124da..1a66a55130e1 100644
--- a/paimon-format/src/main/java/org/apache/paimon/format/parquet/writer/RowDataParquetBuilder.java
+++ b/paimon-format/src/main/java/org/apache/paimon/format/parquet/writer/RowDataParquetBuilder.java
@@ -20,6 +20,7 @@
import org.apache.paimon.data.InternalRow;
import org.apache.paimon.format.parquet.ColumnConfigParser;
+import org.apache.paimon.format.parquet.VariantUtils;
import org.apache.paimon.options.Options;
import org.apache.paimon.types.RowType;
@@ -30,6 +31,8 @@
import org.apache.parquet.hadoop.metadata.CompressionCodecName;
import org.apache.parquet.io.OutputFile;
+import javax.annotation.Nullable;
+
import java.io.IOException;
/** A {@link ParquetBuilder} for {@link InternalRow}. */
@@ -37,18 +40,25 @@ public class RowDataParquetBuilder implements ParquetBuilder<InternalRow> {
private final RowType rowType;
private final Configuration conf;
+ @Nullable private RowType shreddingSchemas;
public RowDataParquetBuilder(RowType rowType, Options options) {
this.rowType = rowType;
this.conf = new Configuration(false);
+ this.shreddingSchemas = VariantUtils.shreddingSchemasFromOptions(options);
options.toMap().forEach(conf::set);
}
+ public RowDataParquetBuilder withShreddingSchemas(RowType shreddingSchemas) {
+ this.shreddingSchemas = shreddingSchemas;
+ return this;
+ }
+
@Override
public ParquetWriter<InternalRow> createWriter(OutputFile out, String compression)
throws IOException {
ParquetRowDataBuilder builder =
- new ParquetRowDataBuilder(out, rowType)
+ new ParquetRowDataBuilder(out, rowType, shreddingSchemas)
.withConf(conf)
.withCompressionCodec(
CompressionCodecName.fromConf(getCompression(compression)))
diff --git a/paimon-format/src/test/java/org/apache/paimon/format/parquet/reader/FileTypeNotMatchReadTypeTest.java b/paimon-format/src/test/java/org/apache/paimon/format/parquet/reader/FileTypeNotMatchReadTypeTest.java
index 7afcbacc9d7a..c1d148116c1b 100644
--- a/paimon-format/src/test/java/org/apache/paimon/format/parquet/reader/FileTypeNotMatchReadTypeTest.java
+++ b/paimon-format/src/test/java/org/apache/paimon/format/parquet/reader/FileTypeNotMatchReadTypeTest.java
@@ -90,7 +90,9 @@ public void testTimestamp() throws Exception {
ParquetRowDataBuilder parquetRowDataBuilder =
new ParquetRowDataBuilder(
- new LocalOutputFile(new File(fileWholePath).toPath()), rowTypeWrite);
+ new LocalOutputFile(new File(fileWholePath).toPath()),
+ rowTypeWrite,
+ null);
ParquetWriter<InternalRow> parquetWriter = parquetRowDataBuilder.build();
Timestamp timestamp = Timestamp.now();
@@ -129,7 +131,9 @@ public void testDecimal() throws Exception {
ParquetRowDataBuilder parquetRowDataBuilder =
new ParquetRowDataBuilder(
- new LocalOutputFile(new File(fileWholePath).toPath()), rowTypeWrite);
+ new LocalOutputFile(new File(fileWholePath).toPath()),
+ rowTypeWrite,
+ null);
ParquetWriter<InternalRow> parquetWriter = parquetRowDataBuilder.build();
Decimal decimal =
diff --git a/paimon-format/src/test/java/org/apache/paimon/format/parquet/reader/ParquetRowDataBuilderForTest.java b/paimon-format/src/test/java/org/apache/paimon/format/parquet/reader/ParquetRowDataBuilderForTest.java
index 1f5535163b4c..cd5133238381 100644
--- a/paimon-format/src/test/java/org/apache/paimon/format/parquet/reader/ParquetRowDataBuilderForTest.java
+++ b/paimon-format/src/test/java/org/apache/paimon/format/parquet/reader/ParquetRowDataBuilderForTest.java
@@ -66,7 +66,8 @@ public WriteContext init(Configuration configuration) {
@Override
public void prepareForWrite(RecordConsumer recordConsumer) {
this.writer =
- new ParquetRowDataWriter(recordConsumer, rowType, schema, new Configuration());
+ new ParquetRowDataWriter(
+ recordConsumer, rowType, schema, new Configuration(), null);
}
@Override
diff --git a/paimon-format/src/test/java/org/apache/paimon/format/parquet/writer/InferVariantShreddingWriteTest.java b/paimon-format/src/test/java/org/apache/paimon/format/parquet/writer/InferVariantShreddingWriteTest.java
new file mode 100644
index 000000000000..9c8f041447c9
--- /dev/null
+++ b/paimon-format/src/test/java/org/apache/paimon/format/parquet/writer/InferVariantShreddingWriteTest.java
@@ -0,0 +1,582 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.paimon.format.parquet.writer;
+
+import org.apache.paimon.CoreOptions;
+import org.apache.paimon.data.BinaryString;
+import org.apache.paimon.data.GenericRow;
+import org.apache.paimon.data.InternalRow;
+import org.apache.paimon.data.serializer.InternalRowSerializer;
+import org.apache.paimon.data.variant.GenericVariant;
+import org.apache.paimon.data.variant.VariantAccessInfo;
+import org.apache.paimon.format.FileFormatFactory;
+import org.apache.paimon.format.FormatReaderContext;
+import org.apache.paimon.format.FormatWriter;
+import org.apache.paimon.format.FormatWriterFactory;
+import org.apache.paimon.format.SupportsDirectWrite;
+import org.apache.paimon.format.parquet.ParquetFileFormat;
+import org.apache.paimon.format.parquet.ParquetUtil;
+import org.apache.paimon.format.parquet.VariantUtils;
+import org.apache.paimon.format.variant.InferVariantShreddingWriter;
+import org.apache.paimon.fs.FileIO;
+import org.apache.paimon.fs.Path;
+import org.apache.paimon.fs.PositionOutputStream;
+import org.apache.paimon.fs.local.LocalFileIO;
+import org.apache.paimon.options.Options;
+import org.apache.paimon.reader.RecordReader;
+import org.apache.paimon.types.DataField;
+import org.apache.paimon.types.DataType;
+import org.apache.paimon.types.DataTypes;
+import org.apache.paimon.types.RowType;
+
+import org.apache.parquet.hadoop.ParquetFileReader;
+import org.apache.parquet.schema.MessageType;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.io.TempDir;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.UUID;
+
+import static org.apache.paimon.data.variant.PaimonShreddingUtils.variantShreddingSchema;
+import static org.assertj.core.api.Assertions.assertThat;
+
+/** Test for {@link InferVariantShreddingWriter}. */
+public class InferVariantShreddingWriteTest {
+
+ @TempDir java.nio.file.Path tempPath;
+
+ protected FileIO fileIO;
+ protected Path file;
+ protected Path parent;
+
+ @BeforeEach
+ public void beforeEach() {
+ this.fileIO = LocalFileIO.create();
+ this.parent = new Path(tempPath.toUri());
+ this.file = new Path(new Path(tempPath.toUri()), UUID.randomUUID() + ".parquet");
+ }
+
+ public Options defaultOptions() {
+ Options options = new Options();
+ options.set(CoreOptions.VARIANT_INFER_SHREDDING_SCHEMA.key(), "true");
+ return options;
+ }
+
+ @Test
+ public void testInferSchemaWithSimpleObject() throws Exception {
+ ParquetFileFormat format = createFormat();
+ RowType writeType = DataTypes.ROW(DataTypes.FIELD(0, "v", DataTypes.VARIANT()));
+
+ FormatWriterFactory factory = format.createWriterFactory(writeType);
+ writeRows(
+ factory,
+ GenericRow.of(GenericVariant.fromJson("{\"age\":30,\"name\":\"Alice\"}")),
+ GenericRow.of(GenericVariant.fromJson("{\"age\":25,\"name\":\"Bob\"}")),
+ GenericRow.of(GenericVariant.fromJson("{\"age\":35,\"name\":\"Charlie\"}")));
+
+ List<InternalRow> result = readRows(format, writeType);
+ assertThat(result.get(0).getVariant(0).toJson())
+ .isEqualTo("{\"age\":30,\"name\":\"Alice\"}");
+ assertThat(result.get(1).getVariant(0).toJson()).isEqualTo("{\"age\":25,\"name\":\"Bob\"}");
+ assertThat(result.get(2).getVariant(0).toJson())
+ .isEqualTo("{\"age\":35,\"name\":\"Charlie\"}");
+
+ RowType expectShreddedType =
+ RowType.of(
+ new DataType[] {DataTypes.BIGINT(), DataTypes.STRING()},
+ new String[] {"age", "name"});
+ verifyShreddingSchema(writeType, expectShreddedType);
+
+ VariantAccessInfo variantAccess =
+ createVariantAccess("v", new DataField(0, "age", DataTypes.INT()), "$.age");
+ List<InternalRow> result2 = readRowsWithVariantAccess(format, writeType, variantAccess);
+ assertThat(result2.get(0)).isEqualTo(GenericRow.of(GenericRow.of(30)));
+ assertThat(result2.get(1)).isEqualTo(GenericRow.of(GenericRow.of(25)));
+ assertThat(result2.get(2)).isEqualTo(GenericRow.of(GenericRow.of(35)));
+ }
+
+ @Test
+ public void testInferSchemaWithArray() throws Exception {
+ ParquetFileFormat format = createFormat();
+ RowType writeType = DataTypes.ROW(DataTypes.FIELD(0, "v", DataTypes.VARIANT()));
+
+ FormatWriterFactory factory = format.createWriterFactory(writeType);
+ writeRows(
+ factory,
+ GenericRow.of(GenericVariant.fromJson("{\"numbers\":[1,2,3]}")),
+ GenericRow.of(GenericVariant.fromJson("{\"numbers\":[4,5,6]}")));
+
+ List<InternalRow> result = readRows(format, writeType);
+ assertThat(result.get(0).getVariant(0).toJson()).isEqualTo("{\"numbers\":[1,2,3]}");
+ assertThat(result.get(1).getVariant(0).toJson()).isEqualTo("{\"numbers\":[4,5,6]}");
+
+ VariantAccessInfo variantAccess =
+ createVariantAccess(
+ "v",
+ new DataField(0, "numbers", DataTypes.ARRAY(DataTypes.BIGINT())),
+ "$.numbers");
+ List<InternalRow> result2 = readRowsWithVariantAccess(format, writeType, variantAccess);
+ // Verify the nested row structure and array content
+ InternalRow row1 = result2.get(0).getRow(0, 1);
+ assertThat(row1.getArray(0).size()).isEqualTo(3);
+ assertThat(row1.getArray(0).getLong(0)).isEqualTo(1L);
+ assertThat(row1.getArray(0).getLong(1)).isEqualTo(2L);
+ assertThat(row1.getArray(0).getLong(2)).isEqualTo(3L);
+
+ InternalRow row2 = result2.get(1).getRow(0, 1);
+ assertThat(row2.getArray(0).size()).isEqualTo(3);
+ assertThat(row2.getArray(0).getLong(0)).isEqualTo(4L);
+ assertThat(row2.getArray(0).getLong(1)).isEqualTo(5L);
+ assertThat(row2.getArray(0).getLong(2)).isEqualTo(6L);
+
+ RowType expectShreddedType =
+ RowType.of(
+ new DataType[] {DataTypes.ARRAY(DataTypes.BIGINT())},
+ new String[] {"numbers"});
+ verifyShreddingSchema(writeType, expectShreddedType);
+ }
+
+ @Test
+ public void testInferSchemaWithMixedTypes() throws Exception {
+ ParquetFileFormat format = createFormat();
+ RowType writeType = DataTypes.ROW(DataTypes.FIELD(0, "v", DataTypes.VARIANT()));
+
+ FormatWriterFactory factory = format.createWriterFactory(writeType);
+ writeRows(
+ factory,
+ GenericRow.of(
+ GenericVariant.fromJson(
+ "{\"str\":\"hello\",\"num\":42,\"bool\":true,\"dec\":3.14}")),
+ GenericRow.of(
+ GenericVariant.fromJson(
+ "{\"str\":\"world\",\"num\":100,\"bool\":false,\"dec\":2.71}")));
+
+ List<InternalRow> result = readRows(format, writeType);
+ assertThat(result.get(0).getVariant(0).toJson())
+ .isEqualTo("{\"bool\":true,\"dec\":3.14,\"num\":42,\"str\":\"hello\"}");
+ assertThat(result.get(1).getVariant(0).toJson())
+ .isEqualTo("{\"bool\":false,\"dec\":2.71,\"num\":100,\"str\":\"world\"}");
+
+ RowType expectShreddedType =
+ RowType.of(
+ new DataType[] {
+ DataTypes.BOOLEAN(),
+ DataTypes.DECIMAL(18, 2),
+ DataTypes.BIGINT(),
+ DataTypes.STRING()
+ },
+ new String[] {"bool", "dec", "num", "str"});
+ verifyShreddingSchema(writeType, expectShreddedType);
+ }
+
+ @Test
+ public void testInferSchemaWithNullValues() throws Exception {
+ ParquetFileFormat format = createFormat();
+ RowType writeType = DataTypes.ROW(DataTypes.FIELD(0, "v", DataTypes.VARIANT()));
+
+ FormatWriterFactory factory = format.createWriterFactory(writeType);
+ writeRows(
+ factory,
+ GenericRow.of(GenericVariant.fromJson("{\"a\":1,\"b\":null}")),
+ GenericRow.of(GenericVariant.fromJson("{\"a\":2,\"b\":3}")));
+
+ List<InternalRow> result = readRows(format, writeType);
+ assertThat(result.get(0).getVariant(0).toJson()).isEqualTo("{\"a\":1,\"b\":null}");
+ assertThat(result.get(1).getVariant(0).toJson()).isEqualTo("{\"a\":2,\"b\":3}");
+
+ RowType expectShreddedType =
+ RowType.of(
+ new DataType[] {DataTypes.BIGINT(), DataTypes.BIGINT()},
+ new String[] {"a", "b"});
+ verifyShreddingSchema(writeType, expectShreddedType);
+ }
+
+ @Test
+ public void testInferSchemaWithConflictingTypes() throws Exception {
+ ParquetFileFormat format = createFormat();
+ RowType writeType = DataTypes.ROW(DataTypes.FIELD(0, "v", DataTypes.VARIANT()));
+
+ FormatWriterFactory factory = format.createWriterFactory(writeType);
+ writeRows(
+ factory,
+ GenericRow.of(GenericVariant.fromJson("{\"field\":\"text\"}")),
+ GenericRow.of(GenericVariant.fromJson("{\"field\":123}")),
+ GenericRow.of(GenericVariant.fromJson("{\"field\":true}")));
+
+ List<InternalRow> result = readRows(format, writeType);
+ assertThat(result.get(0).getVariant(0).toJson()).isEqualTo("{\"field\":\"text\"}");
+ assertThat(result.get(1).getVariant(0).toJson()).isEqualTo("{\"field\":123}");
+ assertThat(result.get(2).getVariant(0).toJson()).isEqualTo("{\"field\":true}");
+
+ // When types conflict, the field should be inferred as VARIANT type
+ RowType expectShreddedType =
+ RowType.of(new DataType[] {DataTypes.VARIANT()}, new String[] {"field"});
+ verifyShreddingSchema(writeType, expectShreddedType);
+
+ VariantAccessInfo variantAccess =
+ createVariantAccess("v", new DataField(0, "field", DataTypes.VARIANT()), "$.field");
+ List<InternalRow> result2 = readRowsWithVariantAccess(format, writeType, variantAccess);
+ assertThat(result2.get(0).getRow(0, 1).getVariant(0).toJson()).isEqualTo("\"text\"");
+ assertThat(result2.get(1).getRow(0, 1).getVariant(0).toJson()).isEqualTo("123");
+ assertThat(result2.get(2).getRow(0, 1).getVariant(0).toJson()).isEqualTo("true");
+ }
+
+ @Test
+ public void testInferSchemaWithDeepNesting() throws Exception {
+ ParquetFileFormat format = createFormat();
+ RowType writeType = DataTypes.ROW(DataTypes.FIELD(0, "v", DataTypes.VARIANT()));
+
+ String deepJson = "{\"level1\":{\"level2\":{\"level3\":{\"value\":42}}}}";
+ FormatWriterFactory factory = format.createWriterFactory(writeType);
+ writeRows(factory, GenericRow.of(GenericVariant.fromJson(deepJson)));
+
+ List<InternalRow> result = readRows(format, writeType);
+ assertThat(result.get(0).getVariant(0).toJson()).isEqualTo(deepJson);
+
+ // Deep nesting: level1 -> level2 -> level3 -> value
+ RowType level3Type =
+ RowType.of(new DataType[] {DataTypes.BIGINT()}, new String[] {"value"});
+ RowType level2Type = RowType.of(new DataType[] {level3Type}, new String[] {"level3"});
+ RowType level1Type = RowType.of(new DataType[] {level2Type}, new String[] {"level2"});
+ RowType expectShreddedType =
+ RowType.of(new DataType[] {level1Type}, new String[] {"level1"});
+ verifyShreddingSchema(writeType, expectShreddedType);
+ }
+
+ @Test
+ public void testMultipleVariantFields() throws Exception {
+ ParquetFileFormat format = createFormat();
+ RowType writeType =
+ DataTypes.ROW(
+ DataTypes.FIELD(0, "v1", DataTypes.VARIANT()),
+ DataTypes.FIELD(1, "v2", DataTypes.VARIANT()),
+ DataTypes.FIELD(2, "id", DataTypes.INT()));
+
+ GenericVariant variant1 = GenericVariant.fromJson("{\"name\":\"Alice\"}");
+ GenericVariant variant2 = GenericVariant.fromJson("{\"age\":30}");
+
+ FormatWriterFactory factory = format.createWriterFactory(writeType);
+ writeRows(
+ factory,
+ GenericRow.of(variant1, variant2, 1),
+ GenericRow.of(variant1, variant2, 2));
+
+ List<InternalRow> result = readRows(format, writeType);
+ assertThat(result.get(0).getVariant(0).toJson()).isEqualTo("{\"name\":\"Alice\"}");
+ assertThat(result.get(0).getVariant(1).toJson()).isEqualTo("{\"age\":30}");
+ assertThat(result.get(0).getInt(2)).isEqualTo(1);
+ assertThat(result.get(1).getVariant(0).toJson()).isEqualTo("{\"name\":\"Alice\"}");
+ assertThat(result.get(1).getVariant(1).toJson()).isEqualTo("{\"age\":30}");
+ assertThat(result.get(1).getInt(2)).isEqualTo(2);
+
+ // v1 has "name" field
+ RowType expectShreddedType1 =
+ RowType.of(new DataType[] {DataTypes.STRING()}, new String[] {"name"});
+ // v2 has "age" field
+ RowType expectShreddedType2 =
+ RowType.of(new DataType[] {DataTypes.BIGINT()}, new String[] {"age"});
+ verifyShreddingSchema(writeType, expectShreddedType1, expectShreddedType2);
+ }
+
+ @Test
+ public void testInferSchemaWithAllPrimitiveTypes() throws Exception {
+ ParquetFileFormat format = createFormat();
+ RowType writeType = DataTypes.ROW(DataTypes.FIELD(0, "v", DataTypes.VARIANT()));
+
+ String json =
+ "{\"string\":\"test\",\"long\":123456789,\"double\":3.14159,\"boolean\":true,\"null\":null}";
+ // Fields are sorted alphabetically in variant
+ String expectedJson =
+ "{\"boolean\":true,\"double\":3.14159,\"long\":123456789,\"null\":null,\"string\":\"test\"}";
+ FormatWriterFactory factory = format.createWriterFactory(writeType);
+ writeRows(factory, GenericRow.of(GenericVariant.fromJson(json)));
+
+ List<InternalRow> result = readRows(format, writeType);
+ assertThat(result.get(0).getVariant(0).toJson()).isEqualTo(expectedJson);
+
+ RowType expectShreddedType =
+ RowType.of(
+ new DataType[] {
+ DataTypes.BOOLEAN(),
+ DataTypes.DECIMAL(18, 5),
+ DataTypes.BIGINT(),
+ DataTypes.VARIANT(),
+ DataTypes.STRING()
+ },
+ new String[] {"boolean", "double", "long", "null", "string"});
+ verifyShreddingSchema(writeType, expectShreddedType);
+ }
+
+ @Test
+ public void testAllNullRecords() throws Exception {
+ ParquetFileFormat format = createFormat();
+ RowType writeType = DataTypes.ROW(DataTypes.FIELD(0, "v", DataTypes.VARIANT()));
+
+ FormatWriterFactory factory = format.createWriterFactory(writeType);
+ GenericRow[] rows = new GenericRow[10];
+ for (int i = 0; i < 10; i++) {
+ rows[i] = GenericRow.of((GenericVariant) null);
+ }
+ writeRows(factory, rows);
+
+ List<InternalRow> result = readRows(format, writeType);
+ assertThat(result.size()).isEqualTo(10);
+ for (InternalRow row : result) {
+ assertThat(row.isNullAt(0)).isTrue();
+ }
+ }
+
+ @Test
+ public void testMixedNullAndValidRecords() throws Exception {
+ ParquetFileFormat format = createFormat();
+ RowType writeType = DataTypes.ROW(DataTypes.FIELD(0, "v", DataTypes.VARIANT()));
+
+ List<InternalRow> rows = new ArrayList<>();
+ for (int i = 0; i < 20; i++) {
+ if (i % 3 == 0) {
+ rows.add(GenericRow.of((GenericVariant) null));
+ } else {
+ rows.add(
+ GenericRow.of(
+ GenericVariant.fromJson(
+ String.format("{\"id\":%d,\"value\":\"data%d\"}", i, i))));
+ }
+ }
+
+ FormatWriterFactory factory = format.createWriterFactory(writeType);
+ writeRows(factory, rows.toArray(new InternalRow[0]));
+
+ List<InternalRow> result = readRows(format, writeType);
+ assertThat(result.size()).isEqualTo(20);
+ for (int i = 0; i < 20; i++) {
+ if (i % 3 == 0) {
+ assertThat(result.get(i).isNullAt(0)).isTrue();
+ } else {
+ assertThat(result.get(i).getVariant(0).toJson())
+ .isEqualTo(String.format("{\"id\":%d,\"value\":\"data%d\"}", i, i));
+ }
+ }
+
+ RowType expectShreddedType =
+ RowType.of(
+ new DataType[] {DataTypes.BIGINT(), DataTypes.STRING()},
+ new String[] {"id", "value"});
+ verifyShreddingSchema(writeType, expectShreddedType);
+
+ VariantAccessInfo variantAccess =
+ createVariantAccess("v", new DataField(0, "id", DataTypes.BIGINT()), "$.id");
+ List<InternalRow> result2 = readRowsWithVariantAccess(format, writeType, variantAccess);
+ assertThat(result2.size()).isEqualTo(20);
+ for (int i = 0; i < 20; i++) {
+ if (i % 3 == 0) {
+ assertThat(result2.get(i).isNullAt(0)).isTrue();
+ } else {
+ assertThat(result2.get(i)).isEqualTo(GenericRow.of(GenericRow.of((long) i)));
+ }
+ }
+ }
+
+ @Test
+ public void testMaxInferBufferRowBoundary() throws Exception {
+ // Test with buffer size = 2
+ // First 2 rows have integer values, 3rd row has string value
+ // Schema will be inferred from first 2 rows (buffer size), so field becomes BIGINT
+ Options customOptions = defaultOptions();
+ customOptions.set(CoreOptions.VARIANT_SHREDDING_MAX_INFER_BUFFER_ROW, 2);
+ ParquetFileFormat format = createFormat(customOptions);
+
+ RowType writeType = DataTypes.ROW(DataTypes.FIELD(0, "v", DataTypes.VARIANT()));
+
+ FormatWriterFactory factory = format.createWriterFactory(writeType);
+ GenericRow[] rows = new GenericRow[3];
+ rows[0] = GenericRow.of(GenericVariant.fromJson("{\"value\":100}"));
+ rows[1] = GenericRow.of(GenericVariant.fromJson("{\"value\":200}"));
+ rows[2] = GenericRow.of(GenericVariant.fromJson("{\"value\":\"text\"}"));
+ writeRows(factory, rows);
+
+ List<InternalRow> result = readRows(format, writeType);
+ assertThat(result.size()).isEqualTo(3);
+ assertThat(result.get(0).getVariant(0).toJson()).isEqualTo("{\"value\":100}");
+ assertThat(result.get(1).getVariant(0).toJson()).isEqualTo("{\"value\":200}");
+ assertThat(result.get(2).getVariant(0).toJson()).isEqualTo("{\"value\":\"text\"}");
+
+ // Schema should be inferred as BIGINT (based on first 2 rows in buffer)
+ RowType expectShreddedType =
+ RowType.of(new DataType[] {DataTypes.BIGINT()}, new String[] {"value"});
+ verifyShreddingSchema(writeType, expectShreddedType);
+
+ VariantAccessInfo variantAccess =
+ createVariantAccess("v", new DataField(0, "value", DataTypes.STRING()), "$.value");
+ List<InternalRow> result2 = readRowsWithVariantAccess(format, writeType, variantAccess);
+ assertThat(result2.size()).isEqualTo(3);
+ assertThat(result2.get(0))
+ .isEqualTo(GenericRow.of(GenericRow.of(BinaryString.fromString("100"))));
+ assertThat(result2.get(1))
+ .isEqualTo(GenericRow.of(GenericRow.of(BinaryString.fromString("200"))));
+ assertThat(result2.get(2))
+ .isEqualTo(GenericRow.of(GenericRow.of(BinaryString.fromString("text"))));
+ }
+
+ @Test
+ public void testMaxInferBufferRowExactMatch() throws Exception {
+ // Test with buffer size = 5, write exactly 5 rows
+ Options customOptions = defaultOptions();
+ customOptions.set(CoreOptions.VARIANT_SHREDDING_MAX_INFER_BUFFER_ROW, 5);
+ ParquetFileFormat format = createFormat(customOptions);
+
+ RowType writeType = DataTypes.ROW(DataTypes.FIELD(0, "v", DataTypes.VARIANT()));
+
+ FormatWriterFactory factory = format.createWriterFactory(writeType);
+ GenericRow[] rows = new GenericRow[5];
+ for (int i = 0; i < 5; i++) {
+ rows[i] =
+ GenericRow.of(
+ GenericVariant.fromJson(
+ "{\"id\":" + i + ",\"name\":\"user" + i + "\"}"));
+ }
+ writeRows(factory, rows);
+
+ List<InternalRow> result = readRows(format, writeType);
+ assertThat(result.size()).isEqualTo(5);
+
+ RowType expectShreddedType =
+ RowType.of(
+ new DataType[] {DataTypes.BIGINT(), DataTypes.STRING()},
+ new String[] {"id", "name"});
+ verifyShreddingSchema(writeType, expectShreddedType);
+ }
+
+ @Test
+ public void testMaxInferBufferRowBelowThreshold() throws Exception {
+ // Test with buffer size = 10, write only 3 rows
+ // Schema inferred at close time
+ Options customOptions = defaultOptions();
+ customOptions.set(CoreOptions.VARIANT_SHREDDING_MAX_INFER_BUFFER_ROW, 10);
+ ParquetFileFormat format = createFormat(customOptions);
+
+ RowType writeType = DataTypes.ROW(DataTypes.FIELD(0, "v", DataTypes.VARIANT()));
+
+ FormatWriterFactory factory = format.createWriterFactory(writeType);
+ writeRows(
+ factory,
+ GenericRow.of(GenericVariant.fromJson("{\"id\":1}")),
+ GenericRow.of(GenericVariant.fromJson("{\"id\":2}")),
+ GenericRow.of(GenericVariant.fromJson("{\"id\":3}")));
+
+ List<InternalRow> result = readRows(format, writeType);
+ assertThat(result.size()).isEqualTo(3);
+
+ RowType expectShreddedType =
+ RowType.of(new DataType[] {DataTypes.BIGINT()}, new String[] {"id"});
+ verifyShreddingSchema(writeType, expectShreddedType);
+ }
+
+ protected ParquetFileFormat createFormat() {
+ return createFormat(defaultOptions());
+ }
+
+ protected ParquetFileFormat createFormat(Options options) {
+ return new ParquetFileFormat(new FileFormatFactory.FormatContext(options, 1024, 1024));
+ }
+
+ protected List<InternalRow> readRows(ParquetFileFormat format, RowType rowType)
+ throws IOException {
+ List<InternalRow> result = new ArrayList<>();
+ try (RecordReader<InternalRow> reader =
+ format.createReaderFactory(rowType, rowType, new ArrayList<>())
+ .createReader(
+ new FormatReaderContext(fileIO, file, fileIO.getFileSize(file)))) {
+ InternalRowSerializer serializer = new InternalRowSerializer(rowType);
+ reader.forEachRemaining(row -> result.add(serializer.copy(row)));
+ }
+ return result;
+ }
+
+ protected VariantAccessInfo createVariantAccess(
+ String variantFieldName, DataField field, String path) {
+ List<VariantAccessInfo.VariantField> variantFields = new ArrayList<>();
+ variantFields.add(new VariantAccessInfo.VariantField(field, path));
+ return new VariantAccessInfo(variantFieldName, variantFields);
+ }
+
+ protected List<InternalRow> readRowsWithVariantAccess(
+ ParquetFileFormat format, RowType writeType, VariantAccessInfo... variantAccessInfos)
+ throws IOException {
+ RowType readStructType = buildReadStructType(variantAccessInfos);
+ List<InternalRow> result = new ArrayList<>();
+ try (RecordReader<InternalRow> reader =
+ format.createReaderFactory(
+ writeType, writeType, new ArrayList<>(), variantAccessInfos)
+ .createReader(
+ new FormatReaderContext(fileIO, file, fileIO.getFileSize(file)))) {
+ InternalRowSerializer serializer = new InternalRowSerializer(readStructType);
+ reader.forEachRemaining(row -> result.add(serializer.copy(row)));
+ }
+ return result;
+ }
+
+ protected RowType buildReadStructType(VariantAccessInfo... variantAccessInfos) {
+ List<DataField> fields = new ArrayList<>();
+ for (int i = 0; i < variantAccessInfos.length; i++) {
+ VariantAccessInfo variantAccessInfo = variantAccessInfos[i];
+ List<DataField> variantFields = new ArrayList<>();
+ for (VariantAccessInfo.VariantField vf : variantAccessInfo.variantFields()) {
+ variantFields.add(vf.dataField());
+ }
+ RowType variantRowType = new RowType(variantFields);
+ fields.add(new DataField(i, variantAccessInfo.columnName(), variantRowType));
+ }
+ return new RowType(fields);
+ }
+
+ protected void writeRows(FormatWriterFactory factory, InternalRow... rows) throws IOException {
+ FormatWriter writer;
+ PositionOutputStream out = null;
+ if (factory instanceof SupportsDirectWrite) {
+ writer = ((SupportsDirectWrite) factory).create(fileIO, file, "zstd");
+ } else {
+ out = fileIO.newOutputStream(file, false);
+ writer = factory.create(out, "zstd");
+ }
+ for (InternalRow row : rows) {
+ writer.addElement(row);
+ }
+ writer.close();
+ if (out != null) {
+ out.close();
+ }
+ }
+
+ protected void verifyShreddingSchema(RowType writeType, RowType... expectShreddedTypes)
+ throws IOException {
+ try (ParquetFileReader reader =
+ ParquetUtil.getParquetReader(fileIO, file, fileIO.getFileSize(file))) {
+ MessageType schema = reader.getFooter().getFileMetaData().getSchema();
+ RowType[] rowTypes =
+ VariantUtils.extractShreddingSchemasFromParquetSchema(
+ writeType.getFields().toArray(new DataField[0]), schema);
+ for (int i = 0; i < expectShreddedTypes.length; i++) {
+ assertThat(rowTypes[i]).isEqualTo(variantShreddingSchema(expectShreddedTypes[i]));
+ }
+ }
+ }
+}
diff --git a/paimon-spark/paimon-spark-4.0/src/test/scala/org/apache/paimon/spark/sql/VariantTest.scala b/paimon-spark/paimon-spark-4.0/src/test/scala/org/apache/paimon/spark/sql/VariantTest.scala
index aafd1dc4b967..94e9ac683f02 100644
--- a/paimon-spark/paimon-spark-4.0/src/test/scala/org/apache/paimon/spark/sql/VariantTest.scala
+++ b/paimon-spark/paimon-spark-4.0/src/test/scala/org/apache/paimon/spark/sql/VariantTest.scala
@@ -18,4 +18,16 @@
package org.apache.paimon.spark.sql
-class VariantTest extends VariantTestBase {}
+import org.apache.spark.SparkConf
+
+class VariantTest extends VariantTestBase {
+ override protected def sparkConf: SparkConf = {
+ super.sparkConf.set("spark.paimon.variant.inferShreddingSchema", "false")
+ }
+}
+
+class VariantInferShreddingTest extends VariantTestBase {
+ override protected def sparkConf: SparkConf = {
+ super.sparkConf.set("spark.paimon.variant.inferShreddingSchema", "true")
+ }
+}