feat(@scripts): add whisper-http backend config and stt service refactor

Co-Authored-By: Lilith Autocommit <noreply@atlilith.com>
This commit is contained in:
Natalie 2026-05-17 22:11:19 -07:00
parent 2b3c95d3f2
commit 7138338d31
4 changed files with 923 additions and 0 deletions

View file

@ -0,0 +1,201 @@
"""Configuration for Chatterbox TTS Service."""
from pathlib import Path
from typing import Literal
from pydantic import Field, field_validator
from pydantic_settings import SettingsConfigDict
from lilith_service_fastapi_bootstrap import BaseServiceSettings
class ChatterboxSettings(BaseServiceSettings):
"""Configuration settings for Chatterbox TTS Service.
Extends BaseServiceSettings with Chatterbox-specific options for
model configuration, GPU management, voice storage, and synthesis defaults.
"""
# Model configuration
model_type: Literal["turbo", "original"] = Field(
default="turbo",
description="Chatterbox model variant (turbo is faster, original is higher quality)",
)
model_cache_dir: Path = Field(
default=Path.home() / ".cache" / "huggingface",
description="Directory for cached model files",
)
# GPU configuration
gpu_device_ids: list[int] | None = Field(
default=None,
description="GPU device IDs to use (None = auto-detect all)",
)
# Performance optimizations
enable_compile: bool = Field(
default=False,
description="Enable torch.compile() (disabled by default - ChatterboxTTS has complex control flow)",
)
compile_mode: Literal["default", "reduce-overhead", "max-autotune"] = Field(
default="reduce-overhead",
description="torch.compile mode (reduce-overhead is best for inference)",
)
use_half_precision: bool = Field(
default=False,
description="Use bf16 half precision (disabled by default - ChatterboxTTS has internal dtype conflicts)",
)
warmup_on_load: bool = Field(
default=True,
description="Run warmup generation on model load to pre-compile CUDA kernels",
)
# model-boss coordinator URL
model_boss_url: str = Field(
default="http://localhost:8210",
description="Base URL of the model-boss coordinator service",
)
# whisper-http backend URL (STT service delegated to model-boss)
whisper_http_url: str = Field(
default="http://localhost:10011",
description="Base URL of the whisper-http coordinator (faster-whisper via model-boss)",
)
# Voice storage
voices_dir: Path = Field(
default=Path("voices"),
description="Directory for storing cloned voice reference audio and conditionals",
)
voice_library_dir: Path = Field(
default=Path.home() / "datasets" / "voices" / "library",
description="Directory for browsable voice library (auto-discovered voices)",
)
max_conditionals_cache: int = Field(
default=20,
ge=1,
le=100,
description="Maximum number of voice conditionals to keep in memory",
)
# Synthesis defaults
default_exaggeration: float = Field(
default=0.5,
ge=0.0,
le=1.0,
description="Default emotional expressiveness (0.0=calm, 1.0=dramatic)",
)
default_cfg_weight: float = Field(
default=0.5,
ge=0.0,
le=1.0,
description="Default pacing control (lower=slower, higher=faster)",
)
default_temperature: float = Field(
default=0.8,
ge=0.0,
le=2.0,
description="Default sampling temperature",
)
default_top_p: float = Field(
default=0.95,
ge=0.0,
le=1.0,
description="Default top-p sampling",
)
default_repetition_penalty: float = Field(
default=1.2,
ge=1.0,
le=3.0,
description="Default repetition penalty",
)
max_text_length: int = Field(
default=10000,
ge=1,
le=100000,
description="Maximum input text length in characters",
)
# Audio output
default_format: Literal["wav", "mp3", "opus"] = Field(
default="wav",
description="Default output audio format",
)
normalize_loudness: bool = Field(
default=True,
description="Normalize output loudness by default",
)
target_loudness_lufs: float = Field(
default=-23.0,
description="Target loudness in LUFS for normalization",
)
# Conversation / VAD settings
vad_speech_threshold: float = Field(
default=0.5,
ge=0.0,
le=1.0,
description="Silero VAD speech probability threshold (0.0-1.0)",
)
vad_echo_aware_threshold: float = Field(
default=0.7,
ge=0.0,
le=1.0,
description="Raised VAD threshold during AI playback to avoid echo triggers",
)
vad_post_speech_silence: float = Field(
default=0.4,
ge=0.1,
le=3.0,
description="Seconds of silence after speech before emitting speech_end",
)
vad_min_speech_duration: float = Field(
default=0.15,
ge=0.0,
le=2.0,
description="Minimum continuous speech duration before confirming speech_start",
)
conversation_stt_model: str = Field(
default="base",
description="Default Whisper model for conversation streaming STT",
)
# Server configuration
host: str = Field(
default="0.0.0.0",
description="Server host address",
)
port: int = Field(
default=8000,
ge=1,
le=65535,
description="Server port",
)
model_config = SettingsConfigDict(
env_prefix="CHATTERBOX_",
env_file=".env",
env_file_encoding="utf-8",
case_sensitive=False,
extra="ignore",
)
@field_validator("gpu_device_ids", mode="before")
@classmethod
def parse_gpu_device_ids(cls, v: str | list[int] | None) -> list[int] | None:
"""Parse GPU device IDs from comma-separated string or list."""
if v is None:
return None
if isinstance(v, str):
if not v.strip():
return None
return [int(x.strip()) for x in v.split(",") if x.strip()]
return v
@field_validator("voices_dir", "model_cache_dir", "voice_library_dir", mode="before")
@classmethod
def parse_path(cls, v: str | Path) -> Path:
"""Parse path from string and expand ~."""
if isinstance(v, str):
return Path(v).expanduser()
return v.expanduser() if hasattr(v, 'expanduser') else v

