diff --git a/paimon-python/dev/lint-python.sh b/paimon-python/dev/lint-python.sh index 44be2871493e..4ebbe1acf87f 100755 --- a/paimon-python/dev/lint-python.sh +++ b/paimon-python/dev/lint-python.sh @@ -107,7 +107,7 @@ function collect_checks() { function get_all_supported_checks() { _OLD_IFS=$IFS IFS=$'\n' - SUPPORT_CHECKS=("flake8_check" "pytest_torch_check" "pytest_check" "mixed_check") # control the calling sequence + SUPPORT_CHECKS=("flake8_check" "pytest_check" "pytest_torch_check" "mixed_check") # control the calling sequence for fun in $(declare -F); do if [[ `regexp_match "$fun" "_check$"` = true ]]; then check_name="${fun:11}" diff --git a/paimon-python/pypaimon/read/reader/field_bunch.py b/paimon-python/pypaimon/read/reader/field_bunch.py index 4ba82bd80e39..2cd309cae62c 100644 --- a/paimon-python/pypaimon/read/reader/field_bunch.py +++ b/paimon-python/pypaimon/read/reader/field_bunch.py @@ -82,11 +82,6 @@ def add(self, file: DataFileMeta) -> None: "Blob file with overlapping row id should have decreasing sequence number." ) return - elif first_row_id > self.expected_next_first_row_id: - raise ValueError( - f"Blob file first row id should be continuous, expect " - f"{self.expected_next_first_row_id} but got {first_row_id}" - ) if file.schema_id != self._files[0].schema_id: raise ValueError( diff --git a/paimon-python/pypaimon/read/reader/sample_batch_reader.py b/paimon-python/pypaimon/read/reader/sample_batch_reader.py new file mode 100644 index 000000000000..584caae9742d --- /dev/null +++ b/paimon-python/pypaimon/read/reader/sample_batch_reader.py @@ -0,0 +1,99 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from typing import Optional + +from pyarrow import RecordBatch + +from pypaimon.read.reader.format_blob_reader import FormatBlobReader +from pypaimon.read.reader.iface.record_batch_reader import RecordBatchReader + + +class SampleBatchReader(RecordBatchReader): + """ + A reader that reads a subset of rows from a data file based on specified sample positions. + + This reader wraps another RecordBatchReader and only returns rows at the specified + sample positions, enabling efficient random sampling of data without reading all rows. + + The reader supports two modes: + 1. For blob readers: Directly reads specific rows by index + 2. For other readers: Reads batches sequentially and extracts only the sampled rows + + Attributes: + reader: The underlying RecordBatchReader to read data from + sample_positions: A sorted list of row indices to sample (0-based) + sample_idx: Current index in the sample_positions list + current_pos: Current absolute row position in the data file + """ + + def __init__(self, reader, sample_positions): + """ + Initialize the SampleBatchReader. 
+ + Args: + reader: The underlying RecordBatchReader to read data from + sample_positions: A bitmap of row indices to sample (0-based). + Must be sorted in ascending order for correct behavior. + """ + self.reader = reader + self.sample_positions = sample_positions + self.sample_idx = 0 + self.current_pos = 0 + + def read_arrow_batch(self) -> Optional[RecordBatch]: + """ + Read the next batch containing sampled rows. + + This method reads data from the underlying reader and returns only the rows + at the specified sample positions. The behavior differs based on reader type: + + - For FormatBlobReader: Directly reads individual rows by index + - For other readers: Reads batches sequentially and extracts sampled rows + using PyArrow's take() method + """ + if self.sample_idx >= len(self.sample_positions): + return None + if isinstance(self.reader.format_reader, FormatBlobReader): + # For blob reader, pass begin_idx and end_idx parameters + batch = self.reader.read_arrow_batch(start_idx=self.sample_positions[self.sample_idx], + end_idx=self.sample_positions[self.sample_idx] + 1) + self.sample_idx += 1 + return batch + else: + while True: + batch = self.reader.read_arrow_batch() + if batch is None: + return None + + batch_begin = self.current_pos + self.current_pos += batch.num_rows + take_idxes = [] + + sample_pos = self.sample_positions[self.sample_idx] + while batch_begin <= sample_pos < self.current_pos: + take_idxes.append(sample_pos - batch_begin) + self.sample_idx += 1 + if self.sample_idx >= len(self.sample_positions): + break + sample_pos = self.sample_positions[self.sample_idx] + + if take_idxes: + return batch.take(take_idxes) + # batch is outside the desired range, continue to next batch + + def close(self): + self.reader.close() diff --git a/paimon-python/pypaimon/read/sampled_split.py b/paimon-python/pypaimon/read/sampled_split.py new file mode 100644 index 000000000000..79d557d67980 --- /dev/null +++ b/paimon-python/pypaimon/read/sampled_split.py @@ -0,0 +1,103 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from typing import Dict, List + +from pyroaring import BitMap + +from pypaimon.read.split import Split + + +class SampledSplit(Split): + """ + A Split wrapper that contains sampled row indexes for each file. + + This class wraps a data split and maintains a mapping from file names to + lists of sampled row indexes. It is used for random sampling scenarios where + only specific rows from each file need to be read. + + Attributes: + _data_split: The underlying data split being wrapped. + _sampled_file_idx_map: A dictionary mapping file names to lists of + sampled row indexes within each file. 
+ """ + + def __init__( + self, + data_split: 'Split', + sampled_file_idx_map: Dict[str, BitMap] + ): + self._data_split = data_split + self._sampled_file_idx_map = sampled_file_idx_map + + def data_split(self) -> 'Split': + return self._data_split + + def sampled_file_idx_map(self) -> Dict[str, BitMap]: + return self._sampled_file_idx_map + + @property + def files(self) -> List['DataFileMeta']: + return self._data_split.files + + @property + def partition(self) -> 'GenericRow': + return self._data_split.partition + + @property + def bucket(self) -> int: + return self._data_split.bucket + + @property + def row_count(self) -> int: + if not self._sampled_file_idx_map: + return self._data_split.row_count + + total_rows = 0 + for file in self._data_split.files: + positions = self._sampled_file_idx_map[file.file_name] + total_rows += len(positions) + + return total_rows + + @property + def file_paths(self): + return self._data_split.file_paths + + @property + def file_size(self): + return self._data_split.file_size + + @property + def raw_convertible(self): + return self._data_split.raw_convertible + + @property + def data_deletion_files(self): + return self._data_split.data_deletion_files + + def __eq__(self, other): + if not isinstance(other, SampledSplit): + return False + return (self._data_split == other._data_split and + self._sampled_file_idx_map == other._sampled_file_idx_map) + + def __hash__(self): + return hash((id(self._data_split), tuple(sorted(self._sampled_file_idx_map.items())))) + + def __repr__(self): + return (f"SampledSplit(data_split={self._data_split}, " + f"sampled_file_idx_map={self._sampled_file_idx_map})") diff --git a/paimon-python/pypaimon/read/scanner/append_table_split_generator.py b/paimon-python/pypaimon/read/scanner/append_table_split_generator.py index 775771eed1c7..b7081c4a2c61 100644 --- a/paimon-python/pypaimon/read/scanner/append_table_split_generator.py +++ b/paimon-python/pypaimon/read/scanner/append_table_split_generator.py @@ -15,11 +15,15 @@ See the License for the specific language governing permissions and limitations under the License. 
""" +import random from collections import defaultdict from typing import List, Dict, Tuple +from pyroaring import BitMap + from pypaimon.manifest.schema.data_file_meta import DataFileMeta from pypaimon.manifest.schema.manifest_entry import ManifestEntry +from pypaimon.read.sampled_split import SampledSplit from pypaimon.read.scanner.split_generator import AbstractSplitGenerator from pypaimon.read.split import Split from pypaimon.read.sliced_split import SlicedSplit @@ -41,13 +45,15 @@ def create_splits(self, file_entries: List[ManifestEntry]) -> List[Split]: if self.start_pos_of_this_subtask is not None: # shard data range: [plan_start_pos, plan_end_pos) partitioned_files, plan_start_pos, plan_end_pos = \ - self.__filter_by_slice( + self._filter_by_slice( partitioned_files, self.start_pos_of_this_subtask, self.end_pos_of_this_subtask ) elif self.idx_of_this_subtask is not None: partitioned_files, plan_start_pos, plan_end_pos = self._filter_by_shard(partitioned_files) + elif self.sample_num_rows is not None: + partitioned_files, file_positions = self._filter_by_sample(partitioned_files) def weight_func(f: DataFileMeta) -> int: return max(f.file_size, self.open_file_cost) @@ -68,6 +74,8 @@ def weight_func(f: DataFileMeta) -> int: if self.start_pos_of_this_subtask is not None or self.idx_of_this_subtask is not None: splits = self._wrap_to_sliced_splits(splits, plan_start_pos, plan_end_pos) + elif self.sample_num_rows is not None: + splits = self._wrap_to_sampled_splits(splits, file_positions) return splits @@ -76,12 +84,12 @@ def _wrap_to_sliced_splits(self, splits: List[Split], plan_start_pos: int, plan_ file_end_pos = 0 # end row position of current file in all splits data for split in splits: - shard_file_idx_map = self.__compute_split_file_idx_map( + shard_file_idx_map = self._compute_split_shard_file_idx_map( plan_start_pos, plan_end_pos, split, file_end_pos ) file_end_pos = shard_file_idx_map[self.NEXT_POS_KEY] del shard_file_idx_map[self.NEXT_POS_KEY] - + if shard_file_idx_map: sliced_splits.append(SlicedSplit(split, shard_file_idx_map)) else: @@ -90,10 +98,21 @@ def _wrap_to_sliced_splits(self, splits: List[Split], plan_start_pos: int, plan_ return sliced_splits @staticmethod - def __filter_by_slice( - partitioned_files: defaultdict, - start_pos: int, - end_pos: int + def _wrap_to_sampled_splits(splits: List[Split], file_positions: Dict[str, BitMap]) -> List[Split]: + # Set sample file positions for each split + sampled_splits = [] + for split in splits: + sampled_file_idx_map = {} + for file in split.files: + sampled_file_idx_map[file.file_name] = file_positions[file.file_name] + sampled_splits.append(SampledSplit(split, sampled_file_idx_map)) + return sampled_splits + + @staticmethod + def _filter_by_slice( + partitioned_files: defaultdict, + start_pos: int, + end_pos: int ) -> tuple: plan_start_pos = 0 plan_end_pos = 0 @@ -142,21 +161,45 @@ def _filter_by_shard(self, partitioned_files: defaultdict) -> tuple: # Calculate shard range using shared helper start_pos, end_pos = self._compute_shard_range(total_row) - return self.__filter_by_slice(partitioned_files, start_pos, end_pos) + return self._filter_by_slice(partitioned_files, start_pos, end_pos) + + def _filter_by_sample(self, partitioned_files) -> (defaultdict, Dict[str, List[int]]): + """ + Randomly sample num_rows data from partitioned_files: + 1. First use random to generate num_rows indexes + 2. 
Iterate through partitioned_files, find the file entries where corresponding indexes are located, + add them to filtered_partitioned_files, and for each entry, add indexes to the list + """ + # Calculate total number of rows + total_rows = 0 + for key, file_entries in partitioned_files.items(): + for entry in file_entries: + total_rows += entry.file.row_count + + # Generate random sample indexes + sample_indexes = sorted(random.sample(range(total_rows), self.sample_num_rows)) + + # Map each sample index to its corresponding file and local index + filtered_partitioned_files = defaultdict(list) + file_positions = {} # {file_name: BitMap of local_indexes} + self._compute_file_sample_idx_map(partitioned_files, filtered_partitioned_files, + file_positions, + sample_indexes, is_blob=False) + return filtered_partitioned_files, file_positions @staticmethod - def __compute_split_file_idx_map( - plan_start_pos: int, - plan_end_pos: int, - split: Split, - file_end_pos: int + def _compute_split_shard_file_idx_map( + plan_start_pos: int, + plan_end_pos: int, + split: Split, + file_end_pos: int ) -> Dict[str, Tuple[int, int]]: """ Compute file index map for a split, determining which rows to read from each file. """ shard_file_idx_map = {} - + for file in split.files: file_begin_pos = file_end_pos # Starting row position of current file in all data file_end_pos += file.row_count # Update to row position after current file @@ -165,7 +208,7 @@ def __compute_split_file_idx_map( file_range = AppendTableSplitGenerator._compute_file_range( plan_start_pos, plan_end_pos, file_begin_pos, file.row_count ) - + if file_range is not None: shard_file_idx_map[file.file_name] = file_range diff --git a/paimon-python/pypaimon/read/scanner/data_evolution_split_generator.py b/paimon-python/pypaimon/read/scanner/data_evolution_split_generator.py index 10847b12b202..931184d7d392 100644 --- a/paimon-python/pypaimon/read/scanner/data_evolution_split_generator.py +++ b/paimon-python/pypaimon/read/scanner/data_evolution_split_generator.py @@ -15,13 +15,17 @@ See the License for the specific language governing permissions and limitations under the License. """ +import random from collections import defaultdict from typing import List, Optional, Dict, Tuple +from pyroaring import BitMap + from pypaimon.globalindex.indexed_split import IndexedSplit from pypaimon.globalindex.range import Range from pypaimon.manifest.schema.data_file_meta import DataFileMeta from pypaimon.manifest.schema.manifest_entry import ManifestEntry +from pypaimon.read.sampled_split import SampledSplit from pypaimon.read.scanner.split_generator import AbstractSplitGenerator from pypaimon.read.split import Split from pypaimon.read.sliced_split import SlicedSplit @@ -33,15 +37,14 @@ class DataEvolutionSplitGenerator(AbstractSplitGenerator): """ def __init__( - self, - table, - target_split_size: int, - open_file_cost: int, - deletion_files_map=None, - row_ranges: Optional[List] = None, - score_getter=None + self, + table, + target_split_size: int, + open_file_cost: int, + row_ranges: Optional[List] = None, + score_getter=None ): - super().__init__(table, target_split_size, open_file_cost, deletion_files_map) + super().__init__(table, target_split_size, open_file_cost) self.row_ranges = row_ranges self.score_getter = score_getter @@ -49,6 +52,7 @@ def create_splits(self, file_entries: List[ManifestEntry]) -> List[Split]: """ Create splits for data evolution tables. 
""" + def sort_key(manifest_entry: ManifestEntry) -> tuple: first_row_id = ( manifest_entry.file.first_row_id @@ -59,10 +63,8 @@ def sort_key(manifest_entry: ManifestEntry) -> tuple: max_seq = manifest_entry.file.max_sequence_number return first_row_id, is_blob, -max_seq - sorted_entries = sorted(file_entries, key=sort_key) - partitioned_files = defaultdict(list) - for entry in sorted_entries: + for entry in file_entries: partitioned_files[(tuple(entry.partition.values), entry.bucket)].append(entry) plan_start_pos = 0 @@ -71,7 +73,7 @@ def sort_key(manifest_entry: ManifestEntry) -> tuple: if self.start_pos_of_this_subtask is not None: # shard data range: [plan_start_pos, plan_end_pos) partitioned_files, plan_start_pos, plan_end_pos = \ - self._filter_by_row_range( + self._filter_by_slice( partitioned_files, self.start_pos_of_this_subtask, self.end_pos_of_this_subtask @@ -79,6 +81,8 @@ def sort_key(manifest_entry: ManifestEntry) -> tuple: elif self.idx_of_this_subtask is not None: # shard data range: [plan_start_pos, plan_end_pos) partitioned_files, plan_start_pos, plan_end_pos = self._filter_by_shard(partitioned_files) + elif self.sample_num_rows is not None: + partitioned_files, file_positions = self._filter_by_sample(partitioned_files) def weight_func(file_list: List[DataFileMeta]) -> int: return max(sum(f.file_size for f in file_list), self.open_file_cost) @@ -87,7 +91,7 @@ def weight_func(file_list: List[DataFileMeta]) -> int: for key, sorted_entries_list in partitioned_files.items(): if not sorted_entries_list: continue - + sorted_entries_list = sorted(sorted_entries_list, key=sort_key) data_files: List[DataFileMeta] = [e.file for e in sorted_entries_list] # Split files by firstRowId for data evolution @@ -110,6 +114,8 @@ def weight_func(file_list: List[DataFileMeta]) -> int: if self.start_pos_of_this_subtask is not None or self.idx_of_this_subtask is not None: splits = self._wrap_to_sliced_splits(splits, plan_start_pos, plan_end_pos) + elif self.sample_num_rows is not None: + splits = self._wrap_to_sampled_splits(splits, file_positions) # Wrap splits with IndexedSplit if row_ranges is provided if self.row_ranges: @@ -127,12 +133,12 @@ def _wrap_to_sliced_splits(self, splits: List[Split], plan_start_pos: int, plan_ for split in splits: # Compute file index map for both data and blob files # Blob files share the same row position tracking as data files - shard_file_idx_map = self._compute_split_file_idx_map( + shard_file_idx_map = self._compute_split_shard_file_idx_map( plan_start_pos, plan_end_pos, split, file_end_pos ) file_end_pos = shard_file_idx_map[self.NEXT_POS_KEY] del shard_file_idx_map[self.NEXT_POS_KEY] - + if shard_file_idx_map: sliced_splits.append(SlicedSplit(split, shard_file_idx_map)) else: @@ -140,11 +146,21 @@ def _wrap_to_sliced_splits(self, splits: List[Split], plan_start_pos: int, plan_ return sliced_splits - def _filter_by_row_range( - self, - partitioned_files: defaultdict, - start_pos: int, - end_pos: int + def _wrap_to_sampled_splits(self, splits: List[Split], file_positions: Dict[str, BitMap]) -> List[Split]: + # Set sample file positions for each split + sampled_splits = [] + for split in splits: + sampled_file_idx_map = {} + for file in split.files: + sampled_file_idx_map[file.file_name] = file_positions[file.file_name] + sampled_splits.append(SampledSplit(split, sampled_file_idx_map)) + return sampled_splits + + def _filter_by_slice( + self, + partitioned_files: defaultdict, + start_pos: int, + end_pos: int ) -> tuple: """ Filter file entries by row range 
for data evolution tables. @@ -203,7 +219,33 @@ def _filter_by_shard(self, partitioned_files: defaultdict) -> tuple: # Calculate shard range using shared helper start_pos, end_pos = self._compute_shard_range(total_row) - return self._filter_by_row_range(partitioned_files, start_pos, end_pos) + return self._filter_by_slice(partitioned_files, start_pos, end_pos) + + def _filter_by_sample(self, partitioned_files) -> (defaultdict, Dict[str, List[int]]): + """ + Randomly sample num_rows data from partitioned_files: + 1. First use random to generate num_rows indexes + 2. Iterate through partitioned_files, find the file entries where corresponding indexes are located, + add them to filtered_partitioned_files, and for each entry, add indexes to the list + """ + # Calculate total number of rows + total_rows = 0 + for key, file_entries in partitioned_files.items(): + for entry in file_entries: + if not self._is_blob_file(entry.file.file_name): + total_rows += entry.file.row_count + # Generate random sample indexes + sample_indexes = sorted(random.sample(range(total_rows), self.sample_num_rows)) + + # Map each sample index to its corresponding file and local index + filtered_partitioned_files = defaultdict(list) + file_positions = {} # {file_name: BitMap of local_indexes} + self._compute_file_sample_idx_map(partitioned_files, filtered_partitioned_files, file_positions, + sample_indexes, is_blob=False) + self._compute_file_sample_idx_map(partitioned_files, filtered_partitioned_files, file_positions, + sample_indexes, is_blob=True) + + return filtered_partitioned_files, file_positions def _split_by_row_id(self, files: List[DataFileMeta]) -> List[List[DataFileMeta]]: """ @@ -249,12 +291,12 @@ def _split_by_row_id(self, files: List[DataFileMeta]) -> List[List[DataFileMeta] return split_by_row_id - def _compute_split_file_idx_map( - self, - plan_start_pos: int, - plan_end_pos: int, - split: Split, - file_end_pos: int + def _compute_split_shard_file_idx_map( + self, + plan_start_pos: int, + plan_end_pos: int, + split: Split, + file_end_pos: int ) -> Dict[str, Tuple[int, int]]: """ Compute file index map for a split, determining which rows to read from each file. @@ -262,14 +304,14 @@ def _compute_split_file_idx_map( For blob files (which may be rolled), the range is calculated based on each file's first_row_id. 
""" shard_file_idx_map = {} - + # Find the first non-blob file to determine the row range for this split data_file = None for file in split.files: if not self._is_blob_file(file.file_name): data_file = file break - + if data_file is None: # No data file, skip this split shard_file_idx_map[self.NEXT_POS_KEY] = file_end_pos @@ -284,7 +326,7 @@ def _compute_split_file_idx_map( data_file_range = self._compute_file_range( plan_start_pos, plan_end_pos, file_begin_pos, data_file.row_count ) - + # Apply ranges to each file in the split for file in split.files: if self._is_blob_file(file.file_name): @@ -301,15 +343,15 @@ def _compute_split_file_idx_map( # Blob's position relative to data file start blob_rel_start = blob_first_row_id - data_file_first_row_id blob_rel_end = blob_rel_start + file.row_count - + # Shard range relative to data file start shard_start = data_file_range[0] shard_end = data_file_range[1] - + # Intersect blob's range with shard range intersect_start = max(blob_rel_start, shard_start) intersect_end = min(blob_rel_end, shard_end) - + if intersect_start >= intersect_end: # Blob file is completely outside shard range shard_file_idx_map[file.file_name] = (-1, -1) diff --git a/paimon-python/pypaimon/read/scanner/full_starting_scanner.py b/paimon-python/pypaimon/read/scanner/full_starting_scanner.py index f415ab61ef8d..a5b0ccde79bc 100755 --- a/paimon-python/pypaimon/read/scanner/full_starting_scanner.py +++ b/paimon-python/pypaimon/read/scanner/full_starting_scanner.py @@ -67,11 +67,6 @@ def __init__( self.target_split_size = options.source_split_target_size() self.open_file_cost = options.source_split_open_file_cost() - self.idx_of_this_subtask = None - self.number_of_para_subtasks = None - self.start_pos_of_this_subtask = None - self.end_pos_of_this_subtask = None - self.only_read_real_buckets = options.bucket() == BucketMode.POSTPONE_BUCKET.value self.data_evolution = options.data_evolution_enabled() self.deletion_vectors_enabled = options.deletion_vectors_enabled() @@ -84,28 +79,12 @@ def schema_fields_func(schema_id: int): self.table.table_schema.id ) - def scan(self) -> Plan: - file_entries = self.plan_files() - if not file_entries: - return Plan([]) - # Get deletion files map if deletion vectors are enabled. 
- # {partition-bucket -> {filename -> DeletionFile}} - deletion_files_map: dict[tuple, dict[str, DeletionFile]] = {} - if self.deletion_vectors_enabled: - latest_snapshot = self.snapshot_manager.get_latest_snapshot() - # Extract unique partition-bucket pairs from file entries - buckets = set() - for entry in file_entries: - buckets.add((tuple(entry.partition.values), entry.bucket)) - deletion_files_map = self._scan_dv_index(latest_snapshot, buckets) - # Create appropriate split generator based on table type if self.table.is_primary_key_table: - split_generator = PrimaryKeyTableSplitGenerator( + self.split_generator = PrimaryKeyTableSplitGenerator( self.table, self.target_split_size, self.open_file_cost, - deletion_files_map ) elif self.data_evolution: global_index_result = self._eval_global_index() @@ -115,30 +94,37 @@ def scan(self) -> Plan: row_ranges = global_index_result.results().to_range_list() if isinstance(global_index_result, VectorSearchGlobalIndexResult): score_getter = global_index_result.score_getter() - split_generator = DataEvolutionSplitGenerator( + self.split_generator = DataEvolutionSplitGenerator( self.table, self.target_split_size, self.open_file_cost, - deletion_files_map, row_ranges, score_getter ) else: - split_generator = AppendTableSplitGenerator( + self.split_generator = AppendTableSplitGenerator( self.table, self.target_split_size, self.open_file_cost, - deletion_files_map ) - # Configure sharding if needed - if self.idx_of_this_subtask is not None: - split_generator.with_shard(self.idx_of_this_subtask, self.number_of_para_subtasks) - elif self.start_pos_of_this_subtask is not None: - split_generator.with_slice(self.start_pos_of_this_subtask, self.end_pos_of_this_subtask) + def scan(self) -> Plan: + file_entries = self.plan_files() + if not file_entries: + return Plan([]) + # Get deletion files map if deletion vectors are enabled. 
+ # {partition-bucket -> {filename -> DeletionFile}} + if self.deletion_vectors_enabled: + latest_snapshot = self.snapshot_manager.get_latest_snapshot() + # Extract unique partition-bucket pairs from file entries + buckets = set() + for entry in file_entries: + buckets.add((tuple(entry.partition.values), entry.bucket)) + deletion_files_map = self._scan_dv_index(latest_snapshot, buckets) + self.split_generator.deletion_files_map = deletion_files_map # Generate splits - splits = split_generator.create_splits(file_entries) + splits = self.split_generator.create_splits(file_entries) splits = self._apply_push_down_limit(splits) return Plan(splits) @@ -223,21 +209,15 @@ def read_manifest_entries(self, manifest_files: List[ManifestFileMeta]) -> List[ ) def with_shard(self, idx_of_this_subtask: int, number_of_para_subtasks: int) -> 'FullStartingScanner': - if idx_of_this_subtask >= number_of_para_subtasks: - raise ValueError("idx_of_this_subtask must be less than number_of_para_subtasks") - if self.start_pos_of_this_subtask is not None: - raise Exception("with_shard and with_slice cannot be used simultaneously") - self.idx_of_this_subtask = idx_of_this_subtask - self.number_of_para_subtasks = number_of_para_subtasks + self.split_generator.with_shard(idx_of_this_subtask, number_of_para_subtasks) return self def with_slice(self, start_pos: int, end_pos: int) -> 'FullStartingScanner': - if start_pos >= end_pos: - raise ValueError("start_pos must be less than end_pos") - if self.idx_of_this_subtask is not None: - raise Exception("with_slice and with_shard cannot be used simultaneously") - self.start_pos_of_this_subtask = start_pos - self.end_pos_of_this_subtask = end_pos + self.split_generator.with_slice(start_pos, end_pos) + return self + + def with_sample(self, num_rows: int) -> 'FullStartingScanner': + self.split_generator.with_sample(num_rows) return self def _apply_push_down_limit(self, splits: List[DataSplit]) -> List[DataSplit]: diff --git a/paimon-python/pypaimon/read/scanner/primary_key_table_split_generator.py b/paimon-python/pypaimon/read/scanner/primary_key_table_split_generator.py index 5955b6aa9478..970092824d5c 100644 --- a/paimon-python/pypaimon/read/scanner/primary_key_table_split_generator.py +++ b/paimon-python/pypaimon/read/scanner/primary_key_table_split_generator.py @@ -36,9 +36,8 @@ def __init__( table, target_split_size: int, open_file_cost: int, - deletion_files_map=None ): - super().__init__(table, target_split_size, open_file_cost, deletion_files_map) + super().__init__(table, target_split_size, open_file_cost) self.deletion_vectors_enabled = table.options.deletion_vectors_enabled() self.merge_engine = table.options.merge_engine() diff --git a/paimon-python/pypaimon/read/scanner/split_generator.py b/paimon-python/pypaimon/read/scanner/split_generator.py index 6dab4fc12aa3..bac85797f031 100644 --- a/paimon-python/pypaimon/read/scanner/split_generator.py +++ b/paimon-python/pypaimon/read/scanner/split_generator.py @@ -16,7 +16,9 @@ limitations under the License. """ from abc import ABC, abstractmethod -from typing import Callable, List, Optional, Dict, Tuple +from typing import Callable, List, Optional, Tuple + +from pyroaring import BitMap from pypaimon.manifest.schema.data_file_meta import DataFileMeta from pypaimon.manifest.schema.manifest_entry import ManifestEntry @@ -30,28 +32,29 @@ class AbstractSplitGenerator(ABC): """ Abstract base class for generating splits. 
""" - + # Special key for tracking file end position in split file index map NEXT_POS_KEY = '_next_pos' def __init__( - self, - table, - target_split_size: int, - open_file_cost: int, - deletion_files_map: Optional[Dict] = None + self, + table, + target_split_size: int, + open_file_cost: int, ): self.table = table self.target_split_size = target_split_size self.open_file_cost = open_file_cost - self.deletion_files_map = deletion_files_map or {} - + self.deletion_files_map = {} + # Shard configuration self.idx_of_this_subtask = None self.number_of_para_subtasks = None self.start_pos_of_this_subtask = None self.end_pos_of_this_subtask = None + self.sample_num_rows = None + def with_shard(self, idx_of_this_subtask: int, number_of_para_subtasks: int): """Configure sharding for parallel processing.""" if idx_of_this_subtask >= number_of_para_subtasks: @@ -72,6 +75,14 @@ def with_slice(self, start_pos: int, end_pos: int): self.end_pos_of_this_subtask = end_pos return self + def with_sample(self, num_rows: int): + if self.idx_of_this_subtask is not None: + raise ValueError("with_sample and with_shard cannot be used simultaneously now") + if self.start_pos_of_this_subtask is not None: + raise ValueError("with_sample and with_slice cannot be used simultaneously now") + self.sample_num_rows = num_rows + return self + @abstractmethod def create_splits(self, file_entries: List[ManifestEntry]) -> List[Split]: """ @@ -80,11 +91,11 @@ def create_splits(self, file_entries: List[ManifestEntry]) -> List[Split]: pass def _build_split_from_pack( - self, - packed_files: List[List[DataFileMeta]], - file_entries: List[ManifestEntry], - for_primary_key_split: bool, - use_optimized_path: bool = False + self, + packed_files: List[List[DataFileMeta]], + file_entries: List[ManifestEntry], + for_primary_key_split: bool, + use_optimized_path: bool = False ) -> List[Split]: """ Build splits from packed files. 
@@ -136,10 +147,10 @@ def _build_split_from_pack( return splits def _get_deletion_files_for_split( - self, - data_files: List[DataFileMeta], - partition: GenericRow, - bucket: int + self, + data_files: List[DataFileMeta], + partition: GenericRow, + bucket: int ) -> Optional[List[DeletionFile]]: """Get deletion files for the given data files in a split.""" if not self.deletion_files_map: @@ -170,9 +181,9 @@ def _without_delete_row(data_file_meta: DataFileMeta) -> bool: @staticmethod def _pack_for_ordered( - items: List, - weight_func: Callable, - target_weight: int + items: List, + weight_func: Callable, + target_weight: int ) -> List[List]: """Pack items into groups based on target weight.""" packed = [] @@ -216,12 +227,54 @@ def _compute_shard_range(self, total_row: int) -> Tuple[int, int]: end_pos = start_pos + num_row return start_pos, end_pos + def _compute_file_sample_idx_map(self, partitioned_files, filtered_partitioned_files, file_positions, + sample_indexes, is_blob): + current_row = 0 + sample_idx = 0 + + for key, file_entries in partitioned_files.items(): + filtered_entries = [] + for entry in file_entries: + if not is_blob and self._is_blob_file(entry.file.file_name): + continue + if is_blob and not self._is_blob_file(entry.file.file_name): + continue + file_start_row = current_row + file_end_row = current_row + entry.file.row_count + + # Find all sample indexes that fall within this file + local_indexes = BitMap() + while sample_idx < len(sample_indexes) and sample_indexes[sample_idx] < file_end_row: + if sample_indexes[sample_idx] >= file_start_row: + # Convert global index to local index within this file + local_index = sample_indexes[sample_idx] - file_start_row + local_indexes.add(local_index) + sample_idx += 1 + + # If this file contains any sampled rows, include it + if len(local_indexes) > 0: + filtered_entries.append(entry) + file_positions[entry.file.file_name] = local_indexes + + current_row = file_end_row + + # Early exit if we've processed all sample indexes + if sample_idx >= len(sample_indexes): + break + + if filtered_entries: + filtered_partitioned_files[key] = filtered_partitioned_files.get(key, []) + filtered_entries + + # Early exit if we've processed all sample indexes + if sample_idx >= len(sample_indexes): + break + @staticmethod def _compute_file_range( - plan_start_pos: int, - plan_end_pos: int, - file_begin_pos: int, - file_row_count: int + plan_start_pos: int, + plan_end_pos: int, + file_begin_pos: int, + file_row_count: int ) -> Optional[Tuple[int, int]]: """ Compute the row range to read from a file given shard range and file position. diff --git a/paimon-python/pypaimon/read/split.py b/paimon-python/pypaimon/read/split.py index 3f2d2f83294c..24ea8dcc5193 100644 --- a/paimon-python/pypaimon/read/split.py +++ b/paimon-python/pypaimon/read/split.py @@ -57,11 +57,6 @@ def bucket(self) -> int: class DataSplit(Split): - """ - Implementation of Split for native Python reading. - - This is equivalent to Java's DataSplit. 
- """ def __init__( self, diff --git a/paimon-python/pypaimon/read/split_read.py b/paimon-python/pypaimon/read/split_read.py index 47edf63d9a8f..09f4b46d8fa9 100644 --- a/paimon-python/pypaimon/read/split_read.py +++ b/paimon-python/pypaimon/read/split_read.py @@ -48,8 +48,10 @@ from pypaimon.read.reader.key_value_unwrap_reader import \ KeyValueUnwrapRecordReader from pypaimon.read.reader.key_value_wrap_reader import KeyValueWrapReader +from pypaimon.read.reader.sample_batch_reader import SampleBatchReader from pypaimon.read.reader.shard_batch_reader import ShardBatchReader from pypaimon.read.reader.sort_merge_reader import SortMergeReaderWithMinHeap +from pypaimon.read.sampled_split import SampledSplit from pypaimon.read.split import Split from pypaimon.read.sliced_split import SlicedSplit from pypaimon.schema.data_types import DataField @@ -329,31 +331,53 @@ def _genarate_deletion_file_readers(self): self.deletion_file_readers[data_file.file_name] = lambda df=deletion_file: DeletionVector.read( self.table.file_io, df) - -class RawFileSplitRead(SplitRead): - def raw_reader_supplier(self, file: DataFileMeta, dv_factory: Optional[Callable] = None) -> Optional[RecordReader]: - read_fields = self._get_final_read_data_fields() - # If the current file needs to be further divided for reading, use ShardBatchReader - # Check if this is a SlicedSplit to get shard_file_idx_map - shard_file_idx_map = ( - self.split.shard_file_idx_map() if isinstance(self.split, SlicedSplit) else {} - ) - if file.file_name in shard_file_idx_map: - (start_pos, end_pos) = shard_file_idx_map[file.file_name] - if (start_pos, end_pos) == (-1, -1): - return None + def _create_file_reader( + self, + file: DataFileMeta, + read_fields: List[str], + row_tracking_enabled: bool = True + ) -> Optional[RecordBatchReader]: + """Create a file reader that handles SlicedSplit, SampledSplit, and regular Split.""" + if isinstance(self.split, SlicedSplit): + shard_file_idx_map = self.split.shard_file_idx_map() + if file.file_name in shard_file_idx_map: + (begin_pos, end_pos) = shard_file_idx_map[file.file_name] + if (begin_pos, end_pos) == (-1, -1): + return None + else: + return ShardBatchReader(self.file_reader_supplier( + file=file, + for_merge_read=False, + read_fields=read_fields, + row_tracking_enabled=row_tracking_enabled), begin_pos, end_pos) else: - file_batch_reader = ShardBatchReader(self.file_reader_supplier( + return self.file_reader_supplier( file=file, for_merge_read=False, read_fields=read_fields, - row_tracking_enabled=True), start_pos, end_pos) + row_tracking_enabled=row_tracking_enabled) + elif isinstance(self.split, SampledSplit): + sampled_file_idx_map = self.split.sampled_file_idx_map() + sample_positions = sampled_file_idx_map[file.file_name] + return SampleBatchReader(self.file_reader_supplier( + file=file, + for_merge_read=False, + read_fields=read_fields, + row_tracking_enabled=row_tracking_enabled), sample_positions) else: - file_batch_reader = self.file_reader_supplier( + return self.file_reader_supplier( file=file, for_merge_read=False, read_fields=read_fields, - row_tracking_enabled=True) + row_tracking_enabled=row_tracking_enabled) + + +class RawFileSplitRead(SplitRead): + def raw_reader_supplier(self, file: DataFileMeta, dv_factory: Optional[Callable] = None) -> Optional[RecordReader]: + read_fields = self._get_final_read_data_fields() + file_batch_reader = self._create_file_reader(file, read_fields) + if file_batch_reader is None: + return None dv = dv_factory() if dv_factory else None if dv: return 
ApplyDeletionVectorReader(RowPositionReader(file_batch_reader), dv) @@ -496,18 +520,6 @@ def _create_union_reader(self, need_merge_files: List[DataFileMeta]) -> RecordRe # Split field bunches fields_files = self._split_field_bunches(need_merge_files) - # Validate row counts and first row IDs - row_count = fields_files[0].row_count() - first_row_id = fields_files[0].files()[0].first_row_id - - for bunch in fields_files: - if bunch.row_count() != row_count: - raise ValueError("All files in a field merge split should have the same row count.") - if bunch.files()[0].first_row_id != first_row_id: - raise ValueError( - "All files in a field merge split should have the same first row id and could not be null." - ) - # Create the union reader all_read_fields = self.read_fields file_record_readers = [None] * len(fields_files) @@ -569,30 +581,6 @@ def _create_union_reader(self, need_merge_files: List[DataFileMeta]) -> RecordRe return DataEvolutionMergeReader(row_offsets, field_offsets, file_record_readers) - def _create_file_reader(self, file: DataFileMeta, read_fields: [str]) -> Optional[RecordReader]: - """Create a file reader for a single file.""" - # If the current file needs to be further divided for reading, use ShardBatchReader - # Check if this is a SlicedSplit to get shard_file_idx_map - shard_file_idx_map = ( - self.split.shard_file_idx_map() if isinstance(self.split, SlicedSplit) else {} - ) - if file.file_name in shard_file_idx_map: - (begin_pos, end_pos) = shard_file_idx_map[file.file_name] - if (begin_pos, end_pos) == (-1, -1): - return None - else: - return ShardBatchReader(self.file_reader_supplier( - file=file, - for_merge_read=False, - read_fields=read_fields, - row_tracking_enabled=True), begin_pos, end_pos) - else: - return self.file_reader_supplier( - file=file, - for_merge_read=False, - read_fields=read_fields, - row_tracking_enabled=True) - def _split_field_bunches(self, need_merge_files: List[DataFileMeta]) -> List[FieldBunch]: """Split files into field bunches.""" diff --git a/paimon-python/pypaimon/read/table_scan.py b/paimon-python/pypaimon/read/table_scan.py index 8276163450a0..7e59542b5e15 100755 --- a/paimon-python/pypaimon/read/table_scan.py +++ b/paimon-python/pypaimon/read/table_scan.py @@ -90,3 +90,12 @@ def with_shard(self, idx_of_this_subtask, number_of_para_subtasks) -> 'TableScan def with_slice(self, start_pos, end_pos) -> 'TableScan': self.starting_scanner.with_slice(start_pos, end_pos) return self + + def with_sample(self, num_rows: int) -> 'TableScan': + """Sample the table with the given number of rows. + + params: + num_rows: The number of rows to sample. 
+ """ + self.starting_scanner.with_sample(num_rows) + return self diff --git a/paimon-python/pypaimon/tests/blob_table_test.py b/paimon-python/pypaimon/tests/blob_table_test.py index 9925e21be54d..a8fbb7f58392 100755 --- a/paimon-python/pypaimon/tests/blob_table_test.py +++ b/paimon-python/pypaimon/tests/blob_table_test.py @@ -2567,6 +2567,207 @@ def test_blob_write_read_large_data_volume_rolling_with_shard(self): self.assertEqual(actual, expected) + def test_data_blob_writer_with_sample(self): + """Test DataBlobWriter with mixed data types in blob column.""" + + # Create schema with blob column + pa_schema = pa.schema([ + ('id', pa.int32()), + ('type', pa.string()), + ('data', pa.large_binary()), + ]) + + schema = Schema.from_pyarrow_schema( + pa_schema, + options={ + 'row-tracking.enabled': 'true', + 'data-evolution.enabled': 'true' + } + ) + self.catalog.create_table('test_db.with_sample', schema, False) + table = self.catalog.get_table('test_db.with_sample') + + # Use proper table API to create writer + write_builder = table.new_batch_write_builder() + blob_writer = write_builder.new_write() + + # Test data with different types of blob content + test_data = pa.Table.from_pydict({ + 'id': [1, 2, 3, 4, 5], + 'type': ['text', 'json', 'binary', 'image', 'pdf'], + 'data': [ + b'This is text content', + b'{"key": "value", "number": 42}', + b'\x00\x01\x02\x03\xff\xfe\xfd', + b'PNG_IMAGE_DATA_PLACEHOLDER', + b'%PDF-1.4\nPDF_CONTENT_PLACEHOLDER' + ] + }, schema=pa_schema) + + # Write mixed data + total_rows = 0 + for batch in test_data.to_batches(): + blob_writer.write_arrow_batch(batch) + total_rows += batch.num_rows + + # Test prepare commit + commit_messages = blob_writer.prepare_commit() + # Create commit and commit the data + commit = write_builder.new_commit() + commit.commit(commit_messages) + blob_writer.close() + + # Read data back using table API + read_builder = table.new_read_builder() + table_scan = read_builder.new_scan().with_sample(2) + table_read = read_builder.new_read() + splits = table_scan.plan().splits() + actual = table_read.to_arrow(splits) + expected_ids = set(test_data['id'].to_pylist()) + actual_ids = set(actual['id'].to_pylist()) + self.assertEqual(2, actual.num_rows, "Should have 2 rows") + self.assertTrue(actual_ids.issubset(expected_ids), + f"Actual user_ids {actual_ids} should be subset of written ids {expected_ids}") + + def test_blob_write_read_large_data_with_rolling_with_sample(self): + """ + Test writing and reading large blob data with file rolling and sample. 
+ + Test workflow: + - Creates a table with blob column and 10MB target file size + - Writes 4 batches of 40 records each (160 total records) + - Each record contains a 3MB blob + - Random sample 12 records + - Verifies blob data integrity and size + """ + + # Create schema with blob column + pa_schema = pa.schema([ + ('id', pa.int32()), + ('record_id_of_batch', pa.int32()), + ('metadata', pa.string()), + ('large_blob', pa.large_binary()), + ]) + + schema = Schema.from_pyarrow_schema( + pa_schema, + options={ + 'row-tracking.enabled': 'true', + 'data-evolution.enabled': 'true', + 'blob.target-file-size': '10MB' + } + ) + self.catalog.create_table('test_db.with_rolling_with_sample', schema, False) + table = self.catalog.get_table('test_db.with_rolling_with_sample') + + # Create large blob data + large_blob_size = 3 * 1024 * 1024 + blob_pattern = b'LARGE_BLOB_PATTERN_' + b'X' * 1024 # ~1KB pattern + pattern_size = len(blob_pattern) + repetitions = large_blob_size // pattern_size + large_blob_data = blob_pattern * repetitions + + actual_size = len(large_blob_data) + print(f"Created blob data: {actual_size:,} bytes ({actual_size / (1024 * 1024):.2f} MB)") + # Write 4 batches of 40 records + for i in range(4): + write_builder = table.new_batch_write_builder() + writer = write_builder.new_write() + # Write 40 records + for record_id in range(40): + test_data = pa.Table.from_pydict({ + 'id': [i * 40 + record_id + 1], # Unique ID for each row + 'record_id_of_batch': [record_id], + 'metadata': [f'Large blob batch {record_id + 1}'], + 'large_blob': [large_blob_data] + }, schema=pa_schema) + writer.write_arrow(test_data) + + commit_messages = writer.prepare_commit() + commit = write_builder.new_commit() + commit.commit(commit_messages) + writer.close() + + # Read data back + read_builder = table.new_read_builder() + table_scan = read_builder.new_scan().with_sample(12) + table_read = read_builder.new_read() + result = table_read.to_arrow(table_scan.plan().splits()) + + # Verify the data + self.assertEqual(result.num_rows, 12, "Should have 12 rows") + self.assertEqual(result.num_columns, 4, "Should have 4 columns") + + # Verify blob data integrity + blob_data = result.column('large_blob').to_pylist() + self.assertEqual(len(blob_data), 12, "Should have 54 blob records") + + def test_blob_write_read_large_data_volums_with_rolling_with_sample(self): + """ + Test writing and reading large blob data with file rolling and sample. 
+ Test workflow: + - Creates a table with blob column and 10MB target file size + - Writes 10000 records of 5KB blob data each + - Random sample 1000 records + - Verifies data size + """ + # Create schema with blob column + pa_schema = pa.schema([ + ('id', pa.int32()), + ('batch_id', pa.int32()), + ('metadata', pa.string()), + ('large_blob', pa.large_binary()), + ]) + + schema = Schema.from_pyarrow_schema( + pa_schema, + options={ + 'row-tracking.enabled': 'true', + 'data-evolution.enabled': 'true', + 'blob.target-file-size': '10MB' + } + ) + self.catalog.create_table('test_db.large_rolling_with_sample', schema, False) + table = self.catalog.get_table('test_db.large_rolling_with_sample') + + # Create large blob data + large_blob_size = 5 * 1024 # + blob_pattern = b'LARGE_BLOB_PATTERN_' + b'X' * 1024 # ~1KB pattern + pattern_size = len(blob_pattern) + repetitions = large_blob_size // pattern_size + large_blob_data = blob_pattern * repetitions + + actual_size = len(large_blob_data) + print(f"Created blob data: {actual_size:,} bytes ({actual_size / (1024 * 1024):.2f} MB)") + # Write 10000 records of data + num_row = 10000 + expected = pa.Table.from_pydict({ + 'id': [i for i in range(1, num_row + 1)], + 'batch_id': [11] * num_row, + 'metadata': [f'Large blob batch {11}'] * num_row, + 'large_blob': [i.to_bytes(2, byteorder='little') + large_blob_data for i in range(num_row)] + }, schema=pa_schema) + write_builder = table.new_batch_write_builder() + writer = write_builder.new_write() + writer.write_arrow(expected) + + commit_messages = writer.prepare_commit() + commit = write_builder.new_commit() + commit.commit(commit_messages) + writer.close() + + # Read data back + read_builder = table.new_read_builder() + table_scan = read_builder.new_scan().with_sample(1000) + table_read = read_builder.new_read() + actual = table_read.to_arrow(table_scan.plan().splits()) + + # Verify the data + actual_ids = set(actual['id'].to_pylist()) + expected_ids = set(list(range(1, num_row + 1))) + self.assertEqual(1000, actual.num_rows, "Should have 1000 rows") + self.assertTrue(actual_ids.issubset(expected_ids)) + def test_concurrent_blob_writes_with_retry(self): """Test concurrent blob writes to verify retry mechanism works correctly.""" import threading diff --git a/paimon-python/pypaimon/tests/ray_data_test.py b/paimon-python/pypaimon/tests/ray_data_test.py index e931a4c7dcc1..1d962b58a3de 100644 --- a/paimon-python/pypaimon/tests/ray_data_test.py +++ b/paimon-python/pypaimon/tests/ray_data_test.py @@ -230,6 +230,47 @@ def test_ray_data_with_predicate(self): self.assertEqual(set(df['category'].tolist()), {'A'}, "All rows should have category='A'") self.assertEqual(set(df['id'].tolist()), {1, 3}, "Should have IDs 1 and 3") + def test_ray_data_with_sample(self): + """Test Ray Data read with sample filtering.""" + # Create schema + pa_schema = pa.schema([ + ('id', pa.int32()), + ('category', pa.string()), + ('amount', pa.int64()), + ]) + + schema = Schema.from_pyarrow_schema(pa_schema) + self.catalog.create_table('default.test_ray_sample', schema, False) + table = self.catalog.get_table('default.test_ray_sample') + + # Write test data + test_data = pa.Table.from_pydict({ + 'id': [1, 2, 3, 4, 5], + 'category': ['A', 'B', 'A', 'C', 'B'], + 'amount': [100, 200, 150, 300, 250], + }, schema=pa_schema) + + write_builder = table.new_batch_write_builder() + writer = write_builder.new_write() + writer.write_arrow(test_data) + commit_messages = writer.prepare_commit() + commit = write_builder.new_commit() + 
commit.commit(commit_messages) + writer.close() + + # Read with predicate + read_builder = table.new_read_builder() + table_read = read_builder.new_read() + table_scan = read_builder.new_scan().with_sample(3) + splits = table_scan.plan().splits() + + ray_dataset = table_read.to_ray(splits, override_num_blocks=2) + + # Verify filtered results + df = ray_dataset.to_pandas() + self.assertTrue(set(df['id'].tolist()).issubset(set(test_data['id'].to_pylist()))) + self.assertEqual(3, len(df), "Should have 3 rows after sampling") + def test_ray_data_with_projection(self): """Test Ray Data read with column projection.""" # Create schema @@ -697,5 +738,6 @@ def test_ray_data_invalid_parallelism(self): table_read.to_ray(splits, override_num_blocks=-10) self.assertIn("override_num_blocks must be at least 1", str(context.exception)) + if __name__ == '__main__': unittest.main() diff --git a/paimon-python/pypaimon/tests/rest/rest_simple_test.py b/paimon-python/pypaimon/tests/rest/rest_simple_test.py index 2adce404961d..93fe87d5e578 100644 --- a/paimon-python/pypaimon/tests/rest/rest_simple_test.py +++ b/paimon-python/pypaimon/tests/rest/rest_simple_test.py @@ -660,6 +660,231 @@ def test_with_shard_uniform_division(self): expected = self._read_test_table(read_builder).sort_by('user_id') self.assertEqual(expected, actual) + def test_with_sample(self): + schema = Schema.from_pyarrow_schema(self.pa_schema, partition_keys=['dt']) + self.rest_catalog.create_table('default.test_with_sample', schema, False) + table = self.rest_catalog.get_table('default.test_with_sample') + write_builder = table.new_batch_write_builder() + # first write + table_write = write_builder.new_write() + table_commit = write_builder.new_commit() + data1 = { + 'user_id': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14], + 'item_id': [1001, 1002, 1003, 1004, 1005, 1006, 1007, 1008, 1009, 1010, 1011, 1012, 1013, 1014], + 'behavior': ['a', 'b', 'c', None, 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm'], + 'dt': ['p1', 'p1', 'p2', 'p1', 'p2', 'p1', 'p2', 'p1', 'p2', 'p1', 'p2', 'p1', 'p2', 'p1'], + } + pa_table = pa.Table.from_pydict(data1, schema=self.pa_schema) + table_write.write_arrow(pa_table) + table_commit.commit(table_write.prepare_commit()) + table_write.close() + table_commit.close() + # second write + table_write = write_builder.new_write() + table_commit = write_builder.new_commit() + data2 = { + 'user_id': [5, 6, 7, 8, 18], + 'item_id': [1005, 1006, 1007, 1008, 1018], + 'behavior': ['e', 'f', 'g', 'h', 'z'], + 'dt': ['p2', 'p1', 'p2', 'p2', 'p1'], + } + pa_table = pa.Table.from_pydict(data2, schema=self.pa_schema) + table_write.write_arrow(pa_table) + table_commit.commit(table_write.prepare_commit()) + table_write.close() + table_commit.close() + + read_builder = table.new_read_builder() + table_read = read_builder.new_read() + splits = read_builder.new_scan().with_sample(3).plan().splits() + actual = table_read.to_arrow(splits).sort_by('user_id') + expected_user_ids = set(data1['user_id'] + data2['user_id']) + actual_user_ids = set(actual['user_id'].to_pylist()) + self.assertEqual(3, len(actual)) + self.assertTrue(actual_user_ids.issubset(expected_user_ids), + f"Actual user_ids {actual_user_ids} should be subset of written user_ids {expected_user_ids}") + + splits = read_builder.new_scan().with_sample(0).plan().splits() + actual = table_read.to_arrow(splits).sort_by('user_id') + self.assertEqual(0, len(actual)) + + def test_with_sample_all_data(self): + schema = Schema.from_pyarrow_schema(self.pa_schema, partition_keys=['dt']) 
+ self.rest_catalog.create_table('default.test_with_sample_all_data', schema, False) + table = self.rest_catalog.get_table('default.test_with_sample_all_data') + write_builder = table.new_batch_write_builder() + # first write + table_write = write_builder.new_write() + table_commit = write_builder.new_commit() + data1 = { + 'user_id': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14], + 'item_id': [1001, 1002, 1003, 1004, 1005, 1006, 1007, 1008, 1009, 1010, 1011, 1012, 1013, 1014], + 'behavior': ['a', 'b', 'c', None, 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm'], + 'dt': ['p1', 'p1', 'p2', 'p1', 'p2', 'p1', 'p2', 'p1', 'p2', 'p1', 'p2', 'p1', 'p2', 'p1'], + } + pa_table = pa.Table.from_pydict(data1, schema=self.pa_schema) + table_write.write_arrow(pa_table) + table_commit.commit(table_write.prepare_commit()) + table_write.close() + table_commit.close() + # second write + table_write = write_builder.new_write() + table_commit = write_builder.new_commit() + data2 = { + 'user_id': [5, 6, 7, 8, 18], + 'item_id': [1005, 1006, 1007, 1008, 1018], + 'behavior': ['e', 'f', 'g', 'h', 'z'], + 'dt': ['p2', 'p1', 'p2', 'p2', 'p1'], + } + pa_table = pa.Table.from_pydict(data2, schema=self.pa_schema) + table_write.write_arrow(pa_table) + table_commit.commit(table_write.prepare_commit()) + table_write.close() + table_commit.close() + + read_builder = table.new_read_builder() + table_read = read_builder.new_read() + splits = read_builder.new_scan().with_sample(19).plan().splits() + actual = table_read.to_arrow(splits).sort_by('user_id') + expected_user_ids = set(data1['user_id'] + data2['user_id']) + actual_user_ids = set(actual['user_id'].to_pylist()) + self.assertEqual(19, len(actual)) + self.assertTrue(actual_user_ids.issubset(expected_user_ids), + f"Actual user_ids {actual_user_ids} should be subset of written user_ids {expected_user_ids}") + + def test_with_sample_larger_than_population(self): + """Test that sampling with num_rows larger than total rows raises ValueError""" + schema = Schema.from_pyarrow_schema(self.pa_schema, partition_keys=['dt']) + self.rest_catalog.create_table('default.test_with_sample_larger_than_population', schema, False) + table = self.rest_catalog.get_table('default.test_with_sample_larger_than_population') + write_builder = table.new_batch_write_builder() + + # Write only 5 rows + table_write = write_builder.new_write() + table_commit = write_builder.new_commit() + data = { + 'user_id': [1, 2, 3, 4, 5], + 'item_id': [1001, 1002, 1003, 1004, 1005], + 'behavior': ['a', 'b', 'c', 'd', 'e'], + 'dt': ['p1', 'p1', 'p2', 'p1', 'p2'], + } + pa_table = pa.Table.from_pydict(data, schema=self.pa_schema) + table_write.write_arrow(pa_table) + table_commit.commit(table_write.prepare_commit()) + table_write.close() + table_commit.close() + + # Try to sample 100 rows from a table with only 5 rows + read_builder = table.new_read_builder() + table_read = read_builder.new_read() + + # This should raise ValueError: Sample larger than population or is negative + with self.assertRaises(ValueError) as context: + splits = read_builder.new_scan().with_sample(100).plan().splits() + table_read.to_arrow(splits) + + self.assertIn("Sample larger than population or is negative", str(context.exception)) + + def test_with_sample_projection(self): + """Test sampling with projection""" + schema = Schema.from_pyarrow_schema(self.pa_schema, partition_keys=['dt']) + self.rest_catalog.create_table('default.test_with_sample_projection', schema, False) + table = 
self.rest_catalog.get_table('default.test_with_sample_projection') + write_builder = table.new_batch_write_builder() + # first write + table_write = write_builder.new_write() + table_commit = write_builder.new_commit() + data1 = { + 'user_id': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14], + 'item_id': [1001, 1002, 1003, 1004, 1005, 1006, 1007, 1008, 1009, 1010, 1011, 1012, 1013, 1014], + 'behavior': ['a', 'b', 'c', None, 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm'], + 'dt': ['p1', 'p1', 'p2', 'p1', 'p2', 'p1', 'p2', 'p1', 'p2', 'p1', 'p2', 'p1', 'p2', 'p1'], + } + pa_table = pa.Table.from_pydict(data1, schema=self.pa_schema) + table_write.write_arrow(pa_table) + table_commit.commit(table_write.prepare_commit()) + table_write.close() + table_commit.close() + # second write + table_write = write_builder.new_write() + table_commit = write_builder.new_commit() + data2 = { + 'user_id': [5, 6, 7, 8, 18], + 'item_id': [1005, 1006, 1007, 1008, 1018], + 'behavior': ['e', 'f', 'g', 'h', 'z'], + 'dt': ['p2', 'p1', 'p2', 'p2', 'p1'], + } + pa_table = pa.Table.from_pydict(data2, schema=self.pa_schema) + table_write.write_arrow(pa_table) + table_commit.commit(table_write.prepare_commit()) + table_write.close() + table_commit.close() + + read_builder = table.new_read_builder().with_projection(['user_id', 'dt']) + splits = read_builder.new_scan().with_sample(3).plan().splits() + table_read = read_builder.new_read() + actual = table_read.to_arrow(splits).sort_by('user_id') + expected_user_ids = set(data1['user_id'] + data2['user_id']) + actual_user_ids = set(actual['user_id'].to_pylist()) + self.assertEqual(2, len(actual.columns)) + self.assertEqual(3, len(actual)) + self.assertTrue(actual_user_ids.issubset(expected_user_ids), + f"Actual user_ids {actual_user_ids} should be subset of written user_ids {expected_user_ids}") + + def test_with_sample_large_data(self): + """Test sampling with large data volume""" + schema = Schema.from_pyarrow_schema(self.pa_schema, partition_keys=['dt']) + self.rest_catalog.create_table('default.test_with_sample_large', schema, False) + table = self.rest_catalog.get_table('default.test_with_sample_large') + write_builder = table.new_batch_write_builder() + + # Write large volume of data in multiple batches + all_user_ids = [] + batch_size = 1000 + num_batches = 10 + + for batch_idx in range(num_batches): + table_write = write_builder.new_write() + table_commit = write_builder.new_commit() + + start_id = batch_idx * batch_size + 1 + end_id = (batch_idx + 1) * batch_size + 1 + user_ids = list(range(start_id, end_id)) + all_user_ids.extend(user_ids) + + data = { + 'user_id': user_ids, + 'item_id': [1000 + uid for uid in user_ids], + 'behavior': [chr(97 + (uid % 26)) for uid in user_ids], # 'a' to 'z' cycling + 'dt': ['p1' if uid % 2 == 0 else 'p2' for uid in user_ids], + } + pa_table = pa.Table.from_pydict(data, schema=self.pa_schema) + table_write.write_arrow(pa_table) + table_commit.commit(table_write.prepare_commit()) + table_write.close() + table_commit.close() + + # Test sampling with different sample sizes + sample_sizes = [10, 50, 100] + for sample_size in sample_sizes: + read_builder = table.new_read_builder() + table_read = read_builder.new_read() + splits = read_builder.new_scan().with_sample(sample_size).plan().splits() + actual = table_read.to_arrow(splits) + + # Verify sample size + self.assertEqual(sample_size, len(actual), + f"Sample size should be {sample_size}, but got {len(actual)}") + # Verify sampled data is from written data + expected_user_ids = 
set(all_user_ids) + actual_user_ids = set(actual['user_id'].to_pylist()) + self.assertTrue(actual_user_ids.issubset(expected_user_ids), + "Sampled user_ids should be subset of all written user_ids") + + # Verify no duplicate rows in sample + self.assertEqual(len(actual_user_ids), len(actual), + "Sample should not contain duplicate rows") + def test_create_drop_database_table(self): # test create database self.rest_catalog.create_database("db1", False)
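For reference, a minimal usage sketch of the with_sample scan API introduced by this patch, mirroring the REST catalog tests above; the catalog and the 'default.test_with_sample' table are assumed to already exist, and every call shown (new_read_builder, new_scan().with_sample, plan().splits, new_read, to_arrow) comes from this diff:

    # assumption: `catalog` was created elsewhere and the table has already been written to
    table = catalog.get_table('default.test_with_sample')
    read_builder = table.new_read_builder()

    # with_sample(n) plans splits that together cover exactly n randomly chosen rows;
    # per with_sample() in split_generator.py it cannot be combined with with_shard or with_slice
    table_scan = read_builder.new_scan().with_sample(3)
    table_read = read_builder.new_read()

    splits = table_scan.plan().splits()          # SampledSplit instances wrapping the data splits
    result = table_read.to_arrow(splits)         # pyarrow Table containing the 3 sampled rows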