Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 40 additions & 49 deletions bin/handle_spike_ins.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import json
import os
from collections import defaultdict
import mappy as mp
import gzip


def parse_spike_in_refs(ref_file, ref_dir):
Expand Down Expand Up @@ -50,6 +50,35 @@ def expand_spike_in_input(list_spike_ins, spike_in_dict):
return spike_taxids, spike_refs


def parse_idxstats(idxstats_file):
counts = defaultdict(int)
total_unmapped = 0
with open(idxstats_file, "r") as f:
for line in f:
parts = line.strip().split("\t")
name, _, mapped, unmapped = parts
mapped = int(mapped)
unmapped = int(unmapped)
if name == "*":
total_unmapped = unmapped
continue
counts[name] = mapped
counts["total"] = sum(counts.values()) + total_unmapped
return counts


def read_reference_headers(spike_refs):
ref_headers = {}
for ref_path in spike_refs:
headers = []
with gzip.open(ref_path, "rt") as f:
for line in f:
if line.startswith(">"):
headers.append(line[1:].strip().split()[0])
ref_headers[ref_path] = headers
return ref_headers


def parse_depth(name):
parse_name = name.split(" ")
depth = 0
Expand Down Expand Up @@ -162,43 +191,16 @@ def parse_report_file(report_file, spike_in, save_json):
return spike_entries


def map_to_refs(query, reference, counts, preset):
a = mp.Aligner(reference, best_n=1, preset=preset) # load or build index
if not a:
raise Exception(f"ERROR: failed to load/build index for {reference}")

read_count = 0
for name, seq, qual in mp.fastx_read(query): # read a fasta/q sequence
read_count += 1
for hit in a.map(seq): # traverse alignments
counts[hit.ctg] += 1
# print("{}\t{}\t{}\t{}\t{}".format(name, hit.ctg, hit.r_st, hit.r_en, hit.cigar_str))
break
# if read_count % 1000000 == 0:
# break
counts["total"] = read_count
return a.seq_names


def identify_spike_map_counts(query, spike_refs, preset):
map_counts = defaultdict(int)
map_ids = defaultdict(list)
for reference in spike_refs:
map_ids[reference] = map_to_refs(query, reference, map_counts, preset)
return map_counts, map_ids


def combine_report_and_map_counts(
list_spike_ins, spike_in_dict, report_entries, map_counts, map_ids
list_spike_ins, spike_in_dict, report_entries, map_counts, ref_headers
):
spike_summary = defaultdict(lambda: {})
for spike in list_spike_ins:
spike_dict = defaultdict(lambda: {})
if spike in spike_in_dict:
spike_name = spike

if spike_in_dict[spike]["ref"]:
for long_name in map_ids[spike_in_dict[spike]["ref"]]:
for long_name in ref_headers[spike_in_dict[spike]["ref"]]:
name, taxid, taxon_name = long_name.split("|")
taxon_name = taxon_name.replace("_", " ")

Expand Down Expand Up @@ -226,7 +228,7 @@ def combine_report_and_map_counts(
spike_dict[spike].update(entry)
elif spike.endswith("f*a") or spike.endswith("f*a.gz"):
spike_name = spike.split("/")[-1].split(".")[0]
for long_name in map_ids[spike]:
for long_name in ref_headers[spike]:
name, taxid, taxon_name = long_name.split("|")
taxon_name = taxon_name.replace("_", " ")

Expand Down Expand Up @@ -325,10 +327,10 @@ def main():
help="Kraken or Bracken file of taxon relationships and quantities",
)
parser.add_argument(
"-i",
dest="fastq_file",
"--idxstats",
dest="idxstats_file",
required=True,
help="Read file",
help="samtools idxstats output generated during spike removal",
)
parser.add_argument(
"--spike_ins",
Expand All @@ -355,22 +357,12 @@ def main():
required=False,
help="Save the kraken report in JSON format",
)
parser.add_argument(
"--illumina",
action="store_true",
required=False,
help="Use the short read minimap preset",
)

args = parser.parse_args()
spike_ins = []
for spike in args.spike_ins:
spike_ins.extend(spike.split(","))

