diff --git a/deepcell_types/predict.py b/deepcell_types/predict.py index b01bd51..dfa48e5 100644 --- a/deepcell_types/predict.py +++ b/deepcell_types/predict.py @@ -33,7 +33,7 @@ def get_result(self): return cell_type_str_pred, top_probs, cell_index -def predict(raw, mask, channel_names, mpp, model_name, device_num, batch_size=256, num_workers=24, tissue_exclude=None): +def predict(raw, mask, channel_names, mpp, model_name, device_num, batch_size=256, num_workers=24, tissue=None): device = torch.device(device_num) embedding_model_name = "deepseek-r1-70b-llama-distill-q4_K_M" @@ -95,8 +95,8 @@ def predict(raw, mask, channel_names, mpp, model_name, device_num, batch_size=25 with torch.no_grad(): for sample, ch_idx, attn_mask, cell_index in tqdm(data_loader, desc=f"(inference)"): ct_exclude = None - if tissue_exclude: - ct_exclude = [[i for i in range(len(ct_embeddings)) if i not in [dct_config.ct2idx[i] for i in tct[tissue_exclude]]] for _ in range(len(sample))] + if tissue: + ct_exclude = [[i for i in range(len(ct_embeddings)) if i not in [dct_config.ct2idx[i] for i in tct[tissue]]] for _ in range(len(sample))] _, _, _, _, probs, _ = model( sample.to(device), ch_idx.to(device),