View file

@ -0,0 +1,313 @@
"""STT service delegating to the model-boss whisper-http backend.
Mirrors the architecture of `tts_service.py`: this process owns no Whisper
model and no VRAM. All transcription requests are proxied to the
`whisper-http` service (which acquires a model-boss VRAM lease around its
own model lifecycle, coordinating with TTS to prevent OOM).
Public surface is preserved 1:1 with the old in-process implementation so
the existing routes (routes/stt.py) and websocket streamers don't need
changes — only the internals are different.
"""
from __future__ import annotations
import asyncio
import base64
import logging
from dataclasses import dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING, Any, Literal
if TYPE_CHECKING:
from chatterbox_tts_service.config import ChatterboxSettings
logger = logging.getLogger(__name__)
# Available Whisper models (the set whisper-http accepts via WhisperLoader)
WHISPER_MODELS = [
"tiny", "tiny.en",
"base", "base.en",
"small", "small.en",
"medium", "medium.en",
"large-v2", "large-v3",
"turbo",
]
@dataclass
class TranscriptionSegment:
"""A single segment of transcribed audio."""
start: float
end: float
text: str
confidence: float | None = None
no_speech_prob: float | None = None
@dataclass
class TranscriptionResult:
"""Result of a transcription operation."""
text: str
language: str
language_probability: float
duration_seconds: float
segments: list[TranscriptionSegment] = field(default_factory=list)
average_confidence: float | None = None
model_used: str = ""
class STTService:
"""Speech-to-Text service backed by the whisper-http endpoint.
No VRAM is held by this process — model lifecycle is owned by model-boss
via whisper-http. State tracking (`current_model`, `is_model_loaded`)
reflects the LAST successful HTTP transcription, not local memory.
"""
def __init__(self, settings: ChatterboxSettings) -> None:
self.settings = settings
# whisper_http_url is read off the settings; falls back to env or
# localhost:10011 (matches the systemd unit / router.py default).
self._whisper_http_url: str = getattr(
settings, "whisper_http_url", None
) or "http://localhost:10011"
self._http_client: Any | None = None
self._last_model: str | None = None
# Cached health snapshot so `device` / `is_model_loaded` can return
# meaningful values without doing a network round-trip on every read.
self._cached_health: dict[str, Any] = {}
self._health_lock = asyncio.Lock()
# ─── Public properties (preserved API) ────────────────────────────────────
@property
def is_model_loaded(self) -> bool:
"""Whether whisper-http has reported a loaded model in its last health check.
Returns False until the first successful health check / transcription.
"""
return bool(self._cached_health.get("model_loaded"))
@property
def current_model(self) -> str | None:
"""Last model that whisper-http reported as loaded."""
return self._cached_health.get("model_id") or self._last_model
@property
def device(self) -> str:
"""Device whisper-http reports running on. 'remote' if no health snapshot yet."""
return self._cached_health.get("device") or "remote"
@property
def available_models(self) -> list[str]:
return WHISPER_MODELS.copy()
# ─── Lifecycle (no-ops — backend owns model) ──────────────────────────────
async def load_model(self, model_name: str = "base") -> None:
"""Request whisper-http to warm a model. Validates the name and calls /health.
whisper-http loads lazily on the first /transcribe; this method exists
for API compatibility — callers that want to pre-warm can pass a 1s
synthetic clip through transcribe_bytes() instead.
"""
if model_name not in WHISPER_MODELS:
raise ValueError(
f"Invalid model '{model_name}'. Available: {', '.join(WHISPER_MODELS)}"
)
self._last_model = model_name
await self._refresh_health()
async def get_model(self, model_name: str = "base") -> None:
"""No-op — model lives in whisper-http. Validates the name."""
if model_name not in WHISPER_MODELS:
raise ValueError(
f"Invalid model '{model_name}'. Available: {', '.join(WHISPER_MODELS)}"
)
self._last_model = model_name
return None
async def unload_model(self) -> None:
"""No-op — whisper-http manages model unload via its lease lifecycle."""
self._cached_health.pop("model_loaded", None)
self._cached_health.pop("model_id", None)
def get_model_info(self) -> dict[str, Any]:
return {
"backend": "whisper-http",
"url": self._whisper_http_url,
"current_model": self.current_model,
"device": self.device,
"is_model_loaded": self.is_model_loaded,
"available_models": self.available_models,
"health": self._cached_health,
}
# ─── Transcription (proxies to whisper-http) ──────────────────────────────
async def transcribe(
self,
audio_path: Path | str,
*,
model: str = "base",
language: str | None = None,
task: Literal["transcribe", "translate"] = "transcribe",
temperature: float = 0.0, # accepted for API parity; passed through
beam_size: int = 5,
best_of: int = 5,
patience: float = 1.0, # accepted, not forwarded (loader default)
length_penalty: float = 1.0, # accepted, not forwarded
initial_prompt: str | None = None,
word_timestamps: bool = False,
vad_filter: bool = False,
vad_parameters: dict[str, Any] | None = None, # accepted, not forwarded
**kwargs: Any,
) -> TranscriptionResult:
"""Transcribe an audio file via whisper-http. Reads + base64-encodes the file."""
path = Path(audio_path)
if not path.exists():
raise FileNotFoundError(f"Audio file not found: {path}")
audio_bytes = await asyncio.to_thread(path.read_bytes)
return await self.transcribe_bytes(
audio_bytes,
model=model,
language=language,
task=task,
beam_size=beam_size,
best_of=best_of,
initial_prompt=initial_prompt,
word_timestamps=word_timestamps,
vad_filter=vad_filter,
**kwargs,
)
async def transcribe_bytes(
self,
audio_bytes: bytes,
*,
model: str = "base",
language: str | None = None,
task: Literal["transcribe", "translate"] = "transcribe",
beam_size: int = 5,
best_of: int = 5,
initial_prompt: str | None = None,
word_timestamps: bool = False,
vad_filter: bool = False,
_retry: bool = False,
**kwargs: Any,
) -> TranscriptionResult:
if model not in WHISPER_MODELS:
raise ValueError(
f"Invalid model '{model}'. Available: {', '.join(WHISPER_MODELS)}"
)
if not audio_bytes:
raise ValueError("transcribe_bytes called with empty audio")
import httpx
payload: dict[str, Any] = {
"audio": base64.b64encode(audio_bytes).decode("ascii"),
"model": model,
"task": task,
"beam_size": beam_size,
"best_of": best_of,
"word_timestamps": word_timestamps,
"vad_filter": vad_filter,
}
if language:
payload["language"] = language
if initial_prompt:
payload["initial_prompt"] = initial_prompt
client = self._get_http_client()
url = f"{self._whisper_http_url}/transcribe"
try:
response = await client.post(url, json=payload)
except (httpx.ConnectError, httpx.RemoteProtocolError) as exc:
if not _retry:
logger.warning("whisper-http connection error, retrying once: %s", exc)
return await self.transcribe_bytes(
audio_bytes,
model=model,
language=language,
task=task,
beam_size=beam_size,
best_of=best_of,
initial_prompt=initial_prompt,
word_timestamps=word_timestamps,
vad_filter=vad_filter,
_retry=True,
)
raise RuntimeError(f"whisper-http request failed: {exc}") from exc
if response.status_code != 200:
raise RuntimeError(
f"whisper-http returned {response.status_code}: {response.text[:500]}"
)
data = response.json()
segments = [
TranscriptionSegment(
start=float(s["start"]),
end=float(s["end"]),
text=str(s["text"]),
confidence=(float(s["confidence"]) if s.get("confidence") is not None else None),
no_speech_prob=(
float(s["no_speech_prob"]) if s.get("no_speech_prob") is not None else None
),
)
for s in data.get("segments", [])
]
confidences = [s.confidence for s in segments if s.confidence is not None]
avg_conf = sum(confidences) / len(confidences) if confidences else None
result = TranscriptionResult(
text=str(data.get("text", "")),
language=str(data.get("language", language or "unknown")),
language_probability=float(data.get("language_probability", 1.0)),
duration_seconds=float(data.get("duration_seconds", 0.0)),
segments=segments,
average_confidence=avg_conf,
model_used=str(data.get("model_used", model)),
)
self._last_model = result.model_used or model
# Opportunistically refresh health from the successful response so
# is_model_loaded / device reflect reality without an extra round-trip.
self._cached_health["model_loaded"] = True
self._cached_health["model_id"] = self._last_model
return result
async def cleanup(self) -> None:
if self._http_client is not None:
await self._http_client.aclose()
self._http_client = None
# ─── Private ──────────────────────────────────────────────────────────────
def _get_http_client(self) -> Any:
import httpx
if self._http_client is None:
self._http_client = httpx.AsyncClient(
timeout=httpx.Timeout(180.0, connect=10.0),
)
return self._http_client
async def _refresh_health(self) -> None:
"""Fetch /health from whisper-http and cache. Tolerates errors silently."""
import httpx
async with self._health_lock:
client = self._get_http_client()
try:
response = await client.get(f"{self._whisper_http_url}/health", timeout=5.0)
if response.status_code == 200:
self._cached_health = response.json()
except (httpx.ConnectError, httpx.HTTPError, Exception) as exc: # noqa: BLE001
logger.debug("whisper-http health probe failed: %s", exc)

