From 08413d3814e0ad40aa8629e5a249a4a11b312ad5 Mon Sep 17 00:00:00 2001 From: Sam Wilson Date: Wed, 13 May 2026 09:28:12 -0700 Subject: [PATCH 1/3] pull request From 1094bc578d93f2e69d87cedca77de21f367a0a41 Mon Sep 17 00:00:00 2001 From: Sam Wilson Date: Wed, 13 May 2026 10:04:18 -0700 Subject: [PATCH 2/3] switch to array, use mmap skip an array access dna attempt --- rounds/1_histogram/solution.py | 33 ++++++----- rounds/3_dna/solution.py | 104 +++++++++++++-------------------- 2 files changed, 61 insertions(+), 76 deletions(-) diff --git a/rounds/1_histogram/solution.py b/rounds/1_histogram/solution.py index 63d7aa7..1709975 100644 --- a/rounds/1_histogram/solution.py +++ b/rounds/1_histogram/solution.py @@ -4,25 +4,32 @@ passes out of the box. Replace the body of ``compute_histogram`` with your own faster implementation. """ +from collections import defaultdict +from mmap import mmap, ACCESS_READ +def b2i(low: int, high: int) -> int: + return high + (low << 8) + +def i2b(x: int) -> bytes: + return bytes([(x & 0xFF00) >> 8, x & 0xFF]) def compute_histogram(path: str) -> dict[bytes, int]: """Frequency of every 2-byte bigram in the file at ``path``.""" # Step 1: read the whole file into memory as a single bytes object. - with open(path, "rb") as f: - data = f.read() + counts = [0 for _ in range(2**16)] - # Create a 2D matrix to count bigrams - counts = [[0] * 256 for _ in range(256)] + source = open(path, "rb", buffering=0) + data = mmap(source.fileno(), 0, access=ACCESS_READ) + # Step 2: slide a 2-byte window across the buffer. For ``b"ABCD"`` the + # iterations produce ``b"AB"``, ``b"BC"``, then ``b"CD"``. For each window, + # bump the matching bucket in a ``dict`` keyed by the bigram itself. + previous = data[0] for i in range(len(data) - 1): - # Increment the count in each cell - counts[data[i]][data[i + 1]] += 1 + current = data[i + 1] + counts[current + (previous << 8)] += 1 + previous = current - # Convert the matrix to the original format - output = {} - for i in range(256): - for j in range(256): - if counts[i][j] > 0: - output[bytes([i, j])] = counts[i][j] - return output + return { + i2b(idx): value for idx, value in enumerate(counts) if value != 0 + } diff --git a/rounds/3_dna/solution.py b/rounds/3_dna/solution.py index 4507c2d..2481444 100644 --- a/rounds/3_dna/solution.py +++ b/rounds/3_dna/solution.py @@ -5,80 +5,58 @@ own faster implementation. """ -from __future__ import annotations +from mmap import mmap, ACCESS_READ +from concurrent.futures import ThreadPoolExecutor, wait -import os -from concurrent.futures import ThreadPoolExecutor +def _subsearch(raw, record_id_start: int, data_start: int, data_end: int, pattern: bytes): + plen = len(pattern) + data = bytes(raw[data_start : data_end - 1]).replace(b"\n", b"") + locations = [] + loc = data.find(pattern) + while loc != -1: + locations.append(loc) + loc = data.find(pattern, loc + plen) -_NL = 0x0A # b"\n" + if not locations: + return None + record_id = raw[record_id_start : data_start - 1].decode("ascii") + return (record_id, locations) def find_matches(fasta_path: str, pattern: bytes) -> list[tuple[str, list[int]]]: - with open(fasta_path, "rb") as f: - data = f.read() + """Find every FASTA record whose sequence contains ``pattern``. - # Step 1: locate every record start. A record starts with ``>`` either at - # offset 0 or immediately after a ``\n``. 
- starts: list[int] = [] - i = 0 - while True: - p = data.find(b">", i) - if p == -1: - break - if p == 0 or data[p - 1] == _NL: - starts.append(p) - i = p + 1 - starts.append(len(data)) # sentinel marking the end of the last record. + Returns ``[(record_id, [positions...]), ...]`` in file order. + """ + source = open(fasta_path, "rb") + data = mmap(source.fileno(), 0, access=ACCESS_READ) - num_records = len(starts) - 1 - if num_records <= 0: - return [] + last = -1 - # Step 2: parallel scan. Choose enough batches to keep workers balanced - # even when record sizes vary. - n_workers = max(1, os.cpu_count() or 1) - batches = max(1, n_workers * 4) - batch_size = max(1, (num_records + batches - 1) // batches) + data_end = len(data) - 1 + while data[data_end] == b"\n": + data_end -= 1 - def scan_batch(start_idx: int, end_idx: int) -> list[tuple[int, str, list[int]]]: - out: list[tuple[int, str, list[int]]] = [] - for j in range(start_idx, end_idx): - rec_start = starts[j] - rec_end = starts[j + 1] + with ThreadPoolExecutor(max_workers=16) as executor: + records = [] + while data_end > 0: + gt_pos = data.rfind(b">", 0, data_end) + if gt_pos == -1: + raise Exception("expected greater than") - # Locate the end of the header line within this record's slice. - nl = data.find(b"\n", rec_start, rec_end) - if nl <= rec_start: - continue # Malformed or header-only. + record_id_start = gt_pos + 1 - record_id = data[rec_start + 1 : nl].decode("ascii").strip() + nl_pos = data.find(b"\n", record_id_start) + if nl_pos == -1: + raise Exception("expected new line") - # Contiguous sequence: drop the newlines so matches that straddle - # line breaks are still found by ``bytes.find``. - sequence = data[nl + 1 : rec_end].replace(b"\n", b"") + data_start = nl_pos + 1 - positions: list[int] = [] - s = 0 - while True: - p = sequence.find(pattern, s) - if p == -1: - break - positions.append(p) - s = p + 1 + records.append( + executor.submit(_subsearch, data, record_id_start, data_start, data_end, pattern) + ) + data_end = gt_pos - if positions: - out.append((j, record_id, positions)) - return out - - with ThreadPoolExecutor(max_workers=n_workers) as pool: - futures = [ - pool.submit(scan_batch, lo, min(lo + batch_size, num_records)) - for lo in range(0, num_records, batch_size) - ] - chunks = [f.result() for f in futures] - - # Step 3: flatten and restore file order (record index is monotonic per - # batch, but batches finish in arbitrary order). - flat = [item for chunk in chunks for item in chunk] - flat.sort(key=lambda triple: triple[0]) - return [(rid, positions) for _, rid, positions in flat] + results = [d.result() for d in records if d.result() is not None] + results.reverse() + return results From b0dd57f58464e065d74a246239b2640ee8eae85b Mon Sep 17 00:00:00 2001 From: "codspeed-hq[bot]" <117304815+codspeed-hq[bot]@users.noreply.github.com> Date: Wed, 13 May 2026 21:40:20 +0000 Subject: [PATCH 3/3] Fix performance regressions in histogram and DNA solutions --- rounds/1_histogram/solution.py | 39 ++++++-------- rounds/3_dna/solution.py | 96 ++++++++++++++++++++-------------- 2 files changed, 73 insertions(+), 62 deletions(-) diff --git a/rounds/1_histogram/solution.py b/rounds/1_histogram/solution.py index 1709975..5eca05f 100644 --- a/rounds/1_histogram/solution.py +++ b/rounds/1_histogram/solution.py @@ -4,32 +4,23 @@ passes out of the box. Replace the body of ``compute_histogram`` with your own faster implementation. 
""" -from collections import defaultdict -from mmap import mmap, ACCESS_READ -def b2i(low: int, high: int) -> int: - return high + (low << 8) +import numpy as np -def i2b(x: int) -> bytes: - return bytes([(x & 0xFF00) >> 8, x & 0xFF]) def compute_histogram(path: str) -> dict[bytes, int]: """Frequency of every 2-byte bigram in the file at ``path``.""" - # Step 1: read the whole file into memory as a single bytes object. - counts = [0 for _ in range(2**16)] - - source = open(path, "rb", buffering=0) - data = mmap(source.fileno(), 0, access=ACCESS_READ) - - # Step 2: slide a 2-byte window across the buffer. For ``b"ABCD"`` the - # iterations produce ``b"AB"``, ``b"BC"``, then ``b"CD"``. For each window, - # bump the matching bucket in a ``dict`` keyed by the bigram itself. - previous = data[0] - for i in range(len(data) - 1): - current = data[i + 1] - counts[current + (previous << 8)] += 1 - previous = current - - return { - i2b(idx): value for idx, value in enumerate(counts) if value != 0 - } + with open(path, "rb") as f: + data = f.read() + + arr = np.frombuffer(data, dtype=np.uint8) + + # Vectorised bigram index: first_byte * 256 + second_byte + bigram_indices = arr[:-1].astype(np.uint16) * 256 + arr[1:] + + # Count every bigram in a single pass (C-level loop inside numpy) + counts = np.bincount(bigram_indices, minlength=65536) + + # Build the result dict from non-zero entries only + nonzero = np.flatnonzero(counts) + return {int(idx).to_bytes(2, "big"): int(counts[idx]) for idx in nonzero} diff --git a/rounds/3_dna/solution.py b/rounds/3_dna/solution.py index 2481444..0756188 100644 --- a/rounds/3_dna/solution.py +++ b/rounds/3_dna/solution.py @@ -5,58 +5,78 @@ own faster implementation. """ -from mmap import mmap, ACCESS_READ -from concurrent.futures import ThreadPoolExecutor, wait +from __future__ import annotations -def _subsearch(raw, record_id_start: int, data_start: int, data_end: int, pattern: bytes): - plen = len(pattern) - data = bytes(raw[data_start : data_end - 1]).replace(b"\n", b"") - locations = [] - loc = data.find(pattern) - while loc != -1: - locations.append(loc) - loc = data.find(pattern, loc + plen) +import os +from concurrent.futures import ThreadPoolExecutor - if not locations: - return None +_NL = 0x0A - record_id = raw[record_id_start : data_start - 1].decode("ascii") - return (record_id, locations) def find_matches(fasta_path: str, pattern: bytes) -> list[tuple[str, list[int]]]: """Find every FASTA record whose sequence contains ``pattern``. Returns ``[(record_id, [positions...]), ...]`` in file order. """ - source = open(fasta_path, "rb") - data = mmap(source.fileno(), 0, access=ACCESS_READ) + with open(fasta_path, "rb") as f: + data = f.read() - last = -1 + # Step 1: locate every record start ('>' at offset 0 or after '\n'). + starts: list[int] = [] + i = 0 + while True: + p = data.find(b">", i) + if p == -1: + break + if p == 0 or data[p - 1] == _NL: + starts.append(p) + i = p + 1 + starts.append(len(data)) # sentinel - data_end = len(data) - 1 - while data[data_end] == b"\n": - data_end -= 1 + num_records = len(starts) - 1 + if num_records <= 0: + return [] - with ThreadPoolExecutor(max_workers=16) as executor: - records = [] - while data_end > 0: - gt_pos = data.rfind(b">", 0, data_end) - if gt_pos == -1: - raise Exception("expected greater than") + # Step 2: parallel scan with batched work units. 
+ n_workers = max(1, os.cpu_count() or 1) + batches = max(1, n_workers * 4) + batch_size = max(1, (num_records + batches - 1) // batches) - record_id_start = gt_pos + 1 + def scan_batch(start_idx: int, end_idx: int) -> list[tuple[int, str, list[int]]]: + out: list[tuple[int, str, list[int]]] = [] + for j in range(start_idx, end_idx): + rec_start = starts[j] + rec_end = starts[j + 1] - nl_pos = data.find(b"\n", record_id_start) - if nl_pos == -1: - raise Exception("expected new line") + nl = data.find(b"\n", rec_start, rec_end) + if nl <= rec_start: + continue - data_start = nl_pos + 1 + # Strip newlines so cross-line matches are found. + sequence = data[nl + 1 : rec_end].replace(b"\n", b"") - records.append( - executor.submit(_subsearch, data, record_id_start, data_start, data_end, pattern) - ) - data_end = gt_pos + positions: list[int] = [] + s = 0 + while True: + p = sequence.find(pattern, s) + if p == -1: + break + positions.append(p) + s = p + 1 - results = [d.result() for d in records if d.result() is not None] - results.reverse() - return results + if positions: + record_id = data[rec_start + 1 : nl].decode("ascii").strip() + out.append((j, record_id, positions)) + return out + + with ThreadPoolExecutor(max_workers=n_workers) as pool: + futures = [ + pool.submit(scan_batch, lo, min(lo + batch_size, num_records)) + for lo in range(0, num_records, batch_size) + ] + chunks = [f.result() for f in futures] + + # Step 3: flatten and restore file order. + flat = [item for chunk in chunks for item in chunk] + flat.sort(key=lambda triple: triple[0]) + return [(rid, positions) for _, rid, positions in flat]
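
As a quick sanity check of the final state of both solutions (after PATCH 3/3), the sketch below loads each module by file path and runs it against tiny synthetic inputs. It is illustrative only: the importlib-by-path loading, the temporary files, and the sample data are assumptions made for this note, not part of the patches or the benchmark harness. It also assumes the working directory is the repository root and that numpy is installed for the histogram round.

import importlib.util
import os
import tempfile


def _load(name: str, path: str):
    # Standard importlib pattern for importing a module from an explicit file
    # path (the "rounds/N_name" directories are not importable packages).
    spec = importlib.util.spec_from_file_location(name, path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module


# Paths are relative to the repository root shown in the diffs above.
histogram = _load("histogram_solution", "rounds/1_histogram/solution.py")
dna = _load("dna_solution", "rounds/3_dna/solution.py")

with tempfile.TemporaryDirectory() as tmp:
    # Histogram: b"ABAB" yields the bigrams AB, BA, AB.
    bin_path = os.path.join(tmp, "sample.bin")
    with open(bin_path, "wb") as f:
        f.write(b"ABAB")
    assert histogram.compute_histogram(bin_path) == {b"AB": 2, b"BA": 1}

    # DNA: the ACGT pattern straddles the line break inside record r1 and is
    # absent from r2, so only r1 should be reported (positions 0 and 4 in the
    # newline-stripped sequence ACGTACGT).
    fa_path = os.path.join(tmp, "sample.fa")
    with open(fa_path, "wb") as f:
        f.write(b">r1 first\nACGTAC\nGT\n>r2 second\nTTTT\n")
    assert dna.find_matches(fa_path, b"ACGT") == [("r1 first", [0, 4])]

print("both solutions behave as expected on the toy inputs")

The expected values follow directly from the patched code: the bigram counts come from the vectorised bincount over byte pairs, and stripping newlines before searching is what lets find_matches report the ACGT occurrence that crosses a line break in the first record.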