Skip to content

Commit 3e5ff70

Browse files
authored
Add files via upload
1 parent e166e11 commit 3e5ff70

17 files changed

Lines changed: 1840 additions & 190 deletions

chronos/backend/__init__.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,15 @@
1010
BackendInfo,
1111
available,
1212
select,
13+
training_available,
14+
select_training,
1315
device_str,
16+
training_device_str,
17+
resolve_training_device,
1418
describe,
19+
describe_training,
1520
AUTO_PRIORITY,
21+
TRAINING_AUTO_PRIORITY,
1622
)
1723

1824
# Back-compat shim for the pre-M5 API that code elsewhere still calls.
@@ -26,6 +32,9 @@ def get_backend():
2632

2733
__all__ = [
2834
"BackendDispatcher", "BackendInfo",
29-
"available", "select", "device_str", "describe", "AUTO_PRIORITY",
35+
"available", "select", "training_available", "select_training",
36+
"device_str", "training_device_str", "resolve_training_device",
37+
"describe", "describe_training",
38+
"AUTO_PRIORITY", "TRAINING_AUTO_PRIORITY",
3039
"get_backend", "build_model", "BackendType",
3140
]

chronos/backend/dispatcher.py

Lines changed: 128 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,23 +3,24 @@
33
44
Unified compute-backend dispatcher.
55
6-
Chronos dispatches across six backend names:
6+
Chronos dispatches across these backend names:
77
88
cpu — PyTorch CPU (always available)
99
cuda — PyTorch CUDA (NVIDIA GPU)
1010
mps — PyTorch Metal Performance Shaders (Apple Silicon via torch)
11-
mlx — Apple MLX (Apple Silicon, unified memory; non-torch)
11+
mlx — Apple MLX (Apple Silicon, non-torch inference path in this repo)
12+
xpu — PyTorch Intel XPU
1213
vulkan — PyTorch Vulkan (only if torch was built with USE_VULKAN=ON)
1314
opencl — third-party extension hook (no upstream backend; plug-in only)
1415
15-
Training support: cpu, cuda, mps, mlx
16-
Inference support (stock torch): cpu, cuda, mps, mlx, (vulkan if built-in)
17-
Inference via ext plugin: opencl (requires chronos.backend.ext.opencl
18-
implementation; stub returns False)
16+
Training support in the current repo: cpu, cuda, mps, xpu
17+
Inference support (stock paths): cpu, cuda, mps, mlx, xpu, (vulkan if built-in)
18+
Inference via ext plugin: opencl
1919
20-
The Vulkan and OpenCL hooks exist so that someone running a custom PyTorch
21-
build (or a custom kernel) can plug in without modifying core Chronos code.
22-
On stock pip-installed PyTorch, they report "not available" and fall back.
20+
Important: although MLX is available as an inference backend, this repository
21+
does not currently implement a full MLX-native training stack comparable to
22+
``chronos.trainer.*``. Training resolvers therefore exclude ``mlx`` until a
23+
real MLX trainer exists.
2324
"""
2425
from __future__ import annotations
2526

@@ -32,12 +33,16 @@
3233

3334
BACKENDS = ("cuda", "mlx", "mps", "xpu", "vulkan", "opencl", "cpu")
3435

35-
# Auto-detect priority. mlx > cuda > mps > xpu > vulkan > opencl > cpu.
36-
# The intent: prefer whichever backend has the most optimized training path
37-
# on the host. On Apple Silicon, mlx has unified memory advantages that
38-
# mps (via torch) cannot match; on NVIDIA, cuda is always best.
36+
# General runtime auto-detect priority. This covers inference and legacy
37+
# backend selection, where MLX is meaningful on Apple Silicon.
3938
AUTO_PRIORITY = ("mlx", "cuda", "xpu", "mps", "vulkan", "opencl", "cpu")
4039

40+
# Training must only pick backends with an actual training implementation in
41+
# this repository. Keep this separate from AUTO_PRIORITY so inference can
42+
# still prefer MLX while training stays honest.
43+
TRAINING_AUTO_PRIORITY = ("cuda", "xpu", "mps", "cpu")
44+
TRAINING_BACKENDS = ("cuda", "xpu", "mps", "cpu")
45+
4146

4247
@dataclass
4348
class BackendInfo:
@@ -101,9 +106,9 @@ def _probe_mlx() -> BackendInfo:
101106
except Exception:
102107
avail = False
103108
return BackendInfo(
104-
name="mlx", available=avail, supports_training=avail,
109+
name="mlx", available=avail, supports_training=False,
105110
supports_amp=avail, torch_device=None,
106-
notes="non-torch backend; Chronos uses chronos.mlx.* paths",
111+
notes="non-torch backend; Chronos uses chronos.mlx.* inference paths",
107112
)
108113

109114

@@ -167,6 +172,13 @@ def available(self) -> List[str]:
167172
"""Returns names in priority order, filtered to those actually usable."""
168173
return [n for n in AUTO_PRIORITY if self.info(n).available]
169174

175+
def training_available(self) -> List[str]:
176+
"""Returns trainable backend names in training-priority order."""
177+
return [
178+
n for n in TRAINING_AUTO_PRIORITY
179+
if self.info(n).available and self.info(n).supports_training
180+
]
181+
170182
def select(self, prefer: Optional[str] = None) -> str:
171183
"""Resolve a concrete backend name.
172184
@@ -189,10 +201,81 @@ def select(self, prefer: Optional[str] = None) -> str:
189201
return n
190202
return "cpu"
191203

