basic_demo/cli_demo_sat.py (13 additions, 4 deletions)

--- a/basic_demo/cli_demo_sat.py
+++ b/basic_demo/cli_demo_sat.py
@@ -12,6 +12,14 @@
 from utils.utils import chat, llama2_tokenizer, llama2_text_processor_inference, get_image_processor
 from utils.models import CogAgentModel, CogVLMModel
 
+
+def _get_device():
+    if torch.cuda.is_available():
+        return 'cuda'
+    if torch.backends.mps.is_available():
+        return 'mps'
+    return 'cpu'
+
 def main():
     parser = argparse.ArgumentParser()
     parser.add_argument("--max_length", type=int, default=2048, help='max length of the total sequence')
@@ -43,8 +51,8 @@ def main():
         model_parallel_size=world_size,
         mode='inference',
         skip_init=True,
-        use_gpu_initialization=True if (torch.cuda.is_available() and args.quant is None) else False,
-        device='cpu' if args.quant else 'cuda',
+        use_gpu_initialization=True if (torch.cuda.is_available() and args.quant is None) else False,  # Note: MPS init handled via device param
+        device='cpu' if args.quant else _get_device(),
         **vars(args)
     ), overwrite_args={'model_parallel_size': world_size} if world_size != 1 else {})
     model = model.eval()
@@ -59,8 +67,9 @@ def main():
 
     if args.quant:
         quantize(model, args.quant)
-    if torch.cuda.is_available():
-        model = model.cuda()
+    device = _get_device()
+    if device != 'cpu':
+        model = model.to(device)
 
 
     model.add_mixin('auto-regressive', CachedAutoregressiveMixin())
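The helper added here encodes a fixed preference order: CUDA if present, otherwise Apple-silicon MPS, otherwise CPU; with --quant the model is still built and quantized on CPU and only moved afterwards. A minimal sketch of that placement flow, using torch.nn.Linear as a stand-in for the real CogAgent/CogVLM model and eliding the quantize step:

import torch

def _get_device():
    # Preference order used by this change: CUDA > MPS > CPU.
    if torch.cuda.is_available():
        return 'cuda'
    if torch.backends.mps.is_available():
        return 'mps'
    return 'cpu'

model = torch.nn.Linear(4, 4).eval()  # placeholder for the SAT model
device = _get_device()
if device != 'cpu':
    # .to(device) covers both 'cuda' and 'mps'; the old code's
    # model.cuda() could only ever reach CUDA.
    model = model.to(device)
print("model parameters on:", next(model.parameters()).device)

The switch from model.cuda() to model.to(device) is what lets the CPU-quantized weights land on MPS as well as CUDA.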
basic_demo/web_demo.py (13 additions, 3 deletions)

--- a/basic_demo/web_demo.py
+++ b/basic_demo/web_demo.py
@@ -30,6 +30,15 @@
 
 
 
+
+def _get_device():
+    if torch.cuda.is_available():
+        return 'cuda'
+    if torch.backends.mps.is_available():
+        return 'mps'
+    return 'cpu'
+
+
 DESCRIPTION = '''<h1 style='text-align: center'> <a href="https://github.com/THUDM/CogVLM">CogVLM / CogAgent</a> </h1>'''
 
 NOTES = '<h3> This app is adapted from <a href="https://github.com/THUDM/CogVLM">https://github.com/THUDM/CogVLM</a>. It would be recommended to check out the repo if you want to see the detail of our model, CogVLM & CogAgent. </h3>'
@@ -75,7 +84,7 @@ def load_model(args):
         bf16=args.bf16,
         skip_init=True,
         use_gpu_initialization=True if (torch.cuda.is_available() and args.quant is None) else False,
-        device='cpu' if args.quant else 'cuda'),
+        device='cpu' if args.quant else _get_device()),
         overwrite_args={'model_parallel_size': world_size} if world_size != 1 else {}
     )
     model = model.eval()
@@ -88,8 +97,9 @@ def load_model(args):
 
     if args.quant:
         quantize(model, args.quant)
-    if torch.cuda.is_available():
-        model = model.cuda()
+    device = _get_device()
+    if device != 'cpu':
+        model = model.to(device)
     model.add_mixin('auto-regressive', CachedAutoregressiveMixin())
 
     text_processor_infer = llama2_text_processor_inference(tokenizer, args.max_length, model.image_length)
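web_demo.py mirrors the same helper and placement logic, so both demos resolve the backend identically. A quick standalone check (a sketch, not part of this diff) to see which backend a machine will pick before launching either demo:

import torch

# The demos prefer 'cuda', then 'mps', then 'cpu', in that order.
print("CUDA available:", torch.cuda.is_available())
print("MPS built:", torch.backends.mps.is_built())
print("MPS available:", torch.backends.mps.is_available())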