diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 28eb579..8f4f5da 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -12,7 +12,7 @@ jobs: - name: Setup python uses: actions/setup-python@v1 with: - python-version: '3.9.22' + python-version: '3.10.19' architecture: x64 - name: Install dependencies run: pip install -r dev-requirements.txt @@ -22,7 +22,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python: [3.9.22, 3.10.17 ] + python: [3.10.19 ] os: [ubuntu-20.04] name: Test on Python ${{ matrix.python }} steps: diff --git a/.gitignore b/.gitignore index 9f4d7d9..c4caba4 100644 --- a/.gitignore +++ b/.gitignore @@ -65,6 +65,7 @@ instance/ # Mac stuff: .DS_Store +*.swp # Sphinx documentation docs/_build/ @@ -114,4 +115,4 @@ venv.bak/ /.idea/ # Temp files -/scratch/ \ No newline at end of file +/scratch/ diff --git a/dev-requirements.txt b/dev-requirements.txt index e49181d..4ffda64 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -1,5 +1,5 @@ bumpversion==0.6.0 -coverage==5.2.1 +coverage==7.6.1 flake8==3.8.2 pytest==8.2.2 pytest-cov==5.0.0 @@ -8,5 +8,5 @@ sphinx>=3.3.1 sphinx-autoapi>=1.5.1 sphinx_rtd_theme>=0.5.0 twine>=2.0.0 -wheel==0.38.1 +wheel==0.46.2 yapf==0.30.0 diff --git a/kb_python/bins/linux/kallisto/kallisto b/kb_python/bins/linux/kallisto/kallisto index 82a141e..468279d 100755 Binary files a/kb_python/bins/linux/kallisto/kallisto and b/kb_python/bins/linux/kallisto/kallisto differ diff --git a/kb_python/bins/linux/kallisto/kallisto_k64 b/kb_python/bins/linux/kallisto/kallisto_k64 index f0c85ec..6941fff 100755 Binary files a/kb_python/bins/linux/kallisto/kallisto_k64 and b/kb_python/bins/linux/kallisto/kallisto_k64 differ diff --git a/kb_python/bins/linux/kallisto/kallisto_optoff b/kb_python/bins/linux/kallisto/kallisto_optoff index 01233e6..94da56b 100755 Binary files a/kb_python/bins/linux/kallisto/kallisto_optoff and b/kb_python/bins/linux/kallisto/kallisto_optoff differ diff --git a/kb_python/bins/linux/kallisto/kallisto_optoff_k64 b/kb_python/bins/linux/kallisto/kallisto_optoff_k64 index 05ccebf..6941fff 100755 Binary files a/kb_python/bins/linux/kallisto/kallisto_optoff_k64 and b/kb_python/bins/linux/kallisto/kallisto_optoff_k64 differ diff --git a/kb_python/bins/linux/kallisto/license.txt b/kb_python/bins/linux/kallisto/license.txt old mode 100755 new mode 100644 diff --git a/kb_python/count.py b/kb_python/count.py index 2ab9797..a6d8268 100755 --- a/kb_python/count.py +++ b/kb_python/count.py @@ -658,8 +658,12 @@ def bustools_whitelist( def matrix_to_cellranger( - matrix_path: str, barcodes_path: str, genes_path: str, t2g_path: str, - out_dir: str, gzip: bool = False + matrix_path: str, + barcodes_path: str, + genes_path: str, + t2g_path: str, + out_dir: str, + gzip: bool = False ) -> Dict[str, str]: """Convert bustools count matrix to cellranger-format matrix. @@ -1065,8 +1069,10 @@ def filter_with_bustools( if cellranger: if not tcc: cr_result = matrix_to_cellranger( - count_result['mtx'], count_result['barcodes'], - count_result['genes'], t2g_path, + count_result['mtx'], + count_result['barcodes'], + count_result['genes'], + t2g_path, os.path.join(counts_dir, CELLRANGER_DIR), gzip=gzip ) @@ -1290,7 +1296,7 @@ def count( by_name: Aggregate counts by name instead of ID. cellranger: Whether to convert the final count matrix into a cellranger-compatible matrix, defaults to `False` - gzip: Whether to gzip compress cellranger output matrices, + gzip: Whether to gzip compress cellranger output matrices, defaults to `False` delete_bus: Whether to delete intermediate BUS files after successful count, defaults to `False` @@ -1649,8 +1655,10 @@ def update_results_with_suffix(current_results, new_results, suffix): final_result = quant_result if quant else count_result if cellranger: cr_result = matrix_to_cellranger( - count_result['mtx'], count_result['barcodes'], - count_result['genes'], t2g_path, + count_result['mtx'], + count_result['barcodes'], + count_result['genes'], + t2g_path, os.path.join(counts_dir, f'{CELLRANGER_DIR}{suffix}'), gzip=gzip ) @@ -1760,24 +1768,26 @@ def update_results_with_suffix(current_results, new_results, suffix): if delete_bus: logger.info('Deleting intermediate BUS files to save disk space') bus_files_to_delete = [] - + # Collect all .bus files from results if 'bus' in unfiltered_results: bus_files_to_delete.append(unfiltered_results['bus']) if 'bus_scs' in unfiltered_results: bus_files_to_delete.append(unfiltered_results['bus_scs']) - + # For smartseq3, delete suffix versions too for suffix in ['', INTERNAL_SUFFIX, UMI_SUFFIX]: if f'bus{suffix}' in unfiltered_results: bus_files_to_delete.append(unfiltered_results[f'bus{suffix}']) if f'bus_scs{suffix}' in unfiltered_results: - bus_files_to_delete.append(unfiltered_results[f'bus_scs{suffix}']) - + bus_files_to_delete.append( + unfiltered_results[f'bus_scs{suffix}'] + ) + # Delete filtered bus if exists if 'filtered' in results and 'bus_scs' in results['filtered']: bus_files_to_delete.append(results['filtered']['bus_scs']) - + # Delete each BUS file for bus_file in bus_files_to_delete: if bus_file and os.path.exists(bus_file): @@ -1875,7 +1885,7 @@ def count_nac( by_name: Aggregate counts by name instead of ID. cellranger: Whether to convert the final count matrix into a cellranger-compatible matrix, defaults to `False` - gzip: Whether to gzip compress cellranger output matrices, + gzip: Whether to gzip compress cellranger output matrices, defaults to `False` cellranger_style: Whether to organize output in CellRanger-style directories (spliced/ and unspliced/ subdirectories), defaults to `False` @@ -2181,13 +2191,19 @@ def update_results_with_suffix(current_results, new_results, suffix): elif i == 1: # unprocessed/unspliced cr_dir = os.path.join(counts_dir, 'unspliced') else: # ambiguous - cr_dir = os.path.join(counts_dir, f'{CELLRANGER_DIR}_{prefix}{suffix}') + cr_dir = os.path.join( + counts_dir, f'{CELLRANGER_DIR}_{prefix}{suffix}' + ) else: - cr_dir = os.path.join(counts_dir, f'{CELLRANGER_DIR}_{prefix}{suffix}') - + cr_dir = os.path.join( + counts_dir, f'{CELLRANGER_DIR}_{prefix}{suffix}' + ) + cr_result = matrix_to_cellranger( - count_result[i]['mtx'], count_result[i]['barcodes'], - count_result[i]['genes'], t2g_path, + count_result[i]['mtx'], + count_result[i]['barcodes'], + count_result[i]['genes'], + t2g_path, cr_dir, gzip=gzip ) @@ -2225,7 +2241,10 @@ def update_results_with_suffix(current_results, new_results, suffix): update_results_with_suffix(prefix_results, res, suffix) if cellranger: cr_result = matrix_to_cellranger( - res['mtx'], res['barcodes'], res['genes'], t2g_path, + res['mtx'], + res['barcodes'], + res['genes'], + t2g_path, os.path.join( counts_dir, f'{CELLRANGER_DIR}_{prefix}{suffix}' ), @@ -2352,17 +2371,28 @@ def update_results_with_suffix(current_results, new_results, suffix): if cellranger_style: # Create spliced/unspliced subdirectories for CellRanger style if i == 0: # processed/spliced - cr_dir = os.path.join(filtered_counts_dir, 'spliced') + cr_dir = os.path.join( + filtered_counts_dir, 'spliced' + ) elif i == 1: # unprocessed/unspliced - cr_dir = os.path.join(filtered_counts_dir, 'unspliced') + cr_dir = os.path.join( + filtered_counts_dir, 'unspliced' + ) else: # ambiguous - cr_dir = os.path.join(filtered_counts_dir, f'{CELLRANGER_DIR}_{prefix}') + cr_dir = os.path.join( + filtered_counts_dir, + f'{CELLRANGER_DIR}_{prefix}' + ) else: - cr_dir = os.path.join(filtered_counts_dir, f'{CELLRANGER_DIR}_{prefix}') - + cr_dir = os.path.join( + filtered_counts_dir, f'{CELLRANGER_DIR}_{prefix}' + ) + cr_result = matrix_to_cellranger( - count_result[i]['mtx'], count_result[i]['barcodes'], - count_result[i]['genes'], t2g_path, + count_result[i]['mtx'], + count_result[i]['barcodes'], + count_result[i]['genes'], + t2g_path, cr_dir, gzip=gzip ) @@ -2396,7 +2426,10 @@ def update_results_with_suffix(current_results, new_results, suffix): filtered_results[prefix] = {} if cellranger: cr_result = matrix_to_cellranger( - res['mtx'], res['barcodes'], res['genes'], t2g_path, + res['mtx'], + res['barcodes'], + res['genes'], + t2g_path, os.path.join( filtered_counts_dir, f'{CELLRANGER_DIR}_{prefix}' @@ -2488,19 +2521,21 @@ def update_results_with_suffix(current_results, new_results, suffix): if delete_bus: logger.info('Deleting intermediate BUS files to save disk space') bus_files_to_delete = [] - + # Collect all .bus files from results prefixes = ['processed', 'unprocessed', 'ambiguous'] for prefix in prefixes: if prefix in unfiltered_results: for suffix in ['', INTERNAL_SUFFIX, UMI_SUFFIX]: if f'bus{suffix}' in unfiltered_results[prefix]: - bus_files_to_delete.append(unfiltered_results[prefix][f'bus{suffix}']) - + bus_files_to_delete.append( + unfiltered_results[prefix][f'bus{suffix}'] + ) + # Delete filtered bus files if they exist if 'filtered' in results and 'bus_scs' in results['filtered']: bus_files_to_delete.append(results['filtered']['bus_scs']) - + # Delete each BUS file for bus_file in bus_files_to_delete: if bus_file and os.path.exists(bus_file): diff --git a/kb_python/main.py b/kb_python/main.py index 30d749f..9620034 100755 --- a/kb_python/main.py +++ b/kb_python/main.py @@ -408,7 +408,7 @@ def parse_count( 'Plots for TCC matrices have not yet been implemented. ' 'The HTML report will not contain any plots.' ) - # Note: We are currently not supporting --genomebam + if args.genomebam: parser.error('--genomebam is not currently supported') if args.genomebam and not args.gtf: @@ -591,11 +591,11 @@ def parse_count( parser.error( f'Option `--aa` cannot be used with workflow {args.workflow}.' ) - + # Auto-enable gzip and cellranger-style when --cellranger is used use_gzip = args.cellranger and not args.no_gzip or args.gzip use_cellranger_style = args.cellranger - + from .count import count_nac count_nac( args.i, @@ -1462,7 +1462,10 @@ def setup_count_args( ) parser_count.add_argument( '--gzip', - help='Gzip compress output matrices (matrix.mtx.gz, barcodes.tsv.gz, genes.tsv.gz). Automatically enabled with --cellranger', + help=( + 'Gzip compress output matrices (matrix.mtx.gz, barcodes.tsv.gz, genes.tsv.gz). ' + 'Automatically enabled with --cellranger. ' + ), action='store_true' ) parser_count.add_argument( @@ -1472,7 +1475,9 @@ def setup_count_args( ) parser_count.add_argument( '--delete-bus', - help='Delete intermediate BUS files after successful count to save disk space', + help=( + 'Delete intermediate BUS files after successful count to save disk space' + ), action='store_true' ) parser_count.add_argument( diff --git a/kb_python/ref.py b/kb_python/ref.py index 1a747d6..7004d04 100755 --- a/kb_python/ref.py +++ b/kb_python/ref.py @@ -1,7 +1,7 @@ import glob -import itertools import os import tarfile +from collections import defaultdict from typing import Callable, Dict, List, Optional, Tuple, Union import ngs_tools as ngs @@ -71,6 +71,10 @@ def generate_mismatches(name, sequence): lengths = set() features = {} variants = {} + + # Store all original sequences to check for collisions with variants + original_sequences = set() + # Generate all feature barcode variations before saving to check for collisions. for i, row in df_features.iterrows(): # Check that the first column contains the sequence @@ -83,6 +87,8 @@ def generate_mismatches(name, sequence): lengths.add(len(row.sequence)) features[row['name']] = row.sequence + original_sequences.add(row.sequence) + variants[row['name']] = { name: seq for name, seq in generate_mismatches(row['name'], row.sequence) @@ -103,45 +109,36 @@ def generate_mismatches(name, sequence): ','.join(str(l) for l in lengths) # noqa ) ) - # Find & remove collisions between barcode and variants - for feature in variants.keys(): - _variants = variants[feature] - collisions = set(_variants.values()) & set(features.values()) - if collisions: - # Remove collisions + + # Invert variants: sequence -> list of (feature_name, variant_name) + seq_to_variants = defaultdict(list) + for feature_name, feature_variants in variants.items(): + for variant_name, seq in feature_variants.items(): + seq_to_variants[seq].append((feature_name, variant_name)) + + # Process collisions + for seq, variant_list in seq_to_variants.items(): + # 1. Check collision with original barcodes + if seq in original_sequences: logger.warning( - f'Colision detected between variants of feature barcode {feature} ' - 'and feature barcode(s). These variants will be removed.' + f'Collision detected between variants of feature barcode(s) {",".join(set(v[0] for v in variant_list))}' + f' and original feature barcode {seq}. These variants will be removed.' ) - variants[feature] = { - name: seq - for name, seq in _variants.items() - if seq not in collisions - } - - # Find & remove collisions between variants - for f1, f2 in itertools.combinations(variants.keys(), 2): - v1 = variants[f1] - v2 = variants[f2] - - collisions = set(v1.values()) & set(v2.values()) - if collisions: + for feature_name, variant_name in variant_list: + if variant_name in variants[feature_name]: + del variants[feature_name][variant_name] + continue + + # 2. Check collision between variants of DIFFERENT features + features_involved = set(v[0] for v in variant_list) + if len(features_involved) > 1: logger.warning( - f'Collision(s) detected between variants of feature barcodes {f1} and {f2}: ' - f'{",".join(collisions)}. These variants will be removed.' + f'Collision(s) detected between variants of feature barcodes {",".join(features_involved)}: ' + f'{seq}. These variants will be removed.' ) - - # Remove collisions - variants[f1] = { - name: seq - for name, seq in v1.items() - if seq not in collisions - } - variants[f2] = { - name: seq - for name, seq in v2.items() - if seq not in collisions - } + for feature_name, variant_name in variant_list: + if variant_name in variants[feature_name]: + del variants[feature_name][variant_name] # Write FASTA with ngs.fasta.Fasta(out_path, 'w') as f: diff --git a/kb_python/utils.py b/kb_python/utils.py index 70a644a..b91a29d 100755 --- a/kb_python/utils.py +++ b/kb_python/utils.py @@ -11,6 +11,7 @@ from urllib.request import urlretrieve import anndata +import numpy as np import ngs_tools as ngs import pandas as pd import scipy.io @@ -171,8 +172,8 @@ def reader(pipe, qu, stop_event, name): stderr_reader.start() while p.poll() is None: - while not out_queue.empty(): - name, line = out_queue.get() + try: + name, line = out_queue.get(timeout=0.1) if stream and not quiet: logger.debug(line) out.append(line) @@ -180,8 +181,8 @@ def reader(pipe, qu, stop_event, name): stdout += f'{line}\n' elif name == 'stderr': stderr += f'{line}\n' - else: - time.sleep(0.1) + except queue.Empty: + pass # Stop readers & flush queue stop_event.set() @@ -520,33 +521,31 @@ def collapse_anndata( if not any(adata.var.index.duplicated()): return adata - var_indices = {} - for i, index in enumerate(adata.var.index): - var_indices.setdefault(index, []).append(i) + # Optimized implementation using matrix multiplication + codes, uniques = pd.factorize(adata.var.index) + n_old = len(codes) + n_new = len(uniques) - # Convert all original matrices to csc for fast column operations - X = sparse.csc_matrix(adata.X) - layers = { - layer: sparse.csc_matrix(adata.layers[layer]) - for layer in adata.layers - } - new_index = [] - # lil_matrix is efficient for row-by-row construction - new_X = sparse.lil_matrix((len(var_indices), adata.shape[0])) - new_layers = {layer: new_X.copy() for layer in adata.layers} - for i, (index, indices) in enumerate(var_indices.items()): - new_index.append(index) - new_X[i] = X[:, indices].sum(axis=1).flatten() - for layer in layers.keys(): - new_layers[layer][i] = layers[layer][:, - indices].sum(axis=1).flatten() + row_indices = np.arange(n_old) + col_indices = codes + data = np.ones(n_old) + + # S maps from old columns to new columns (summing duplicates) + S = sparse.coo_matrix((data, (row_indices, col_indices)), + shape=(n_old, n_new)).tocsr() + + X = sparse.csr_matrix(adata.X) + new_X = X @ S + + new_layers = {} + for layer, mat in adata.layers.items(): + new_layers[layer] = sparse.csr_matrix(mat) @ S return anndata.AnnData( - X=new_X.T.tocsr(), - layers={layer: new_layers[layer].T.tocsr() - for layer in new_layers}, + X=new_X, + layers=new_layers, obs=adata.obs.copy(), - var=pd.DataFrame(index=pd.Series(new_index, name=adata.var.index.name)), + var=pd.DataFrame(index=pd.Series(uniques, name=adata.var.index.name)), ) @@ -766,146 +765,146 @@ def do_sum_matrices( ) -> str: """Sums up two matrices given two matrix files. + This implementation uses a 1-pass streaming merge to minimize I/O + and keep memory usage constant (O(1)), allowing it to handle matrices + larger than available RAM. + Args: mtx1_path: First matrix file path mtx2_path: Second matrix file path out_path: Output file path mm: Whether to allow multimapping (i.e. decimals) - header_line: The header line if we have it + header_line: The header line if we have it (Used for recursion) Returns: Output file path """ logger.info('Summing matrices into {}'.format(out_path)) + + if not os.path.exists(mtx1_path) or not os.path.exists(mtx2_path): + raise Exception("Input matrix files do not exist.") + n = 0 - header = [] - with open_as_text(mtx1_path, - 'r') as f1, open_as_text(mtx2_path, - 'r') as f2, open(out_path, - 'w') as out: - eof1 = eof2 = pause1 = pause2 = False - nums = [0, 0, 0] - nums1 = nums2 = to_write = None - if header_line: - out.write("%%MatrixMarket matrix coordinate real general\n%\n") - while not eof1 or not eof2: - s1 = f1.readline() if not eof1 and not pause1 else '%' - s2 = f2.readline() if not eof2 and not pause2 else '%' - if not s1: - pause1 = eof1 = True - if not s2: - pause2 = eof2 = True - _nums1 = _nums2 = [] - if not eof1 and s1[0] != '%': - _nums1 = s1.split() - if not mm: - _nums1[0] = int(_nums1[0]) - _nums1[1] = int(_nums1[1]) - _nums1[2] = int(float(_nums1[2])) - else: - _nums1[0] = int(_nums1[0]) - _nums1[1] = int(_nums1[1]) - _nums1[2] = float(_nums1[2]) - if not eof2 and s2[0] != '%': - _nums2 = s2.split() - if not mm: - _nums2[0] = int(_nums2[0]) - _nums2[1] = int(_nums2[1]) - _nums2[2] = int(float(_nums2[2])) + header = None + # We use a temporary file to store the body while we count n (nnz) + temp_dir = os.path.dirname(out_path) + tmp_body_path = None + + try: + tmp_body_path = get_temporary_filename(temp_dir) + with open_as_text(mtx1_path, 'r') as f1, \ + open_as_text(mtx2_path, 'r') as f2, \ + open(tmp_body_path, 'w') as tmp_body: + + eof1 = eof2 = pause1 = pause2 = False + nums1 = nums2 = to_write = None + + while not eof1 or not eof2: + s1 = f1.readline() if not eof1 and not pause1 else '%' + s2 = f2.readline() if not eof2 and not pause2 else '%' + if not s1: + pause1 = eof1 = True + if not s2: + pause2 = eof2 = True + + _nums1 = _nums2 = [] + if not eof1 and s1[0] != '%': + tokens1 = s1.split() + if not mm: + _nums1 = [ + int(tokens1[0]), + int(tokens1[1]), + int(float(tokens1[2])) + ] + else: + _nums1 = [ + int(tokens1[0]), + int(tokens1[1]), + float(tokens1[2]) + ] + if not eof2 and s2[0] != '%': + tokens2 = s2.split() + if not mm: + _nums2 = [ + int(tokens2[0]), + int(tokens2[1]), + int(float(tokens2[2])) + ] + else: + _nums2 = [ + int(tokens2[0]), + int(tokens2[1]), + float(tokens2[2]) + ] + + if nums1 is not None: + _nums1, nums1 = nums1, None + if nums2 is not None: + _nums2, nums2 = nums2, None + + if eof1 and eof2: + break + elif eof1: + nums, pause2 = _nums2, False + elif eof2: + nums, pause1 = _nums1, False + elif not _nums1 or not _nums2: + # Skip header comments + continue + elif not header: + if (_nums1[0] != _nums2[0] or _nums1[1] != _nums2[1]): + raise Exception( + "Summing up two matrix files failed: Headers incompatible" + ) + header = [_nums1[0], _nums1[1]] + continue + elif (_nums1[0] > _nums2[0] + or (_nums1[0] == _nums2[0] and _nums1[1] > _nums2[1])): + nums, pause1, pause2, nums1, nums2 = _nums2, True, False, _nums1, None + elif (_nums2[0] > _nums1[0] + or (_nums2[0] == _nums1[0] and _nums2[1] > _nums1[1])): + nums, pause2, pause1, nums2, nums1 = _nums1, True, False, _nums2, None + elif _nums1[0] == _nums2[0] and _nums1[1] == _nums2[1]: + nums, pause1, pause2, nums1, nums2 = _nums1, False, False, None, None + nums[2] += _nums2[2] else: - _nums2[0] = int(_nums2[0]) - _nums2[1] = int(_nums2[1]) - _nums2[2] = float(_nums2[2]) - if nums1 is not None: - _nums1 = nums1 - nums1 = None - if nums2 is not None: - _nums2 = nums2 - nums2 = None - if eof1 and eof2: - # Both mtxs are done - break - elif eof1: - # mtx1 is done - nums = _nums2 - pause2 = False - elif eof2: - # mtx2 is done - nums = _nums1 - pause1 = False - elif eof1 and eof2: - # Both mtxs are done - break - # elif (len(_nums1) != len(_nums2)): - # # We have a problem - # raise Exception("Summing up two matrix files failed") - elif not _nums1 or not _nums2: - # We have something other than a matrix line - continue - elif not header: - # We are at the header line and need to read it in - if (_nums1[0] != _nums2[0] or _nums1[1] != _nums2[1]): raise Exception( - "Summing up two matrix files failed: Headers incompatible" + "Summing up two matrix files failed: Assertion failed" ) + + if (to_write and to_write[0] == nums[0] + and to_write[1] == nums[1]): + to_write[2] += nums[2] else: - header = [_nums1[0], _nums1[1]] - if header_line: - out.write(header_line) - continue - elif (_nums1[0] > _nums2[0] - or (_nums1[0] == _nums2[0] and _nums1[1] > _nums2[1])): - # If we're further in mtx1 than mtx2 - nums = _nums2 - pause1 = True - pause2 = False - nums1 = _nums1 - nums2 = None - elif (_nums2[0] > _nums1[0] - or (_nums2[0] == _nums1[0] and _nums2[1] > _nums1[1])): - # If we're further in mtx2 than mtx1 - nums = _nums1 - pause2 = True - pause1 = False - nums2 = _nums2 - nums1 = None - elif _nums1[0] == _nums2[0] and _nums1[1] == _nums2[1]: - # If we're at the same location in mtx1 and mtx2 - nums = _nums1 - nums[2] += _nums2[2] - pause1 = pause2 = False - nums1 = nums2 = None - else: - # Shouldn't happen - raise Exception( - "Summing up two matrix files failed: Assertion failed" - ) - # Write out a line - _nums_prev = to_write - if (_nums_prev and _nums_prev[0] == nums[0] - and _nums_prev[1] == nums[1]): - nums[2] += _nums_prev[2] - pause1 = pause2 = False - to_write = [nums[0], nums[1], nums[2]] - else: - if to_write: - if header_line: - if mm and to_write[2].is_integer(): - to_write[2] = int(to_write[2]) - out.write( - f'{to_write[0]} {to_write[1]} {to_write[2]}\n' - ) - n += 1 - to_write = [nums[0], nums[1], nums[2]] - if to_write: - if header_line: - if mm and to_write[2].is_integer(): - to_write[2] = int(to_write[2]) - out.write(f'{to_write[0]} {to_write[1]} {to_write[2]}\n') - n += 1 - if not header_line: - header_line = f'{header[0]} {header[1]} {n}\n' - do_sum_matrices(mtx1_path, mtx2_path, out_path, mm, header_line) + if to_write: + val = to_write[2] + if not mm: + val = int(val) + tmp_body.write(f'{to_write[0]} {to_write[1]} {val}\n') + n += 1 + to_write = [nums[0], nums[1], nums[2]] + + if to_write: + val = to_write[2] + if not mm: + val = int(val) + tmp_body.write(f'{to_write[0]} {to_write[1]} {val}\n') + n += 1 + + if header is None: + raise Exception( + f"Summing up two matrix files failed: Missing header in {mtx1_path} or {mtx2_path}" + ) + + # Final assembly: Prepend header and copy body + with open(out_path, 'w') as out, open(tmp_body_path, 'r') as body: + out.write("%%MatrixMarket matrix coordinate real general\n%\n") + out.write(f"{header[0]} {header[1]} {n}\n") + shutil.copyfileobj(body, out) + finally: + if tmp_body_path and os.path.exists(tmp_body_path): + os.remove(tmp_body_path) + return out_path diff --git a/kb_python/validate.py b/kb_python/validate.py index ecca731..fe631c8 100755 --- a/kb_python/validate.py +++ b/kb_python/validate.py @@ -52,7 +52,7 @@ def validate_mtx(path: str): ValidateError: If the file failed verification """ try: - scipy.io.mmread(path) + scipy.io.mminfo(path) except ValueError: raise ValidateError(f'{path} is not a valid matrix market file') diff --git a/tests/test_count.py b/tests/test_count.py index d37215e..d211ee3 100755 --- a/tests/test_count.py +++ b/tests/test_count.py @@ -947,6 +947,7 @@ def test_filter_with_bustools_cellranger(self): '{}.genes.txt'.format(counts_prefix), t2g_path, cellranger_dir, + gzip=False ) def test_stream_fastqs_local(self): @@ -1642,7 +1643,7 @@ def test_count_cellranger(self): '{}.mtx'.format(counts_prefix), '{}.barcodes.txt'.format(counts_prefix), '{}.genes.txt'.format(counts_prefix), self.t2g_path, - cellranger_dir + cellranger_dir, gzip=False ) def test_count_filter(self): @@ -1809,16 +1810,17 @@ def test_count_filter(self): out_dir, FILTERED_COUNTS_DIR, COUNTS_PREFIX ), kite=False, + tcc=False, temp_dir=temp_dir, threads=threads, memory=memory, loom=False, + loom_names=['barcode', 'target_name'], h5ad=False, by_name=False, - tcc=False, + gzip=False, umi_gene=True, - em=False, - loom_names=['barcode', 'target_name'], + em=False ) convert_matrix.assert_not_called() @@ -2253,16 +2255,17 @@ def test_count_kite_filter(self): out_dir, FILTERED_COUNTS_DIR, FEATURE_PREFIX ), kite=True, + tcc=False, temp_dir=temp_dir, threads=threads, memory=memory, loom=False, + loom_names=['barcode', 'target_name'], h5ad=False, by_name=False, - tcc=False, + gzip=False, umi_gene=True, - em=False, - loom_names=['barcode', 'target_name'], + em=False ) convert_matrix.assert_not_called() diff --git a/tests/test_utils.py b/tests/test_utils.py index d09f4d2..11f1529 100755 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -340,3 +340,66 @@ def test_create_10x_feature_barcode_map(self): self.assertTrue(os.path.exists(map_path)) with open(map_path, 'r') as f: self.assertIn('\t', f.readline()) + + def test_do_sum_matrices(self): + import scipy.io + import scipy.sparse + + m1 = scipy.sparse.csr_matrix([[1, 2], [3, 4]]) + m2 = scipy.sparse.csr_matrix([[5, 6], [7, 8]]) + + m1_path = os.path.join(self.temp_dir, 'm1.mtx') + m2_path = os.path.join(self.temp_dir, 'm2.mtx') + out_path = os.path.join(self.temp_dir, 'sum.mtx') + + scipy.io.mmwrite(m1_path, m1) + scipy.io.mmwrite(m2_path, m2) + + utils.do_sum_matrices(m1_path, m2_path, out_path) + + m_sum = scipy.io.mmread(out_path) + expected = np.array([[6, 8], [10, 12]]) + + np.testing.assert_array_equal(m_sum.toarray(), expected) + + # Verify robust integer formatting in output + with open(out_path, 'r') as f: + for line in f: + if line.startswith('%'): + continue + parts = line.split() + if len(parts) == 3: + self.assertTrue(parts[0].isdigit()) + self.assertTrue(parts[1].isdigit()) + # Value might be negative, though not in this test + self.assertTrue(parts[2].lstrip('-').isdigit()) + + def test_do_sum_matrices_complex(self): + import scipy.io + import scipy.sparse + + # Test case with: + # - Overlapping coordinates (1,1) + # - Unique coordinates in m1 (1,2) + # - Unique coordinates in m2 (2,1) + # - Sparse structure + m1 = scipy.sparse.coo_matrix(([1, 2], ([0, 0], [0, 1])), shape=(2, 2)) + m2 = scipy.sparse.coo_matrix(([3, 4], ([0, 1], [0, 0])), shape=(2, 2)) + + # m1: [[1, 2], [0, 0]] + # m2: [[3, 0], [4, 0]] + # sum: [[4, 2], [4, 0]] + + m1_path = os.path.join(self.temp_dir, 'm1_complex.mtx') + m2_path = os.path.join(self.temp_dir, 'm2_complex.mtx') + out_path = os.path.join(self.temp_dir, 'sum_complex.mtx') + + scipy.io.mmwrite(m1_path, m1) + scipy.io.mmwrite(m2_path, m2) + + utils.do_sum_matrices(m1_path, m2_path, out_path) + + m_sum = scipy.io.mmread(out_path) + expected = np.array([[4, 2], [4, 0]]) + + np.testing.assert_array_equal(m_sum.toarray(), expected) diff --git a/tests/test_validate.py b/tests/test_validate.py index ed1cf1a..23bcf19 100755 --- a/tests/test_validate.py +++ b/tests/test_validate.py @@ -47,8 +47,8 @@ def test_validate_mtx(self): validate.validate_mtx(self.matrix_path) def test_validate_mtx_raises_on_error(self): - with mock.patch('kb_python.validate.scipy.io.mmread') as mmread: - mmread.side_effect = ValueError('test') + with mock.patch('kb_python.validate.scipy.io.mminfo') as mminfo: + mminfo.side_effect = ValueError('test') with self.assertRaises(validate.ValidateError): validate.validate_mtx('path')