
Include MPS as accelerator for Apple ARM machines #14

@agdiaz


Dear code maintainers,
Would it be possible to add MPS (PyTorch's Metal Performance Shaders backend) as a device option, so that parrot can use the GPU on Apple silicon MacBooks the same way it already uses CUDA devices?

Here is the snippet where MPS could be added:

```python
# Device configuration
if forceCPU:
    device = 'cpu'
elif gpu_id:
    device = torch.device(f"cuda:{gpu_id}" if torch.cuda.is_available() else 'cpu')
    print(f"You've specified to run this network on cuda:{gpu_id}. Running on {device=}")
else:
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
```

Source: https://github.com/idptools/parrot/blob/6e09567afdc3a59d0c03f0802cf4d2fe9c973feb/scripts/parrot-train#L135C1-L142C5
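Before changing the device logic, it may help to check what a given PyTorch install actually supports. The sketch below uses a hypothetical helper, `backend_report`, which is not part of parrot. Note that `torch.backends.mps.is_built()` only reports whether PyTorch was compiled with MPS support, while `torch.backends.mps.is_available()` also requires a supported macOS version and an MPS-capable device. The `getattr` guard assumes that older PyTorch releases (before 1.12) have no `torch.backends.mps` module.

```python
def backend_report() -> dict:
    """Hypothetical diagnostic: summarise which accelerator backends PyTorch can use."""
    try:
        import torch
    except ImportError:
        return {"torch_installed": False}

    # torch.backends.mps is assumed to be missing on PyTorch < 1.12, hence the guard.
    mps = getattr(torch.backends, "mps", None)
    return {
        "torch_installed": True,
        "torch_version": torch.__version__,
        "cuda_available": torch.cuda.is_available(),
        # Compiled with MPS support (says nothing about the current machine).
        "mps_built": mps is not None and mps.is_built(),
        # Usable right now: supported macOS version and MPS-capable hardware.
        "mps_available": mps is not None and mps.is_available(),
    }


if __name__ == "__main__":
    for key, value in backend_report().items():
        print(f"{key}: {value}")
```

On an Apple silicon Mac with a recent PyTorch, `mps_available` would be expected to be `True`; on a Linux CUDA box, `mps_built` would typically be `False`.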

Something like the following could integrate MPS:

```python
import torch

has_mps = torch.backends.mps.is_available()  # is_built() only checks compile-time support
device = "mps" if has_mps else "cuda" if torch.cuda.is_available() else "cpu"
```
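A minimal sketch of how an MPS check might slot into the existing `forceCPU` / `gpu_id` branches of parrot-train. The names `choose_device_name` and `select_device` are hypothetical, not existing parrot functions. Unlike the one-liner above, this version checks CUDA first, so behaviour on Linux/Windows GPU machines stays unchanged, and it ignores `gpu_id` on MPS, assuming a single MPS device. The decision logic is kept free of PyTorch so it can be unit-tested without a GPU.

```python
from typing import Optional


def choose_device_name(force_cpu: bool, gpu_id: Optional[int],
                       cuda_available: bool, mps_available: bool) -> str:
    """Pure decision logic: CPU if forced, then CUDA, then Apple MPS, then CPU."""
    if force_cpu:
        return "cpu"
    if cuda_available:
        # Keep honouring an explicit GPU index on CUDA machines.
        return f"cuda:{gpu_id}" if gpu_id is not None else "cuda"
    if mps_available:
        # Assumes a single MPS device, so any gpu_id is ignored here.
        return "mps"
    return "cpu"


def select_device(force_cpu: bool = False, gpu_id: Optional[int] = None):
    """Hypothetical wrapper: query PyTorch for available backends, return a torch.device."""
    import torch  # imported here so choose_device_name() can be tested without PyTorch

    mps_backend = getattr(torch.backends, "mps", None)  # assumed absent before PyTorch 1.12
    mps_available = mps_backend is not None and mps_backend.is_available()
    name = choose_device_name(force_cpu, gpu_id, torch.cuda.is_available(), mps_available)
    return torch.device(name)
```

Assuming `forceCPU` and `gpu_id` are defined as in the excerpt from parrot-train, the whole block there could then become `device = select_device(forceCPU, gpu_id)`.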

Thanks in advance
