126 changes: 54 additions & 72 deletions rounds/3_dna/solution.py
@@ -5,80 +5,62 @@
 own faster implementation.
 """
 
 from __future__ import annotations
 
 import os
+from mmap import mmap, ACCESS_READ, MADV_RANDOM, MADV_WILLNEED  # MADV_* exist only where madvise(2) does
 from concurrent.futures import ThreadPoolExecutor
 
-_NL = 0x0A  # b"\n"
+
+def _subsearch(raw, pattern, record_id_start: int, data_start: int, data_end: int):
+    """Scan one record; return ``(record_id, [positions...])`` or ``None``."""
+    # Slicing the mmap yields ``bytes``; dropping the newlines keeps matches
+    # that straddle line breaks findable and makes positions contiguous.
+    data = raw[data_start:data_end].replace(b"\n", b"")
+    plen = len(pattern)
+    locations = []
+    loc = data.find(pattern)
+    while loc != -1:
+        locations.append(loc)
+        loc = data.find(pattern, loc + plen)  # non-overlapping scan
+
+    if not locations:
+        return None
+
+    # The header line runs from just past ``>`` to the byte before its newline.
+    record_id = raw[record_id_start : data_start - 1].decode("ascii")
+    return (record_id, locations)

 def find_matches(fasta_path: str, pattern: bytes) -> list[tuple[str, list[int]]]:
-    with open(fasta_path, "rb") as f:
-        data = f.read()
-
-    # Step 1: locate every record start. A record starts with ``>`` either at
-    # offset 0 or immediately after a ``\n``.
-    starts: list[int] = []
-    i = 0
-    while True:
-        p = data.find(b">", i)
-        if p == -1:
-            break
-        if p == 0 or data[p - 1] == _NL:
-            starts.append(p)
-        i = p + 1
-    starts.append(len(data))  # sentinel marking the end of the last record.
-
-    num_records = len(starts) - 1
-    if num_records <= 0:
-        return []
-
-    # Step 2: parallel scan. Choose enough batches to keep workers balanced
-    # even when record sizes vary.
-    n_workers = max(1, os.cpu_count() or 1)
-    batches = max(1, n_workers * 4)
-    batch_size = max(1, (num_records + batches - 1) // batches)
-
-    def scan_batch(start_idx: int, end_idx: int) -> list[tuple[int, str, list[int]]]:
-        out: list[tuple[int, str, list[int]]] = []
-        for j in range(start_idx, end_idx):
-            rec_start = starts[j]
-            rec_end = starts[j + 1]
-
-            # Locate the end of the header line within this record's slice.
-            nl = data.find(b"\n", rec_start, rec_end)
-            if nl <= rec_start:
-                continue  # Malformed or header-only.
-
-            record_id = data[rec_start + 1 : nl].decode("ascii").strip()
-
-            # Contiguous sequence: drop the newlines so matches that straddle
-            # line breaks are still found by ``bytes.find``.
-            sequence = data[nl + 1 : rec_end].replace(b"\n", b"")
-
-            positions: list[int] = []
-            s = 0
-            while True:
-                p = sequence.find(pattern, s)
-                if p == -1:
-                    break
-                positions.append(p)
-                s = p + 1
-
-            if positions:
-                out.append((j, record_id, positions))
-        return out
-
-    with ThreadPoolExecutor(max_workers=n_workers) as pool:
-        futures = [
-            pool.submit(scan_batch, lo, min(lo + batch_size, num_records))
-            for lo in range(0, num_records, batch_size)
-        ]
-        chunks = [f.result() for f in futures]
-
-    # Step 3: flatten and restore file order (record index is monotonic per
-    # batch, but batches finish in arbitrary order).
-    flat = [item for chunk in chunks for item in chunk]
-    flat.sort(key=lambda triple: triple[0])
-    return [(rid, positions) for _, rid, positions in flat]
"""Find every FASTA record whose sequence contains ``pattern``.

Returns ``[(record_id, [positions...]), ...]`` in file order.
"""
source = open(fasta_path, "rb")
data = mmap(source.fileno(), 0, access=ACCESS_READ)
data.madvise(MADV_RANDOM | MADV_WILLNEED)

last = -1

data_end = len(data) - 1
while data[data_end] == b"\n":
data_end -= 1

n_workers = os.cpu_count() or 1

with ThreadPoolExecutor(max_workers=n_workers) as executor:
records = []
while data_end > 0:
gt_pos = data.rfind(b">", 0, data_end)
if gt_pos == -1:
raise Exception("expected greater than")

record_id_start = gt_pos + 1

nl_pos = data.find(b"\n", record_id_start)
if nl_pos == -1:
raise Exception("expected new line")

data_start = nl_pos + 1

records.append(
executor.submit(_subsearch, data, pattern, record_id_start, data_start, data_end)
)
data_end = gt_pos

results = [d.result() for d in records]
results.reverse()
return [r for r in results if r is not None]
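
For reference, a minimal usage sketch of the new ``find_matches`` (the records, pattern, and temporary file below are illustrative, not part of the PR). Positions are offsets into the newline-stripped sequence, so the match at 4 here straddles a line break in the file:

import os
import tempfile

fasta = b">rec1 sample\nACGTAC\nGTACGT\n>rec2\nTTTT\n"
with tempfile.NamedTemporaryFile(suffix=".fa", delete=False) as tmp:
    tmp.write(fasta)
    path = tmp.name
try:
    # rec1's stripped sequence is b"ACGTACGTACGT": b"ACGT" at 0, 4, 8.
    print(find_matches(path, b"ACGT"))  # [('rec1 sample', [0, 4, 8])]
finally:
    os.unlink(path)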
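
One behavioral nuance worth flagging in review: the removed scanner resumed at ``p + 1`` and therefore reported overlapping matches, while ``_subsearch`` resumes at ``loc + plen`` and reports only non-overlapping ones. For self-overlapping patterns the two disagree; a small standalone illustration of the difference:

def scan(seq: bytes, pat: bytes, step: int) -> list[int]:
    out, i = [], seq.find(pat)
    while i != -1:
        out.append(i)
        i = seq.find(pat, i + step)
    return out

print(scan(b"AAAA", b"AA", 1))  # [0, 1, 2] -- old behavior, overlapping
print(scan(b"AAAA", b"AA", 2))  # [0, 2]    -- new behavior, non-overlapping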
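
A note on the ``mmap`` object the new code hands to each worker: indexing it with a single integer returns an ``int`` (exactly as with ``bytes``), and only slicing returns ``bytes``, which is why ``_subsearch`` searches a slice rather than comparing individual elements. A quick standalone check (anonymous mapping, no file needed):

from mmap import mmap

m = mmap(-1, 4)          # anonymous 4-byte mapping
m.write(b">ab\n")
print(m[0])              # 62 -- an int
print(m[0:1])            # b'>' -- a bytes slice
print(m[0] == b">")      # False: an int never equals bytes
print(m[0] == ord(">"))  # True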