Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 25 additions & 3 deletions src/eva/assistant/pipeline/services.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
SessionProperties,
)
from pipecat.services.openai.stt import OpenAISTTService
from pipecat.services.openai.tts import VALID_VOICES, OpenAITTSService
from pipecat.services.openai.tts import OpenAITTSService
from pipecat.services.stt_service import STTService
from pipecat.services.tts_service import TTSService
from pipecat.transcriptions.language import Language
Expand Down Expand Up @@ -134,6 +134,16 @@ def create_stt_service(
),
)

elif model_lower == "cohere":
logger.info(f"Using Cohere STT: {params['model']}")
return OpenAISTTService(
api_key=api_key,
base_url=url,
model=params["model"],
language=Language.EN,
sample_rate=SAMPLE_RATE,
)

elif model_lower.startswith("deepgram"):
# Check if using Flux model
if "flux" in model_lower:
Expand Down Expand Up @@ -207,7 +217,7 @@ def create_stt_service(

else:
raise ValueError(
f"Unknown STT model: {model}. Available: assemblyai, cartesia, deepgram, deepgram-flux, elevenlabs, nvidia, nvidia-baseten, openai"
f"Unknown STT model: {model}. Available: assemblyai, cartesia, cohere, deepgram, deepgram-flux, elevenlabs, nvidia, nvidia-baseten, openai"
)


Expand Down Expand Up @@ -341,6 +351,18 @@ def create_tts_service(

return openai_tts

elif model_lower == "voxtral":
logger.info(f"Using Voxtral TTS: {params['model']}")
voxtral_tts = OpenAITTSService(
api_key=api_key,
model=params["model"],
voice=params.get("voice", "neutral_female"),
base_url=url,
)
OpenAITTSService.run_tts = override_run_tts
voxtral_tts._settings.language = language_code
return voxtral_tts

elif model_lower == "xtts":
logger.info(f"Using XTTS TTS: {params['model']}")
xtts_tts = OpenAITTSService(
Expand Down Expand Up @@ -535,7 +557,7 @@ async def override_run_tts(self, text: str, context_id: str) -> AsyncGenerator[F
create_params = {
"input": text,
"model": self._settings.model,
"voice": VALID_VOICES[self._settings.voice],
"voice": self._settings.voice,
"response_format": "pcm",
"extra_body": {
"streaming_quality": "fast",
Expand Down