diff --git a/paimon-core/src/main/java/org/apache/paimon/catalog/TableQueryAuthResult.java b/paimon-core/src/main/java/org/apache/paimon/catalog/TableQueryAuthResult.java index dcc94031a8ec..c4c33f241a4b 100644 --- a/paimon-core/src/main/java/org/apache/paimon/catalog/TableQueryAuthResult.java +++ b/paimon-core/src/main/java/org/apache/paimon/catalog/TableQueryAuthResult.java @@ -50,4 +50,8 @@ public Predicate rowFilter() { public Map columnMasking() { return columnMasking; } + + public boolean isEmpty() { + return rowFilter == null && (columnMasking == null || columnMasking.isEmpty()); + } } diff --git a/paimon-core/src/main/java/org/apache/paimon/table/source/AbstractDataTableScan.java b/paimon-core/src/main/java/org/apache/paimon/table/source/AbstractDataTableScan.java index dcadfad1ff3a..c64b75557aa5 100644 --- a/paimon-core/src/main/java/org/apache/paimon/table/source/AbstractDataTableScan.java +++ b/paimon-core/src/main/java/org/apache/paimon/table/source/AbstractDataTableScan.java @@ -21,6 +21,7 @@ import org.apache.paimon.CoreOptions; import org.apache.paimon.CoreOptions.ChangelogProducer; import org.apache.paimon.Snapshot; +import org.apache.paimon.catalog.TableQueryAuthResult; import org.apache.paimon.consumer.Consumer; import org.apache.paimon.consumer.ConsumerManager; import org.apache.paimon.data.BinaryRow; @@ -64,14 +65,13 @@ import javax.annotation.Nullable; +import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.Optional; import java.util.TimeZone; import static org.apache.paimon.CoreOptions.FULL_COMPACTION_DELTA_COMMITS; -import static org.apache.paimon.CoreOptions.IncrementalBetweenScanMode.CHANGELOG; -import static org.apache.paimon.CoreOptions.IncrementalBetweenScanMode.DELTA; import static org.apache.paimon.CoreOptions.IncrementalBetweenScanMode.DIFF; import static org.apache.paimon.utils.Preconditions.checkArgument; import static org.apache.paimon.utils.Preconditions.checkNotNull; @@ -87,6 +87,7 @@ abstract class AbstractDataTableScan implements DataTableScan { private final TableQueryAuth queryAuth; @Nullable private RowType readType; + @Nullable private TableQueryAuthResult authResult; protected AbstractDataTableScan( TableSchema schema, @@ -165,12 +166,34 @@ public AbstractDataTableScan withMetricRegistry(MetricRegistry metricsRegistry) return this; } - protected void authQuery() { + protected TableQueryAuthResult authQuery() { if (!options.queryAuthEnabled()) { - return; + return null; } - queryAuth.auth(readType == null ? null : readType.getFieldNames()); - // TODO add support for row level access control + if (authResult == null) { + authResult = queryAuth.auth(readType == null ? 
null : readType.getFieldNames()); + } + return authResult; + } + + protected TableScan.Plan applyAuthToSplits(Plan plan) { + TableQueryAuthResult authResult = authQuery(); + if (authResult == null || authResult.isEmpty()) { + return plan; + } + + List splits = plan.splits(); + List authSplits = new ArrayList<>(splits.size()); + for (Split split : splits) { + if (split instanceof DataSplit) { + DataSplit dataSplit = (DataSplit) split; + authSplits.add(QueryAuthSplit.wrap(dataSplit, authResult)); + } else { + authSplits.add(split); + } + } + + return new DataFilePlan<>(authSplits); } @Override diff --git a/paimon-core/src/main/java/org/apache/paimon/table/source/AuthAwareTableRead.java b/paimon-core/src/main/java/org/apache/paimon/table/source/AuthAwareTableRead.java new file mode 100644 index 000000000000..a7772b3c0adf --- /dev/null +++ b/paimon-core/src/main/java/org/apache/paimon/table/source/AuthAwareTableRead.java @@ -0,0 +1,311 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.paimon.table.source; + +import org.apache.paimon.catalog.TableQueryAuthResult; +import org.apache.paimon.data.GenericRow; +import org.apache.paimon.data.InternalRow; +import org.apache.paimon.disk.IOManager; +import org.apache.paimon.metrics.MetricRegistry; +import org.apache.paimon.predicate.CompoundPredicate; +import org.apache.paimon.predicate.FieldRef; +import org.apache.paimon.predicate.LeafPredicate; +import org.apache.paimon.predicate.Predicate; +import org.apache.paimon.predicate.PredicateVisitor; +import org.apache.paimon.predicate.Transform; +import org.apache.paimon.reader.RecordReader; +import org.apache.paimon.types.DataType; +import org.apache.paimon.types.RowType; + +import javax.annotation.Nullable; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.apache.paimon.utils.InternalRowUtils.get; + +/** A {@link TableRead} wrapper that checks splits for authorization information. 
*/ +public class AuthAwareTableRead implements TableRead { + + private final TableRead wrapped; + private final RowType outputRowType; + + public AuthAwareTableRead(TableRead wrapped, RowType outputRowType) { + this.wrapped = wrapped; + this.outputRowType = outputRowType; + } + + @Override + public TableRead withMetricRegistry(MetricRegistry registry) { + return new AuthAwareTableRead(wrapped.withMetricRegistry(registry), outputRowType); + } + + @Override + public TableRead executeFilter() { + return new AuthAwareTableRead(wrapped.executeFilter(), outputRowType); + } + + @Override + public TableRead withIOManager(IOManager ioManager) { + return new AuthAwareTableRead(wrapped.withIOManager(ioManager), outputRowType); + } + + @Override + public RecordReader<InternalRow> createReader(Split split) throws IOException { + if (split instanceof QueryAuthSplit) { + TableQueryAuthResult authResult = ((QueryAuthSplit) split).authResult(); + if (authResult != null) { + RecordReader<InternalRow> reader = + wrapped.createReader(((QueryAuthSplit) split).dataSplit()); + // Apply row-level filter if present + Predicate rowFilter = authResult.rowFilter(); + if (rowFilter != null) { + Predicate remappedFilter = remapPredicateToOutputRow(outputRowType, rowFilter); + if (remappedFilter != null) { + reader = new FilterRecordReader(reader, remappedFilter); + } + } + + // Apply column masking if present + Map<String, Transform> columnMasking = authResult.columnMasking(); + if (columnMasking != null && !columnMasking.isEmpty()) { + MaskingApplier applier = new MaskingApplier(outputRowType, columnMasking); + reader = new MaskingRecordReader(reader, applier); + } + + return reader; + } + } + return wrapped.createReader(split); + } + + private static class FilterRecordReader implements RecordReader<InternalRow> { + + private final RecordReader<InternalRow> wrapped; + private final Predicate predicate; + + private FilterRecordReader(RecordReader<InternalRow> wrapped, Predicate predicate) { + this.wrapped = wrapped; + this.predicate = predicate; + } + + @Nullable + @Override + public RecordIterator<InternalRow> readBatch() throws IOException { + RecordIterator<InternalRow> batch = wrapped.readBatch(); + if (batch == null) { + return null; + } + return new FilterRecordIterator(batch, predicate); + } + + @Override + public void close() throws IOException { + wrapped.close(); + } + } + + private static class FilterRecordIterator implements RecordReader.RecordIterator<InternalRow> { + + private final RecordReader.RecordIterator<InternalRow> wrapped; + private final Predicate predicate; + + private FilterRecordIterator( + RecordReader.RecordIterator<InternalRow> wrapped, Predicate predicate) { + this.wrapped = wrapped; + this.predicate = predicate; + } + + @Nullable + @Override + public InternalRow next() throws IOException { + while (true) { + InternalRow row = wrapped.next(); + if (row == null) { + return null; + } + if (predicate.test(row)) { + return row; + } + } + } + + @Override + public void releaseBatch() { + wrapped.releaseBatch(); + } + } + + private static class MaskingRecordReader implements RecordReader<InternalRow> { + + private final RecordReader<InternalRow> wrapped; + private final MaskingApplier applier; + + private MaskingRecordReader(RecordReader<InternalRow> wrapped, MaskingApplier applier) { + this.wrapped = wrapped; + this.applier = applier; + } + + @Nullable + @Override + public RecordIterator<InternalRow> readBatch() throws IOException { + RecordIterator<InternalRow> batch = wrapped.readBatch(); + if (batch == null) { + return null; + } + return batch.transform(applier::apply); + } + + @Override + public void close() throws IOException { + wrapped.close(); + } + } + + private static class MaskingApplier { + + 
private final RowType outputRowType; + private final Map<Integer, Transform> remapped; + + private MaskingApplier(RowType outputRowType, Map<String, Transform> masking) { + this.outputRowType = outputRowType; + this.remapped = remapToOutputRow(outputRowType, masking); + } + + private InternalRow apply(InternalRow row) { + if (remapped.isEmpty()) { + return row; + } + int arity = outputRowType.getFieldCount(); + GenericRow out = new GenericRow(row.getRowKind(), arity); + for (int i = 0; i < arity; i++) { + DataType type = outputRowType.getTypeAt(i); + out.setField(i, get(row, i, type)); + } + for (Map.Entry<Integer, Transform> e : remapped.entrySet()) { + int targetIndex = e.getKey(); + Transform transform = e.getValue(); + Object masked = transform.transform(row); + out.setField(targetIndex, masked); + } + return out; + } + + private static Map<Integer, Transform> remapToOutputRow( + RowType outputRowType, Map<String, Transform> masking) { + Map<Integer, Transform> out = new HashMap<>(); + if (masking == null || masking.isEmpty()) { + return out; + } + + for (Map.Entry<String, Transform> e : masking.entrySet()) { + String targetColumn = e.getKey(); + Transform transform = e.getValue(); + if (targetColumn == null || transform == null) { + continue; + } + + int targetIndex = outputRowType.getFieldIndex(targetColumn); + if (targetIndex < 0) { + continue; + } + + List<Object> newInputs = new ArrayList<>(); + for (Object input : transform.inputs()) { + if (input instanceof FieldRef) { + FieldRef ref = (FieldRef) input; + int newIndex = outputRowType.getFieldIndex(ref.name()); + if (newIndex < 0) { + throw new IllegalArgumentException( + "Column masking refers to field '" + + ref.name() + + "' which is not present in output row type " + + outputRowType); + } + DataType type = outputRowType.getTypeAt(newIndex); + newInputs.add(new FieldRef(newIndex, ref.name(), type)); + } else { + newInputs.add(input); + } + } + out.put(targetIndex, transform.copyWithNewInputs(newInputs)); + } + return out; + } + } + + private static Predicate remapPredicateToOutputRow(RowType outputRowType, Predicate predicate) { + return predicate.visit(new PredicateRemapper(outputRowType)); + } + + private static class PredicateRemapper implements PredicateVisitor<Predicate> { + private final RowType outputRowType; + + private PredicateRemapper(RowType outputRowType) { + this.outputRowType = outputRowType; + } + + @Override + public Predicate visit(LeafPredicate predicate) { + Transform transform = predicate.transform(); + List<Object> newInputs = new ArrayList<>(); + boolean hasUnmappedField = false; + for (Object input : transform.inputs()) { + if (input instanceof FieldRef) { + FieldRef ref = (FieldRef) input; + String fieldName = ref.name(); + int newIndex = outputRowType.getFieldIndex(fieldName); + if (newIndex < 0) { + hasUnmappedField = true; + break; + } + DataType type = outputRowType.getTypeAt(newIndex); + newInputs.add(new FieldRef(newIndex, fieldName, type)); + } else { + newInputs.add(input); + } + } + if (hasUnmappedField) { + return null; + } + return predicate.copyWithNewInputs(newInputs); + } + + @Override + public Predicate visit(CompoundPredicate predicate) { + List<Predicate> remappedChildren = new ArrayList<>(); + for (Predicate child : predicate.children()) { + Predicate remapped = child.visit(this); + if (remapped != null) { + remappedChildren.add(remapped); + } + } + if (remappedChildren.isEmpty()) { + return null; + } + if (remappedChildren.size() == 1) { + return remappedChildren.get(0); + } + return new CompoundPredicate(predicate.function(), remappedChildren); + } + } +} diff --git a/paimon-core/src/main/java/org/apache/paimon/table/source/DataTableBatchScan.java 
b/paimon-core/src/main/java/org/apache/paimon/table/source/DataTableBatchScan.java index 4419b3890ef5..9fe11c44fdbf 100644 --- a/paimon-core/src/main/java/org/apache/paimon/table/source/DataTableBatchScan.java +++ b/paimon-core/src/main/java/org/apache/paimon/table/source/DataTableBatchScan.java @@ -91,8 +91,6 @@ public InnerTableScan withTopN(TopN topN) { @Override public TableScan.Plan plan() { - authQuery(); - if (startingScanner == null) { startingScanner = createStartingScanner(false); } @@ -101,13 +99,13 @@ public TableScan.Plan plan() { hasNext = false; Optional pushed = applyPushDownLimit(); if (pushed.isPresent()) { - return DataFilePlan.fromResult(pushed.get()); + return applyAuthToSplits(DataFilePlan.fromResult(pushed.get())); } pushed = applyPushDownTopN(); if (pushed.isPresent()) { - return DataFilePlan.fromResult(pushed.get()); + return applyAuthToSplits(DataFilePlan.fromResult(pushed.get())); } - return DataFilePlan.fromResult(startingScanner.scan(snapshotReader)); + return applyAuthToSplits(DataFilePlan.fromResult(startingScanner.scan(snapshotReader))); } else { throw new EndOfScanException(); } diff --git a/paimon-core/src/main/java/org/apache/paimon/table/source/DataTableStreamScan.java b/paimon-core/src/main/java/org/apache/paimon/table/source/DataTableStreamScan.java index 16bc8509ebd2..678cf0f4dfd8 100644 --- a/paimon-core/src/main/java/org/apache/paimon/table/source/DataTableStreamScan.java +++ b/paimon-core/src/main/java/org/apache/paimon/table/source/DataTableStreamScan.java @@ -118,8 +118,6 @@ public StartingContext startingContext() { @Override public Plan plan() { - authQuery(); - if (!initialized) { initScanner(); } @@ -182,7 +180,7 @@ private Plan tryFirstPlan() { "Starting snapshot is {}, next snapshot will be {}.", scannedResult.plan().snapshotId(), nextSnapshotId); - return scannedResult.plan(); + return applyAuthToSplits(scannedResult.plan()); } else if (result instanceof StartingScanner.NextSnapshot) { nextSnapshotId = ((StartingScanner.NextSnapshot) result).nextSnapshotId(); isFullPhaseEnd = @@ -223,7 +221,7 @@ private Plan nextPlan() { if (overwritePlan.splits().isEmpty()) { continue; } - return overwritePlan; + return applyAuthToSplits(overwritePlan); } } @@ -235,7 +233,7 @@ private Plan nextPlan() { if (plan.splits().isEmpty()) { continue; } - return plan; + return applyAuthToSplits(plan); } else { nextSnapshotId++; } diff --git a/paimon-core/src/main/java/org/apache/paimon/table/source/QueryAuthSplit.java b/paimon-core/src/main/java/org/apache/paimon/table/source/QueryAuthSplit.java new file mode 100644 index 000000000000..9daa9fe9ea63 --- /dev/null +++ b/paimon-core/src/main/java/org/apache/paimon/table/source/QueryAuthSplit.java @@ -0,0 +1,232 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.paimon.table.source; + +import org.apache.paimon.catalog.TableQueryAuthResult; +import org.apache.paimon.data.BinaryRow; +import org.apache.paimon.io.DataFileMeta; +import org.apache.paimon.io.DataInputView; +import org.apache.paimon.io.DataInputViewStreamWrapper; +import org.apache.paimon.io.DataOutputView; +import org.apache.paimon.stats.SimpleStatsEvolutions; + +import javax.annotation.Nullable; + +import java.io.IOException; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import java.util.List; +import java.util.Objects; +import java.util.Optional; +import java.util.OptionalLong; + +/** + * A wrapper class for {@link DataSplit} that adds query authorization information. This class + * delegates all Split interface methods to the wrapped DataSplit, while providing additional auth + * result functionality. + */ +public class QueryAuthSplit extends DataSplit { + + private static final long serialVersionUID = 1L; + + private DataSplit dataSplit; + @Nullable private TableQueryAuthResult authResult; + + public QueryAuthSplit(DataSplit dataSplit, @Nullable TableQueryAuthResult authResult) { + this.dataSplit = dataSplit; + this.authResult = authResult; + } + + public DataSplit dataSplit() { + return dataSplit; + } + + @Nullable + public TableQueryAuthResult authResult() { + return authResult; + } + + // Delegate all DataSplit methods to the wrapped instance + + public long snapshotId() { + return dataSplit.snapshotId(); + } + + public BinaryRow partition() { + return dataSplit.partition(); + } + + public int bucket() { + return dataSplit.bucket(); + } + + public String bucketPath() { + return dataSplit.bucketPath(); + } + + @Nullable + public Integer totalBuckets() { + return dataSplit.totalBuckets(); + } + + public List beforeFiles() { + return dataSplit.beforeFiles(); + } + + public Optional> beforeDeletionFiles() { + return dataSplit.beforeDeletionFiles(); + } + + public List dataFiles() { + return dataSplit.dataFiles(); + } + + @Override + public Optional> deletionFiles() { + return dataSplit.deletionFiles(); + } + + public boolean isStreaming() { + return dataSplit.isStreaming(); + } + + public boolean rawConvertible() { + return dataSplit.rawConvertible(); + } + + public OptionalLong latestFileCreationEpochMillis() { + return dataSplit.latestFileCreationEpochMillis(); + } + + public OptionalLong earliestFileCreationEpochMillis() { + return dataSplit.earliestFileCreationEpochMillis(); + } + + public long rowCount() { + return dataSplit.rowCount(); + } + + public boolean mergedRowCountAvailable() { + return dataSplit.mergedRowCountAvailable(); + } + + public long mergedRowCount() { + return dataSplit.mergedRowCount(); + } + + public Object minValue( + int fieldIndex, + org.apache.paimon.types.DataField dataField, + SimpleStatsEvolutions evolutions) { + return dataSplit.minValue(fieldIndex, dataField, evolutions); + } + + public Object maxValue( + int fieldIndex, + org.apache.paimon.types.DataField dataField, + SimpleStatsEvolutions evolutions) { + return dataSplit.maxValue(fieldIndex, dataField, evolutions); + } + + public Long nullCount(int fieldIndex, SimpleStatsEvolutions evolutions) { + return dataSplit.nullCount(fieldIndex, evolutions); + } + + public long partialMergedRowCount() { + return dataSplit.partialMergedRowCount(); + } + + @Override + public Optional> convertToRawFiles() { + return dataSplit.convertToRawFiles(); + } + + @Override + @Nullable + public Optional> indexFiles() { + return dataSplit.indexFiles(); + } + + @Override 
+ public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + QueryAuthSplit that = (QueryAuthSplit) o; + return Objects.equals(dataSplit, that.dataSplit) + && Objects.equals(authResult, that.authResult); + } + + @Override + public int hashCode() { + return Objects.hash(dataSplit, authResult); + } + + @Override + public String toString() { + return "QueryAuthSplit{" + "dataSplit=" + dataSplit + ", authResult=" + authResult + '}'; + } + + private void writeObject(ObjectOutputStream out) throws IOException { + serialize(new org.apache.paimon.io.DataOutputViewStreamWrapper(out)); + } + + private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException { + QueryAuthSplit other = deserialize(new DataInputViewStreamWrapper(in)); + this.dataSplit = other.dataSplit; + this.authResult = other.authResult; + } + + public void serialize(DataOutputView out) throws IOException { + // Serialize the wrapped DataSplit + dataSplit.serialize(out); + + // Serialize authResult + if (authResult != null) { + out.writeBoolean(true); + TableQueryAuthResultSerializer.serialize(authResult, out); + } else { + out.writeBoolean(false); + } + } + + public static QueryAuthSplit deserialize(DataInputView in) throws IOException { + // Deserialize the wrapped DataSplit + DataSplit dataSplit = DataSplit.deserialize(in); + + // Deserialize authResult + TableQueryAuthResult authResult = null; + if (in.readBoolean()) { + authResult = TableQueryAuthResultSerializer.deserialize(in); + } + + return new QueryAuthSplit(dataSplit, authResult); + } + + public static QueryAuthSplit wrap( + DataSplit dataSplit, @Nullable TableQueryAuthResult authResult) { + if (authResult == null || authResult.isEmpty()) { + return new QueryAuthSplit(dataSplit, null); + } + return new QueryAuthSplit(dataSplit, authResult); + } +} diff --git a/paimon-core/src/main/java/org/apache/paimon/table/source/ReadBuilderImpl.java b/paimon-core/src/main/java/org/apache/paimon/table/source/ReadBuilderImpl.java index c81dfd8e01dd..8da04edae0d0 100644 --- a/paimon-core/src/main/java/org/apache/paimon/table/source/ReadBuilderImpl.java +++ b/paimon-core/src/main/java/org/apache/paimon/table/source/ReadBuilderImpl.java @@ -26,6 +26,7 @@ import org.apache.paimon.predicate.PredicateBuilder; import org.apache.paimon.predicate.TopN; import org.apache.paimon.predicate.VectorSearch; +import org.apache.paimon.table.FileStoreTable; import org.apache.paimon.table.InnerTable; import org.apache.paimon.types.RowType; import org.apache.paimon.utils.Filter; @@ -256,6 +257,12 @@ public TableRead newRead() { if (variantAccessInfo != null) { read.withVariantAccess(variantAccessInfo); } + if (table instanceof FileStoreTable) { + CoreOptions options = new CoreOptions(table.options()); + if (options.queryAuthEnabled()) { + return new AuthAwareTableRead(read, readType()); + } + } return read; } diff --git a/paimon-core/src/main/java/org/apache/paimon/table/source/TableQueryAuthResultSerializer.java b/paimon-core/src/main/java/org/apache/paimon/table/source/TableQueryAuthResultSerializer.java new file mode 100644 index 000000000000..c557ee08e006 --- /dev/null +++ b/paimon-core/src/main/java/org/apache/paimon/table/source/TableQueryAuthResultSerializer.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.paimon.table.source; + +import org.apache.paimon.catalog.TableQueryAuthResult; +import org.apache.paimon.io.DataInputView; +import org.apache.paimon.io.DataOutputView; +import org.apache.paimon.predicate.Predicate; +import org.apache.paimon.predicate.Transform; +import org.apache.paimon.rest.RESTApi; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +/** Serializer for {@link TableQueryAuthResult}. */ +public class TableQueryAuthResultSerializer { + public static void serialize(TableQueryAuthResult authResult, DataOutputView out) + throws IOException { + // Serialize row filter + if (authResult.rowFilter() != null) { + out.writeBoolean(true); + String predicateJson = RESTApi.toJson(authResult.rowFilter()); + out.writeUTF(predicateJson); + } else { + out.writeBoolean(false); + } + + // Serialize column masking + Map<String, Transform> columnMasking = authResult.columnMasking(); + out.writeInt(columnMasking.size()); + for (Map.Entry<String, Transform> entry : columnMasking.entrySet()) { + out.writeUTF(entry.getKey()); + String transformJson = RESTApi.toJson(entry.getValue()); + out.writeUTF(transformJson); + } + } + + public static TableQueryAuthResult deserialize(DataInputView in) throws IOException { + // Deserialize row filter + Predicate rowFilter = null; + if (in.readBoolean()) { + String predicateJson = in.readUTF(); + rowFilter = RESTApi.fromJson(predicateJson, Predicate.class); + } + + // Deserialize column masking + int maskingSize = in.readInt(); + Map<String, Transform> columnMasking = new HashMap<>(maskingSize); + for (int i = 0; i < maskingSize; i++) { + String columnName = in.readUTF(); + String transformJson = in.readUTF(); + Transform transform = RESTApi.fromJson(transformJson, Transform.class); + columnMasking.put(columnName, transform); + } + + return new TableQueryAuthResult(rowFilter, columnMasking); + } +} diff --git a/paimon-core/src/test/java/org/apache/paimon/rest/MockRESTCatalogTest.java b/paimon-core/src/test/java/org/apache/paimon/rest/MockRESTCatalogTest.java index 7a52bf1af9c8..f628a68ea01e 100644 --- a/paimon-core/src/test/java/org/apache/paimon/rest/MockRESTCatalogTest.java +++ b/paimon-core/src/test/java/org/apache/paimon/rest/MockRESTCatalogTest.java @@ -26,13 +26,25 @@ import org.apache.paimon.catalog.CatalogContext; import org.apache.paimon.catalog.Identifier; import org.apache.paimon.catalog.TableQueryAuthResult; +import org.apache.paimon.data.BinaryString; +import org.apache.paimon.data.GenericRow; +import org.apache.paimon.data.InternalRow; import org.apache.paimon.options.CatalogOptions; import org.apache.paimon.options.Options; +import org.apache.paimon.predicate.CastTransform; +import org.apache.paimon.predicate.ConcatTransform; +import org.apache.paimon.predicate.ConcatWsTransform; +import org.apache.paimon.predicate.Equal; import 
org.apache.paimon.predicate.FieldRef; +import org.apache.paimon.predicate.FieldTransform; +import org.apache.paimon.predicate.GreaterOrEqual; +import org.apache.paimon.predicate.GreaterThan; +import org.apache.paimon.predicate.LeafPredicate; import org.apache.paimon.predicate.Predicate; import org.apache.paimon.predicate.PredicateBuilder; import org.apache.paimon.predicate.Transform; import org.apache.paimon.predicate.UpperTransform; +import org.apache.paimon.reader.RecordReader; import org.apache.paimon.rest.auth.AuthProvider; import org.apache.paimon.rest.auth.AuthProviderEnum; import org.apache.paimon.rest.auth.BearTokenAuthProvider; @@ -41,9 +53,17 @@ import org.apache.paimon.rest.auth.DLFTokenLoaderFactory; import org.apache.paimon.rest.auth.RESTAuthParameter; import org.apache.paimon.rest.exceptions.NotAuthorizedException; -import org.apache.paimon.rest.responses.AuthTableQueryResponse; import org.apache.paimon.rest.responses.ConfigResponse; import org.apache.paimon.schema.Schema; +import org.apache.paimon.table.Table; +import org.apache.paimon.table.sink.BatchTableCommit; +import org.apache.paimon.table.sink.BatchTableWrite; +import org.apache.paimon.table.sink.BatchWriteBuilder; +import org.apache.paimon.table.sink.CommitMessage; +import org.apache.paimon.table.source.ReadBuilder; +import org.apache.paimon.table.source.Split; +import org.apache.paimon.table.source.TableRead; +import org.apache.paimon.types.DataField; import org.apache.paimon.types.DataTypes; import org.apache.paimon.types.RowType; import org.apache.paimon.utils.JsonSerdeUtil; @@ -56,12 +76,15 @@ import java.io.File; import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.UUID; +import static org.apache.paimon.CoreOptions.QUERY_AUTH_ENABLED; import static org.apache.paimon.catalog.Catalog.TABLE_DEFAULT_OPTION_PREFIX; import static org.apache.paimon.rest.RESTApi.HEADER_PREFIX; import static org.assertj.core.api.Assertions.assertThat; @@ -273,14 +296,13 @@ void testAuthTableQueryResponseWithColumnMasking() throws Exception { Transform transform = new UpperTransform( Collections.singletonList(new FieldRef(1, "col2", DataTypes.STRING()))); - String transformJson = JsonSerdeUtil.toFlatJson(transform); // Set up mock response with filter and columnMasking - List filter = Collections.singletonList(predicateJson); - Map columnMasking = new HashMap<>(); - columnMasking.put("col2", transformJson); - AuthTableQueryResponse response = new AuthTableQueryResponse(filter, columnMasking); - restCatalogServer.setTableQueryAuthResponse(identifier, response); + List rowFilters = Collections.singletonList(predicate); + Map columnMasking = new HashMap<>(); + columnMasking.put("col2", transform); + restCatalogServer.setRowFilterAuth(identifier, rowFilters); + restCatalogServer.setColumnMaskingAuth(identifier, columnMasking); TableQueryAuthResult result = catalog.authTableQuery(identifier, null); assertThat(result.rowFilter()).isEqualTo(predicate); @@ -292,6 +314,453 @@ void testAuthTableQueryResponseWithColumnMasking() throws Exception { catalog.dropDatabase(identifier.getDatabaseName(), true, true); } + @Test + void testColumnMaskingApplyOnRead() throws Exception { + Identifier identifier = Identifier.create("test_table_db", "auth_table_masking_apply"); + catalog.createDatabase(identifier.getDatabaseName(), true); + + // Create table with multiple columns of different types + List fields = new 
ArrayList<>(); + fields.add(new DataField(0, "col1", DataTypes.STRING())); + fields.add(new DataField(1, "col2", DataTypes.STRING())); + fields.add(new DataField(2, "col3", DataTypes.INT())); + fields.add(new DataField(3, "col4", DataTypes.STRING())); + fields.add(new DataField(4, "col5", DataTypes.STRING())); + + catalog.createTable( + identifier, + new Schema( + fields, + Collections.emptyList(), + Collections.emptyList(), + Collections.singletonMap(QUERY_AUTH_ENABLED.key(), "true"), + ""), + true); + + Table table = catalog.getTable(identifier); + + // Write test data + BatchWriteBuilder writeBuilder = table.newBatchWriteBuilder(); + BatchTableWrite write = writeBuilder.newWrite(); + write.write( + GenericRow.of( + BinaryString.fromString("hello"), + BinaryString.fromString("world"), + 100, + BinaryString.fromString("test"), + BinaryString.fromString("data"))); + write.write( + GenericRow.of( + BinaryString.fromString("foo"), + BinaryString.fromString("bar"), + 200, + BinaryString.fromString("example"), + BinaryString.fromString("value"))); + List messages = write.prepareCommit(); + BatchTableCommit commit = writeBuilder.newCommit(); + commit.commit(messages); + write.close(); + commit.close(); + + // Set up column masking with various transform types + Map columnMasking = new HashMap<>(); + + // Test 1: ConcatTransform - mask col1 with "****" + ConcatTransform concatTransform = + new ConcatTransform(Collections.singletonList(BinaryString.fromString("****"))); + columnMasking.put("col1", concatTransform); + + // Test 2: UpperTransform - convert col2 to uppercase + UpperTransform upperTransform = + new UpperTransform( + Collections.singletonList(new FieldRef(1, "col2", DataTypes.STRING()))); + columnMasking.put("col2", upperTransform); + + // Test 3: CastTransform - cast col3 (INT) to STRING + CastTransform castTransform = + new CastTransform(new FieldRef(2, "col3", DataTypes.INT()), DataTypes.STRING()); + columnMasking.put("col3", castTransform); + + // Test 4: ConcatWsTransform - concatenate col4 with separator + ConcatWsTransform concatWsTransform = + new ConcatWsTransform( + java.util.Arrays.asList( + BinaryString.fromString("-"), + BinaryString.fromString("prefix"), + new FieldRef(3, "col4", DataTypes.STRING()))); + columnMasking.put("col4", concatWsTransform); + + // col5 is intentionally not masked to verify unmasked columns work correctly + + restCatalogServer.setColumnMaskingAuth(identifier, columnMasking); + + // Read and verify masked data + ReadBuilder readBuilder = table.newReadBuilder(); + List splits = readBuilder.newScan().plan().splits(); + TableRead read = readBuilder.newRead(); + RecordReader reader = read.createReader(splits); + + List rows = new ArrayList<>(); + reader.forEachRemaining(rows::add); + + assertThat(rows).hasSize(2); + + // Verify first row + InternalRow row1 = rows.get(0); + assertThat(row1.getString(0).toString()) + .isEqualTo("****"); // col1 masked with ConcatTransform + assertThat(row1.getString(1).toString()) + .isEqualTo("WORLD"); // col2 masked with UpperTransform + assertThat(row1.getString(2).toString()) + .isEqualTo("100"); // col3 masked with CastTransform (INT->STRING) + assertThat(row1.getString(3).toString()) + .isEqualTo("prefix-test"); // col4 masked with ConcatWsTransform + assertThat(row1.getString(4).toString()) + .isEqualTo("data"); // col5 NOT masked - original value + + // Verify second row + InternalRow row2 = rows.get(1); + assertThat(row2.getString(0).toString()) + .isEqualTo("****"); // col1 masked with ConcatTransform + 
assertThat(row2.getString(1).toString()) + .isEqualTo("BAR"); // col2 masked with UpperTransform + assertThat(row2.getString(2).toString()) + .isEqualTo("200"); // col3 masked with CastTransform (INT->STRING) + assertThat(row2.getString(3).toString()) + .isEqualTo("prefix-example"); // col4 masked with ConcatWsTransform + assertThat(row2.getString(4).toString()) + .isEqualTo("value"); // col5 NOT masked - original value + } + + @Test + void testRowFilter() throws Exception { + Identifier identifier = Identifier.create("test_table_db", "auth_table_filter"); + catalog.createDatabase(identifier.getDatabaseName(), true); + + // Create table with multiple data types + List fields = new ArrayList<>(); + fields.add(new DataField(0, "id", DataTypes.INT())); + fields.add(new DataField(1, "name", DataTypes.STRING())); + fields.add(new DataField(2, "age", DataTypes.BIGINT())); + fields.add(new DataField(3, "salary", DataTypes.DOUBLE())); + fields.add(new DataField(4, "is_active", DataTypes.BOOLEAN())); + fields.add(new DataField(5, "score", DataTypes.FLOAT())); + + catalog.createTable( + identifier, + new Schema( + fields, + Collections.emptyList(), + Collections.emptyList(), + Collections.singletonMap(QUERY_AUTH_ENABLED.key(), "true"), + ""), + true); + + Table table = catalog.getTable(identifier); + + // Write test data with various types + BatchWriteBuilder writeBuilder = table.newBatchWriteBuilder(); + BatchTableWrite write = writeBuilder.newWrite(); + write.write(GenericRow.of(1, BinaryString.fromString("Alice"), 25L, 50000.0, true, 85.5f)); + write.write(GenericRow.of(2, BinaryString.fromString("Bob"), 30L, 60000.0, false, 90.0f)); + write.write( + GenericRow.of(3, BinaryString.fromString("Charlie"), 35L, 70000.0, true, 95.5f)); + write.write(GenericRow.of(4, BinaryString.fromString("David"), 28L, 55000.0, true, 88.0f)); + List messages = write.prepareCommit(); + BatchTableCommit commit = writeBuilder.newCommit(); + commit.commit(messages); + write.close(); + commit.close(); + + // Test 1: Filter by INT type (id > 2) + LeafPredicate intFilterPredicate = + LeafPredicate.of( + new FieldTransform(new FieldRef(0, "id", DataTypes.INT())), + GreaterThan.INSTANCE, + Collections.singletonList(2)); + restCatalogServer.setRowFilterAuth( + identifier, Collections.singletonList(intFilterPredicate)); + + List result1 = batchRead(table); + assertThat(result1).hasSize(2); + assertThat(result1) + .contains( + "+I[3, Charlie, 35, 70000.0, true, 95.5]", + "+I[4, David, 28, 55000.0, true, 88.0]"); + + // Test 2: Filter by BIGINT type (age >= 30) + LeafPredicate bigintFilterPredicate = + LeafPredicate.of( + new FieldTransform(new FieldRef(2, "age", DataTypes.BIGINT())), + GreaterOrEqual.INSTANCE, + Collections.singletonList(30L)); + restCatalogServer.setRowFilterAuth( + identifier, Collections.singletonList(bigintFilterPredicate)); + + List result2 = batchRead(table); + assertThat(result2).hasSize(2); + assertThat(result2) + .contains( + "+I[2, Bob, 30, 60000.0, false, 90.0]", + "+I[3, Charlie, 35, 70000.0, true, 95.5]"); + + // Test 3: Filter by DOUBLE type (salary > 55000.0) + LeafPredicate doubleFilterPredicate = + LeafPredicate.of( + new FieldTransform(new FieldRef(3, "salary", DataTypes.DOUBLE())), + GreaterThan.INSTANCE, + Collections.singletonList(55000.0)); + restCatalogServer.setRowFilterAuth( + identifier, Collections.singletonList(doubleFilterPredicate)); + + List result3 = batchRead(table); + assertThat(result3).hasSize(2); + assertThat(result3) + .contains( + "+I[2, Bob, 30, 60000.0, false, 90.0]", + 
"+I[3, Charlie, 35, 70000.0, true, 95.5]"); + + // Test 4: Filter by BOOLEAN type (is_active = true) + LeafPredicate booleanFilterPredicate = + LeafPredicate.of( + new FieldTransform(new FieldRef(4, "is_active", DataTypes.BOOLEAN())), + Equal.INSTANCE, + Collections.singletonList(true)); + restCatalogServer.setRowFilterAuth( + identifier, Collections.singletonList(booleanFilterPredicate)); + + List result4 = batchRead(table); + assertThat(result4).hasSize(3); + assertThat(result4) + .contains( + "+I[1, Alice, 25, 50000.0, true, 85.5]", + "+I[3, Charlie, 35, 70000.0, true, 95.5]", + "+I[4, David, 28, 55000.0, true, 88.0]"); + + // Test 5: Filter by FLOAT type (score >= 90.0) + LeafPredicate floatFilterPredicate = + LeafPredicate.of( + new FieldTransform(new FieldRef(5, "score", DataTypes.FLOAT())), + GreaterOrEqual.INSTANCE, + Collections.singletonList(90.0f)); + restCatalogServer.setRowFilterAuth( + identifier, Collections.singletonList(floatFilterPredicate)); + + List result5 = batchRead(table); + assertThat(result5).hasSize(2); + assertThat(result5) + .contains( + "+I[2, Bob, 30, 60000.0, false, 90.0]", + "+I[3, Charlie, 35, 70000.0, true, 95.5]"); + + // Test 6: Filter by STRING type (name = "Alice") + LeafPredicate stringFilterPredicate = + LeafPredicate.of( + new FieldTransform(new FieldRef(1, "name", DataTypes.STRING())), + Equal.INSTANCE, + Collections.singletonList(BinaryString.fromString("Alice"))); + restCatalogServer.setRowFilterAuth( + identifier, Collections.singletonList(stringFilterPredicate)); + + List result6 = batchRead(table); + assertThat(result6).hasSize(1); + assertThat(result6).contains("+I[1, Alice, 25, 50000.0, true, 85.5]"); + + // Test 7: Filter with two predicates (age >= 30 AND is_active = true) + LeafPredicate ageGe30Predicate = + LeafPredicate.of( + new FieldTransform(new FieldRef(2, "age", DataTypes.BIGINT())), + GreaterOrEqual.INSTANCE, + Collections.singletonList(30L)); + LeafPredicate isActiveTruePredicate = + LeafPredicate.of( + new FieldTransform(new FieldRef(4, "is_active", DataTypes.BOOLEAN())), + Equal.INSTANCE, + Collections.singletonList(true)); + restCatalogServer.setRowFilterAuth( + identifier, Arrays.asList(ageGe30Predicate, isActiveTruePredicate)); + + List result7 = batchRead(table); + assertThat(result7).hasSize(1); + assertThat(result7).contains("+I[3, Charlie, 35, 70000.0, true, 95.5]"); + + // Test 8: Filter with two predicates (salary > 55000.0 AND score >= 90.0) + LeafPredicate salaryGt55000Predicate = + LeafPredicate.of( + new FieldTransform(new FieldRef(3, "salary", DataTypes.DOUBLE())), + GreaterThan.INSTANCE, + Collections.singletonList(55000.0)); + LeafPredicate scoreGe90Predicate = + LeafPredicate.of( + new FieldTransform(new FieldRef(5, "score", DataTypes.FLOAT())), + GreaterOrEqual.INSTANCE, + Collections.singletonList(90.0f)); + restCatalogServer.setRowFilterAuth( + identifier, Arrays.asList(salaryGt55000Predicate, scoreGe90Predicate)); + + List result8 = batchRead(table); + assertThat(result8).hasSize(2); + assertThat(result8) + .contains( + "+I[2, Bob, 30, 60000.0, false, 90.0]", + "+I[3, Charlie, 35, 70000.0, true, 95.5]"); + } + + @Test + void testColumnMaskingAndRowFilter() throws Exception { + Identifier identifier = Identifier.create("test_table_db", "combined_auth_table"); + catalog.createDatabase(identifier.getDatabaseName(), true); + + // Create table with test data + List fields = new ArrayList<>(); + fields.add(new DataField(0, "id", DataTypes.INT())); + fields.add(new DataField(1, "name", DataTypes.STRING())); + 
fields.add(new DataField(2, "salary", DataTypes.STRING())); + fields.add(new DataField(3, "age", DataTypes.INT())); + fields.add(new DataField(4, "department", DataTypes.STRING())); + + catalog.createTable( + identifier, + new Schema( + fields, + Collections.emptyList(), + Collections.emptyList(), + Collections.singletonMap(QUERY_AUTH_ENABLED.key(), "true"), + ""), + true); + + Table table = catalog.getTable(identifier); + + // Write test data + BatchWriteBuilder writeBuilder = table.newBatchWriteBuilder(); + BatchTableWrite write = writeBuilder.newWrite(); + write.write( + GenericRow.of( + 1, + BinaryString.fromString("Alice"), + BinaryString.fromString("50000.0"), + 25, + BinaryString.fromString("IT"))); + write.write( + GenericRow.of( + 2, + BinaryString.fromString("Bob"), + BinaryString.fromString("60000.0"), + 30, + BinaryString.fromString("HR"))); + write.write( + GenericRow.of( + 3, + BinaryString.fromString("Charlie"), + BinaryString.fromString("70000.0"), + 35, + BinaryString.fromString("IT"))); + write.write( + GenericRow.of( + 4, + BinaryString.fromString("David"), + BinaryString.fromString("55000.0"), + 28, + BinaryString.fromString("Finance"))); + List messages = write.prepareCommit(); + BatchTableCommit commit = writeBuilder.newCommit(); + commit.commit(messages); + write.close(); + commit.close(); + + // Test column masking only + Transform salaryMaskTransform = + new ConcatTransform(Collections.singletonList(BinaryString.fromString("***"))); + Map columnMasking = new HashMap<>(); + columnMasking.put("salary", salaryMaskTransform); + restCatalogServer.setColumnMaskingAuth(identifier, columnMasking); + + ReadBuilder readBuilder = table.newReadBuilder(); + List splits = readBuilder.newScan().plan().splits(); + TableRead read = readBuilder.newRead(); + RecordReader reader = read.createReader(splits); + + List rows = new ArrayList<>(); + reader.forEachRemaining(rows::add); + assertThat(rows).hasSize(4); + assertThat(rows.get(0).getString(2).toString()).isEqualTo("***"); + + // Test row filter only (clear column masking first) + restCatalogServer.setColumnMaskingAuth(identifier, new HashMap<>()); + Predicate ageGe30Predicate = + LeafPredicate.of( + new FieldTransform(new FieldRef(3, "age", DataTypes.INT())), + GreaterOrEqual.INSTANCE, + Collections.singletonList(30)); + restCatalogServer.setRowFilterAuth(identifier, Collections.singletonList(ageGe30Predicate)); + + readBuilder = table.newReadBuilder(); + splits = readBuilder.newScan().plan().splits(); + read = readBuilder.newRead(); + reader = read.createReader(splits); + + rows = new ArrayList<>(); + reader.forEachRemaining(rows::add); + assertThat(rows).hasSize(2); + + // Test both column masking and row filter together + columnMasking.put("salary", salaryMaskTransform); + Transform nameMaskTransform = + new ConcatTransform(Collections.singletonList(BinaryString.fromString("***"))); + columnMasking.put("name", nameMaskTransform); + restCatalogServer.setColumnMaskingAuth(identifier, columnMasking); + Predicate deptPredicate = + LeafPredicate.of( + new FieldTransform(new FieldRef(4, "department", DataTypes.STRING())), + Equal.INSTANCE, + Collections.singletonList(BinaryString.fromString("IT"))); + restCatalogServer.setRowFilterAuth(identifier, Collections.singletonList(deptPredicate)); + + readBuilder = table.newReadBuilder(); + splits = readBuilder.newScan().plan().splits(); + read = readBuilder.newRead(); + reader = read.createReader(splits); + + rows = new ArrayList<>(); + reader.forEachRemaining(rows::add); + 
assertThat(rows).hasSize(2); + assertThat(rows.get(0).getString(1).toString()).isEqualTo("***"); // name masked + assertThat(rows.get(0).getString(2).toString()).isEqualTo("***"); // salary masked + assertThat(rows.get(0).getString(4).toString()).isEqualTo("IT"); // department not masked + + // Test complex scenario: row filter + column masking combined + Predicate combinedPredicate = PredicateBuilder.and(ageGe30Predicate, deptPredicate); + restCatalogServer.setRowFilterAuth( + identifier, Collections.singletonList(combinedPredicate)); + + readBuilder = table.newReadBuilder(); + splits = readBuilder.newScan().plan().splits(); + read = readBuilder.newRead(); + reader = read.createReader(splits); + + rows = new ArrayList<>(); + reader.forEachRemaining(rows::add); + assertThat(rows).hasSize(1); + assertThat(rows.get(0).getInt(0)).isEqualTo(3); // id + assertThat(rows.get(0).getString(1).toString()).isEqualTo("***"); // name masked + assertThat(rows.get(0).getString(2).toString()).isEqualTo("***"); // salary masked + assertThat(rows.get(0).getInt(3)).isEqualTo(35); // age not masked + + // Clear both column masking and row filter + restCatalogServer.setColumnMaskingAuth(identifier, new HashMap<>()); + restCatalogServer.setRowFilterAuth(identifier, null); + + readBuilder = table.newReadBuilder(); + splits = readBuilder.newScan().plan().splits(); + read = readBuilder.newRead(); + reader = read.createReader(splits); + + rows = new ArrayList<>(); + reader.forEachRemaining(rows::add); + assertThat(rows).hasSize(4); + assertThat(rows.get(0).getString(1).toString()).isIn("Alice", "Bob", "Charlie", "David"); + } + private void checkHeader(String headerName, String headerValue) { // Verify that the header were included in the requests List> receivedHeaders = restCatalogServer.getReceivedHeaders(); diff --git a/paimon-core/src/test/java/org/apache/paimon/rest/RESTCatalogServer.java b/paimon-core/src/test/java/org/apache/paimon/rest/RESTCatalogServer.java index 8716d4ea7adf..6420bc956d74 100644 --- a/paimon-core/src/test/java/org/apache/paimon/rest/RESTCatalogServer.java +++ b/paimon-core/src/test/java/org/apache/paimon/rest/RESTCatalogServer.java @@ -41,6 +41,8 @@ import org.apache.paimon.partition.Partition; import org.apache.paimon.partition.PartitionStatistics; import org.apache.paimon.partition.PartitionUtils; +import org.apache.paimon.predicate.Predicate; +import org.apache.paimon.predicate.Transform; import org.apache.paimon.rest.auth.AuthProvider; import org.apache.paimon.rest.auth.RESTAuthParameter; import org.apache.paimon.rest.requests.AlterDatabaseRequest; @@ -97,6 +99,7 @@ import org.apache.paimon.table.object.ObjectTable; import org.apache.paimon.tag.Tag; import org.apache.paimon.utils.BranchManager; +import org.apache.paimon.utils.JsonSerdeUtil; import org.apache.paimon.utils.LazyField; import org.apache.paimon.utils.Pair; import org.apache.paimon.utils.SnapshotManager; @@ -185,8 +188,8 @@ public class RESTCatalogServer { private final List noPermissionTables = new ArrayList<>(); private final Map functionStore = new HashMap<>(); private final Map> columnAuthHandler = new HashMap<>(); - private final Map tableQueryAuthResponseHandler = - new HashMap<>(); + private final Map> rowFilterAuthHandler = new HashMap<>(); + private final Map> columnMaskingAuthHandler = new HashMap<>(); public final ConfigResponse configResponse; public final String warehouse; @@ -268,8 +271,12 @@ public void addTableColumnAuth(Identifier identifier, List select) { columnAuthHandler.put(identifier.getFullName(), 
select); } - public void setTableQueryAuthResponse(Identifier identifier, AuthTableQueryResponse response) { - tableQueryAuthResponseHandler.put(identifier.getFullName(), response); + public void setRowFilterAuth(Identifier identifier, List rowFilters) { + rowFilterAuthHandler.put(identifier.getFullName(), rowFilters); + } + + public void setColumnMaskingAuth(Identifier identifier, Map columnMasking) { + columnMaskingAuthHandler.put(identifier.getFullName(), columnMasking); } public RESTToken getDataToken(Identifier identifier) { @@ -835,8 +842,30 @@ private MockResponse authTable(Identifier identifier, String data) throws Except } }); } + List rowFilters = rowFilterAuthHandler.get(identifier.getFullName()); + Map columnMasking = + columnMaskingAuthHandler.get(identifier.getFullName()); + + // Convert Predicate list to JSON string list + List filterJsonList = null; + if (rowFilters != null) { + filterJsonList = + rowFilters.stream().map(JsonSerdeUtil::toFlatJson).collect(Collectors.toList()); + } + + // Convert Transform map to JSON string map + Map columnMaskingJsonMap = null; + if (columnMasking != null) { + columnMaskingJsonMap = + columnMasking.entrySet().stream() + .collect( + Collectors.toMap( + Map.Entry::getKey, + entry -> JsonSerdeUtil.toFlatJson(entry.getValue()))); + } + AuthTableQueryResponse response = - tableQueryAuthResponseHandler.get(identifier.getFullName()); + new AuthTableQueryResponse(filterJsonList, columnMaskingJsonMap); if (response == null) { response = new AuthTableQueryResponse(Collections.emptyList(), ImmutableMap.of()); } diff --git a/paimon-core/src/test/java/org/apache/paimon/rest/RESTCatalogTest.java b/paimon-core/src/test/java/org/apache/paimon/rest/RESTCatalogTest.java index 288c925ceb23..bc0a74bf720c 100644 --- a/paimon-core/src/test/java/org/apache/paimon/rest/RESTCatalogTest.java +++ b/paimon-core/src/test/java/org/apache/paimon/rest/RESTCatalogTest.java @@ -3030,11 +3030,27 @@ protected List batchRead(Table table) throws IOException { TableRead read = readBuilder.newRead(); RecordReader reader = read.createReader(splits); List result = new ArrayList<>(); + + // Create field getters for each column + InternalRow.FieldGetter[] fieldGetters = + new InternalRow.FieldGetter[table.rowType().getFieldCount()]; + for (int i = 0; i < table.rowType().getFieldCount(); i++) { + fieldGetters[i] = InternalRow.createFieldGetter(table.rowType().getTypeAt(i), i); + } + reader.forEachRemaining( row -> { - String rowStr = - String.format("%s[%d]", row.getRowKind().shortString(), row.getInt(0)); - result.add(rowStr); + StringBuilder sb = new StringBuilder(); + sb.append(row.getRowKind().shortString()).append("["); + for (int i = 0; i < row.getFieldCount(); i++) { + if (i > 0) { + sb.append(", "); + } + Object value = fieldGetters[i].getFieldOrNull(row); + sb.append(value); + } + sb.append("]"); + result.add(sb.toString()); }); return result; } diff --git a/paimon-flink/paimon-flink-common/src/test/java/org/apache/paimon/flink/RESTCatalogITCase.java b/paimon-flink/paimon-flink-common/src/test/java/org/apache/paimon/flink/RESTCatalogITCase.java index 034816d6a0fc..48fefa8e245f 100644 --- a/paimon-flink/paimon-flink-common/src/test/java/org/apache/paimon/flink/RESTCatalogITCase.java +++ b/paimon-flink/paimon-flink-common/src/test/java/org/apache/paimon/flink/RESTCatalogITCase.java @@ -19,8 +19,23 @@ package org.apache.paimon.flink; import org.apache.paimon.catalog.Identifier; +import org.apache.paimon.data.BinaryString; import org.apache.paimon.partition.Partition; +import 
org.apache.paimon.predicate.ConcatTransform; +import org.apache.paimon.predicate.ConcatWsTransform; +import org.apache.paimon.predicate.Equal; +import org.apache.paimon.predicate.FieldRef; +import org.apache.paimon.predicate.FieldTransform; +import org.apache.paimon.predicate.GreaterOrEqual; +import org.apache.paimon.predicate.GreaterThan; +import org.apache.paimon.predicate.LeafPredicate; +import org.apache.paimon.predicate.LessThan; +import org.apache.paimon.predicate.Predicate; +import org.apache.paimon.predicate.PredicateBuilder; +import org.apache.paimon.predicate.Transform; +import org.apache.paimon.predicate.UpperTransform; import org.apache.paimon.rest.RESTToken; +import org.apache.paimon.types.DataTypes; import org.apache.paimon.shade.guava30.com.google.common.collect.ImmutableList; import org.apache.paimon.shade.guava30.com.google.common.collect.ImmutableMap; @@ -35,7 +50,11 @@ import org.apache.flink.types.Row; import org.junit.jupiter.api.Test; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.UUID; import static org.assertj.core.api.Assertions.assertThat; @@ -82,7 +101,6 @@ public void testWriteAndRead() { .containsExactlyInAnyOrder(Row.of("1", 11.0D), Row.of("2", 22.0D)); } - @Test public void testExpiredDataToken() { Identifier identifier = Identifier.create(DATABASE_NAME, TABLE_NAME); RESTToken expiredDataToken = @@ -199,4 +217,362 @@ private void validateTotalBuckets( assertThat(partitions.get(0).totalBuckets()).isEqualTo(expectedTotalBuckets); } } + + public void testColumnMasking() { + String maskingTable = "column_masking_table"; + batchSql( + String.format( + "CREATE TABLE %s.%s (id INT, secret STRING, email STRING, phone STRING, salary STRING) WITH ('query-auth.enabled' = 'true')", + DATABASE_NAME, maskingTable)); + batchSql( + String.format( + "INSERT INTO %s.%s VALUES (1, 's1', 'user1@example.com', '12345678901', '50000.0'), (2, 's2', 'user2@example.com', '12345678902', '60000.0')", + DATABASE_NAME, maskingTable)); + + // Test single column masking + Transform maskTransform = + new ConcatTransform(Collections.singletonList(BinaryString.fromString("****"))); + Map columnMasking = new HashMap<>(); + columnMasking.put("secret", maskTransform); + restCatalogServer.setColumnMaskingAuth( + Identifier.create(DATABASE_NAME, maskingTable), columnMasking); + + assertThat(batchSql(String.format("SELECT secret FROM %s.%s", DATABASE_NAME, maskingTable))) + .containsExactlyInAnyOrder(Row.of("****"), Row.of("****")); + assertThat(batchSql(String.format("SELECT id FROM %s.%s", DATABASE_NAME, maskingTable))) + .containsExactlyInAnyOrder(Row.of(1), Row.of(2)); + + // Test multiple columns masking + Transform emailMaskTransform = + new ConcatTransform( + Collections.singletonList(BinaryString.fromString("***@***.com"))); + Transform phoneMaskTransform = + new ConcatTransform( + Collections.singletonList(BinaryString.fromString("***********"))); + Transform salaryMaskTransform = + new ConcatTransform(Collections.singletonList(BinaryString.fromString("0.0"))); + + columnMasking.put("email", emailMaskTransform); + columnMasking.put("phone", phoneMaskTransform); + columnMasking.put("salary", salaryMaskTransform); + restCatalogServer.setColumnMaskingAuth( + Identifier.create(DATABASE_NAME, maskingTable), columnMasking); + + assertThat(batchSql(String.format("SELECT email FROM %s.%s", DATABASE_NAME, maskingTable))) + .containsExactlyInAnyOrder(Row.of("***@***.com"), Row.of("***@***.com")); + 
assertThat(batchSql(String.format("SELECT phone FROM %s.%s", DATABASE_NAME, maskingTable))) + .containsExactlyInAnyOrder(Row.of("***********"), Row.of("***********")); + assertThat(batchSql(String.format("SELECT salary FROM %s.%s", DATABASE_NAME, maskingTable))) + .containsExactlyInAnyOrder(Row.of("0.0"), Row.of("0.0")); + + // Test SELECT * with column masking + List allRows = + batchSql( + String.format( + "SELECT * FROM %s.%s ORDER BY id", DATABASE_NAME, maskingTable)); + assertThat(allRows.size()).isEqualTo(2); + assertThat(allRows.get(0).getField(1)).isEqualTo("****"); + assertThat(allRows.get(0).getField(2)).isEqualTo("***@***.com"); + assertThat(allRows.get(0).getField(3)).isEqualTo("***********"); + assertThat(allRows.get(0).getField(4)).isEqualTo("0.0"); + + // Test WHERE clause with masked column + assertThat( + batchSql( + String.format( + "SELECT id FROM %s.%s WHERE id = 1", + DATABASE_NAME, maskingTable))) + .containsExactlyInAnyOrder(Row.of(1)); + + // Test aggregation with masked columns + assertThat( + batchSql( + String.format( + "SELECT COUNT(*) FROM %s.%s", DATABASE_NAME, maskingTable))) + .containsExactlyInAnyOrder(Row.of(2L)); + + // Test JOIN with masked columns + String joinTable = "join_table"; + batchSql( + String.format( + "CREATE TABLE %s.%s (id INT, name STRING)", DATABASE_NAME, joinTable)); + batchSql( + String.format( + "INSERT INTO %s.%s VALUES (1, 'Alice'), (2, 'Bob')", + DATABASE_NAME, joinTable)); + + List joinResult = + batchSql( + String.format( + "SELECT t1.id, t1.secret, t2.name FROM %s.%s t1 JOIN %s.%s t2 ON t1.id = t2.id ORDER BY t1.id", + DATABASE_NAME, maskingTable, DATABASE_NAME, joinTable)); + assertThat(joinResult.size()).isEqualTo(2); + assertThat(joinResult.get(0).getField(1)).isEqualTo("****"); + assertThat(joinResult.get(0).getField(2)).isEqualTo("Alice"); + + // Test UpperTransform + Transform upperTransform = + new UpperTransform( + Collections.singletonList(new FieldRef(1, "secret", DataTypes.STRING()))); + columnMasking.clear(); + columnMasking.put("secret", upperTransform); + restCatalogServer.setColumnMaskingAuth( + Identifier.create(DATABASE_NAME, maskingTable), columnMasking); + + assertThat( + batchSql( + String.format( + "SELECT secret FROM %s.%s ORDER BY id", + DATABASE_NAME, maskingTable))) + .containsExactlyInAnyOrder(Row.of("S1"), Row.of("S2")); + + // Test ConcatWsTransform + Transform concatWsTransform = + new ConcatWsTransform( + Arrays.asList( + BinaryString.fromString("-"), + new FieldRef(1, "secret", DataTypes.STRING()), + BinaryString.fromString("masked"))); + columnMasking.clear(); + columnMasking.put("secret", concatWsTransform); + restCatalogServer.setColumnMaskingAuth( + Identifier.create(DATABASE_NAME, maskingTable), columnMasking); + + assertThat( + batchSql( + String.format( + "SELECT secret FROM %s.%s ORDER BY id", + DATABASE_NAME, maskingTable))) + .containsExactlyInAnyOrder(Row.of("s1-masked"), Row.of("s2-masked")); + + // Clear masking and verify original data + restCatalogServer.setColumnMaskingAuth( + Identifier.create(DATABASE_NAME, maskingTable), new HashMap<>()); + assertThat( + batchSql( + String.format( + "SELECT secret FROM %s.%s ORDER BY id", + DATABASE_NAME, maskingTable))) + .containsExactlyInAnyOrder(Row.of("s1"), Row.of("s2")); + assertThat( + batchSql( + String.format( + "SELECT email FROM %s.%s ORDER BY id", + DATABASE_NAME, maskingTable))) + .containsExactlyInAnyOrder( + Row.of("user1@example.com"), Row.of("user2@example.com")); + } + + @Test + public void testRowFilter() { + String filterTable 
= "row_filter_table"; + batchSql( + String.format( + "CREATE TABLE %s.%s (id INT, name STRING, age INT, department STRING) WITH ('query-auth.enabled' = 'true')", + DATABASE_NAME, filterTable)); + batchSql( + String.format( + "INSERT INTO %s.%s VALUES (1, 'Alice', 25, 'IT'), (2, 'Bob', 30, 'HR'), (3, 'Charlie', 35, 'IT'), (4, 'David', 28, 'Finance')", + DATABASE_NAME, filterTable)); + + // Test single condition row filter (age > 28) + Predicate agePredicate = + LeafPredicate.of( + new FieldTransform(new FieldRef(2, "age", DataTypes.INT())), + GreaterThan.INSTANCE, + Collections.singletonList(28)); + restCatalogServer.setRowFilterAuth( + Identifier.create(DATABASE_NAME, filterTable), + Collections.singletonList(agePredicate)); + + assertThat( + batchSql( + String.format( + "SELECT * FROM %s.%s ORDER BY id", + DATABASE_NAME, filterTable))) + .containsExactlyInAnyOrder( + Row.of(2, "Bob", 30, "HR"), Row.of(3, "Charlie", 35, "IT")); + + // Test string condition row filter (department = 'IT') + Predicate deptPredicate = + LeafPredicate.of( + new FieldTransform(new FieldRef(3, "department", DataTypes.STRING())), + Equal.INSTANCE, + Collections.singletonList(BinaryString.fromString("IT"))); + restCatalogServer.setRowFilterAuth( + Identifier.create(DATABASE_NAME, filterTable), + Collections.singletonList(deptPredicate)); + + assertThat( + batchSql( + String.format( + "SELECT * FROM %s.%s ORDER BY id", + DATABASE_NAME, filterTable))) + .containsExactlyInAnyOrder( + Row.of(1, "Alice", 25, "IT"), Row.of(3, "Charlie", 35, "IT")); + + // Test combined conditions (age >= 30 AND department = 'IT') + Predicate ageGePredicate = + LeafPredicate.of( + new FieldTransform(new FieldRef(2, "age", DataTypes.INT())), + GreaterOrEqual.INSTANCE, + Collections.singletonList(30)); + Predicate combinedPredicate = PredicateBuilder.and(ageGePredicate, deptPredicate); + restCatalogServer.setRowFilterAuth( + Identifier.create(DATABASE_NAME, filterTable), + Collections.singletonList(combinedPredicate)); + + assertThat( + batchSql( + String.format( + "SELECT * FROM %s.%s ORDER BY id", + DATABASE_NAME, filterTable))) + .containsExactlyInAnyOrder(Row.of(3, "Charlie", 35, "IT")); + + // Test OR condition (age < 27 OR department = 'Finance') + Predicate ageLtPredicate = + LeafPredicate.of( + new FieldTransform(new FieldRef(2, "age", DataTypes.INT())), + LessThan.INSTANCE, + Collections.singletonList(27)); + Predicate financePredicate = + LeafPredicate.of( + new FieldTransform(new FieldRef(3, "department", DataTypes.STRING())), + Equal.INSTANCE, + Collections.singletonList(BinaryString.fromString("Finance"))); + Predicate orPredicate = PredicateBuilder.or(ageLtPredicate, financePredicate); + restCatalogServer.setRowFilterAuth( + Identifier.create(DATABASE_NAME, filterTable), + Collections.singletonList(orPredicate)); + + assertThat( + batchSql( + String.format( + "SELECT * FROM %s.%s ORDER BY id", + DATABASE_NAME, filterTable))) + .containsExactlyInAnyOrder( + Row.of(1, "Alice", 25, "IT"), Row.of(4, "David", 28, "Finance")); + + // Test WHERE clause combined with row filter + Predicate ageGt25Predicate = + LeafPredicate.of( + new FieldTransform(new FieldRef(2, "age", DataTypes.INT())), + GreaterThan.INSTANCE, + Collections.singletonList(25)); + restCatalogServer.setRowFilterAuth( + Identifier.create(DATABASE_NAME, filterTable), + Collections.singletonList(ageGt25Predicate)); + + assertThat( + batchSql( + String.format( + "SELECT * FROM %s.%s WHERE department = 'IT' ORDER BY id", + DATABASE_NAME, filterTable))) + 
.containsExactlyInAnyOrder(Row.of(3, "Charlie", 35, "IT")); + + // Test JOIN with row filter + String joinTable = "join_table"; + batchSql( + String.format( + "CREATE TABLE %s.%s (id INT, salary DOUBLE)", DATABASE_NAME, joinTable)); + batchSql( + String.format( + "INSERT INTO %s.%s VALUES (1, 50000.0), (2, 60000.0), (3, 70000.0), (4, 55000.0)", + DATABASE_NAME, joinTable)); + + Predicate ageGe30Predicate = + LeafPredicate.of( + new FieldTransform(new FieldRef(2, "age", DataTypes.INT())), + GreaterOrEqual.INSTANCE, + Collections.singletonList(30)); + restCatalogServer.setRowFilterAuth( + Identifier.create(DATABASE_NAME, filterTable), + Collections.singletonList(ageGe30Predicate)); + + List joinResult = + batchSql( + String.format( + "SELECT t1.id, t1.name, t1.age, t2.salary FROM %s.%s t1 JOIN %s.%s t2 ON t1.id = t2.id ORDER BY t1.id", + DATABASE_NAME, filterTable, DATABASE_NAME, joinTable)); + assertThat(joinResult.size()).isEqualTo(2); + assertThat(joinResult.get(0)).isEqualTo(Row.of(2, "Bob", 30, 60000.0)); + assertThat(joinResult.get(1)).isEqualTo(Row.of(3, "Charlie", 35, 70000.0)); + + // Clear row filter and verify original data + restCatalogServer.setRowFilterAuth(Identifier.create(DATABASE_NAME, filterTable), null); + + assertThat( + batchSql( + String.format( + "SELECT COUNT(*) FROM %s.%s", DATABASE_NAME, filterTable))) + .containsExactlyInAnyOrder(Row.of(4L)); + } + + @Test + public void testColumnMaskingAndRowFilter() { + String combinedTable = "combined_auth_table"; + batchSql( + String.format( + "CREATE TABLE %s.%s (id INT, name STRING, salary STRING, age INT, department STRING) WITH ('query-auth.enabled' = 'true')", + DATABASE_NAME, combinedTable)); + batchSql( + String.format( + "INSERT INTO %s.%s VALUES (1, 'Alice', '50000.0', 25, 'IT'), (2, 'Bob', '60000.0', 30, 'HR'), (3, 'Charlie', '70000.0', 35, 'IT'), (4, 'David', '55000.0', 28, 'Finance')", + DATABASE_NAME, combinedTable)); + Transform salaryMaskTransform = + new ConcatTransform(Collections.singletonList(BinaryString.fromString("***"))); + Map columnMasking = new HashMap<>(); + columnMasking.put("salary", salaryMaskTransform); + Transform nameMaskTransform = + new ConcatTransform(Collections.singletonList(BinaryString.fromString("***"))); + columnMasking.put("name", nameMaskTransform); + Predicate deptPredicate = + LeafPredicate.of( + new FieldTransform(new FieldRef(4, "department", DataTypes.STRING())), + Equal.INSTANCE, + Collections.singletonList(BinaryString.fromString("IT"))); + restCatalogServer.setColumnMaskingAuth( + Identifier.create(DATABASE_NAME, combinedTable), columnMasking); + restCatalogServer.setRowFilterAuth( + Identifier.create(DATABASE_NAME, combinedTable), + Collections.singletonList(deptPredicate)); + + // Test both column masking and row filter together + List combinedResult = + batchSql( + String.format( + "SELECT * FROM %s.%s ORDER BY id", DATABASE_NAME, combinedTable)); + assertThat(combinedResult.size()).isEqualTo(2); + assertThat(combinedResult.get(0).getField(0)).isEqualTo(1); // id + assertThat(combinedResult.get(0).getField(1)).isEqualTo("***"); // name masked + assertThat(combinedResult.get(0).getField(2)).isEqualTo("***"); // salary masked + assertThat(combinedResult.get(0).getField(3)).isEqualTo(25); // age not masked + assertThat(combinedResult.get(0).getField(4)).isEqualTo("IT"); // department not masked + + // Test WHERE clause with both features + assertThat( + batchSql( + String.format( + "SELECT id, name FROM %s.%s WHERE age > 30 ORDER BY id", + DATABASE_NAME, combinedTable))) + 
.containsExactlyInAnyOrder(Row.of(3, "***")); + + // Clear both column masking and row filter + restCatalogServer.setColumnMaskingAuth( + Identifier.create(DATABASE_NAME, combinedTable), new HashMap<>()); + restCatalogServer.setRowFilterAuth(Identifier.create(DATABASE_NAME, combinedTable), null); + + assertThat( + batchSql( + String.format( + "SELECT COUNT(*) FROM %s.%s", + DATABASE_NAME, combinedTable))) + .containsExactlyInAnyOrder(Row.of(4L)); + assertThat( + batchSql( + String.format( + "SELECT name FROM %s.%s WHERE id = 1", + DATABASE_NAME, combinedTable))) + .containsExactlyInAnyOrder(Row.of("Alice")); + } } diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/scan/BinPackingSplits.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/scan/BinPackingSplits.scala index 679234bc7492..8e635d404cf6 100644 --- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/scan/BinPackingSplits.scala +++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/scan/BinPackingSplits.scala @@ -25,7 +25,7 @@ import org.apache.paimon.spark.PaimonInputPartition import org.apache.paimon.spark.util.SplitUtils import org.apache.paimon.table.FallbackReadFileStoreTable.FallbackSplit import org.apache.paimon.table.format.FormatDataSplit -import org.apache.paimon.table.source.{DataSplit, DeletionFile, Split} +import org.apache.paimon.table.source.{DataSplit, DeletionFile, QueryAuthSplit, Split} import org.apache.spark.internal.Logging import org.apache.spark.sql.PaimonSparkSession @@ -160,19 +160,32 @@ case class BinPackingSplits(coreOptions: CoreOptions, readRowSizeRatio: Double = split: DataSplit, dataFiles: Seq[DataFileMeta], deletionFiles: Seq[DeletionFile]): DataSplit = { + val (actualSplit, authResult) = split match { + case queryAuthSplit: QueryAuthSplit => + (queryAuthSplit.dataSplit(), queryAuthSplit.authResult()) + case _ => + (split, null) + } + val builder = DataSplit .builder() - .withSnapshot(split.snapshotId()) - .withPartition(split.partition()) - .withBucket(split.bucket()) - .withTotalBuckets(split.totalBuckets()) + .withSnapshot(actualSplit.snapshotId()) + .withPartition(actualSplit.partition()) + .withBucket(actualSplit.bucket()) + .withTotalBuckets(actualSplit.totalBuckets()) .withDataFiles(dataFiles.toList.asJava) - .rawConvertible(split.rawConvertible) - .withBucketPath(split.bucketPath) + .rawConvertible(actualSplit.rawConvertible) + .withBucketPath(actualSplit.bucketPath) if (deletionVectors) { builder.withDataDeletionFiles(deletionFiles.toList.asJava) } - builder.build() + val newDataSplit = builder.build() + + if (authResult != null) { + new QueryAuthSplit(newDataSplit, authResult) + } else { + newDataSplit + } } private def withSamePartitionAndBucket(split1: DataSplit, split2: DataSplit): Boolean = { diff --git a/paimon-spark/paimon-spark-ut/src/test/java/org/apache/paimon/spark/SparkCatalogWithRestTest.java b/paimon-spark/paimon-spark-ut/src/test/java/org/apache/paimon/spark/SparkCatalogWithRestTest.java index ee8978c68767..bbae01af9500 100644 --- a/paimon-spark/paimon-spark-ut/src/test/java/org/apache/paimon/spark/SparkCatalogWithRestTest.java +++ b/paimon-spark/paimon-spark-ut/src/test/java/org/apache/paimon/spark/SparkCatalogWithRestTest.java @@ -20,11 +20,25 @@ import org.apache.paimon.catalog.Catalog; import org.apache.paimon.catalog.Identifier; +import org.apache.paimon.data.BinaryString; import org.apache.paimon.function.Function; import org.apache.paimon.function.FunctionChange; 
import org.apache.paimon.function.FunctionDefinition; import org.apache.paimon.function.FunctionImpl; import org.apache.paimon.options.CatalogOptions; +import org.apache.paimon.predicate.ConcatTransform; +import org.apache.paimon.predicate.ConcatWsTransform; +import org.apache.paimon.predicate.Equal; +import org.apache.paimon.predicate.FieldRef; +import org.apache.paimon.predicate.FieldTransform; +import org.apache.paimon.predicate.GreaterOrEqual; +import org.apache.paimon.predicate.GreaterThan; +import org.apache.paimon.predicate.LeafPredicate; +import org.apache.paimon.predicate.LessThan; +import org.apache.paimon.predicate.Predicate; +import org.apache.paimon.predicate.PredicateBuilder; +import org.apache.paimon.predicate.Transform; +import org.apache.paimon.predicate.UpperTransform; import org.apache.paimon.rest.RESTCatalogInternalOptions; import org.apache.paimon.rest.RESTCatalogServer; import org.apache.paimon.rest.auth.AuthProvider; @@ -38,6 +52,7 @@ import org.apache.paimon.shade.guava30.com.google.common.collect.ImmutableList; import org.apache.paimon.shade.guava30.com.google.common.collect.ImmutableMap; +import org.apache.spark.sql.Row; import org.apache.spark.sql.SparkSession; import org.apache.spark.sql.connector.catalog.CatalogManager; import org.junit.jupiter.api.AfterEach; @@ -48,7 +63,11 @@ import java.io.File; import java.io.IOException; import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.UUID; import static org.assertj.core.api.Assertions.assertThat; @@ -237,6 +256,279 @@ public void testMapFunction() throws Exception { cleanFunction(functionName); } + @Test + public void testColumnMasking() { + spark.sql( + "CREATE TABLE t_column_masking (id INT, secret STRING, email STRING, phone STRING) TBLPROPERTIES" + + " ('query-auth.enabled'='true')"); + spark.sql( + "INSERT INTO t_column_masking VALUES (1, 's1', 'user1@example.com', '12345678901'), (2, 's2', 'user2@example.com', '12345678902')"); + + // Test single column masking + Transform maskTransform = + new ConcatTransform(Collections.singletonList(BinaryString.fromString("****"))); + + Map<String, Transform> columnMasking = new HashMap<>(); + columnMasking.put("secret", maskTransform); + restCatalogServer.setColumnMaskingAuth( + Identifier.create("db2", "t_column_masking"), columnMasking); + + assertThat(spark.sql("SELECT secret FROM t_column_masking").collectAsList().toString()) + .isEqualTo("[[****], [****]]"); + assertThat(spark.sql("SELECT id FROM t_column_masking").collectAsList().toString()) + .isEqualTo("[[1], [2]]"); + + // Test multiple columns masking + Transform emailMaskTransform = + new ConcatTransform( + Collections.singletonList(BinaryString.fromString("***@***.com"))); + Transform phoneMaskTransform = + new ConcatTransform( + Collections.singletonList(BinaryString.fromString("***********"))); + + columnMasking.put("email", emailMaskTransform); + columnMasking.put("phone", phoneMaskTransform); + restCatalogServer.setColumnMaskingAuth( + Identifier.create("db2", "t_column_masking"), columnMasking); + + assertThat(spark.sql("SELECT email FROM t_column_masking").collectAsList().toString()) + .isEqualTo("[[***@***.com], [***@***.com]]"); + assertThat(spark.sql("SELECT phone FROM t_column_masking").collectAsList().toString()) + .isEqualTo("[[***********], [***********]]"); + + // Test SELECT * with column masking + List<Row> allRows = spark.sql("SELECT * FROM t_column_masking").collectAsList(); +
assertThat(allRows.size()).isEqualTo(2); + assertThat(allRows.get(0).getString(1)).isEqualTo("****"); + assertThat(allRows.get(0).getString(2)).isEqualTo("***@***.com"); + assertThat(allRows.get(0).getString(3)).isEqualTo("***********"); + + // Test WHERE clause with masked column + assertThat( + spark.sql("SELECT id FROM t_column_masking WHERE id = 1") + .collectAsList() + .toString()) + .isEqualTo("[[1]]"); + + // Test aggregation with masked columns + assertThat(spark.sql("SELECT COUNT(*) FROM t_column_masking").collectAsList().toString()) + .isEqualTo("[[2]]"); + + // Test UpperTransform + Transform upperTransform = + new UpperTransform( + Collections.singletonList(new FieldRef(1, "secret", DataTypes.STRING()))); + columnMasking.clear(); + columnMasking.put("secret", upperTransform); + restCatalogServer.setColumnMaskingAuth( + Identifier.create("db2", "t_column_masking"), columnMasking); + + assertThat( + spark.sql("SELECT secret FROM t_column_masking ORDER BY id") + .collectAsList() + .toString()) + .isEqualTo("[[S1], [S2]]"); + + // Test ConcatWsTransform + Transform concatWsTransform = + new ConcatWsTransform( + Arrays.asList( + BinaryString.fromString("-"), + new FieldRef(1, "secret", DataTypes.STRING()), + BinaryString.fromString("masked"))); + columnMasking.clear(); + columnMasking.put("secret", concatWsTransform); + restCatalogServer.setColumnMaskingAuth( + Identifier.create("db2", "t_column_masking"), columnMasking); + + assertThat( + spark.sql("SELECT secret FROM t_column_masking ORDER BY id") + .collectAsList() + .toString()) + .isEqualTo("[[s1-masked], [s2-masked]]"); + + // Clear masking and verify original data + restCatalogServer.setColumnMaskingAuth( + Identifier.create("db2", "t_column_masking"), new HashMap<>()); + assertThat( + spark.sql("SELECT secret FROM t_column_masking ORDER BY id") + .collectAsList() + .toString()) + .isEqualTo("[[s1], [s2]]"); + assertThat( + spark.sql("SELECT email FROM t_column_masking ORDER BY id") + .collectAsList() + .toString()) + .isEqualTo("[[user1@example.com], [user2@example.com]]"); + } + + @Test + public void testRowFilter() { + spark.sql( + "CREATE TABLE t_row_filter (id INT, name STRING, age INT, department STRING) TBLPROPERTIES" + + " ('query-auth.enabled'='true')"); + spark.sql( + "INSERT INTO t_row_filter VALUES (1, 'Alice', 25, 'IT'), (2, 'Bob', 30, 'HR'), (3, 'Charlie', 35, 'IT'), (4, 'David', 28, 'Finance')"); + + // Test single condition row filter (age > 28) + Predicate agePredicate = + LeafPredicate.of( + new FieldTransform(new FieldRef(2, "age", DataTypes.INT())), + GreaterThan.INSTANCE, + Collections.singletonList(28)); + restCatalogServer.setRowFilterAuth( + Identifier.create("db2", "t_row_filter"), Collections.singletonList(agePredicate)); + + assertThat(spark.sql("SELECT * FROM t_row_filter ORDER BY id").collectAsList().toString()) + .isEqualTo("[[2,Bob,30,HR], [3,Charlie,35,IT]]"); + + // Test string condition row filter (department = 'IT') + Predicate deptPredicate = + LeafPredicate.of( + new FieldTransform(new FieldRef(3, "department", DataTypes.STRING())), + Equal.INSTANCE, + Collections.singletonList(BinaryString.fromString("IT"))); + restCatalogServer.setRowFilterAuth( + Identifier.create("db2", "t_row_filter"), Collections.singletonList(deptPredicate)); + + assertThat(spark.sql("SELECT * FROM t_row_filter ORDER BY id").collectAsList().toString()) + .isEqualTo("[[1,Alice,25,IT], [3,Charlie,35,IT]]"); + + // Test combined conditions (age >= 30 AND department = 'IT') + Predicate ageGePredicate = + 
LeafPredicate.of( + new FieldTransform(new FieldRef(2, "age", DataTypes.INT())), + GreaterOrEqual.INSTANCE, + Collections.singletonList(30)); + Predicate combinedPredicate = PredicateBuilder.and(ageGePredicate, deptPredicate); + restCatalogServer.setRowFilterAuth( + Identifier.create("db2", "t_row_filter"), + Collections.singletonList(combinedPredicate)); + + assertThat(spark.sql("SELECT * FROM t_row_filter ORDER BY id").collectAsList().toString()) + .isEqualTo("[[3,Charlie,35,IT]]"); + + // Test OR condition (age < 27 OR department = 'Finance') + Predicate ageLtPredicate = + LeafPredicate.of( + new FieldTransform(new FieldRef(2, "age", DataTypes.INT())), + LessThan.INSTANCE, + Collections.singletonList(27)); + Predicate financePredicate = + LeafPredicate.of( + new FieldTransform(new FieldRef(3, "department", DataTypes.STRING())), + Equal.INSTANCE, + Collections.singletonList(BinaryString.fromString("Finance"))); + Predicate orPredicate = PredicateBuilder.or(ageLtPredicate, financePredicate); + restCatalogServer.setRowFilterAuth( + Identifier.create("db2", "t_row_filter"), Collections.singletonList(orPredicate)); + + assertThat(spark.sql("SELECT * FROM t_row_filter ORDER BY id").collectAsList().toString()) + .isEqualTo("[[1,Alice,25,IT], [4,David,28,Finance]]"); + + // Test WHERE clause combined with row filter + Predicate ageGt25Predicate = + LeafPredicate.of( + new FieldTransform(new FieldRef(2, "age", DataTypes.INT())), + GreaterThan.INSTANCE, + Collections.singletonList(25)); + restCatalogServer.setRowFilterAuth( + Identifier.create("db2", "t_row_filter"), + Collections.singletonList(ageGt25Predicate)); + + assertThat( + spark.sql("SELECT * FROM t_row_filter WHERE department = 'IT' ORDER BY id") + .collectAsList() + .toString()) + .isEqualTo("[[3,Charlie,35,IT]]"); + + // Test JOIN with row filter + spark.sql("CREATE TABLE t_join2 (id INT, salary DOUBLE)"); + spark.sql( + "INSERT INTO t_join2 VALUES (1, 50000.0), (2, 60000.0), (3, 70000.0), (4, 55000.0)"); + + Predicate ageGe30Predicate = + LeafPredicate.of( + new FieldTransform(new FieldRef(2, "age", DataTypes.INT())), + GreaterOrEqual.INSTANCE, + Collections.singletonList(30)); + restCatalogServer.setRowFilterAuth( + Identifier.create("db2", "t_row_filter"), + Collections.singletonList(ageGe30Predicate)); + + List<Row> joinResult = + spark.sql( + "SELECT t1.id, t1.name, t1.age, t2.salary FROM t_row_filter t1 JOIN t_join2 t2 ON t1.id = t2.id ORDER BY t1.id") + .collectAsList(); + assertThat(joinResult.size()).isEqualTo(2); + assertThat(joinResult.get(0).toString()).isEqualTo("[2,Bob,30,60000.0]"); + assertThat(joinResult.get(1).toString()).isEqualTo("[3,Charlie,35,70000.0]"); + + // Clear row filter and verify original data + restCatalogServer.setRowFilterAuth(Identifier.create("db2", "t_row_filter"), null); + + assertThat(spark.sql("SELECT COUNT(*) FROM t_row_filter").collectAsList().toString()) + .isEqualTo("[[4]]"); + } + + @Test + public void testColumnMaskingAndRowFilter() { + spark.sql( + "CREATE TABLE t_combined (id INT, name STRING, salary STRING, age INT, department STRING) TBLPROPERTIES" + + " ('query-auth.enabled'='true')"); + spark.sql( + "INSERT INTO t_combined VALUES (1, 'Alice', '50000.0', 25, 'IT'), (2, 'Bob', '60000.0', 30, 'HR'), (3, 'Charlie', '70000.0', 35, 'IT'), (4, 'David', '55000.0', 28, 'Finance')"); + + Transform salaryMaskTransform = + new ConcatTransform(Collections.singletonList(BinaryString.fromString("***"))); + Map<String, Transform> columnMasking = new HashMap<>(); + Predicate ageGe30Predicate = + LeafPredicate.of( + new
FieldTransform(new FieldRef(3, "age", DataTypes.INT())), + GreaterOrEqual.INSTANCE, + Collections.singletonList(30)); + + // Test both column masking and row filter together + columnMasking.put("salary", salaryMaskTransform); + Transform nameMaskTransform = + new ConcatTransform(Collections.singletonList(BinaryString.fromString("***"))); + columnMasking.put("name", nameMaskTransform); + restCatalogServer.setColumnMaskingAuth( + Identifier.create("db2", "t_combined"), columnMasking); + Predicate deptPredicate = + LeafPredicate.of( + new FieldTransform(new FieldRef(4, "department", DataTypes.STRING())), + Equal.INSTANCE, + Collections.singletonList(BinaryString.fromString("IT"))); + restCatalogServer.setRowFilterAuth( + Identifier.create("db2", "t_combined"), Collections.singletonList(deptPredicate)); + + List combinedResult = + spark.sql("SELECT * FROM t_combined ORDER BY id").collectAsList(); + assertThat(combinedResult.size()).isEqualTo(2); + assertThat(combinedResult.get(0).getString(1)).isEqualTo("***"); // name masked + assertThat(combinedResult.get(0).getString(2)).isEqualTo("***"); // salary masked + assertThat(combinedResult.get(0).getInt(3)).isEqualTo(25); // age not masked + assertThat(combinedResult.get(0).getString(4)).isEqualTo("IT"); // department not masked + + // Test WHERE clause with both features + assertThat( + spark.sql("SELECT id, name FROM t_combined WHERE age > 30 ORDER BY id") + .collectAsList() + .toString()) + .isEqualTo("[[3,***]]"); + + // Clear both column masking and row filter + restCatalogServer.setColumnMaskingAuth( + Identifier.create("db2", "t_combined"), new HashMap<>()); + restCatalogServer.setRowFilterAuth(Identifier.create("db2", "t_combined"), null); + + assertThat(spark.sql("SELECT COUNT(*) FROM t_combined").collectAsList().toString()) + .isEqualTo("[[4]]"); + assertThat(spark.sql("SELECT name FROM t_combined WHERE id = 1").collectAsList().toString()) + .isEqualTo("[[Alice]]"); + } + private Catalog getPaimonCatalog() { CatalogManager catalogManager = spark.sessionState().catalogManager(); WithPaimonCatalog withPaimonCatalog = (WithPaimonCatalog) catalogManager.currentCatalog();