204+
def select_training(self, prefer: Optional[str] = None) -> str:
205+
"""Resolve a concrete training backend name.
206+
207+
Resolution order:
208+
1. ``CHRONOS_TRAIN_BACKEND`` env var if set.
209+
2. ``prefer`` argument when available and trainable.
210+
3. First available backend in training priority order.
211+
4. ``cpu``.
212+
213+
``prefer`` may be ``None`` or ``"auto"`` to request automatic
214+
selection. Non-trainable values such as ``mlx`` are ignored here.
215+
"""
216+
env = os.environ.get("CHRONOS_TRAIN_BACKEND")
217+
if env:
218+
env = env.strip().lower()
219+
if env == "auto":
220+
env = ""
221+
elif env in BACKENDS and self.info(env).available and self.info(env).supports_training:
222+
return env
223+
elif env:
224+
print(f"[chronos.backend] CHRONOS_TRAIN_BACKEND={env} not available for training; auto-selecting.")
225+
226+
prefer = (prefer or "").strip().lower()
227+
if prefer == "auto":
228+
prefer = ""
229+
if prefer and prefer in BACKENDS:
230+
info = self.info(prefer)
231+
if info.available and info.supports_training:
232+
return prefer
233+
234+
for n in TRAINING_AUTO_PRIORITY:
235+
info = self.info(n)
236+
if info.available and info.supports_training:
237+
return n
238+
return "cpu"
239+
192240
def device_str(self, name: str) -> Optional[str]:
193241
"""PyTorch device string for a backend, or None for non-torch backends."""
194242
return self.info(name).torch_device
195243

244+
def training_device_str(self, name: Optional[str] = None) -> str:
245+
"""PyTorch device string for a training backend."""
246+
backend = self.select_training(name)
247+
return self.info(backend).torch_device or "cpu"
248+
249+
def resolve_training_device(self, prefer: Optional[str] = None) -> tuple[str, str]:
250+
"""Resolve ``(backend, torch_device)`` for training.
251+
252+
Accepts backend-level requests such as ``auto`` / ``cuda`` / ``mps``
253+
as well as explicit torch device strings like ``cuda:0``.
254+
"""
255+
raw = (prefer or "").strip()
256+
name = raw.lower()
257+
258+
explicit_map = {
259+
"cuda:": "cuda",
260+
"xpu:": "xpu",
261+
}
262+
for prefix, backend in explicit_map.items():
263+
if name.startswith(prefix):
264+
info = self.info(backend)
265+
if info.available and info.supports_training:
266+
return backend, raw
267+
print(f"[chronos.backend] requested training device {raw} is not available; auto-selecting.")
268+
chosen = self.select_training()
269+
return chosen, self.info(chosen).torch_device or "cpu"
270+
271+
if name in {"cpu", "mps", "cuda", "xpu"}:
272+
chosen = self.select_training(name)
273+
if chosen == name:
274+
return chosen, self.info(chosen).torch_device or "cpu"
275+
276+
chosen = self.select_training(name or None)
277+
return chosen, self.info(chosen).torch_device or "cpu"
278+
196279
def supports_training(self, name: str) -> bool:
197280
return self.info(name).supports_training
198281

@@ -213,6 +296,16 @@ def describe(self) -> str:
213296
lines.append(f" {marker} {n:<8} {tr:<10} dev={i.torch_device or '-':<6}{extra}")
214297
return "\n".join(lines)
215298

299+
def describe_training(self) -> str:
300+
"""Human-readable summary for trainable backends only."""
301+
lines = ["Chronos training backends:"]
302+
for n in TRAINING_AUTO_PRIORITY:
303+
i = self.info(n)
304+
marker = "✓" if (i.available and i.supports_training) else "·"
305+
extra = f" — {i.notes}" if i.notes else ""
306+
lines.append(f" {marker} {n:<8} train dev={i.torch_device or '-':<6}{extra}")
307+
return "\n".join(lines)
308+
216309

217310
# Module-level convenience singleton
218311
_default = BackendDispatcher()
@@ -226,9 +319,29 @@ def select(prefer: Optional[str] = None) -> str:
226319
return _default.select(prefer)
227320

228321

322+
def training_available() -> List[str]:
323+
return _default.training_available()
324+
325+
326+
def select_training(prefer: Optional[str] = None) -> str:
327+
return _default.select_training(prefer)
328+
329+
229330
def device_str(name: Optional[str] = None) -> Optional[str]:
230331
return _default.device_str(name or select())
231332

232333

334+
def training_device_str(name: Optional[str] = None) -> str:
335+
return _default.training_device_str(name)
336+
337+
338+
def resolve_training_device(prefer: Optional[str] = None) -> tuple[str, str]:
339+
return _default.resolve_training_device(prefer)
340+
341+
233342
def describe() -> str:
234343
return _default.describe()
344+
345+
346+
def describe_training() -> str:
347+
return _default.describe_training()

0 commit comments

Comments
 (0)