From 3eec255b95df7eebbfa5fb389b29fb62d1f4d60d Mon Sep 17 00:00:00 2001 From: Mr-Neutr0n <64578610+Mr-Neutr0n@users.noreply.github.com> Date: Thu, 12 Feb 2026 00:00:13 +0530 Subject: [PATCH] fix: replace hardcoded .cuda() with configurable device --- basic_demo/cli_demo_sat.py | 17 +++++++++++++---- basic_demo/web_demo.py | 16 +++++++++++++--- 2 files changed, 26 insertions(+), 7 deletions(-) diff --git a/basic_demo/cli_demo_sat.py b/basic_demo/cli_demo_sat.py index 5f7b1e12..88362e56 100644 --- a/basic_demo/cli_demo_sat.py +++ b/basic_demo/cli_demo_sat.py @@ -12,6 +12,14 @@ from utils.utils import chat, llama2_tokenizer, llama2_text_processor_inference, get_image_processor from utils.models import CogAgentModel, CogVLMModel + +def _get_device(): + if torch.cuda.is_available(): + return 'cuda' + if torch.backends.mps.is_available(): + return 'mps' + return 'cpu' + def main(): parser = argparse.ArgumentParser() parser.add_argument("--max_length", type=int, default=2048, help='max length of the total sequence') @@ -43,8 +51,8 @@ def main(): model_parallel_size=world_size, mode='inference', skip_init=True, - use_gpu_initialization=True if (torch.cuda.is_available() and args.quant is None) else False, - device='cpu' if args.quant else 'cuda', + use_gpu_initialization=True if (torch.cuda.is_available() and args.quant is None) else False, # Note: MPS init handled via device param + device='cpu' if args.quant else _get_device(), **vars(args) ), overwrite_args={'model_parallel_size': world_size} if world_size != 1 else {}) model = model.eval() @@ -59,8 +67,9 @@ def main(): if args.quant: quantize(model, args.quant) - if torch.cuda.is_available(): - model = model.cuda() + device = _get_device() + if device != 'cpu': + model = model.to(device) model.add_mixin('auto-regressive', CachedAutoregressiveMixin()) diff --git a/basic_demo/web_demo.py b/basic_demo/web_demo.py index 7426171d..a12bc3d4 100644 --- a/basic_demo/web_demo.py +++ b/basic_demo/web_demo.py @@ -30,6 +30,15 @@ + +def _get_device(): + if torch.cuda.is_available(): + return 'cuda' + if torch.backends.mps.is_available(): + return 'mps' + return 'cpu' + + DESCRIPTION = '''