preset = None
if args.illumina:
preset = "sr"

# Start Program
now = datetime.now()
time = now.strftime("%m/%d/%Y, %H:%M:%S")
Expand All @@ -383,13 +375,12 @@ def main():
args.report_file, spike_taxids, args.save_json
)

if len(spike_taxids) > 0 or len(spike_refs) > 0:
map_counts, map_ids = identify_spike_map_counts(
args.fastq_file, spike_refs, preset
)
map_counts = parse_idxstats(args.idxstats_file)
ref_headers = read_reference_headers(spike_refs)

if spike_ins:
spike_summary = combine_report_and_map_counts(
spike_ins, spike_in_dict, spike_kraken_entries, map_counts, map_ids
spike_ins, spike_in_dict, spike_kraken_entries, map_counts, ref_headers
)
check_spike_summary(spike_summary)

Expand Down
57 changes: 57 additions & 0 deletions bin/prep_spike_refs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
#!/usr/bin/env python
import argparse
import gzip
import json
from pathlib import Path


def load_spike_in_dict(ref_file):
with open(ref_file) as handle:
return json.load(handle)


def concatenate_references(ref_paths, output_file):
with open(output_file, "w") as f:
for ref_path in ref_paths:
with gzip.open(ref_path, "rt") as gz_file:
for line in gz_file:
f.write(line)


def main():
parser = argparse.ArgumentParser(
description="Combine spike-in reference sequences"
)
parser.add_argument(
"--spike_ins",
required=True,
help="Comma-separated list of spike-in names",
)
parser.add_argument(
"--spike_in_dict",
required=True,
help="JSON file mapping spike-in names to reference files",
)
parser.add_argument(
"--spike_in_ref_dir",
required=True,
help="Directory containing spike-in reference files",
)
parser.add_argument(
"-o",
"--output",
default="combined_spikes.fa",
help="Output FASTA file",
)

args = parser.parse_args()

spike_names = [name.strip() for name in args.spike_ins.split(",") if name.strip()]
spike_map = load_spike_in_dict(args.spike_in_dict)
base_dir = Path(args.spike_in_ref_dir)
ref_paths = [str(base_dir / spike_map[name]["ref"]) for name in spike_names]
concatenate_references(ref_paths, args.output)


if __name__ == "__main__":
main()
19 changes: 7 additions & 12 deletions modules/check_spike_status.nf
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ process check_spike_ins {
publishDir "${params.outdir}/${unique_id}/classifications", mode: "copy", overwrite: true, pattern: "*.json"

input:
tuple val(unique_id), val(database_name), path(kreport), path(reads)
tuple val(unique_id), val(database_name), path(kreport), path(reads), path(spike_mapping_stats)
val spike_ins
path spike_in_dict
path spike_in_ref_dir
Expand All @@ -23,36 +23,31 @@ process check_spike_ins {
tuple val(unique_id), path("${kreport.baseName}*.json"), emit: kreport

script:
preset = ""
if (params.read_type == "illumina") {
preset = "--illumina"
}
else if (params.paired) {
preset = "--illumina"
}
"""
handle_spike_ins.py \
-r ${kreport} \
-i ${reads} \
--spike_ins ${spike_ins} \
--spike_in_dict ${spike_in_dict} \
--spike_in_ref_dir ${spike_in_ref_dir} \
--save_json ${preset}
--idxstats ${spike_mapping_stats} \
--save_json
"""
}

workflow check_spike_status {
take:
kreport_ch
fastq_ch

spike_mapping_stats_ch

main:
spike_ins = "${params.spike_ins}"
println(spike_ins)
spike_in_dict = file("${params.spike_in_dict}", type: "file", checkIfExists: true)
spike_in_ref_dir = file("${params.spike_in_ref_dir}", type: "dir", checkIfExists: true)

kreport_ch.join(fastq_ch).set { input_ch }
kreport_ch.join(fastq_ch).join(spike_mapping_stats_ch) .set { input_ch }

check_spike_ins(input_ch, spike_ins, spike_in_dict, spike_in_ref_dir)

empty_file = file("${baseDir}/resources/empty_file")
Expand Down
Loading