Skip to content

Commit 111afa1

Browse files
committed
Simplify logic to extract model info
1 parent 5effad4 commit 111afa1

1 file changed

Lines changed: 47 additions & 62 deletions

File tree

apps/analysis.py

Lines changed: 47 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -378,22 +378,38 @@ def _extract_model_details(run_config: dict) -> dict[str, str]:
378378
model_cfg = run_config.get("pipeline") or run_config.get("model") or {}
379379
details: dict[str, str] = {}
380380

381-
# LLM / realtime model
382-
llm = model_cfg.get("realtime_model") or model_cfg.get("llm_model") or model_cfg.get("llm") or ""
383-
if llm:
384-
details["LLM"] = llm
385-
386-
# STT provider + specific model
387-
stt = model_cfg.get("stt_model") or model_cfg.get("stt") or ""
388-
if stt:
389-
stt_model = (model_cfg.get("stt_params") or {}).get("model")
390-
details["STT"] = f"{stt} ({stt_model})" if stt_model else stt
391-
392-
# TTS provider + specific model
393-
tts = model_cfg.get("tts_model") or model_cfg.get("tts") or ""
394-
if tts:
395-
tts_model = (model_cfg.get("tts_params") or {}).get("model")
396-
details["TTS"] = f"{tts} ({tts_model})" if tts_model else tts
381+
# Speech-to-speech
382+
s2s = model_cfg.get("s2s") or model_cfg.get("realtime_model") or ""
383+
if s2s:
384+
s2s_params = model_cfg.get("s2s_params") or {}
385+
label = s2s_params.get("alias") or s2s_params.get("model") or s2s
386+
details["S2S"] = label
387+
else:
388+
# Audio LLM
389+
audio_llm = model_cfg.get("audio_llm") or ""
390+
if audio_llm:
391+
audio_llm_params = model_cfg.get("audio_llm_params") or {}
392+
details["Audio LLM"] = audio_llm_params.get("alias") or audio_llm_params.get("model") or audio_llm
393+
else:
394+
# Cascade: LLM
395+
llm = model_cfg.get("llm") or model_cfg.get("llm_model") or ""
396+
if llm:
397+
details["LLM"] = llm
398+
399+
# STT (cascade only, not S2S/AudioLLM)
400+
if not audio_llm:
401+
stt = model_cfg.get("stt") or model_cfg.get("stt_model") or ""
402+
if stt:
403+
stt_params = model_cfg.get("stt_params") or {}
404+
label = stt_params.get("alias") or stt_params.get("model") or stt
405+
details["STT"] = f"{stt} ({label})" if label != stt else stt
406+
407+
# TTS (cascade and AudioLLM)
408+
tts = model_cfg.get("tts") or model_cfg.get("tts_model") or ""
409+
if tts:
410+
tts_params = model_cfg.get("tts_params") or {}
411+
label = tts_params.get("alias") or tts_params.get("model") or tts
412+
details["TTS"] = f"{tts} ({label})" if label != tts else tts
397413

398414
# Turn strategy
399415
turn_strategy = model_cfg.get("turn_strategy")
@@ -619,49 +635,25 @@ def _classify_pipeline_type(run_config: dict) -> str:
619635

620636
def _extract_llm_model_name(run_config: dict) -> str:
621637
"""Extract the primary model name from config."""
622-
model_cfg = run_config.get("pipeline") or run_config.get("model") or {}
623-
audio_llm_params = model_cfg.get("audio_llm_params") or {}
624-
return (
625-
model_cfg.get("s2s")
626-
or model_cfg.get("realtime_model")
627-
or audio_llm_params.get("alias")
628-
or audio_llm_params.get("model")
629-
or model_cfg.get("llm")
630-
or model_cfg.get("llm_model")
631-
or "unknown"
638+
details = _extract_model_details(run_config)
639+
return next(
640+
(details[k] for k in ("S2S", "Audio LLM", "LLM") if k in details),
641+
"unknown",
632642
)
633643

634644

635645
def _extract_all_models(run_config: dict) -> set[str]:
636646
"""Extract all model names with their role (LLM/STT/TTS) from config."""
637-
model_cfg = run_config.get("pipeline") or run_config.get("model") or {}
647+
details = _extract_model_details(run_config)
638648
models: set[str] = set()
639-
640-
# LLM / S2S / Audio LLM
641-
audio_llm_params = model_cfg.get("audio_llm_params") or {}
642-
llm = (
643-
model_cfg.get("s2s")
644-
or model_cfg.get("realtime_model")
645-
or audio_llm_params.get("alias")
646-
or audio_llm_params.get("model")
647-
or model_cfg.get("llm")
648-
or model_cfg.get("llm_model")
649-
)
650-
if llm:
651-
models.add(f"{llm} (LLM)")
652-
653-
# STT
654-
stt_params = model_cfg.get("stt_params") or {}
655-
stt = stt_params.get("alias") or stt_params.get("model") or model_cfg.get("stt_model") or model_cfg.get("stt") or ""
656-
if stt:
657-
models.add(f"{stt} (STT)")
658-
659-
# TTS
660-
tts_params = model_cfg.get("tts_params") or {}
661-
tts = tts_params.get("alias") or tts_params.get("model") or model_cfg.get("tts_model") or model_cfg.get("tts") or ""
662-
if tts:
663-
models.add(f"{tts} (TTS)")
664-
649+
for role in ("S2S", "Audio LLM", "LLM"):
650+
if role in details:
651+
models.add(f"{details[role]} (LLM)")
652+
break
653+
if "STT" in details:
654+
models.add(f"{details['STT']} (STT)")
655+
if "TTS" in details:
656+
models.add(f"{details['TTS']} (TTS)")
665657
return models or {"unknown"}
666658

667659

@@ -1770,16 +1762,9 @@ def render_conversation_trace_tab(metrics: Optional[RecordMetrics], record_dir:
17701762

17711763
def _render_sidebar_run_metadata(run_name: str, run_config: dict):
17721764
"""Render run metadata in the sidebar."""
1773-
pipeline = run_config.get("pipeline") or run_config.get("model") or {}
17741765
metadata_parts = [f"**Run:** {run_name}"]
1775-
if pipeline.get("realtime_model"):
1776-
metadata_parts.append(f"**Realtime Model:** {pipeline['realtime_model']}")
1777-
elif pipeline.get("llm_model"):
1778-
metadata_parts.append(f"**LLM Model:** {pipeline['llm_model']}")
1779-
if pipeline.get("stt_model"):
1780-
metadata_parts.append(f"**STT:** {pipeline['stt_model']}")
1781-
if pipeline.get("tts_model"):
1782-
metadata_parts.append(f"**TTS:** {pipeline['tts_model']}")
1766+
for label, value in _extract_model_details(run_config).items():
1767+
metadata_parts.append(f"**{label}:** {value}")
17831768
if run_config.get("num_trials"):
17841769
metadata_parts.append(f"**Trials:** {run_config['num_trials']}")
17851770
provenance = run_config.get("provenance", {})

0 commit comments

Comments
 (0)