Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 11 additions & 13 deletions rounds/1_histogram/solution.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,24 +5,22 @@
own faster implementation.
"""

import numpy as np


def compute_histogram(path: str) -> dict[bytes, int]:
"""Frequency of every 2-byte bigram in the file at ``path``."""
# Step 1: read the whole file into memory as a single bytes object.
with open(path, "rb") as f:
data = f.read()

# Create a 2D matrix to count bigrams
counts = [[0] * 256 for _ in range(256)]
arr = np.frombuffer(data, dtype=np.uint8)

# Vectorised bigram index: first_byte * 256 + second_byte
bigram_indices = arr[:-1].astype(np.uint16) * 256 + arr[1:]

for i in range(len(data) - 1):
# Increment the count in each cell
counts[data[i]][data[i + 1]] += 1
# Count every bigram in a single pass (C-level loop inside numpy)
counts = np.bincount(bigram_indices, minlength=65536)

# Convert the matrix to the original format
output = {}
for i in range(256):
for j in range(256):
if counts[i][j] > 0:
output[bytes([i, j])] = counts[i][j]
return output
# Build the result dict from non-zero entries only
nonzero = np.flatnonzero(counts)
return {int(idx).to_bytes(2, "big"): int(counts[idx]) for idx in nonzero}
26 changes: 12 additions & 14 deletions rounds/3_dna/solution.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,18 @@
import os
from concurrent.futures import ThreadPoolExecutor

_NL = 0x0A # b"\n"
_NL = 0x0A


def find_matches(fasta_path: str, pattern: bytes) -> list[tuple[str, list[int]]]:
"""Find every FASTA record whose sequence contains ``pattern``.

Returns ``[(record_id, [positions...]), ...]`` in file order.
"""
with open(fasta_path, "rb") as f:
data = f.read()

# Step 1: locate every record start. A record starts with ``>`` either at
# offset 0 or immediately after a ``\n``.
# Step 1: locate every record start ('>' at offset 0 or after '\n').
starts: list[int] = []
i = 0
while True:
Expand All @@ -28,14 +31,13 @@ def find_matches(fasta_path: str, pattern: bytes) -> list[tuple[str, list[int]]]
if p == 0 or data[p - 1] == _NL:
starts.append(p)
i = p + 1
starts.append(len(data)) # sentinel marking the end of the last record.
starts.append(len(data)) # sentinel

num_records = len(starts) - 1
if num_records <= 0:
return []

# Step 2: parallel scan. Choose enough batches to keep workers balanced
# even when record sizes vary.
# Step 2: parallel scan with batched work units.
n_workers = max(1, os.cpu_count() or 1)
batches = max(1, n_workers * 4)
batch_size = max(1, (num_records + batches - 1) // batches)
Expand All @@ -46,15 +48,11 @@ def scan_batch(start_idx: int, end_idx: int) -> list[tuple[int, str, list[int]]]
rec_start = starts[j]
rec_end = starts[j + 1]

# Locate the end of the header line within this record's slice.
nl = data.find(b"\n", rec_start, rec_end)
if nl <= rec_start:
continue # Malformed or header-only.

record_id = data[rec_start + 1 : nl].decode("ascii").strip()
continue

# Contiguous sequence: drop the newlines so matches that straddle
# line breaks are still found by ``bytes.find``.
# Strip newlines so cross-line matches are found.
sequence = data[nl + 1 : rec_end].replace(b"\n", b"")

positions: list[int] = []
Expand All @@ -67,6 +65,7 @@ def scan_batch(start_idx: int, end_idx: int) -> list[tuple[int, str, list[int]]]
s = p + 1

if positions:
record_id = data[rec_start + 1 : nl].decode("ascii").strip()
out.append((j, record_id, positions))
return out

Expand All @@ -77,8 +76,7 @@ def scan_batch(start_idx: int, end_idx: int) -> list[tuple[int, str, list[int]]]
]
chunks = [f.result() for f in futures]

# Step 3: flatten and restore file order (record index is monotonic per
# batch, but batches finish in arbitrary order).
# Step 3: flatten and restore file order.
flat = [item for chunk in chunks for item in chunk]
flat.sort(key=lambda triple: triple[0])
return [(rid, positions) for _, rid, positions in flat]
Loading