diff --git a/graphkir/kir_typing.py b/graphkir/kir_typing.py index 80b02a8..9ed035c 100644 --- a/graphkir/kir_typing.py +++ b/graphkir/kir_typing.py @@ -8,7 +8,7 @@ from .utils import NumpyEncoder, logger from .msa2hisat import Variant from .hisat2 import loadReadsAndVariantsData, removeMultipleMapped, PairRead -from .typing_mulit_allele import AlleleTyping, AlleleTypingExonFirst +from .typing_mulit_allele import AlleleTyping, AlleleTypingExonFirst, isHetrozygous from .typing_em import preprocessHisatReads, hisat2TypingPerGene, printHisatTyping @@ -103,38 +103,32 @@ def __init__( def typingPerGene(self, gene: str, cn: int) -> tuple[list[str], int]: """Select reads belonged to the gene and typing it""" logger.debug(f"[Allele] {gene=} {cn=}") + force_homo = False if isHetrozygous(gene) else None + if not self._exon_first and not self._exon_only: typ = AlleleTyping( self._gene_reads[gene], self._gene_variants[gene], - gene, - cn, + force_homo=force_homo, top_n=self._top_n, variant_correction=self._variant_correction, ) - if not typ.homo: - res = typ.typing(cn) - self._result[gene] = typ.result - alleles = res.selectBest() - # KIR2DL1*BACKBONE -> KIR2DL1 - alleles = [i if i != "fail" else f"{pure_gene}*" for i in alleles] - else: - alleles = [f"{typ.typingHomo()}" for i in range(cn)] else: typ = AlleleTypingExonFirst( self._gene_reads[gene], self._gene_variants[gene], + force_homo=force_homo, top_n=self._top_n, exon_only=self._exon_only, candidate_set_threshold=self._exon_candidate_threshold, ) - res = typ.typing(cn) - self._result[gene] = typ.result - # return res.selectBest(filter_minor=True) - alleles = res.selectBest() - # KIR2DL1*BACKBONE -> KIR2DL1 - pure_gene = gene.split("*")[0] - alleles = [i if i != "fail" else f"{pure_gene}*" for i in alleles] + res = typ.typing(cn) + self._result[gene] = typ.result + # return res.selectBest(filter_minor=True) + alleles = res.selectBest() + # KIR2DL1*BACKBONE -> KIR2DL1 + pure_gene = gene.split("*")[0] + alleles = [i if i != "fail" else f"{pure_gene}*" for i in alleles] return alleles, typ.getReadsNum() def getAllPossibleTyping(self) -> list[dict[Any, Any]]: diff --git a/graphkir/typing_mulit_allele.py b/graphkir/typing_mulit_allele.py index 1239e6d..e938622 100644 --- a/graphkir/typing_mulit_allele.py +++ b/graphkir/typing_mulit_allele.py @@ -3,7 +3,6 @@ """ from __future__ import annotations import io -import sys import copy from typing import Optional, Iterable from itertools import chain @@ -231,8 +230,7 @@ def __init__( self, reads: list[PairRead], variants: list[Variant], - gene: str, - cn: int, + force_homo: bool | None = None, top_n: int = 300, no_empty: bool = True, variant_correction: bool = True, @@ -247,6 +245,7 @@ def __init__( """ self.top_n = top_n self._no_empty = no_empty + self.force_homo: bool | None = force_homo allele_names = self.collectAlleleNames(variants) # create variant_id -> variant @@ -255,17 +254,11 @@ def __init__( self.id_to_allele: dict[int, str] = dict(enumerate(sorted(allele_names))) self.allele_to_id: dict[str, int] = {j: i for i, j in self.id_to_allele.items()} - if "2DL1S1" in gene or "2DL5" in gene: - self.homo = False - elif cn > 1: - self.homo = self.readToHomoHetero(reads, gene, cn) - else: - self.homo = False - if variant_correction: # var_errcorr reads = self.errorCorrection(reads) if self._no_empty: # reserve read position reads = self.removeEmptyReads(reads) + self.reads = reads self.probs = self.reads2AlleleProb(reads) self.log_probs = np.log10(self.probs) @@ -278,53 +271,6 @@ def __init__( def getReadsNum(self) -> int: return len(self.probs) - def typingHomo(self) -> str: - return self.id_to_allele[np.argmax(np.prod(self.probs, axis=0))] - - def readToHomoHetero(self, reads: list[PairRead], gene: str, cn: int) -> bool: - homo = False - v_record = defaultdict(lambda: defaultdict(int)) - hit_score = 0 - - # variants call by read -> dict - # note: pv record the variant on read, nv record the variant not on read - for read in reads: - for i in read.lpv: - v = self.variants[i] - if v.typ != "deletion": - v_record[v.pos][v.val] += 1 - for i in read.rpv: - v = self.variants[i] - if v.typ != "deletion": - v_record[v.pos][v.val] += 1 - for i in read.lnv: - v = self.variants[i] - if v.typ != "deletion": - v_record[v.pos][f"*{v.val}"] += 1 - for i in read.rnv: - v = self.variants[i] - if v.typ != "deletion": - v_record[v.pos][f"*{v.val}"] += 1 - # find heterozygous variant - for val in v_record.values(): - if len(val) > 1: - if all('*' in key for key in val): - continue - counts = sorted(list(val.values()), reverse=True) - counts = [c for c in counts if c > 3] # low coverage variant -> sequencing error - if len(counts) <= 1: - continue - sum_counts = sum(counts) - hetero_percentage = [c/sum_counts for c in counts if c/sum_counts > 0.1] # filter very minor variant - if len(hetero_percentage) == 1: - continue - if sum_counts < 20: # not processing low coverage possition - pass - elif hetero_percentage[1] > (1/(cn*2)): - hit_score += 1 - # break - return hit_score == 0 - @staticmethod def removeEmptyReads(reads: list[PairRead]) -> list[PairRead]: """ @@ -442,17 +388,71 @@ def typing(self, cn: int) -> TypingResult: The top-n allele-set are best fit the reads (with maximum probility). Each set has CN alleles. """ + if cn < 1: + raise ValueError(f"CN should be >= 1, got {cn}") + + # decide homo/hetero if not forced + if self.force_homo is None: + homo = isHomozygous(self.reads, self.variants, cn) + else: + homo = self.force_homo + self.result = [] - for _ in range(cn): + if homo: self.addCandidate() + self.addHomoResultForCn(cn) + else: + for _ in range(cn): + self.addCandidate() # self.result[-1].print() + self.result[-1].print() return self.result[-1] + def addHomoResultForCn(self, cn: int) -> None: + """Add homozygous result for copy number > 1.""" + if cn > 1: + homo_result = self.createHomoResult(self.result[0], cn) + assert homo_result is not None + self.result.append(homo_result) + def mapAlleleIDs(self, list_ids: IdArray) -> list[list[str]]: """id (m x n np array) -> name (m list x n list of str)""" return [[self.id_to_allele[id] for id in ids] for ids in list_ids] + @staticmethod + def createHomoResult(cn1_result: TypingResult, cn: int) -> TypingResult: + """ + Generate homozygous typing result with CN copies from CN=1 result. + Used when sample is homozygous to replicate the same allele n times. + + Parameters: + cn1_result: Typing result with CN=1 (single allele) + cn: Target copy number (> 1) + Returns: + TypingResult with cn copies of the allele + Raises: + ValueError: If cn <= 1 + """ + if cn <= 1: + raise ValueError(f"CN should be > 1, got {cn}") + + # Replicate the single allele cn times + new_allele_id = np.repeat(cn1_result.allele_id, cn, axis=1) + new_allele_name = [[name[0]] * cn for name in cn1_result.allele_name] + + result = TypingResult( + n=cn, + value=cn1_result.value * cn, # Scale likelihood by CN + value_sum_indv=np.repeat(cn1_result.value_sum_indv, cn, axis=1), + allele_id=new_allele_id, + allele_name=new_allele_name, + allele_prob=cn1_result.allele_prob, + fraction=np.ones((len(cn1_result.value), cn)) / cn, + fraction_uniq=np.ones((len(cn1_result.value), cn)) / cn, + ) + return result + @staticmethod def uniqueAllele(data: IdArray) -> BoolArray: """ @@ -632,6 +632,7 @@ def __init__( exon_only: bool = False, candidate_set_threshold: float = 1., variant_correction: bool = True, + force_homo: bool | None = None, ): """Extracting exon alleles""" # extract exon variants @@ -660,7 +661,7 @@ def __init__( # pprint(self.allele_group) # same as before - super().__init__(exon_reads, exon_variants, top_n=top_n) + super().__init__(exon_reads, exon_variants, force_homo=force_homo, top_n=top_n) self.candidate_set_threshold = candidate_set_threshold """ self.first_set_only = candidate_set == "first_score" @@ -676,6 +677,7 @@ def __init__( self.full_model: AlleleTyping | None = AlleleTyping( reads, variants, + force_homo=force_homo, top_n = top_n // 5, # TODO: default = 30 variant_correction=variant_correction, # variant_correction=False, # no_intron_corr @@ -793,3 +795,63 @@ def typing(self, cn: int) -> TypingResult: logger.debug("[Allele] Typing intron + exon") result.print() return result + + +def isHetrozygous(gene: str) -> bool: + """Decide whether the gene is heterozygous by gene name only""" + if "2DL1S1" in gene or "2DL5" in gene: + return True + return False + + +def isHomozygous(reads: list[PairRead], variants_map: dict[str, Variant], cn: int) -> bool: + """ + Determine whether the sample is homozygous for a given gene based on read variants. + + Heuristics: + - If copy number <= 1, treat as heterozygous (return False). + - Otherwise, inspect per-position variant support; if we observe convincing + bi-allelic evidence at any position (second-most allele passes thresholds), + we treat as heterozygous; else homozygous. + """ + if cn <= 1: + return False + + v_record: dict[int, dict[str, int]] = defaultdict(lambda: defaultdict(int)) + hit_score = 0 + + # Tally variant observations per genomic position + for read in reads: + for i in chain.from_iterable([read.lpv, read.rpv]): + v = variants_map[i] + if v.typ != "deletion": + v_record[v.pos][str(v.val)] += 1 + for i in chain.from_iterable([read.lnv, read.rnv]): + v = variants_map[i] + if v.typ != "deletion": + v_record[v.pos][f"*{v.val}"] += 1 + + # Identify heterozygous positions + for val in v_record.values(): + if len(val) <= 1: + continue + # Ignore positions where all are negative evidence + if all("*" in key for key in val): + continue + counts = sorted(list(val.values()), reverse=True) + + # filter low coverage to avoid sequencing error + counts = [c for c in counts if c > 3] + sum_counts = sum(counts) + # too low coverage: skip decision for this site + if sum_counts < 20: + continue + # filter very minor variant + hetero_percentage = [c / sum_counts for c in counts if c / sum_counts > 0.1] + if len(hetero_percentage) == 1: + continue + # heterozygous if runner-up fraction is above 1/(2*cn) + if hetero_percentage[1] > (1 / (cn * 2)): + hit_score += 1 + + return hit_score == 0