Skip to content

ValueError: at site "data_target", invalid log_prob shape #416

@nextgenius-ai

Description

@nextgenius-ai

version info:
PyTorch: 2.6.0+cu124
Pyro: 1.9.1
Cell2location: 0.1.4
CUDA available: True
CUDA version: 12.4
GPU name: Tesla V100-PCIE-32GB
Lightning: 2.5.3
Pyro: 1.9.1

ref_adata

AnnData object with n_obs × n_vars = 29851 × 26255
obs: 'orig.ident', 'nCount_RNA', 'nFeature_RNA', 'percent.mito', 'sample', 'percent.ribo', 'RNA_snn_res.0.6', 'seurat_clusters', 'condition', 'sample2', 'cell_type', 'cell_type2', 'cell_type3', 'cell_type4'
var: 'vst.mean', 'vst.variance', 'vst.variance.expected', 'vst.variance.standardized', 'vst.variable'
obsm: 'X_pca', 'X_tsne', 'X_umap'

ref_adata.X = ref_adata.X.astype("int32")

from cell2location.utils.filtering import filter_genes
selected = filter_genes(ref_adata, cell_count_cutoff=5, cell_percentage_cutoff2=0.03, nonz_mean_cutoff=1.12)
ref_adata = ref_adata[:, selected].copy()

cell2location.models.RegressionModel.setup_anndata(adata=ref_adata,
                        batch_key='sample',
                        labels_key='cell_type'
                       )

from cell2location.models import RegressionModel
mod = RegressionModel(ref_adata)
mod.view_anndata_setup()

mod.train(
    max_epochs=250,
    batch_size=128,
    accelerator="gpu"
)

