
Add MPS (Apple Silicon) auto-detection and fallback in trainer #43

@SahilKumar75

Description

Context

Issues #26 and #34 have tracked MPS support. This issue focuses specifically on the trainer-side auto-detection as a self-contained, beginner-friendly task.

Problem

trainer/finetune.py currently determines the compute device like this (paraphrased):

device = "cuda" if torch.cuda.is_available() else "cpu"

On Apple Silicon Macs (macOS 12.3+ with an MPS-enabled PyTorch build), torch.backends.mps.is_available() returns True, but the trainer never checks for it, so training always falls back to CPU, which is 10–30× slower.

Fix

Update the device detection logic to:

if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

Then pass the detected device through to model loading and to TrainingArguments / Trainer, so that model placement and the Trainer agree on the same device.
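A minimal sketch of the detection helper, assuming it lives in trainer/finetune.py. The names select_device and detect_device are hypothetical, not existing functions in the repo. Splitting the pure priority decision from the torch probing keeps the CUDA > MPS > CPU order unit-testable on any machine, including CI runners with no GPU and no Mac.

```python
def select_device(cuda_available: bool, mps_available: bool) -> str:
    """Pure decision logic (hypothetical helper): prefer CUDA, then MPS, else CPU."""
    if cuda_available:
        return "cuda"
    if mps_available:
        return "mps"
    return "cpu"


def detect_device() -> str:
    """Probe the local PyTorch install and return "cuda", "mps", or "cpu"."""
    import torch  # imported lazily so select_device can be tested without torch

    cuda_ok = torch.cuda.is_available()
    # torch.backends.mps only exists in PyTorch 1.12+; guard older builds.
    mps_backend = getattr(torch.backends, "mps", None)
    mps_ok = bool(mps_backend is not None and mps_backend.is_available())
    return select_device(cuda_ok, mps_ok)
```

The returned string could then be reused when loading the model, e.g. device_map={"": device} on MPS as described under Caveats.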

Caveats

  • bitsandbytes quantization (QLoRA) does not support MPS: when MPS is detected and the user selected QLoRA, fall back to plain LoRA automatically and surface a warning in the UI.
  • device_map="auto" from accelerate does not always handle MPS correctly; set device_map={"":"mps"} explicitly.
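Both caveats can be handled in one place before the model is loaded. This is a sketch under assumptions: the user's choice arrives as the string "lora" or "qlora", the UI can display a returned message, and non-MPS devices keep device_map="auto" (assumed to be the current behaviour). resolve_training_config and ResolvedConfig are hypothetical names.

```python
import logging
from dataclasses import dataclass
from typing import Dict, Optional, Union

logger = logging.getLogger(__name__)

QLORA_ON_MPS_WARNING = (
    "QLoRA needs bitsandbytes quantization, which does not support Apple Silicon (MPS). "
    "Falling back to plain LoRA."
)


@dataclass
class ResolvedConfig:
    method: str                             # "lora" or "qlora"
    device_map: Union[str, Dict[str, str]]  # value to pass as from_pretrained(device_map=...)
    warning: Optional[str] = None           # message to surface in the UI, if any


def resolve_training_config(device: str, requested_method: str) -> ResolvedConfig:
    """Apply the MPS caveats: QLoRA -> LoRA fallback and an explicit device_map."""
    method = requested_method.lower()
    warning = None

    if device == "mps" and method == "qlora":
        # bitsandbytes cannot quantize on MPS, so drop to plain LoRA and tell the user.
        method = "lora"
        warning = QLORA_ON_MPS_WARNING
        logger.warning(warning)

    if device == "mps":
        # Pin the whole model to MPS instead of relying on device_map="auto".
        device_map: Union[str, Dict[str, str]] = {"": "mps"}
    else:
        # Assumption: CUDA/CPU keep the trainer's existing "auto" placement.
        device_map = "auto"

    return ResolvedConfig(method=method, device_map=device_map, warning=warning)
```

For example, resolve_training_config("mps", "qlora") yields method "lora", device_map {"": "mps"}, and a non-empty warning; in that case the trainer would skip building any quantization config.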

Acceptance criteria

  • trainer/finetune.py detects MPS and sets the device correctly
  • QLoRA → LoRA fallback with warning when on MPS
  • Unit test added for the device detection helper
  • Manual test: a ~1-epoch training run completes on an M1/M2/M3 Mac
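One possible shape for the unit test, using the standard-library unittest and unittest.mock. It repeats the hypothetical select_device / detect_device helpers so the file runs on its own; in the repo the test would import them from trainer/finetune.py instead. The torch-facing tests patch torch.cuda.is_available and torch.backends.mps.is_available, so they pass on any hardware, and are skipped when torch is not installed.

```python
import importlib.util
import unittest
from unittest import mock

HAS_TORCH = importlib.util.find_spec("torch") is not None


def select_device(cuda_available: bool, mps_available: bool) -> str:
    """Copy of the hypothetical helper: CUDA, then MPS, then CPU."""
    if cuda_available:
        return "cuda"
    if mps_available:
        return "mps"
    return "cpu"


def detect_device() -> str:
    """Copy of the hypothetical torch-probing wrapper."""
    import torch

    mps_backend = getattr(torch.backends, "mps", None)
    mps_ok = bool(mps_backend is not None and mps_backend.is_available())
    return select_device(torch.cuda.is_available(), mps_ok)


class SelectDeviceTest(unittest.TestCase):
    def test_cuda_wins_over_mps(self):
        self.assertEqual(select_device(True, True), "cuda")

    def test_mps_when_no_cuda(self):
        self.assertEqual(select_device(False, True), "mps")

    def test_cpu_fallback(self):
        self.assertEqual(select_device(False, False), "cpu")


@unittest.skipUnless(HAS_TORCH, "torch not installed")
class DetectDeviceTest(unittest.TestCase):
    def test_reports_mps_on_apple_silicon(self):
        # Simulate an Apple Silicon machine regardless of the real hardware.
        with mock.patch("torch.cuda.is_available", return_value=False), \
             mock.patch("torch.backends.mps.is_available", return_value=True):
            self.assertEqual(detect_device(), "mps")

    def test_reports_cpu_without_accelerators(self):
        with mock.patch("torch.cuda.is_available", return_value=False), \
             mock.patch("torch.backends.mps.is_available", return_value=False):
            self.assertEqual(detect_device(), "cpu")
```

Both python -m unittest and pytest will collect these TestCase classes.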

Metadata

Assignees

No one assigned

    Labels

    area:fine-tuning (LoRA, QLoRA, training configuration, and tuning workflows), enhancement (New feature or request), good first issue (Good for newcomers)

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests
