|
31 | 31 |
|
32 | 32 | logger = logging.getLogger(__name__) |
33 | 33 |
|
| 34 | +_ENV_ATP_READONLY_POOL_MAX_SIZE = "ATP_READONLY_POOL_MAX_SIZE" |
| 35 | + |
| 36 | + |
| 37 | +def _read_required_positive_int_env(var_name: str) -> int: |
| 38 | + """Read and validate a required positive integer environment variable.""" |
| 39 | + raw = os.getenv(var_name) |
| 40 | + if raw is None: |
| 41 | + raise RuntimeError( |
| 42 | + f"Missing required environment variable '{var_name}'. " |
| 43 | + "Configure it before calling Query.iter()." |
| 44 | + ) |
| 45 | + try: |
| 46 | + value = int(raw) |
| 47 | + except ValueError as exc: |
| 48 | + raise ValueError( |
| 49 | + f"Environment variable '{var_name}' must be an integer, got: {raw!r}" |
| 50 | + ) from exc |
| 51 | + if value <= 0: |
| 52 | + raise ValueError( |
| 53 | + f"Environment variable '{var_name}' must be > 0, got: {value}" |
| 54 | + ) |
| 55 | + return value |
| 56 | + |
34 | 57 | if TYPE_CHECKING: |
35 | 58 | from .kuzu_session import KuzuSession |
36 | 59 |
|
@@ -599,8 +622,8 @@ def iter(self, page_size: int = 10, prefetch_pages: int = 1) -> Iterator[Union[M |
599 | 622 | ) |
600 | 623 |
|
601 | 624 | # Check if parallel execution is available and beneficial |
602 | | - pool_size = int(os.environ["ATP_READONLY_POOL_MAX_SIZE"]) |
603 | | - parallel_threshold = int(os.environ["ATP_READONLY_POOL_MAX_SIZE"]) |
| 625 | + pool_size = _read_required_positive_int_env(_ENV_ATP_READONLY_POOL_MAX_SIZE) |
| 626 | + parallel_threshold = pool_size |
604 | 627 | use_parallel = pool_size > 1 and parallel_threshold > 0 |
605 | 628 |
|
606 | 629 | def fetch_page(offset: int) -> List[Union[ModelType, Dict[str, Any]]]: |
@@ -629,6 +652,37 @@ def fetch_page(offset: int) -> List[Union[ModelType, Dict[str, Any]]]: |
629 | 652 |
|
630 | 653 | return mapped |
631 | 654 |
|
        def fetch_page_with_lookahead(offset: int) -> Tuple[List[Union[ModelType, Dict[str, Any]]], bool]:
            """Fetch a page with one-row lookahead to avoid terminal out-of-range SKIP queries.

            Requests ``ps + 1`` rows starting at *offset*; if the extra row comes
            back, there is at least one more page, so callers never have to issue
            a follow-up query whose SKIP lands past the end of the result set.

            Args:
                offset: Zero-based row offset at which this page starts.

            Returns:
                Tuple of (page rows truncated to at most ``ps`` items, flag that
                is True when more rows exist beyond this page).
            """
            # ps is the page size captured from the enclosing iter() scope.
            q = self.offset(offset).limit(ps + 1)
            # Time execution and mapping separately so the log line below can
            # attribute latency to the database vs. Python-side result mapping.
            t0 = time.perf_counter()
            raw = q._execute()
            t1 = time.perf_counter()
            mapped = q._map_results(raw)
            t2 = time.perf_counter()

            # The lookahead row only signals continuation; it is sliced off so
            # the caller still receives exactly ps rows per full page.
            has_more = len(mapped) > ps
            page_data = mapped[:ps] if has_more else mapped

            # Log when per-session debug timing is enabled, or unconditionally
            # for slow pages (>= 250 ms total) so outliers surface in production.
            if getattr(self._session, "_debug_timing", False) or ((t2 - t0) >= 0.25):
                raw_rows = len(raw) if isinstance(raw, list) else None
                mapped_rows = len(page_data) if isinstance(page_data, list) else None
                # model_name and pairs_subset_meta come from the enclosing
                # iter() scope (not visible here) — presumably the target model's
                # name and prefetch-pair metadata; confirm against iter().
                logger.info(
                    "kuzu.query.page.lookahead rel=%s offset=%d page_size=%d raw_rows=%s mapped_rows=%s has_more=%s exec_seconds=%.6f map_seconds=%.6f total_seconds=%.6f pairs_subset=%s",
                    model_name,
                    int(offset),
                    int(ps),
                    raw_rows,
                    mapped_rows,
                    has_more,
                    (t1 - t0),
                    (t2 - t1),
                    (t2 - t0),
                    pairs_subset_meta,
                )

            return page_data, has_more
| 685 | + |
632 | 686 | def fetch_pages_parallel(offsets: List[int]) -> List[List[Union[ModelType, Dict[str, Any]]]]: |
633 | 687 | """Fetch multiple pages in parallel using Rust rayon via ATP pipeline.""" |
634 | 688 | if not offsets: |
@@ -672,63 +726,81 @@ def fetch_pages_parallel(offsets: List[int]) -> List[List[Union[ModelType, Dict[ |
672 | 726 | mapped_pages.append(mapped) |
673 | 727 | return mapped_pages |
674 | 728 |
|
675 | | - # Fetch first page to determine if more pages exist |
676 | | - offset = 0 |
677 | | - page = fetch_page(offset) |
678 | | - offset += ps |
| 729 | + # If parallel execution is enabled, preserve existing count-bounded parallel strategy. |
| 730 | + if use_parallel: |
| 731 | + offset = 0 |
| 732 | + page = fetch_page(offset) |
| 733 | + offset += ps |
| 734 | + |
| 735 | + # If first page is not full, result set fits in one page. |
| 736 | + if len(page) < ps: |
| 737 | + for item in page: |
| 738 | + yield item |
| 739 | + return |
| 740 | + |
| 741 | + total_rows = self.count_results() |
| 742 | + remaining_rows = max(total_rows - ps, 0) |
679 | 743 |
|
680 | | - # If parallel execution is enabled and first page is full, try parallel fetching |
681 | | - if use_parallel and len(page) == ps: |
682 | 744 | # Yield first page items |
683 | 745 | for item in page: |
684 | 746 | yield item |
685 | | - |
| 747 | + |
| 748 | + if remaining_rows == 0: |
| 749 | + return |
| 750 | + |
686 | 751 | # Parallel batch fetching |
687 | 752 | batch_size = min(pool_size, parallel_threshold) |
688 | | - while True: |
689 | | - # Build batch of offsets |
690 | | - batch_offsets = [offset + i * ps for i in range(batch_size)] |
691 | | - |
| 753 | + while remaining_rows > 0: |
| 754 | + pages_in_batch = min(batch_size, (remaining_rows + ps - 1) // ps) |
| 755 | + batch_offsets = [offset + i * ps for i in range(pages_in_batch)] |
| 756 | + |
692 | 757 | # Fetch batch in parallel |
693 | 758 | batch_pages = fetch_pages_parallel(batch_offsets) |
694 | | - |
695 | | - # Yield results and track if we got a partial page |
696 | | - last_page_full = True |
697 | | - for page_idx, page_data in enumerate(batch_pages): |
| 759 | + |
| 760 | + # Yield results in requested page order |
| 761 | + for page_data in batch_pages: |
698 | 762 | for item in page_data: |
699 | 763 | yield item |
700 | | - if len(page_data) < ps: |
701 | | - last_page_full = False |
702 | | - break |
703 | | - |
704 | | - if not last_page_full: |
705 | | - break |
706 | | - |
707 | | - offset += batch_size * ps |
708 | | - elif pf > 0: |
| 764 | + |
| 765 | + advanced_rows = pages_in_batch * ps |
| 766 | + offset += advanced_rows |
| 767 | + remaining_rows = max(remaining_rows - advanced_rows, 0) |
| 768 | + |
| 769 | + return |
| 770 | + |
| 771 | + # Sequential modes: use +1 lookahead to avoid issuing a terminal out-of-range page. |
| 772 | + offset = 0 |
| 773 | + page, has_more = fetch_page_with_lookahead(offset) |
| 774 | + offset += ps |
| 775 | + |
| 776 | + if pf > 0: |
709 | 777 | # Sequential with prefetch (original behavior) |
710 | 778 | with ThreadPoolExecutor(max_workers=1) as executor: |
711 | | - next_future = executor.submit(fetch_page, offset) if len(page) == ps else None |
| 779 | + next_future = executor.submit(fetch_page_with_lookahead, offset) if has_more else None |
712 | 780 | while True: |
713 | 781 | for item in page: |
714 | 782 | yield item |
715 | | - if len(page) < ps: |
| 783 | + if not has_more: |
716 | 784 | break |
717 | | - next_page = next_future.result() if next_future is not None else fetch_page(offset) |
| 785 | + if next_future is not None: |
| 786 | + next_page, next_has_more = next_future.result() |
| 787 | + else: |
| 788 | + next_page, next_has_more = fetch_page_with_lookahead(offset) |
718 | 789 | offset += ps |
719 | | - if len(next_page) == ps and pf > 0: |
720 | | - next_future = executor.submit(fetch_page, offset) |
| 790 | + if next_has_more and pf > 0: |
| 791 | + next_future = executor.submit(fetch_page_with_lookahead, offset) |
721 | 792 | else: |
722 | 793 | next_future = None |
723 | 794 | page = next_page |
| 795 | + has_more = next_has_more |
724 | 796 | else: |
725 | 797 | # Pure sequential (no prefetch) |
726 | 798 | while True: |
727 | 799 | for item in page: |
728 | 800 | yield item |
729 | | - if len(page) < ps: |
| 801 | + if not has_more: |
730 | 802 | break |
731 | | - page = fetch_page(offset) |
| 803 | + page, has_more = fetch_page_with_lookahead(offset) |
732 | 804 | offset += ps |
733 | 805 |
|
734 | 806 | def all(self, as_iterator: bool = False, page_size: Optional[int] = None, prefetch_pages: int = 1) -> Union[List[ModelType], List[Dict[str, Any]], Iterator[Union[ModelType, Dict[str, Any]]]]: |
@@ -784,7 +856,9 @@ def exists(self) -> bool: |
784 | 856 |
|
785 | 857 | def count_results(self) -> int: |
786 | 858 | """Count the number of results.""" |
787 | | - count_query = self.count() |
| 859 | + # ORDER BY columns are not valid after scalar COUNT aggregation in Kuzu. |
| 860 | + # Keep all filters/joins while stripping ORDER BY for the COUNT query only. |
| 861 | + count_query = self._copy_with_state(order_by=[]).count() |
788 | 862 | result = count_query._execute() |
789 | 863 | if type(result) is not list: |
790 | 864 | logger.error("Count query returned non-list result type: %r", type(result)) |
|
0 commit comments