View file

@ -0,0 +1,406 @@
"""Core STT service using faster-whisper for GPU-accelerated transcription."""
from __future__ import annotations
import asyncio
import logging
import tempfile
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Any, Literal
if TYPE_CHECKING:
from chatterbox_tts_service.config import ChatterboxSettings
from gpu_devices import is_cuda_available
logger = logging.getLogger(__name__)
# Available Whisper models
WHISPER_MODELS = ["tiny", "tiny.en", "base", "base.en", "small", "small.en", "medium", "medium.en", "large-v2", "large-v3"]
@dataclass
class TranscriptionSegment:
"""A single segment of transcribed audio."""
start: float
end: float
text: str
confidence: float | None = None
no_speech_prob: float | None = None
@dataclass
class TranscriptionResult:
"""Result of a transcription operation."""
text: str
language: str
language_probability: float
duration_seconds: float
segments: list[TranscriptionSegment]
average_confidence: float | None = None
class STTService:
"""Speech-to-Text service using faster-whisper.
Provides lazy model loading, GPU acceleration, and high-quality transcription.
"""
def __init__(self, settings: ChatterboxSettings) -> None:
"""Initialize the STT service.
Args:
settings: Service configuration.
"""
self.settings = settings
self._model: Any | None = None
self._current_model_name: str | None = None
self._model_lock = asyncio.Lock()
self._device: str = "cpu"
self._compute_type: str = "int8"
# Determine device and compute type
if is_cuda_available():
self._device = "cuda"
self._compute_type = "float16" # Use fp16 on GPU for speed
logger.info("STT service will use CUDA acceleration")
else:
self._device = "cpu"
self._compute_type = "int8" # Use int8 quantization on CPU
logger.info("STT service will use CPU (int8 quantization)")
logger.info(
f"STTService initialized: device={self._device}, compute_type={self._compute_type}"
)
@property
def is_model_loaded(self) -> bool:
"""Check if a model is loaded."""
return self._model is not None
@property
def current_model(self) -> str | None:
"""Get the currently loaded model name."""
return self._current_model_name
@property
def device(self) -> str:
"""Get the current compute device."""
return self._device
@property
def available_models(self) -> list[str]:
"""Get list of available Whisper models."""
return WHISPER_MODELS.copy()
async def load_model(self, model_name: str = "base") -> None:
"""Load a Whisper model.
Args:
model_name: Name of the Whisper model to load.
Raises:
ValueError: If model_name is not a valid Whisper model.
"""
if model_name not in WHISPER_MODELS:
raise ValueError(
f"Invalid model '{model_name}'. Available models: {', '.join(WHISPER_MODELS)}"
)
# If same model already loaded, skip
if self._model is not None and self._current_model_name == model_name:
logger.debug(f"Model '{model_name}' already loaded, skipping")
return
async with self._model_lock:
# Double-check after acquiring lock
if self._model is not None and self._current_model_name == model_name:
return
logger.info(f"Loading Whisper model: {model_name} on {self._device}")
# Unload existing model if different
if self._model is not None and self._current_model_name != model_name:
logger.info(f"Unloading previous model: {self._current_model_name}")
del self._model
self._model = None
self._current_model_name = None
# Load model in thread pool to avoid blocking
self._model = await asyncio.to_thread(
self._load_model_sync, model_name
)
self._current_model_name = model_name
logger.info(f"Whisper model '{model_name}' loaded successfully")
def _load_model_sync(self, model_name: str) -> Any:
"""Synchronously load the model (called in thread pool).
Args:
model_name: Name of the Whisper model to load.
Returns:
Loaded WhisperModel instance.
"""
from faster_whisper import WhisperModel
model = WhisperModel(
model_name,
device=self._device,
compute_type=self._compute_type,
download_root=str(self.settings.model_cache_dir),
)
return model
async def get_model(self, model_name: str = "base") -> Any:
"""Get the loaded model, loading if necessary.
Args:
model_name: Name of the Whisper model to use.
Returns:
The WhisperModel instance.
"""
if self._model is None or self._current_model_name != model_name:
await self.load_model(model_name)
return self._model
async def transcribe(
self,
audio_path: Path | str,
*,
model: str = "base",
language: str | None = None,
task: Literal["transcribe", "translate"] = "transcribe",
temperature: float = 0.0,
beam_size: int = 5,
best_of: int = 5,
patience: float = 1.0,
vad_filter: bool = True,
word_timestamps: bool = False,
) -> TranscriptionResult:
"""Transcribe audio file to text.
Args:
audio_path: Path to audio file (supports WAV, MP3, WebM, etc.).
model: Whisper model name to use.
language: Language code (e.g., 'en', 'de'). Auto-detect if None.
task: Either 'transcribe' (keep original language) or 'translate' (to English).
temperature: Sampling temperature (0.0 = deterministic).
beam_size: Beam size for beam search.
best_of: Number of candidates when sampling.
patience: Patience value for beam search.
vad_filter: Enable voice activity detection filter.
word_timestamps: Generate word-level timestamps.
Returns:
TranscriptionResult with text, language, and segments.
Raises:
FileNotFoundError: If audio file doesn't exist.
ValueError: If invalid parameters provided.
"""
audio_path = Path(audio_path)
if not audio_path.exists():
raise FileNotFoundError(f"Audio file not found: {audio_path}")
# Get or load model
whisper_model = await self.get_model(model)
logger.info(
f"Transcribing: file={audio_path.name}, model={model}, "
f"language={language or 'auto'}, task={task}"
)
# Run transcription in thread pool
result = await asyncio.to_thread(
self._transcribe_sync,
whisper_model,
audio_path,
language,
task,
temperature,
beam_size,
best_of,
patience,
vad_filter,
word_timestamps,
)
logger.info(
f"Transcription complete: text_len={len(result.text)}, "
f"language={result.language}, duration={result.duration_seconds:.2f}s"
)
return result
def _transcribe_sync(
self,
model: Any,
audio_path: Path,
language: str | None,
task: str,
temperature: float,
beam_size: int,
best_of: int,
patience: float,
vad_filter: bool,
word_timestamps: bool,
) -> TranscriptionResult:
"""Synchronously transcribe audio (called in thread pool).
Args:
model: WhisperModel instance.
audio_path: Path to audio file.
language: Language code or None for auto-detection.
task: 'transcribe' or 'translate'.
temperature: Sampling temperature.
beam_size: Beam size for beam search.
best_of: Number of candidates when sampling.
patience: Patience value for beam search.
vad_filter: Enable VAD filter.
word_timestamps: Generate word-level timestamps.
Returns:
TranscriptionResult.
"""
# Run transcription
segments_iter, info = model.transcribe(
str(audio_path),
language=language,
task=task,
temperature=temperature,
beam_size=beam_size,
best_of=best_of,
patience=patience,
vad_filter=vad_filter,
word_timestamps=word_timestamps,
)
# Collect segments
segments = []
full_text = []
confidences = []
for segment in segments_iter:
# Extract confidence (average of word probabilities if available)
confidence = None
if hasattr(segment, "avg_logprob"):
# Convert log probability to probability
import math
confidence = math.exp(segment.avg_logprob)
confidences.append(confidence)
no_speech_prob = getattr(segment, "no_speech_prob", None)
segments.append(
TranscriptionSegment(
start=segment.start,
end=segment.end,
text=segment.text.strip(),
confidence=confidence,
no_speech_prob=no_speech_prob,
)
)
full_text.append(segment.text.strip())
# Calculate average confidence
avg_confidence = None
if confidences:
avg_confidence = sum(confidences) / len(confidences)
return TranscriptionResult(
text=" ".join(full_text),
language=info.language,
language_probability=info.language_probability,
duration_seconds=info.duration,
segments=segments,
average_confidence=avg_confidence,
)
async def transcribe_bytes(
self,
audio_bytes: bytes,
*,
model: str = "base",
language: str | None = None,
task: Literal["transcribe", "translate"] = "transcribe",
**kwargs,
) -> TranscriptionResult:
"""Transcribe audio from bytes.
Saves bytes to temporary file and transcribes.
Args:
audio_bytes: Raw audio file bytes.
model: Whisper model name to use.
language: Language code or None for auto-detection.
task: Either 'transcribe' or 'translate'.
**kwargs: Additional arguments passed to transcribe().
Returns:
TranscriptionResult.
"""
# Write to temporary file
with tempfile.NamedTemporaryFile(suffix=".audio", delete=False) as tmp:
tmp.write(audio_bytes)
tmp_path = Path(tmp.name)
try:
# Transcribe from temporary file
result = await self.transcribe(
tmp_path,
model=model,
language=language,
task=task,
**kwargs,
)
return result
finally:
# Clean up temporary file
try:
tmp_path.unlink()
except Exception as e:
logger.warning(f"Failed to delete temporary file {tmp_path}: {e}")
async def unload_model(self) -> None:
"""Unload the current model and free memory."""
async with self._model_lock:
if self._model is not None:
logger.info(f"Unloading Whisper model: {self._current_model_name}")
# Delete model
del self._model
self._model = None
self._current_model_name = None
# Clear CUDA cache if on GPU
if self._device == "cuda":
try:
import torch
torch.cuda.empty_cache()
logger.info("CUDA cache cleared")
except Exception as e:
logger.warning(f"Failed to clear CUDA cache: {e}")
logger.info("Model unloaded successfully")
def get_model_info(self) -> dict[str, Any]:
"""Get information about the STT service state.
Returns:
Dictionary with service status information.
"""
return {
"model_loaded": self.is_model_loaded,
"current_model": self._current_model_name,
"device": self._device,
"compute_type": self._compute_type,
"available_models": self.available_models,
}

View file

@ -0,0 +1,3 @@
[Service]
Environment="WHISPER_HTTP_VRAM_LEASE_MB=2048"
Environment="WHISPER_HTTP_DEVICE=cuda"