33
44Unified compute-backend dispatcher.
55
6- Chronos dispatches across six backend names:
6+ Chronos dispatches across these backend names:
77
88 cpu — PyTorch CPU (always available)
99 cuda — PyTorch CUDA (NVIDIA GPU)
1010 mps — PyTorch Metal Performance Shaders (Apple Silicon via torch)
11- mlx — Apple MLX (Apple Silicon, unified memory; non-torch)
11+ mlx — Apple MLX (Apple Silicon, non-torch inference path in this repo)
12+ xpu — PyTorch Intel XPU
1213 vulkan — PyTorch Vulkan (only if torch was built with USE_VULKAN=ON)
1314 opencl — third-party extension hook (no upstream backend; plug-in only)
1415
15- Training support: cpu, cuda, mps, mlx
16- Inference support (stock torch): cpu, cuda, mps, mlx, (vulkan if built-in)
17- Inference via ext plugin: opencl (requires chronos.backend.ext.opencl
18- implementation; stub returns False)
16+ Training support in the current repo: cpu, cuda, mps, xpu
17+ Inference support (stock paths): cpu, cuda, mps, mlx, xpu, (vulkan if built-in)
18+ Inference via ext plugin: opencl
1919
20- The Vulkan and OpenCL hooks exist so that someone running a custom PyTorch
21- build (or a custom kernel) can plug in without modifying core Chronos code.
22- On stock pip-installed PyTorch, they report "not available" and fall back.
20+ Important: although MLX is available as an inference backend, this repository
21+ does not currently implement a full MLX-native training stack comparable to
22+ ``chronos.trainer.*``. Training resolvers therefore exclude ``mlx`` until a
23+ real MLX trainer exists.
2324"""
2425from __future__ import annotations
2526
3233
# Every backend name the dispatcher knows about.
BACKENDS = (
    "cuda",
    "mlx",
    "mps",
    "xpu",
    "vulkan",
    "opencl",
    "cpu",
)
3435
35- # Auto-detect priority. mlx > cuda > mps > xpu > vulkan > opencl > cpu.
36- # The intent: prefer whichever backend has the most optimized training path
37- # on the host. On Apple Silicon, mlx has unified memory advantages that
38- # mps (via torch) cannot match; on NVIDIA, cuda is always best.
# Auto-detect order for general runtime use (inference and legacy backend
# selection); earlier names are preferred. MLX is included here because it
# is meaningful on Apple Silicon.
AUTO_PRIORITY = (
    "mlx",
    "cuda",
    "xpu",
    "mps",
    "vulkan",
    "opencl",
    "cpu",
)
4039
# Training may only resolve to backends that have a real training
# implementation in this repository. This is deliberately kept apart from
# AUTO_PRIORITY: inference can keep preferring MLX while training never
# selects it.
TRAINING_AUTO_PRIORITY = (
    "cuda",
    "xpu",
    "mps",
    "cpu",
)
TRAINING_BACKENDS = (
    "cuda",
    "xpu",
    "mps",
    "cpu",
)
45+
4146
4247@dataclass
4348class BackendInfo :
@@ -101,9 +106,9 @@ def _probe_mlx() -> BackendInfo:
101106 except Exception :
102107 avail = False
103108 return BackendInfo (
104- name = "mlx" , available = avail , supports_training = avail ,
109+ name = "mlx" , available = avail , supports_training = False ,
105110 supports_amp = avail , torch_device = None ,
106- notes = "non-torch backend; Chronos uses chronos.mlx.* paths" ,
111+ notes = "non-torch backend; Chronos uses chronos.mlx.* inference paths" ,
107112 )
108113
109114
@@ -167,6 +172,13 @@ def available(self) -> List[str]:
167172 """Returns names in priority order, filtered to those actually usable."""
168173 return [n for n in AUTO_PRIORITY if self .info (n ).available ]
169174
def training_available(self) -> List[str]:
    """Return trainable backend names in training-priority order.

    Walks ``TRAINING_AUTO_PRIORITY`` and keeps every name whose info record
    reports both ``available`` and ``supports_training``.
    """
    usable: List[str] = []
    for name in TRAINING_AUTO_PRIORITY:
        # Fetch the record once per backend (select_training() does the
        # same) instead of calling self.info() once per attribute.
        details = self.info(name)
        if details.available and details.supports_training:
            usable.append(name)
    return usable
181+
170182 def select (self , prefer : Optional [str ] = None ) -> str :
171183 """Resolve a concrete backend name.
172184
@@ -189,10 +201,81 @@ def select(self, prefer: Optional[str] = None) -> str:
189201 return n
190202 return "cpu"
191203
def select_training(self, prefer: Optional[str] = None) -> str:
    """Pick the backend name that training should run on.

    Candidates are tried in this order; the first one that is both
    available and trainable wins:

    1. The ``CHRONOS_TRAIN_BACKEND`` environment variable.
    2. The ``prefer`` argument.
    3. Each entry of ``TRAINING_AUTO_PRIORITY``, in order.
    4. ``"cpu"`` as the last resort.

    ``None``, an empty string and ``"auto"`` all mean "choose for me", for
    both the environment variable and ``prefer``. A ``prefer`` value that is
    unknown or not trainable (``mlx``, for instance) is skipped silently; an
    unusable environment value is reported on stdout and then skipped.
    """
    override = os.environ.get("CHRONOS_TRAIN_BACKEND")
    if override:
        override = override.strip().lower()
        # Whitespace-only and "auto" fall through without a warning.
        if override and override != "auto":
            if (
                override in BACKENDS
                and self.info(override).available
                and self.info(override).supports_training
            ):
                return override
            print(f"[chronos.backend] CHRONOS_TRAIN_BACKEND={override} not available for training; auto-selecting.")

    wanted = (prefer or "").strip().lower()
    if wanted not in ("", "auto") and wanted in BACKENDS:
        details = self.info(wanted)
        if details.available and details.supports_training:
            return wanted

    for candidate in TRAINING_AUTO_PRIORITY:
        details = self.info(candidate)
        if not details.available:
            continue
        if details.supports_training:
            return candidate
    return "cpu"