# error:
......
File [~/.conda/envs/cell2loc_env/lib/python3.10/site-packages/pyro/infer/enum.py:80](http://localhost:8880/lab/tree/work/Stereo_seq_AGA/~/.conda/envs/cell2loc_env/lib/python3.10/site-packages/pyro/infer/enum.py#line=79), in get_importance_trace(graph_type, max_plate_nesting, model, guide, args, kwargs, detach)
     78 for site in model_trace.nodes.values():
     79     if site["type"] == "sample":
---> 80         check_site_shape(site, max_plate_nesting)
     81 for site in guide_trace.nodes.values():
     82     if site["type"] == "sample":

File [~/.conda/envs/cell2loc_env/lib/python3.10/site-packages/pyro/util.py:437](http://localhost:8880/lab/tree/work/Stereo_seq_AGA/~/.conda/envs/cell2loc_env/lib/python3.10/site-packages/pyro/util.py#line=436), in check_site_shape(site, max_plate_nesting)
    433 for actual_size, expected_size in zip_longest(
    434     reversed(actual_shape), reversed(expected_shape), fillvalue=1
    435 ):
    436     if expected_size != -1 and expected_size != actual_size:
--> 437         raise ValueError(
    438             "\n  ".join(
    439                 [
    440                     'at site "{}", invalid log_prob shape'.format(site["name"]),
    441                     "Expected {}, actual {}".format(expected_shape, actual_shape),
    442                     "Try one of the following fixes:",
    443                     "- enclose the batched tensor in a with pyro.plate(...): context",
    444                     "- .to_event(...) the distribution being sampled",
    445                     "- .permute() data dimensions",
    446                 ]
    447             )
    448         )
    450 # Check parallel dimensions on the left of max_plate_nesting.
    451 enum_dim = site["infer"].get("_enumerate_dim")

ValueError: at site "data_target", invalid log_prob shape
  Expected [128, -1], actual [128, 128, 15125]
  Try one of the following fixes:
  - enclose the batched tensor in a with pyro.plate(...): context
  - .to_event(...) the distribution being sampled
  - .permute() data dimensions

I'm sure I used the count matrix, and I have tried setting ref_adata.X.astype("int32") or ref_adata.X.astype("int"), but the problem still exists. Moreover, after changing to accelerator="cpu", the problem still exists.

When I run the lymph node tutorial, the error still occurs:

adata_ref = sc.read(
    f'./data/sc.h5ad',
    backup_url='https://cell2location.cog.sanger.ac.uk/paper/integrated_lymphoid_organ_scrna/RegressionNBV4Torch_57covariates_73260cells_10237genes/sc.h5ad'
)
adata_ref.var['SYMBOL'] = adata_ref.var.index
adata_ref.var.set_index('GeneID-2', drop=True, inplace=True)
del adata_ref.raw
from cell2location.utils.filtering import filter_genes
selected = filter_genes(adata_ref, cell_count_cutoff=5, cell_percentage_cutoff2=0.03, nonz_mean_cutoff=1.12)
adata_ref = adata_ref[:, selected].copy()
cell2location.models.RegressionModel.setup_anndata(adata=adata_ref,
                        # 10X reaction / sample / batch
                        batch_key='Sample',
                        # cell type, covariate used for constructing signatures
                        labels_key='Subset',
                        # multiplicative technical effects (platform, 3' vs 5', donor effect)
                        categorical_covariate_keys=['Method']
                       )
from cell2location.models import RegressionModel
mod = RegressionModel(adata_ref)
mod.view_anndata_setup()
mod.train(max_epochs=250, batch_size=32,accelerator='gpu')

......
File [~/.conda/envs/cell2loc_env/lib/python3.10/site-packages/pyro/infer/trace_elbo.py:57](http://localhost:8880/lab/tree/work/Stereo_seq_AGA/st_AGA3/~/.conda/envs/cell2loc_env/lib/python3.10/site-packages/pyro/infer/trace_elbo.py#line=56), in Trace_ELBO._get_trace(self, model, guide, args, kwargs)
     52 def _get_trace(self, model, guide, args, kwargs):
     53     """
     54     Returns a single trace from the guide, and the model that is run
     55     against it.
     56     """
---> 57     model_trace, guide_trace = get_importance_trace(
     58         "flat", self.max_plate_nesting, model, guide, args, kwargs
     59     )
     60     if is_validation_enabled():
     61         check_if_enumerated(guide_trace)

File [~/.conda/envs/cell2loc_env/lib/python3.10/site-packages/pyro/infer/enum.py:80](http://localhost:8880/lab/tree/work/Stereo_seq_AGA/st_AGA3/~/.conda/envs/cell2loc_env/lib/python3.10/site-packages/pyro/infer/enum.py#line=79), in get_importance_trace(graph_type, max_plate_nesting, model, guide, args, kwargs, detach)
     78 for site in model_trace.nodes.values():
     79     if site["type"] == "sample":
---> 80         check_site_shape(site, max_plate_nesting)
     81 for site in guide_trace.nodes.values():
     82     if site["type"] == "sample":

File [~/.conda/envs/cell2loc_env/lib/python3.10/site-packages/pyro/util.py:437](http://localhost:8880/lab/tree/work/Stereo_seq_AGA/st_AGA3/~/.conda/envs/cell2loc_env/lib/python3.10/site-packages/pyro/util.py#line=436), in check_site_shape(site, max_plate_nesting)
    433 for actual_size, expected_size in zip_longest(
    434     reversed(actual_shape), reversed(expected_shape), fillvalue=1
    435 ):
    436     if expected_size != -1 and expected_size != actual_size:
--> 437         raise ValueError(
    438             "\n  ".join(
    439                 [
    440                     'at site "{}", invalid log_prob shape'.format(site["name"]),
    441                     "Expected {}, actual {}".format(expected_shape, actual_shape),
    442                     "Try one of the following fixes:",
    443                     "- enclose the batched tensor in a with pyro.plate(...): context",
    444                     "- .to_event(...) the distribution being sampled",
    445                     "- .permute() data dimensions",
    446                 ]
    447             )
    448         )
    450 # Check parallel dimensions on the left of max_plate_nesting.
    451 enum_dim = site["infer"].get("_enumerate_dim")

ValueError: at site "data_target", invalid log_prob shape
  Expected [32, -1], actual [32, 32, 10237]
  Try one of the following fixes:
  - enclose the batched tensor in a with pyro.plate(...): context
  - .to_event(...) the distribution being sampled
  - .permute() data dimensions

Metadata

Metadata

Assignees

No one assigned

    Labels

    question: Further information is requested

    Type

    No type
    No fields configured for issues without a type.

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions