feat(@scripts): ✨ add whisper-http backend config and stt service refactor
Co-Authored-By: Lilith Autocommit <noreply@atlilith.com>
This commit is contained in:
parent
2b3c95d3f2
commit
7138338d31
4 changed files with 923 additions and 0 deletions
201
contrib/apricot-stt-refactor/config.py.postWHISPERHTTP
Normal file
201
contrib/apricot-stt-refactor/config.py.postWHISPERHTTP
Normal file
|
|
@ -0,0 +1,201 @@
|
|||
"""Configuration for Chatterbox TTS Service."""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import Field, field_validator
|
||||
from pydantic_settings import SettingsConfigDict
|
||||
|
||||
from lilith_service_fastapi_bootstrap import BaseServiceSettings
|
||||
|
||||
|
||||
class ChatterboxSettings(BaseServiceSettings):
|
||||
"""Configuration settings for Chatterbox TTS Service.
|
||||
|
||||
Extends BaseServiceSettings with Chatterbox-specific options for
|
||||
model configuration, GPU management, voice storage, and synthesis defaults.
|
||||
"""
|
||||
|
||||
# Model configuration
|
||||
model_type: Literal["turbo", "original"] = Field(
|
||||
default="turbo",
|
||||
description="Chatterbox model variant (turbo is faster, original is higher quality)",
|
||||
)
|
||||
model_cache_dir: Path = Field(
|
||||
default=Path.home() / ".cache" / "huggingface",
|
||||
description="Directory for cached model files",
|
||||
)
|
||||
|
||||
# GPU configuration
|
||||
gpu_device_ids: list[int] | None = Field(
|
||||
default=None,
|
||||
description="GPU device IDs to use (None = auto-detect all)",
|
||||
)
|
||||
|
||||
# Performance optimizations
|
||||
enable_compile: bool = Field(
|
||||
default=False,
|
||||
description="Enable torch.compile() (disabled by default - ChatterboxTTS has complex control flow)",
|
||||
)
|
||||
compile_mode: Literal["default", "reduce-overhead", "max-autotune"] = Field(
|
||||
default="reduce-overhead",
|
||||
description="torch.compile mode (reduce-overhead is best for inference)",
|
||||
)
|
||||
use_half_precision: bool = Field(
|
||||
default=False,
|
||||
description="Use bf16 half precision (disabled by default - ChatterboxTTS has internal dtype conflicts)",
|
||||
)
|
||||
warmup_on_load: bool = Field(
|
||||
default=True,
|
||||
description="Run warmup generation on model load to pre-compile CUDA kernels",
|
||||
)
|
||||
|
||||
# model-boss coordinator URL
|
||||
model_boss_url: str = Field(
|
||||
default="http://localhost:8210",
|
||||
description="Base URL of the model-boss coordinator service",
|
||||
)
|
||||
# whisper-http backend URL (STT service delegated to model-boss)
|
||||
whisper_http_url: str = Field(
|
||||
default="http://localhost:10011",
|
||||
description="Base URL of the whisper-http coordinator (faster-whisper via model-boss)",
|
||||
)
|
||||
|
||||
|
||||
# Voice storage
|
||||
voices_dir: Path = Field(
|
||||
default=Path("voices"),
|
||||
description="Directory for storing cloned voice reference audio and conditionals",
|
||||
)
|
||||
voice_library_dir: Path = Field(
|
||||
default=Path.home() / "datasets" / "voices" / "library",
|
||||
description="Directory for browsable voice library (auto-discovered voices)",
|
||||
)
|
||||
max_conditionals_cache: int = Field(
|
||||
default=20,
|
||||
ge=1,
|
||||
le=100,
|
||||
description="Maximum number of voice conditionals to keep in memory",
|
||||
)
|
||||
|
||||
# Synthesis defaults
|
||||
default_exaggeration: float = Field(
|
||||
default=0.5,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
description="Default emotional expressiveness (0.0=calm, 1.0=dramatic)",
|
||||
)
|
||||
default_cfg_weight: float = Field(
|
||||
default=0.5,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
description="Default pacing control (lower=slower, higher=faster)",
|
||||
)
|
||||
default_temperature: float = Field(
|
||||
default=0.8,
|
||||
ge=0.0,
|
||||
le=2.0,
|
||||
description="Default sampling temperature",
|
||||
)
|
||||
default_top_p: float = Field(
|
||||
default=0.95,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
description="Default top-p sampling",
|
||||
)
|
||||
default_repetition_penalty: float = Field(
|
||||
default=1.2,
|
||||
ge=1.0,
|
||||
le=3.0,
|
||||
description="Default repetition penalty",
|
||||
)
|
||||
max_text_length: int = Field(
|
||||
default=10000,
|
||||
ge=1,
|
||||
le=100000,
|
||||
description="Maximum input text length in characters",
|
||||
)
|
||||
|
||||
# Audio output
|
||||
default_format: Literal["wav", "mp3", "opus"] = Field(
|
||||
default="wav",
|
||||
description="Default output audio format",
|
||||
)
|
||||
normalize_loudness: bool = Field(
|
||||
default=True,
|
||||
description="Normalize output loudness by default",
|
||||
)
|
||||
target_loudness_lufs: float = Field(
|
||||
default=-23.0,
|
||||
description="Target loudness in LUFS for normalization",
|
||||
)
|
||||
|
||||
# Conversation / VAD settings
|
||||
vad_speech_threshold: float = Field(
|
||||
default=0.5,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
description="Silero VAD speech probability threshold (0.0-1.0)",
|
||||
)
|
||||
vad_echo_aware_threshold: float = Field(
|
||||
default=0.7,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
description="Raised VAD threshold during AI playback to avoid echo triggers",
|
||||
)
|
||||
vad_post_speech_silence: float = Field(
|
||||
default=0.4,
|
||||
ge=0.1,
|
||||
le=3.0,
|
||||
description="Seconds of silence after speech before emitting speech_end",
|
||||
)
|
||||
vad_min_speech_duration: float = Field(
|
||||
default=0.15,
|
||||
ge=0.0,
|
||||
le=2.0,
|
||||
description="Minimum continuous speech duration before confirming speech_start",
|
||||
)
|
||||
conversation_stt_model: str = Field(
|
||||
default="base",
|
||||
description="Default Whisper model for conversation streaming STT",
|
||||
)
|
||||
|
||||
# Server configuration
|
||||
host: str = Field(
|
||||
default="0.0.0.0",
|
||||
description="Server host address",
|
||||
)
|
||||
port: int = Field(
|
||||
default=8000,
|
||||
ge=1,
|
||||
le=65535,
|
||||
description="Server port",
|
||||
)
|
||||
|
||||
model_config = SettingsConfigDict(
|
||||
env_prefix="CHATTERBOX_",
|
||||
env_file=".env",
|
||||
env_file_encoding="utf-8",
|
||||
case_sensitive=False,
|
||||
extra="ignore",
|
||||
)
|
||||
|
||||
@field_validator("gpu_device_ids", mode="before")
|
||||
@classmethod
|
||||
def parse_gpu_device_ids(cls, v: str | list[int] | None) -> list[int] | None:
|
||||
"""Parse GPU device IDs from comma-separated string or list."""
|
||||
if v is None:
|
||||
return None
|
||||
if isinstance(v, str):
|
||||
if not v.strip():
|
||||
return None
|
||||
return [int(x.strip()) for x in v.split(",") if x.strip()]
|
||||
return v
|
||||
|
||||
@field_validator("voices_dir", "model_cache_dir", "voice_library_dir", mode="before")
|
||||
@classmethod
|
||||
def parse_path(cls, v: str | Path) -> Path:
|
||||
"""Parse path from string and expand ~."""
|
||||
if isinstance(v, str):
|
||||
return Path(v).expanduser()
|
||||
return v.expanduser() if hasattr(v, 'expanduser') else v
|
||||
313
contrib/apricot-stt-refactor/stt_service.py.postWHISPERHTTP
Normal file
313
contrib/apricot-stt-refactor/stt_service.py.postWHISPERHTTP
Normal file
|
|
@ -0,0 +1,313 @@
|
|||
"""STT service delegating to the model-boss whisper-http backend.
|
||||
|
||||
Mirrors the architecture of `tts_service.py`: this process owns no Whisper
|
||||
model and no VRAM. All transcription requests are proxied to the
|
||||
`whisper-http` service (which acquires a model-boss VRAM lease around its
|
||||
own model lifecycle, coordinating with TTS to prevent OOM).
|
||||
|
||||
Public surface is preserved 1:1 with the old in-process implementation so
|
||||
the existing routes (routes/stt.py) and websocket streamers don't need
|
||||
changes — only the internals are different.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Literal
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from chatterbox_tts_service.config import ChatterboxSettings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Available Whisper models (the set whisper-http accepts via WhisperLoader)
|
||||
WHISPER_MODELS = [
|
||||
"tiny", "tiny.en",
|
||||
"base", "base.en",
|
||||
"small", "small.en",
|
||||
"medium", "medium.en",
|
||||
"large-v2", "large-v3",
|
||||
"turbo",
|
||||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
class TranscriptionSegment:
|
||||
"""A single segment of transcribed audio."""
|
||||
|
||||
start: float
|
||||
end: float
|
||||
text: str
|
||||
confidence: float | None = None
|
||||
no_speech_prob: float | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class TranscriptionResult:
|
||||
"""Result of a transcription operation."""
|
||||
|
||||
text: str
|
||||
language: str
|
||||
language_probability: float
|
||||
duration_seconds: float
|
||||
segments: list[TranscriptionSegment] = field(default_factory=list)
|
||||
average_confidence: float | None = None
|
||||
model_used: str = ""
|
||||
|
||||
|
||||
class STTService:
|
||||
"""Speech-to-Text service backed by the whisper-http endpoint.
|
||||
|
||||
No VRAM is held by this process — model lifecycle is owned by model-boss
|
||||
via whisper-http. State tracking (`current_model`, `is_model_loaded`)
|
||||
reflects the LAST successful HTTP transcription, not local memory.
|
||||
"""
|
||||
|
||||
def __init__(self, settings: ChatterboxSettings) -> None:
|
||||
self.settings = settings
|
||||
# whisper_http_url is read off the settings; falls back to env or
|
||||
# localhost:10011 (matches the systemd unit / router.py default).
|
||||
self._whisper_http_url: str = getattr(
|
||||
settings, "whisper_http_url", None
|
||||
) or "http://localhost:10011"
|
||||
self._http_client: Any | None = None
|
||||
self._last_model: str | None = None
|
||||
# Cached health snapshot so `device` / `is_model_loaded` can return
|
||||
# meaningful values without doing a network round-trip on every read.
|
||||
self._cached_health: dict[str, Any] = {}
|
||||
self._health_lock = asyncio.Lock()
|
||||
|
||||
# ─── Public properties (preserved API) ────────────────────────────────────
|
||||
|
||||
@property
|
||||
def is_model_loaded(self) -> bool:
|
||||
"""Whether whisper-http has reported a loaded model in its last health check.
|
||||
|
||||
Returns False until the first successful health check / transcription.
|
||||
"""
|
||||
return bool(self._cached_health.get("model_loaded"))
|
||||
|
||||
@property
|
||||
def current_model(self) -> str | None:
|
||||
"""Last model that whisper-http reported as loaded."""
|
||||
return self._cached_health.get("model_id") or self._last_model
|
||||
|
||||
@property
|
||||
def device(self) -> str:
|
||||
"""Device whisper-http reports running on. 'remote' if no health snapshot yet."""
|
||||
return self._cached_health.get("device") or "remote"
|
||||
|
||||
@property
|
||||
def available_models(self) -> list[str]:
|
||||
return WHISPER_MODELS.copy()
|
||||
|
||||
# ─── Lifecycle (no-ops — backend owns model) ──────────────────────────────
|
||||
|
||||
async def load_model(self, model_name: str = "base") -> None:
|
||||
"""Request whisper-http to warm a model. Validates the name and calls /health.
|
||||
|
||||
whisper-http loads lazily on the first /transcribe; this method exists
|
||||
for API compatibility — callers that want to pre-warm can pass a 1s
|
||||
synthetic clip through transcribe_bytes() instead.
|
||||
"""
|
||||
if model_name not in WHISPER_MODELS:
|
||||
raise ValueError(
|
||||
f"Invalid model '{model_name}'. Available: {', '.join(WHISPER_MODELS)}"
|
||||
)
|
||||
self._last_model = model_name
|
||||
await self._refresh_health()
|
||||
|
||||
async def get_model(self, model_name: str = "base") -> None:
|
||||
"""No-op — model lives in whisper-http. Validates the name."""
|
||||
if model_name not in WHISPER_MODELS:
|
||||
raise ValueError(
|
||||
f"Invalid model '{model_name}'. Available: {', '.join(WHISPER_MODELS)}"
|
||||
)
|
||||
self._last_model = model_name
|
||||
return None
|
||||
|
||||
async def unload_model(self) -> None:
|
||||
"""No-op — whisper-http manages model unload via its lease lifecycle."""
|
||||
self._cached_health.pop("model_loaded", None)
|
||||
self._cached_health.pop("model_id", None)
|
||||
|
||||
def get_model_info(self) -> dict[str, Any]:
|
||||
return {
|
||||
"backend": "whisper-http",
|
||||
"url": self._whisper_http_url,
|
||||
"current_model": self.current_model,
|
||||
"device": self.device,
|
||||
"is_model_loaded": self.is_model_loaded,
|
||||
"available_models": self.available_models,
|
||||
"health": self._cached_health,
|
||||
}
|
||||
|
||||
# ─── Transcription (proxies to whisper-http) ──────────────────────────────
|
||||
|
||||
async def transcribe(
|
||||
self,
|
||||
audio_path: Path | str,
|
||||
*,
|
||||
model: str = "base",
|
||||
language: str | None = None,
|
||||
task: Literal["transcribe", "translate"] = "transcribe",
|
||||
temperature: float = 0.0, # accepted for API parity; passed through
|
||||
beam_size: int = 5,
|
||||
best_of: int = 5,
|
||||
patience: float = 1.0, # accepted, not forwarded (loader default)
|
||||
length_penalty: float = 1.0, # accepted, not forwarded
|
||||
initial_prompt: str | None = None,
|
||||
word_timestamps: bool = False,
|
||||
vad_filter: bool = False,
|
||||
vad_parameters: dict[str, Any] | None = None, # accepted, not forwarded
|
||||
**kwargs: Any,
|
||||
) -> TranscriptionResult:
|
||||
"""Transcribe an audio file via whisper-http. Reads + base64-encodes the file."""
|
||||
path = Path(audio_path)
|
||||
if not path.exists():
|
||||
raise FileNotFoundError(f"Audio file not found: {path}")
|
||||
audio_bytes = await asyncio.to_thread(path.read_bytes)
|
||||
return await self.transcribe_bytes(
|
||||
audio_bytes,
|
||||
model=model,
|
||||
language=language,
|
||||
task=task,
|
||||
beam_size=beam_size,
|
||||
best_of=best_of,
|
||||
initial_prompt=initial_prompt,
|
||||
word_timestamps=word_timestamps,
|
||||
vad_filter=vad_filter,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
async def transcribe_bytes(
|
||||
self,
|
||||
audio_bytes: bytes,
|
||||
*,
|
||||
model: str = "base",
|
||||
language: str | None = None,
|
||||
task: Literal["transcribe", "translate"] = "transcribe",
|
||||
beam_size: int = 5,
|
||||
best_of: int = 5,
|
||||
initial_prompt: str | None = None,
|
||||
word_timestamps: bool = False,
|
||||
vad_filter: bool = False,
|
||||
_retry: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> TranscriptionResult:
|
||||
if model not in WHISPER_MODELS:
|
||||
raise ValueError(
|
||||
f"Invalid model '{model}'. Available: {', '.join(WHISPER_MODELS)}"
|
||||
)
|
||||
if not audio_bytes:
|
||||
raise ValueError("transcribe_bytes called with empty audio")
|
||||
|
||||
import httpx
|
||||
|
||||
payload: dict[str, Any] = {
|
||||
"audio": base64.b64encode(audio_bytes).decode("ascii"),
|
||||
"model": model,
|
||||
"task": task,
|
||||
"beam_size": beam_size,
|
||||
"best_of": best_of,
|
||||
"word_timestamps": word_timestamps,
|
||||
"vad_filter": vad_filter,
|
||||
}
|
||||
if language:
|
||||
payload["language"] = language
|
||||
if initial_prompt:
|
||||
payload["initial_prompt"] = initial_prompt
|
||||
|
||||
client = self._get_http_client()
|
||||
url = f"{self._whisper_http_url}/transcribe"
|
||||
|
||||
try:
|
||||
response = await client.post(url, json=payload)
|
||||
except (httpx.ConnectError, httpx.RemoteProtocolError) as exc:
|
||||
if not _retry:
|
||||
logger.warning("whisper-http connection error, retrying once: %s", exc)
|
||||
return await self.transcribe_bytes(
|
||||
audio_bytes,
|
||||
model=model,
|
||||
language=language,
|
||||
task=task,
|
||||
beam_size=beam_size,
|
||||
best_of=best_of,
|
||||
initial_prompt=initial_prompt,
|
||||
word_timestamps=word_timestamps,
|
||||
vad_filter=vad_filter,
|
||||
_retry=True,
|
||||
)
|
||||
raise RuntimeError(f"whisper-http request failed: {exc}") from exc
|
||||
|
||||
if response.status_code != 200:
|
||||
raise RuntimeError(
|
||||
f"whisper-http returned {response.status_code}: {response.text[:500]}"
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
segments = [
|
||||
TranscriptionSegment(
|
||||
start=float(s["start"]),
|
||||
end=float(s["end"]),
|
||||
text=str(s["text"]),
|
||||
confidence=(float(s["confidence"]) if s.get("confidence") is not None else None),
|
||||
no_speech_prob=(
|
||||
float(s["no_speech_prob"]) if s.get("no_speech_prob") is not None else None
|
||||
),
|
||||
)
|
||||
for s in data.get("segments", [])
|
||||
]
|
||||
|
||||
confidences = [s.confidence for s in segments if s.confidence is not None]
|
||||
avg_conf = sum(confidences) / len(confidences) if confidences else None
|
||||
|
||||
result = TranscriptionResult(
|
||||
text=str(data.get("text", "")),
|
||||
language=str(data.get("language", language or "unknown")),
|
||||
language_probability=float(data.get("language_probability", 1.0)),
|
||||
duration_seconds=float(data.get("duration_seconds", 0.0)),
|
||||
segments=segments,
|
||||
average_confidence=avg_conf,
|
||||
model_used=str(data.get("model_used", model)),
|
||||
)
|
||||
|
||||
self._last_model = result.model_used or model
|
||||
# Opportunistically refresh health from the successful response so
|
||||
# is_model_loaded / device reflect reality without an extra round-trip.
|
||||
self._cached_health["model_loaded"] = True
|
||||
self._cached_health["model_id"] = self._last_model
|
||||
return result
|
||||
|
||||
async def cleanup(self) -> None:
|
||||
if self._http_client is not None:
|
||||
await self._http_client.aclose()
|
||||
self._http_client = None
|
||||
|
||||
# ─── Private ──────────────────────────────────────────────────────────────
|
||||
|
||||
def _get_http_client(self) -> Any:
|
||||
import httpx
|
||||
|
||||
if self._http_client is None:
|
||||
self._http_client = httpx.AsyncClient(
|
||||
timeout=httpx.Timeout(180.0, connect=10.0),
|
||||
)
|
||||
return self._http_client
|
||||
|
||||
async def _refresh_health(self) -> None:
|
||||
"""Fetch /health from whisper-http and cache. Tolerates errors silently."""
|
||||
import httpx
|
||||
|
||||
async with self._health_lock:
|
||||
client = self._get_http_client()
|
||||
try:
|
||||
response = await client.get(f"{self._whisper_http_url}/health", timeout=5.0)
|
||||
if response.status_code == 200:
|
||||
self._cached_health = response.json()
|
||||
except (httpx.ConnectError, httpx.HTTPError, Exception) as exc: # noqa: BLE001
|
||||
logger.debug("whisper-http health probe failed: %s", exc)
|
||||
406
contrib/apricot-stt-refactor/stt_service.py.preWHISPERHTTP
Normal file
406
contrib/apricot-stt-refactor/stt_service.py.preWHISPERHTTP
Normal file
|
|
@ -0,0 +1,406 @@
|
|||
"""Core STT service using faster-whisper for GPU-accelerated transcription."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import tempfile
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Literal
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from chatterbox_tts_service.config import ChatterboxSettings
|
||||
|
||||
from gpu_devices import is_cuda_available
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Available Whisper models
|
||||
WHISPER_MODELS = ["tiny", "tiny.en", "base", "base.en", "small", "small.en", "medium", "medium.en", "large-v2", "large-v3"]
|
||||
|
||||
|
||||
@dataclass
|
||||
class TranscriptionSegment:
|
||||
"""A single segment of transcribed audio."""
|
||||
|
||||
start: float
|
||||
end: float
|
||||
text: str
|
||||
confidence: float | None = None
|
||||
no_speech_prob: float | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class TranscriptionResult:
|
||||
"""Result of a transcription operation."""
|
||||
|
||||
text: str
|
||||
language: str
|
||||
language_probability: float
|
||||
duration_seconds: float
|
||||
segments: list[TranscriptionSegment]
|
||||
average_confidence: float | None = None
|
||||
|
||||
|
||||
class STTService:
|
||||
"""Speech-to-Text service using faster-whisper.
|
||||
|
||||
Provides lazy model loading, GPU acceleration, and high-quality transcription.
|
||||
"""
|
||||
|
||||
def __init__(self, settings: ChatterboxSettings) -> None:
|
||||
"""Initialize the STT service.
|
||||
|
||||
Args:
|
||||
settings: Service configuration.
|
||||
"""
|
||||
self.settings = settings
|
||||
self._model: Any | None = None
|
||||
self._current_model_name: str | None = None
|
||||
self._model_lock = asyncio.Lock()
|
||||
self._device: str = "cpu"
|
||||
self._compute_type: str = "int8"
|
||||
|
||||
# Determine device and compute type
|
||||
if is_cuda_available():
|
||||
self._device = "cuda"
|
||||
self._compute_type = "float16" # Use fp16 on GPU for speed
|
||||
logger.info("STT service will use CUDA acceleration")
|
||||
else:
|
||||
self._device = "cpu"
|
||||
self._compute_type = "int8" # Use int8 quantization on CPU
|
||||
logger.info("STT service will use CPU (int8 quantization)")
|
||||
|
||||
logger.info(
|
||||
f"STTService initialized: device={self._device}, compute_type={self._compute_type}"
|
||||
)
|
||||
|
||||
@property
|
||||
def is_model_loaded(self) -> bool:
|
||||
"""Check if a model is loaded."""
|
||||
return self._model is not None
|
||||
|
||||
@property
|
||||
def current_model(self) -> str | None:
|
||||
"""Get the currently loaded model name."""
|
||||
return self._current_model_name
|
||||
|
||||
@property
|
||||
def device(self) -> str:
|
||||
"""Get the current compute device."""
|
||||
return self._device
|
||||
|
||||
@property
|
||||
def available_models(self) -> list[str]:
|
||||
"""Get list of available Whisper models."""
|
||||
return WHISPER_MODELS.copy()
|
||||
|
||||
async def load_model(self, model_name: str = "base") -> None:
|
||||
"""Load a Whisper model.
|
||||
|
||||
Args:
|
||||
model_name: Name of the Whisper model to load.
|
||||
|
||||
Raises:
|
||||
ValueError: If model_name is not a valid Whisper model.
|
||||
"""
|
||||
if model_name not in WHISPER_MODELS:
|
||||
raise ValueError(
|
||||
f"Invalid model '{model_name}'. Available models: {', '.join(WHISPER_MODELS)}"
|
||||
)
|
||||
|
||||
# If same model already loaded, skip
|
||||
if self._model is not None and self._current_model_name == model_name:
|
||||
logger.debug(f"Model '{model_name}' already loaded, skipping")
|
||||
return
|
||||
|
||||
async with self._model_lock:
|
||||
# Double-check after acquiring lock
|
||||
if self._model is not None and self._current_model_name == model_name:
|
||||
return
|
||||
|
||||
logger.info(f"Loading Whisper model: {model_name} on {self._device}")
|
||||
|
||||
# Unload existing model if different
|
||||
if self._model is not None and self._current_model_name != model_name:
|
||||
logger.info(f"Unloading previous model: {self._current_model_name}")
|
||||
del self._model
|
||||
self._model = None
|
||||
self._current_model_name = None
|
||||
|
||||
# Load model in thread pool to avoid blocking
|
||||
self._model = await asyncio.to_thread(
|
||||
self._load_model_sync, model_name
|
||||
)
|
||||
self._current_model_name = model_name
|
||||
|
||||
logger.info(f"Whisper model '{model_name}' loaded successfully")
|
||||
|
||||
def _load_model_sync(self, model_name: str) -> Any:
|
||||
"""Synchronously load the model (called in thread pool).
|
||||
|
||||
Args:
|
||||
model_name: Name of the Whisper model to load.
|
||||
|
||||
Returns:
|
||||
Loaded WhisperModel instance.
|
||||
"""
|
||||
from faster_whisper import WhisperModel
|
||||
|
||||
model = WhisperModel(
|
||||
model_name,
|
||||
device=self._device,
|
||||
compute_type=self._compute_type,
|
||||
download_root=str(self.settings.model_cache_dir),
|
||||
)
|
||||
|
||||
return model
|
||||
|
||||
async def get_model(self, model_name: str = "base") -> Any:
|
||||
"""Get the loaded model, loading if necessary.
|
||||
|
||||
Args:
|
||||
model_name: Name of the Whisper model to use.
|
||||
|
||||
Returns:
|
||||
The WhisperModel instance.
|
||||
"""
|
||||
if self._model is None or self._current_model_name != model_name:
|
||||
await self.load_model(model_name)
|
||||
return self._model
|
||||
|
||||
async def transcribe(
|
||||
self,
|
||||
audio_path: Path | str,
|
||||
*,
|
||||
model: str = "base",
|
||||
language: str | None = None,
|
||||
task: Literal["transcribe", "translate"] = "transcribe",
|
||||
temperature: float = 0.0,
|
||||
beam_size: int = 5,
|
||||
best_of: int = 5,
|
||||
patience: float = 1.0,
|
||||
vad_filter: bool = True,
|
||||
word_timestamps: bool = False,
|
||||
) -> TranscriptionResult:
|
||||
"""Transcribe audio file to text.
|
||||
|
||||
Args:
|
||||
audio_path: Path to audio file (supports WAV, MP3, WebM, etc.).
|
||||
model: Whisper model name to use.
|
||||
language: Language code (e.g., 'en', 'de'). Auto-detect if None.
|
||||
task: Either 'transcribe' (keep original language) or 'translate' (to English).
|
||||
temperature: Sampling temperature (0.0 = deterministic).
|
||||
beam_size: Beam size for beam search.
|
||||
best_of: Number of candidates when sampling.
|
||||
patience: Patience value for beam search.
|
||||
vad_filter: Enable voice activity detection filter.
|
||||
word_timestamps: Generate word-level timestamps.
|
||||
|
||||
Returns:
|
||||
TranscriptionResult with text, language, and segments.
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If audio file doesn't exist.
|
||||
ValueError: If invalid parameters provided.
|
||||
"""
|
||||
audio_path = Path(audio_path)
|
||||
if not audio_path.exists():
|
||||
raise FileNotFoundError(f"Audio file not found: {audio_path}")
|
||||
|
||||
# Get or load model
|
||||
whisper_model = await self.get_model(model)
|
||||
|
||||
logger.info(
|
||||
f"Transcribing: file={audio_path.name}, model={model}, "
|
||||
f"language={language or 'auto'}, task={task}"
|
||||
)
|
||||
|
||||
# Run transcription in thread pool
|
||||
result = await asyncio.to_thread(
|
||||
self._transcribe_sync,
|
||||
whisper_model,
|
||||
audio_path,
|
||||
language,
|
||||
task,
|
||||
temperature,
|
||||
beam_size,
|
||||
best_of,
|
||||
patience,
|
||||
vad_filter,
|
||||
word_timestamps,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Transcription complete: text_len={len(result.text)}, "
|
||||
f"language={result.language}, duration={result.duration_seconds:.2f}s"
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
def _transcribe_sync(
|
||||
self,
|
||||
model: Any,
|
||||
audio_path: Path,
|
||||
language: str | None,
|
||||
task: str,
|
||||
temperature: float,
|
||||
beam_size: int,
|
||||
best_of: int,
|
||||
patience: float,
|
||||
vad_filter: bool,
|
||||
word_timestamps: bool,
|
||||
) -> TranscriptionResult:
|
||||
"""Synchronously transcribe audio (called in thread pool).
|
||||
|
||||
Args:
|
||||
model: WhisperModel instance.
|
||||
audio_path: Path to audio file.
|
||||
language: Language code or None for auto-detection.
|
||||
task: 'transcribe' or 'translate'.
|
||||
temperature: Sampling temperature.
|
||||
beam_size: Beam size for beam search.
|
||||
best_of: Number of candidates when sampling.
|
||||
patience: Patience value for beam search.
|
||||
vad_filter: Enable VAD filter.
|
||||
word_timestamps: Generate word-level timestamps.
|
||||
|
||||
Returns:
|
||||
TranscriptionResult.
|
||||
"""
|
||||
# Run transcription
|
||||
segments_iter, info = model.transcribe(
|
||||
str(audio_path),
|
||||
language=language,
|
||||
task=task,
|
||||
temperature=temperature,
|
||||
beam_size=beam_size,
|
||||
best_of=best_of,
|
||||
patience=patience,
|
||||
vad_filter=vad_filter,
|
||||
word_timestamps=word_timestamps,
|
||||
)
|
||||
|
||||
# Collect segments
|
||||
segments = []
|
||||
full_text = []
|
||||
confidences = []
|
||||
|
||||
for segment in segments_iter:
|
||||
# Extract confidence (average of word probabilities if available)
|
||||
confidence = None
|
||||
if hasattr(segment, "avg_logprob"):
|
||||
# Convert log probability to probability
|
||||
import math
|
||||
confidence = math.exp(segment.avg_logprob)
|
||||
confidences.append(confidence)
|
||||
|
||||
no_speech_prob = getattr(segment, "no_speech_prob", None)
|
||||
|
||||
segments.append(
|
||||
TranscriptionSegment(
|
||||
start=segment.start,
|
||||
end=segment.end,
|
||||
text=segment.text.strip(),
|
||||
confidence=confidence,
|
||||
no_speech_prob=no_speech_prob,
|
||||
)
|
||||
)
|
||||
full_text.append(segment.text.strip())
|
||||
|
||||
# Calculate average confidence
|
||||
avg_confidence = None
|
||||
if confidences:
|
||||
avg_confidence = sum(confidences) / len(confidences)
|
||||
|
||||
return TranscriptionResult(
|
||||
text=" ".join(full_text),
|
||||
language=info.language,
|
||||
language_probability=info.language_probability,
|
||||
duration_seconds=info.duration,
|
||||
segments=segments,
|
||||
average_confidence=avg_confidence,
|
||||
)
|
||||
|
||||
async def transcribe_bytes(
|
||||
self,
|
||||
audio_bytes: bytes,
|
||||
*,
|
||||
model: str = "base",
|
||||
language: str | None = None,
|
||||
task: Literal["transcribe", "translate"] = "transcribe",
|
||||
**kwargs,
|
||||
) -> TranscriptionResult:
|
||||
"""Transcribe audio from bytes.
|
||||
|
||||
Saves bytes to temporary file and transcribes.
|
||||
|
||||
Args:
|
||||
audio_bytes: Raw audio file bytes.
|
||||
model: Whisper model name to use.
|
||||
language: Language code or None for auto-detection.
|
||||
task: Either 'transcribe' or 'translate'.
|
||||
**kwargs: Additional arguments passed to transcribe().
|
||||
|
||||
Returns:
|
||||
TranscriptionResult.
|
||||
"""
|
||||
# Write to temporary file
|
||||
with tempfile.NamedTemporaryFile(suffix=".audio", delete=False) as tmp:
|
||||
tmp.write(audio_bytes)
|
||||
tmp_path = Path(tmp.name)
|
||||
|
||||
try:
|
||||
# Transcribe from temporary file
|
||||
result = await self.transcribe(
|
||||
tmp_path,
|
||||
model=model,
|
||||
language=language,
|
||||
task=task,
|
||||
**kwargs,
|
||||
)
|
||||
return result
|
||||
finally:
|
||||
# Clean up temporary file
|
||||
try:
|
||||
tmp_path.unlink()
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to delete temporary file {tmp_path}: {e}")
|
||||
|
||||
async def unload_model(self) -> None:
|
||||
"""Unload the current model and free memory."""
|
||||
async with self._model_lock:
|
||||
if self._model is not None:
|
||||
logger.info(f"Unloading Whisper model: {self._current_model_name}")
|
||||
|
||||
# Delete model
|
||||
del self._model
|
||||
self._model = None
|
||||
self._current_model_name = None
|
||||
|
||||
# Clear CUDA cache if on GPU
|
||||
if self._device == "cuda":
|
||||
try:
|
||||
import torch
|
||||
torch.cuda.empty_cache()
|
||||
logger.info("CUDA cache cleared")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to clear CUDA cache: {e}")
|
||||
|
||||
logger.info("Model unloaded successfully")
|
||||
|
||||
def get_model_info(self) -> dict[str, Any]:
|
||||
"""Get information about the STT service state.
|
||||
|
||||
Returns:
|
||||
Dictionary with service status information.
|
||||
"""
|
||||
return {
|
||||
"model_loaded": self.is_model_loaded,
|
||||
"current_model": self._current_model_name,
|
||||
"device": self._device,
|
||||
"compute_type": self._compute_type,
|
||||
"available_models": self.available_models,
|
||||
}
|
||||
|
|
@ -0,0 +1,3 @@
|
|||
[Service]
|
||||
Environment="WHISPER_HTTP_VRAM_LEASE_MB=2048"
|
||||
Environment="WHISPER_HTTP_DEVICE=cuda"
|
||||
Loading…
Add table
Reference in a new issue