239+
def device_str(self, name: str) -> Optional[str]:
    """Return the torch device string recorded for backend ``name``.

    The value is ``None`` when the backend's info record carries no torch
    device (non-torch backends).
    """
    record = self.info(name)
    return record.torch_device
195243
def training_device_str(self, name: Optional[str] = None) -> str:
    """Return the torch device string to train on.

    ``name`` is handed to ``select_training`` as the preferred backend. If
    the backend it resolves to has no torch device, ``"cpu"`` is returned.
    """
    resolved = self.select_training(name)
    device = self.info(resolved).torch_device
    return device or "cpu"
248+
def resolve_training_device(self, prefer: Optional[str] = None) -> tuple[str, str]:
    """Resolve ``(backend, torch_device)`` for training.

    ``prefer`` may be:

    * ``None``, ``""`` or ``"auto"`` - automatic selection;
    * a backend name such as ``cuda`` / ``mps`` / ``cpu`` - passed to
      ``select_training`` as the preference;
    * an explicit indexed device such as ``cuda:0`` or ``xpu:1``.

    Matching is case-insensitive and ignores surrounding whitespace.

    An explicit indexed device whose backend is available and trainable is
    returned directly, in lower-cased form, without consulting
    ``select_training``; the device index itself is not validated here. If
    its backend is not usable, a notice is printed and automatic selection
    is used instead.

    Returns:
        A ``(backend_name, torch_device_string)`` pair. The device string
        falls back to ``"cpu"`` when the chosen backend has no torch device.
    """
    raw = (prefer or "").strip()
    name = raw.lower()

    backend, colon, _index = name.partition(":")
    if colon and backend in ("cuda", "xpu"):
        info = self.info(backend)
        if info.available and info.supports_training:
            # Return the normalised spelling: the request was matched
            # case-insensitively, so echoing the raw text could hand back a
            # mixed-case device string such as "CUDA:0".
            return backend, name
        print(f"[chronos.backend] requested training device {raw} is not available; auto-selecting.")
        chosen = self.select_training()
    else:
        # One call covers plain backend names, "auto" and empty input alike.
        # Calling select_training() a second time after a mismatch would
        # only repeat the same resolution and any warning it prints.
        chosen = self.select_training(name or None)

    return chosen, self.info(chosen).torch_device or "cpu"
278+
def supports_training(self, name: str) -> bool:
    """Report whether backend ``name`` is flagged as trainable in its info record."""
    record = self.info(name)
    return record.supports_training
198281
@@ -213,6 +296,16 @@ def describe(self) -> str:
213296 lines .append (f" { marker } { n :<8} { tr :<10} dev={ i .torch_device or '-' :<6} { extra } " )
214297 return "\n " .join (lines )
215298
299+ def describe_training (self ) -> str :
300+ """Build a human-readable summary of the training backends.

Produces a header line followed by one row per name in
``TRAINING_AUTO_PRIORITY``. A row is marked "✓" when the backend is both
available and trainable and "·" otherwise; it also shows the backend's
torch device ("-" when it has none) and its notes, when present.
"""
301+ lines = ["Chronos training backends:" ]
302+ for n in TRAINING_AUTO_PRIORITY :
303+ i = self .info (n )
304+ marker = "✓" if (i .available and i .supports_training ) else "·"
305+ extra = f" — { i .notes } " if i .notes else ""
306+ lines .append (f" { marker } { n :<8} train dev={ i .torch_device or '-' :<6} { extra } " )
307+ return "\n " .join (lines )
308+
216309
217310# Module-level convenience singleton
218311_default = BackendDispatcher ()
@@ -226,9 +319,29 @@ def select(prefer: Optional[str] = None) -> str:
226319 return _default .select (prefer )
227320
228321
def training_available() -> List[str]:
    """Return ``training_available()`` of the module-level default dispatcher."""
    dispatcher = _default
    return dispatcher.training_available()
324+
325+
def select_training(prefer: Optional[str] = None) -> str:
    """Resolve a training backend name using the module-level default dispatcher."""
    dispatcher = _default
    return dispatcher.select_training(prefer)
328+
329+
def device_str(name: Optional[str] = None) -> Optional[str]:
    """Return the torch device string for ``name`` via the default dispatcher.

    When ``name`` is ``None`` or empty, the backend chosen by ``select()`` is
    used instead.
    """
    backend = name or select()
    return _default.device_str(backend)
231332
232333
def training_device_str(name: Optional[str] = None) -> str:
    """Return the training torch device string from the module-level default dispatcher."""
    dispatcher = _default
    return dispatcher.training_device_str(name)
336+
337+
def resolve_training_device(prefer: Optional[str] = None) -> tuple[str, str]:
    """Resolve ``(backend, torch_device)`` for training via the default dispatcher."""
    dispatcher = _default
    return dispatcher.resolve_training_device(prefer)
340+
341+
def describe() -> str:
    """Return the backend summary text produced by the default dispatcher."""
    dispatcher = _default
    return dispatcher.describe()
344+
345+
def describe_training() -> str:
    """Return the training-backend summary text produced by the default dispatcher."""
    dispatcher = _default
    return dispatcher.describe_training()
0 commit comments