"""Parallel FASTA pattern search.

Memory-maps the FASTA file, walks record boundaries from the end of the
file backwards, and scans each record's sequence for ``pattern`` on a
thread pool.  Newlines inside a sequence are stripped before searching so
matches that straddle line breaks are still found.
"""

import mmap as _mmapmod
import os
from concurrent.futures import ThreadPoolExecutor

_NL = 0x0A  # byte value of b"\n"; indexing an mmap yields an int, not bytes


def _subsearch(raw, pattern: bytes, record_id_start: int, data_start: int, data_end: int):
    """Scan a single FASTA record for ``pattern``.

    ``raw[data_start:data_end]`` is the record's sequence (embedded and
    trailing newlines still present — they are stripped here), and
    ``raw[record_id_start : data_start - 1]`` is its header text.

    Returns ``(record_id, [positions...])`` with positions measured in the
    newline-free sequence, or ``None`` when the pattern does not occur.
    """
    sequence = raw[data_start:data_end].replace(b"\n", b"")
    positions: list[int] = []
    hit = sequence.find(pattern)
    while hit != -1:
        positions.append(hit)
        # Advance by one, not len(pattern): overlapping occurrences count,
        # and this also terminates for an empty pattern.
        hit = sequence.find(pattern, hit + 1)
    if not positions:
        return None
    record_id = raw[record_id_start : data_start - 1].decode("ascii").strip()
    return record_id, positions


def find_matches(fasta_path: str, pattern: bytes) -> list[tuple[str, list[int]]]:
    """Find every FASTA record whose sequence contains ``pattern``.

    Returns ``[(record_id, [positions...]), ...]`` in file order.  Records
    with no sequence line are skipped; an empty file yields ``[]``.
    """
    with open(fasta_path, "rb") as source:
        # mmap(fd, 0) rejects zero-length files; an empty file has no records.
        if os.fstat(source.fileno()).st_size == 0:
            return []
        with _mmapmod.mmap(source.fileno(), 0, access=_mmapmod.ACCESS_READ) as data:
            # Best-effort prefetch hint.  madvise()/MADV_* do not exist on
            # every platform (e.g. Windows), so never let the hint fail hard.
            advice = getattr(_mmapmod, "MADV_WILLNEED", None)
            if advice is not None and hasattr(data, "madvise"):
                try:
                    data.madvise(advice)
                except OSError:
                    pass

            futures = []
            record_end = len(data)   # exclusive end of the current record
            search_end = record_end  # where the backwards ``>`` search resumes
            with ThreadPoolExecutor(max_workers=os.cpu_count() or 1) as executor:
                while search_end > 0:
                    gt_pos = data.rfind(b">", 0, search_end)
                    if gt_pos == -1:
                        break  # nothing but non-record bytes remain
                    # A record starts with ``>`` at offset 0 or immediately
                    # after a newline; any other ``>`` is ordinary text
                    # inside a header or sequence line — keep searching.
                    if gt_pos != 0 and data[gt_pos - 1] != _NL:
                        search_end = gt_pos
                        continue
                    nl_pos = data.find(b"\n", gt_pos + 1, record_end)
                    if nl_pos != -1:  # header-only record has no sequence: skip
                        futures.append(
                            executor.submit(
                                _subsearch, data, pattern,
                                gt_pos + 1, nl_pos + 1, record_end,
                            )
                        )
                    record_end = gt_pos
                    search_end = gt_pos

                # Resolve futures while the executor and the mmap are still
                # alive: every worker reads from ``data``.
                results = [f.result() for f in futures]

            results.reverse()  # records were discovered back-to-front
            return [r for r in results if r is not None]