Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 12 additions & 18 deletions graphkir/kir_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from .utils import NumpyEncoder, logger
from .msa2hisat import Variant
from .hisat2 import loadReadsAndVariantsData, removeMultipleMapped, PairRead
from .typing_mulit_allele import AlleleTyping, AlleleTypingExonFirst
from .typing_mulit_allele import AlleleTyping, AlleleTypingExonFirst, isHetrozygous
from .typing_em import preprocessHisatReads, hisat2TypingPerGene, printHisatTyping


Expand Down Expand Up @@ -103,38 +103,32 @@ def __init__(
def typingPerGene(self, gene: str, cn: int) -> tuple[list[str], int]:
"""Select reads belonged to the gene and typing it"""
logger.debug(f"[Allele] {gene=} {cn=}")
force_homo = False if isHetrozygous(gene) else None

if not self._exon_first and not self._exon_only:
typ = AlleleTyping(
self._gene_reads[gene],
self._gene_variants[gene],
gene,
cn,
force_homo=force_homo,
top_n=self._top_n,
variant_correction=self._variant_correction,
)
if not typ.homo:
res = typ.typing(cn)
self._result[gene] = typ.result
alleles = res.selectBest()
# KIR2DL1*BACKBONE -> KIR2DL1
alleles = [i if i != "fail" else f"{pure_gene}*" for i in alleles]
else:
alleles = [f"{typ.typingHomo()}" for i in range(cn)]
else:
typ = AlleleTypingExonFirst(
self._gene_reads[gene],
self._gene_variants[gene],
force_homo=force_homo,
top_n=self._top_n,
exon_only=self._exon_only,
candidate_set_threshold=self._exon_candidate_threshold,
)
res = typ.typing(cn)
self._result[gene] = typ.result
# return res.selectBest(filter_minor=True)
alleles = res.selectBest()
# KIR2DL1*BACKBONE -> KIR2DL1
pure_gene = gene.split("*")[0]
alleles = [i if i != "fail" else f"{pure_gene}*" for i in alleles]
res = typ.typing(cn)
self._result[gene] = typ.result
# return res.selectBest(filter_minor=True)
alleles = res.selectBest()
# KIR2DL1*BACKBONE -> KIR2DL1
pure_gene = gene.split("*")[0]
alleles = [i if i != "fail" else f"{pure_gene}*" for i in alleles]
return alleles, typ.getReadsNum()

def getAllPossibleTyping(self) -> list[dict[Any, Any]]:
Expand Down
180 changes: 121 additions & 59 deletions graphkir/typing_mulit_allele.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
"""
from __future__ import annotations
import io
import sys
import copy
from typing import Optional, Iterable
from itertools import chain
Expand Down Expand Up @@ -231,8 +230,7 @@ def __init__(
self,
reads: list[PairRead],
variants: list[Variant],
gene: str,
cn: int,
force_homo: bool | None = None,
top_n: int = 300,
no_empty: bool = True,
variant_correction: bool = True,
Expand All @@ -247,6 +245,7 @@ def __init__(
"""
self.top_n = top_n
self._no_empty = no_empty
self.force_homo: bool | None = force_homo
allele_names = self.collectAlleleNames(variants)

# create variant_id -> variant
Expand All @@ -255,17 +254,11 @@ def __init__(
self.id_to_allele: dict[int, str] = dict(enumerate(sorted(allele_names)))
self.allele_to_id: dict[str, int] = {j: i for i, j in self.id_to_allele.items()}

if "2DL1S1" in gene or "2DL5" in gene:
self.homo = False
elif cn > 1:
self.homo = self.readToHomoHetero(reads, gene, cn)
else:
self.homo = False

if variant_correction: # var_errcorr
reads = self.errorCorrection(reads)
if self._no_empty: # reserve read position
reads = self.removeEmptyReads(reads)
self.reads = reads
self.probs = self.reads2AlleleProb(reads)
self.log_probs = np.log10(self.probs)

Expand All @@ -278,53 +271,6 @@ def __init__(
def getReadsNum(self) -> int:
return len(self.probs)

def typingHomo(self) -> str:
return self.id_to_allele[np.argmax(np.prod(self.probs, axis=0))]

def readToHomoHetero(self, reads: list[PairRead], gene: str, cn: int) -> bool:
homo = False
v_record = defaultdict(lambda: defaultdict(int))
hit_score = 0

# variants call by read -> dict
# note: pv record the variant on read, nv record the variant not on read
for read in reads:
for i in read.lpv:
v = self.variants[i]
if v.typ != "deletion":
v_record[v.pos][v.val] += 1
for i in read.rpv:
v = self.variants[i]
if v.typ != "deletion":
v_record[v.pos][v.val] += 1
for i in read.lnv:
v = self.variants[i]
if v.typ != "deletion":
v_record[v.pos][f"*{v.val}"] += 1
for i in read.rnv:
v = self.variants[i]
if v.typ != "deletion":
v_record[v.pos][f"*{v.val}"] += 1
# find heterozygous variant
for val in v_record.values():
if len(val) > 1:
if all('*' in key for key in val):
continue
counts = sorted(list(val.values()), reverse=True)
counts = [c for c in counts if c > 3] # low coverage variant -> sequencing error
if len(counts) <= 1:
continue
sum_counts = sum(counts)
hetero_percentage = [c/sum_counts for c in counts if c/sum_counts > 0.1] # filter very minor variant
if len(hetero_percentage) == 1:
continue
if sum_counts < 20: # not processing low coverage possition
pass
elif hetero_percentage[1] > (1/(cn*2)):
hit_score += 1
# break
return hit_score == 0

@staticmethod
def removeEmptyReads(reads: list[PairRead]) -> list[PairRead]:
"""
Expand Down Expand Up @@ -442,17 +388,71 @@ def typing(self, cn: int) -> TypingResult:
The top-n allele-set are best fit the reads (with maximum probility).
Each set has CN alleles.
"""
if cn < 1:
raise ValueError(f"CN should be >= 1, got {cn}")

# decide homo/hetero if not forced
if self.force_homo is None:
homo = isHomozygous(self.reads, self.variants, cn)
else:
homo = self.force_homo

self.result = []
for _ in range(cn):
if homo:
self.addCandidate()
self.addHomoResultForCn(cn)
else:
for _ in range(cn):
self.addCandidate()
# self.result[-1].print()

self.result[-1].print()
return self.result[-1]

def addHomoResultForCn(self, cn: int) -> None:
"""Add homozygous result for copy number > 1."""
if cn > 1:
homo_result = self.createHomoResult(self.result[0], cn)
assert homo_result is not None
self.result.append(homo_result)

def mapAlleleIDs(self, list_ids: IdArray) -> list[list[str]]:
"""id (m x n np array) -> name (m list x n list of str)"""
return [[self.id_to_allele[id] for id in ids] for ids in list_ids]

@staticmethod
def createHomoResult(cn1_result: TypingResult, cn: int) -> TypingResult:
"""
Generate homozygous typing result with CN copies from CN=1 result.
Used when sample is homozygous to replicate the same allele n times.

Parameters:
cn1_result: Typing result with CN=1 (single allele)
cn: Target copy number (> 1)
Returns:
TypingResult with cn copies of the allele
Raises:
ValueError: If cn <= 1
"""
if cn <= 1:
raise ValueError(f"CN should be > 1, got {cn}")

# Replicate the single allele cn times
new_allele_id = np.repeat(cn1_result.allele_id, cn, axis=1)
new_allele_name = [[name[0]] * cn for name in cn1_result.allele_name]

result = TypingResult(
n=cn,
value=cn1_result.value * cn, # Scale likelihood by CN
value_sum_indv=np.repeat(cn1_result.value_sum_indv, cn, axis=1),
allele_id=new_allele_id,
allele_name=new_allele_name,
allele_prob=cn1_result.allele_prob,
fraction=np.ones((len(cn1_result.value), cn)) / cn,
fraction_uniq=np.ones((len(cn1_result.value), cn)) / cn,
)
return result

@staticmethod
def uniqueAllele(data: IdArray) -> BoolArray:
"""
Expand Down Expand Up @@ -632,6 +632,7 @@ def __init__(
exon_only: bool = False,
candidate_set_threshold: float = 1.,
variant_correction: bool = True,
force_homo: bool | None = None,
):
"""Extracting exon alleles"""
# extract exon variants
Expand Down Expand Up @@ -660,7 +661,7 @@ def __init__(
# pprint(self.allele_group)

# same as before
super().__init__(exon_reads, exon_variants, top_n=top_n)
super().__init__(exon_reads, exon_variants, force_homo=force_homo, top_n=top_n)
self.candidate_set_threshold = candidate_set_threshold
"""
self.first_set_only = candidate_set == "first_score"
Expand All @@ -676,6 +677,7 @@ def __init__(
self.full_model: AlleleTyping | None = AlleleTyping(
reads,
variants,
force_homo=force_homo,
top_n = top_n // 5, # TODO: default = 30
variant_correction=variant_correction,
# variant_correction=False, # no_intron_corr
Expand Down Expand Up @@ -793,3 +795,63 @@ def typing(self, cn: int) -> TypingResult:
logger.debug("[Allele] Typing intron + exon")
result.print()
return result


def isHetrozygous(gene: str) -> bool:
"""Decide whether the gene is heterozygous by gene name only"""
if "2DL1S1" in gene or "2DL5" in gene:
return True
return False


def isHomozygous(reads: list[PairRead], variants_map: dict[str, Variant], cn: int) -> bool:
"""
Determine whether the sample is homozygous for a given gene based on read variants.

Heuristics:
- If copy number <= 1, treat as heterozygous (return False).
- Otherwise, inspect per-position variant support; if we observe convincing
bi-allelic evidence at any position (second-most allele passes thresholds),
we treat as heterozygous; else homozygous.
"""
if cn <= 1:
return False

v_record: dict[int, dict[str, int]] = defaultdict(lambda: defaultdict(int))
hit_score = 0

# Tally variant observations per genomic position
for read in reads:
for i in chain.from_iterable([read.lpv, read.rpv]):
v = variants_map[i]
if v.typ != "deletion":
v_record[v.pos][str(v.val)] += 1
for i in chain.from_iterable([read.lnv, read.rnv]):
v = variants_map[i]
if v.typ != "deletion":
v_record[v.pos][f"*{v.val}"] += 1

# Identify heterozygous positions
for val in v_record.values():
if len(val) <= 1:
continue
# Ignore positions where all are negative evidence
if all("*" in key for key in val):
continue
counts = sorted(list(val.values()), reverse=True)

# filter low coverage to avoid sequencing error
counts = [c for c in counts if c > 3]
sum_counts = sum(counts)
# too low coverage: skip decision for this site
if sum_counts < 20:
continue
# filter very minor variant
hetero_percentage = [c / sum_counts for c in counts if c / sum_counts > 0.1]
if len(hetero_percentage) == 1:
continue
# heterozygous if runner-up fraction is above 1/(2*cn)
if hetero_percentage[1] > (1 / (cn * 2)):
hit_score += 1

return hit_score == 0