diff --git a/rounds/1_histogram/solution.py b/rounds/1_histogram/solution.py
index 63d7aa7..5eca05f 100644
--- a/rounds/1_histogram/solution.py
+++ b/rounds/1_histogram/solution.py
@@ -5,24 +5,22 @@
 own faster implementation.
 """
 
+import numpy as np
+
 
 def compute_histogram(path: str) -> dict[bytes, int]:
     """Frequency of every 2-byte bigram in the file at ``path``."""
-    # Step 1: read the whole file into memory as a single bytes object.
     with open(path, "rb") as f:
         data = f.read()
 
-    # Create a 2D matrix to count bigrams
-    counts = [[0] * 256 for _ in range(256)]
+    arr = np.frombuffer(data, dtype=np.uint8)
+
+    # Vectorised bigram index: first_byte * 256 + second_byte
+    bigram_indices = arr[:-1].astype(np.uint16) * 256 + arr[1:]
 
-    for i in range(len(data) - 1):
-        # Increment the count in each cell
-        counts[data[i]][data[i + 1]] += 1
+    # Count every bigram in a single pass (C-level loop inside numpy)
+    counts = np.bincount(bigram_indices, minlength=65536)
 
-    # Convert the matrix to the original format
-    output = {}
-    for i in range(256):
-        for j in range(256):
-            if counts[i][j] > 0:
-                output[bytes([i, j])] = counts[i][j]
-    return output
+    # Build the result dict from non-zero entries only
+    nonzero = np.flatnonzero(counts)
+    return {int(idx).to_bytes(2, "big"): int(counts[idx]) for idx in nonzero}
diff --git a/rounds/3_dna/solution.py b/rounds/3_dna/solution.py
index 4507c2d..0756188 100644
--- a/rounds/3_dna/solution.py
+++ b/rounds/3_dna/solution.py
@@ -10,15 +10,18 @@
 import os
 from concurrent.futures import ThreadPoolExecutor
 
-_NL = 0x0A  # b"\n"
+_NL = 0x0A
 
 
 def find_matches(fasta_path: str, pattern: bytes) -> list[tuple[str, list[int]]]:
+    """Find every FASTA record whose sequence contains ``pattern``.
+
+    Returns ``[(record_id, [positions...]), ...]`` in file order.
+    """
     with open(fasta_path, "rb") as f:
         data = f.read()
 
-    # Step 1: locate every record start. A record starts with ``>`` either at
-    # offset 0 or immediately after a ``\n``.
+    # Step 1: locate every record start ('>' at offset 0 or after '\n').
     starts: list[int] = []
     i = 0
     while True:
@@ -28,14 +31,13 @@ def find_matches(fasta_path: str, pattern: bytes) -> list[tuple[str, list[int]]]
         if p == 0 or data[p - 1] == _NL:
             starts.append(p)
         i = p + 1
-    starts.append(len(data))  # sentinel marking the end of the last record.
+    starts.append(len(data))  # sentinel
 
     num_records = len(starts) - 1
    if num_records <= 0:
         return []
 
-    # Step 2: parallel scan. Choose enough batches to keep workers balanced
-    # even when record sizes vary.
+    # Step 2: parallel scan with batched work units.
     n_workers = max(1, os.cpu_count() or 1)
     batches = max(1, n_workers * 4)
     batch_size = max(1, (num_records + batches - 1) // batches)
@@ -46,15 +48,11 @@ def scan_batch(start_idx: int, end_idx: int) -> list[tuple[int, str, list[int]]]
             rec_start = starts[j]
             rec_end = starts[j + 1]
 
-            # Locate the end of the header line within this record's slice.
             nl = data.find(b"\n", rec_start, rec_end)
             if nl <= rec_start:
-                continue  # Malformed or header-only.
-
-            record_id = data[rec_start + 1 : nl].decode("ascii").strip()
+                continue
 
-            # Contiguous sequence: drop the newlines so matches that straddle
-            # line breaks are still found by ``bytes.find``.
+            # Strip newlines so cross-line matches are found.
             sequence = data[nl + 1 : rec_end].replace(b"\n", b"")
 
             positions: list[int] = []
@@ -67,6 +65,7 @@ def scan_batch(start_idx: int, end_idx: int) -> list[tuple[int, str, list[int]]]
                 s = p + 1
 
             if positions:
+                record_id = data[rec_start + 1 : nl].decode("ascii").strip()
                 out.append((j, record_id, positions))
 
         return out
@@ -77,8 +76,7 @@ def scan_batch(start_idx: int, end_idx: int) -> list[tuple[int, str, list[int]]]
     ]
     chunks = [f.result() for f in futures]
 
-    # Step 3: flatten and restore file order (record index is monotonic per
-    # batch, but batches finish in arbitrary order).
+    # Step 3: flatten and restore file order.
     flat = [item for chunk in chunks for item in chunk]
     flat.sort(key=lambda triple: triple[0])
     return [(rid, positions) for _, rid, positions in flat]
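
For reviewers who want to verify the `bincount` rewrite, here is a minimal sanity-check sketch. It assumes `compute_histogram` is importable from the histogram solution; `sample_path` is a hypothetical placeholder for any local binary file, not something defined in this patch.

```python
# Sanity check: the vectorised histogram must agree with a naive reference.
# `sample_path` is a hypothetical placeholder for any local test file.
import collections


def naive_histogram(path: str) -> dict[bytes, int]:
    """Reference implementation: count 2-byte bigrams with a plain Counter."""
    with open(path, "rb") as f:
        data = f.read()
    # data[i : i + 2] yields each overlapping 2-byte slice of the file.
    return dict(collections.Counter(data[i : i + 2] for i in range(len(data) - 1)))


# assert compute_histogram(sample_path) == naive_histogram(sample_path)
```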
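
Similarly, a small smoke test for `find_matches`. The FASTA records and pattern below are invented for illustration; the expected output follows from the patched code, where positions index into the newline-stripped sequence and records with no match are omitted.

```python
# Hypothetical smoke test for find_matches; records and pattern are made up.
import tempfile

fasta = b">rec1 demo\nACGTAC\nGTACGT\n>rec2 demo\nTTTTTT\n"
with tempfile.NamedTemporaryFile(suffix=".fa", delete=False) as tmp:
    tmp.write(fasta)

# rec1's newline-stripped sequence is ACGTACGTACGT, so b"CGT" occurs at
# offsets 1, 5 and 9; rec2 contains no match and is dropped from the result.
print(find_matches(tmp.name, b"CGT"))
# Expected: [('rec1 demo', [1, 5, 9])]
```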