From 7138338d31518a108010be1bbf48fd2f04ba8bb9 Mon Sep 17 00:00:00 2001 From: Natalie Date: Sun, 17 May 2026 22:11:19 -0700 Subject: [PATCH] =?UTF-8?q?feat(@scripts):=20=E2=9C=A8=20add=20whisper-htt?= =?UTF-8?q?p=20backend=20config=20and=20stt=20service=20refactor?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-Authored-By: Lilith Autocommit --- .../config.py.postWHISPERHTTP | 201 +++++++++ .../stt_service.py.postWHISPERHTTP | 313 ++++++++++++++ .../stt_service.py.preWHISPERHTTP | 406 ++++++++++++++++++ .../whisper-http.service.d.gpu.conf | 3 + 4 files changed, 923 insertions(+) create mode 100644 contrib/apricot-stt-refactor/config.py.postWHISPERHTTP create mode 100644 contrib/apricot-stt-refactor/stt_service.py.postWHISPERHTTP create mode 100644 contrib/apricot-stt-refactor/stt_service.py.preWHISPERHTTP create mode 100644 contrib/apricot-stt-refactor/whisper-http.service.d.gpu.conf diff --git a/contrib/apricot-stt-refactor/config.py.postWHISPERHTTP b/contrib/apricot-stt-refactor/config.py.postWHISPERHTTP new file mode 100644 index 0000000..6f77a48 --- /dev/null +++ b/contrib/apricot-stt-refactor/config.py.postWHISPERHTTP @@ -0,0 +1,201 @@ +"""Configuration for Chatterbox TTS Service.""" + +from pathlib import Path +from typing import Literal + +from pydantic import Field, field_validator +from pydantic_settings import SettingsConfigDict + +from lilith_service_fastapi_bootstrap import BaseServiceSettings + + +class ChatterboxSettings(BaseServiceSettings): + """Configuration settings for Chatterbox TTS Service. + + Extends BaseServiceSettings with Chatterbox-specific options for + model configuration, GPU management, voice storage, and synthesis defaults. 
+ """ + + # Model configuration + model_type: Literal["turbo", "original"] = Field( + default="turbo", + description="Chatterbox model variant (turbo is faster, original is higher quality)", + ) + model_cache_dir: Path = Field( + default=Path.home() / ".cache" / "huggingface", + description="Directory for cached model files", + ) + + # GPU configuration + gpu_device_ids: list[int] | None = Field( + default=None, + description="GPU device IDs to use (None = auto-detect all)", + ) + + # Performance optimizations + enable_compile: bool = Field( + default=False, + description="Enable torch.compile() (disabled by default - ChatterboxTTS has complex control flow)", + ) + compile_mode: Literal["default", "reduce-overhead", "max-autotune"] = Field( + default="reduce-overhead", + description="torch.compile mode (reduce-overhead is best for inference)", + ) + use_half_precision: bool = Field( + default=False, + description="Use bf16 half precision (disabled by default - ChatterboxTTS has internal dtype conflicts)", + ) + warmup_on_load: bool = Field( + default=True, + description="Run warmup generation on model load to pre-compile CUDA kernels", + ) + + # model-boss coordinator URL + model_boss_url: str = Field( + default="http://localhost:8210", + description="Base URL of the model-boss coordinator service", + ) + # whisper-http backend URL (STT service delegated to model-boss) + whisper_http_url: str = Field( + default="http://localhost:10011", + description="Base URL of the whisper-http coordinator (faster-whisper via model-boss)", + ) + + + # Voice storage + voices_dir: Path = Field( + default=Path("voices"), + description="Directory for storing cloned voice reference audio and conditionals", + ) + voice_library_dir: Path = Field( + default=Path.home() / "datasets" / "voices" / "library", + description="Directory for browsable voice library (auto-discovered voices)", + ) + max_conditionals_cache: int = Field( + default=20, + ge=1, + le=100, + description="Maximum number 
of voice conditionals to keep in memory", + ) + + # Synthesis defaults + default_exaggeration: float = Field( + default=0.5, + ge=0.0, + le=1.0, + description="Default emotional expressiveness (0.0=calm, 1.0=dramatic)", + ) + default_cfg_weight: float = Field( + default=0.5, + ge=0.0, + le=1.0, + description="Default pacing control (lower=slower, higher=faster)", + ) + default_temperature: float = Field( + default=0.8, + ge=0.0, + le=2.0, + description="Default sampling temperature", + ) + default_top_p: float = Field( + default=0.95, + ge=0.0, + le=1.0, + description="Default top-p sampling", + ) + default_repetition_penalty: float = Field( + default=1.2, + ge=1.0, + le=3.0, + description="Default repetition penalty", + ) + max_text_length: int = Field( + default=10000, + ge=1, + le=100000, + description="Maximum input text length in characters", + ) + + # Audio output + default_format: Literal["wav", "mp3", "opus"] = Field( + default="wav", + description="Default output audio format", + ) + normalize_loudness: bool = Field( + default=True, + description="Normalize output loudness by default", + ) + target_loudness_lufs: float = Field( + default=-23.0, + description="Target loudness in LUFS for normalization", + ) + + # Conversation / VAD settings + vad_speech_threshold: float = Field( + default=0.5, + ge=0.0, + le=1.0, + description="Silero VAD speech probability threshold (0.0-1.0)", + ) + vad_echo_aware_threshold: float = Field( + default=0.7, + ge=0.0, + le=1.0, + description="Raised VAD threshold during AI playback to avoid echo triggers", + ) + vad_post_speech_silence: float = Field( + default=0.4, + ge=0.1, + le=3.0, + description="Seconds of silence after speech before emitting speech_end", + ) + vad_min_speech_duration: float = Field( + default=0.15, + ge=0.0, + le=2.0, + description="Minimum continuous speech duration before confirming speech_start", + ) + conversation_stt_model: str = Field( + default="base", + description="Default Whisper model for 
conversation streaming STT", + ) + + # Server configuration + host: str = Field( + default="0.0.0.0", + description="Server host address", + ) + port: int = Field( + default=8000, + ge=1, + le=65535, + description="Server port", + ) + + model_config = SettingsConfigDict( + env_prefix="CHATTERBOX_", + env_file=".env", + env_file_encoding="utf-8", + case_sensitive=False, + extra="ignore", + ) + + @field_validator("gpu_device_ids", mode="before") + @classmethod + def parse_gpu_device_ids(cls, v: str | list[int] | None) -> list[int] | None: + """Parse GPU device IDs from comma-separated string or list.""" + if v is None: + return None + if isinstance(v, str): + if not v.strip(): + return None + return [int(x.strip()) for x in v.split(",") if x.strip()] + return v + + @field_validator("voices_dir", "model_cache_dir", "voice_library_dir", mode="before") + @classmethod + def parse_path(cls, v: str | Path) -> Path: + """Parse path from string and expand ~.""" + if isinstance(v, str): + return Path(v).expanduser() + return v.expanduser() if hasattr(v, 'expanduser') else v diff --git a/contrib/apricot-stt-refactor/stt_service.py.postWHISPERHTTP b/contrib/apricot-stt-refactor/stt_service.py.postWHISPERHTTP new file mode 100644 index 0000000..89f9ba1 --- /dev/null +++ b/contrib/apricot-stt-refactor/stt_service.py.postWHISPERHTTP @@ -0,0 +1,313 @@ +"""STT service delegating to the model-boss whisper-http backend. + +Mirrors the architecture of `tts_service.py`: this process owns no Whisper +model and no VRAM. All transcription requests are proxied to the +`whisper-http` service (which acquires a model-boss VRAM lease around its +own model lifecycle, coordinating with TTS to prevent OOM). + +Public method names and call signatures stay compatible with the old in-process +implementation so the existing routes (routes/stt.py) and websocket streamers +don't need changes. Differences: `vad_filter` now defaults to False (was True), `get_model()` returns None, and `get_model_info()` reports different keys. 
+""" +from __future__ import annotations + +import asyncio +import base64 +import logging +from dataclasses import dataclass, field +from pathlib import Path +from typing import TYPE_CHECKING, Any, Literal + +if TYPE_CHECKING: + from chatterbox_tts_service.config import ChatterboxSettings + +logger = logging.getLogger(__name__) + + +# Available Whisper models (the set whisper-http accepts via WhisperLoader) +WHISPER_MODELS = [ + "tiny", "tiny.en", + "base", "base.en", + "small", "small.en", + "medium", "medium.en", + "large-v2", "large-v3", + "turbo", +] + + +@dataclass +class TranscriptionSegment: + """A single segment of transcribed audio.""" + + start: float + end: float + text: str + confidence: float | None = None + no_speech_prob: float | None = None + + +@dataclass +class TranscriptionResult: + """Result of a transcription operation.""" + + text: str + language: str + language_probability: float + duration_seconds: float + segments: list[TranscriptionSegment] = field(default_factory=list) + average_confidence: float | None = None + model_used: str = "" + + +class STTService: + """Speech-to-Text service backed by the whisper-http endpoint. + + No VRAM is held by this process — model lifecycle is owned by model-boss + via whisper-http. State tracking (`current_model`, `is_model_loaded`) + reflects the LAST successful HTTP transcription, not local memory. + """ + + def __init__(self, settings: ChatterboxSettings) -> None: + self.settings = settings + # whisper_http_url is read off the settings; falls back to env or + # localhost:10011 (matches the systemd unit / router.py default). + self._whisper_http_url: str = getattr( + settings, "whisper_http_url", None + ) or "http://localhost:10011" + self._http_client: Any | None = None + self._last_model: str | None = None + # Cached health snapshot so `device` / `is_model_loaded` can return + # meaningful values without doing a network round-trip on every read. 
+ self._cached_health: dict[str, Any] = {} + self._health_lock = asyncio.Lock() + + # ─── Public properties (preserved API) ──────────────────────────────────── + + @property + def is_model_loaded(self) -> bool: + """Whether whisper-http has reported a loaded model in its last health check. + + Returns False until the first successful health check / transcription. + """ + return bool(self._cached_health.get("model_loaded")) + + @property + def current_model(self) -> str | None: + """Model id from the cached health snapshot, else the last model name requested or used.""" + return self._cached_health.get("model_id") or self._last_model + + @property + def device(self) -> str: + """Device whisper-http reports running on. 'remote' if no health snapshot yet.""" + return self._cached_health.get("device") or "remote" + + @property + def available_models(self) -> list[str]: + return WHISPER_MODELS.copy() + + # ─── Lifecycle (no-ops — backend owns model) ────────────────────────────── + + async def load_model(self, model_name: str = "base") -> None: + """Validate the model name and refresh the cached /health snapshot; loads nothing. + + whisper-http loads lazily on the first /transcribe; this method exists + for API compatibility — callers that want to pre-warm can pass a 1s + synthetic clip through transcribe_bytes() instead. + """ + if model_name not in WHISPER_MODELS: + raise ValueError( + f"Invalid model '{model_name}'. 
Available: {', '.join(WHISPER_MODELS)}" + ) + self._last_model = model_name + return None + + async def unload_model(self) -> None: + """Clear the cached loaded-model state; whisper-http manages the actual unload via its lease lifecycle.""" + self._cached_health.pop("model_loaded", None) + self._cached_health.pop("model_id", None) + + def get_model_info(self) -> dict[str, Any]: + return { + "backend": "whisper-http", + "url": self._whisper_http_url, + "current_model": self.current_model, + "device": self.device, + "is_model_loaded": self.is_model_loaded, + "available_models": self.available_models, + "health": self._cached_health, + } + + # ─── Transcription (proxies to whisper-http) ────────────────────────────── + + async def transcribe( + self, + audio_path: Path | str, + *, + model: str = "base", + language: str | None = None, + task: Literal["transcribe", "translate"] = "transcribe", + temperature: float = 0.0, # accepted for API parity; not forwarded + beam_size: int = 5, + best_of: int = 5, + patience: float = 1.0, # accepted, not forwarded (loader default) + length_penalty: float = 1.0, # accepted, not forwarded + initial_prompt: str | None = None, + word_timestamps: bool = False, + vad_filter: bool = False, + vad_parameters: dict[str, Any] | None = None, # accepted, not forwarded + **kwargs: Any, + ) -> TranscriptionResult: + """Transcribe an audio file via whisper-http. 
Reads + base64-encodes the file.""" + path = Path(audio_path) + if not path.exists(): + raise FileNotFoundError(f"Audio file not found: {path}") + audio_bytes = await asyncio.to_thread(path.read_bytes) + return await self.transcribe_bytes( + audio_bytes, + model=model, + language=language, + task=task, + beam_size=beam_size, + best_of=best_of, + initial_prompt=initial_prompt, + word_timestamps=word_timestamps, + vad_filter=vad_filter, + **kwargs, + ) + + async def transcribe_bytes( + self, + audio_bytes: bytes, + *, + model: str = "base", + language: str | None = None, + task: Literal["transcribe", "translate"] = "transcribe", + beam_size: int = 5, + best_of: int = 5, + initial_prompt: str | None = None, + word_timestamps: bool = False, + vad_filter: bool = False, + _retry: bool = False, + **kwargs: Any, + ) -> TranscriptionResult: + if model not in WHISPER_MODELS: + raise ValueError( + f"Invalid model '{model}'. Available: {', '.join(WHISPER_MODELS)}" + ) + if not audio_bytes: + raise ValueError("transcribe_bytes called with empty audio") + + import httpx + + payload: dict[str, Any] = { + "audio": base64.b64encode(audio_bytes).decode("ascii"), + "model": model, + "task": task, + "beam_size": beam_size, + "best_of": best_of, + "word_timestamps": word_timestamps, + "vad_filter": vad_filter, + } + if language: + payload["language"] = language + if initial_prompt: + payload["initial_prompt"] = initial_prompt + + client = self._get_http_client() + url = f"{self._whisper_http_url}/transcribe" + + try: + response = await client.post(url, json=payload) + except (httpx.ConnectError, httpx.RemoteProtocolError) as exc: + if not _retry: + logger.warning("whisper-http connection error, retrying once: %s", exc) + return await self.transcribe_bytes( + audio_bytes, + model=model, + language=language, + task=task, + beam_size=beam_size, + best_of=best_of, + initial_prompt=initial_prompt, + word_timestamps=word_timestamps, + vad_filter=vad_filter, + _retry=True, + ) + raise 
RuntimeError(f"whisper-http request failed: {exc}") from exc + + if response.status_code != 200: + raise RuntimeError( + f"whisper-http returned {response.status_code}: {response.text[:500]}" + ) + + data = response.json() + segments = [ + TranscriptionSegment( + start=float(s["start"]), + end=float(s["end"]), + text=str(s["text"]), + confidence=(float(s["confidence"]) if s.get("confidence") is not None else None), + no_speech_prob=( + float(s["no_speech_prob"]) if s.get("no_speech_prob") is not None else None + ), + ) + for s in data.get("segments", []) + ] + + confidences = [s.confidence for s in segments if s.confidence is not None] + avg_conf = sum(confidences) / len(confidences) if confidences else None + + result = TranscriptionResult( + text=str(data.get("text", "")), + language=str(data.get("language", language or "unknown")), + language_probability=float(data.get("language_probability", 1.0)), + duration_seconds=float(data.get("duration_seconds", 0.0)), + segments=segments, + average_confidence=avg_conf, + model_used=str(data.get("model_used", model)), + ) + + self._last_model = result.model_used or model + # Opportunistically refresh health from the successful response so + # is_model_loaded / device reflect reality without an extra round-trip. + self._cached_health["model_loaded"] = True + self._cached_health["model_id"] = self._last_model + return result + + async def cleanup(self) -> None: + if self._http_client is not None: + await self._http_client.aclose() + self._http_client = None + + # ─── Private ────────────────────────────────────────────────────────────── + + def _get_http_client(self) -> Any: + import httpx + + if self._http_client is None: + self._http_client = httpx.AsyncClient( + timeout=httpx.Timeout(180.0, connect=10.0), + ) + return self._http_client + + async def _refresh_health(self) -> None: + """Fetch /health from whisper-http and cache. 
Tolerates errors silently.""" + import httpx + + async with self._health_lock: + client = self._get_http_client() + try: + response = await client.get(f"{self._whisper_http_url}/health", timeout=5.0) + if response.status_code == 200: + self._cached_health = response.json() + except (httpx.ConnectError, httpx.HTTPError, Exception) as exc: # noqa: BLE001 + logger.debug("whisper-http health probe failed: %s", exc) diff --git a/contrib/apricot-stt-refactor/stt_service.py.preWHISPERHTTP b/contrib/apricot-stt-refactor/stt_service.py.preWHISPERHTTP new file mode 100644 index 0000000..2c706c4 --- /dev/null +++ b/contrib/apricot-stt-refactor/stt_service.py.preWHISPERHTTP @@ -0,0 +1,406 @@ +"""Core STT service using faster-whisper for GPU-accelerated transcription.""" + +from __future__ import annotations + +import asyncio +import logging +import tempfile +from dataclasses import dataclass +from pathlib import Path +from typing import TYPE_CHECKING, Any, Literal + +if TYPE_CHECKING: + from chatterbox_tts_service.config import ChatterboxSettings + +from gpu_devices import is_cuda_available + +logger = logging.getLogger(__name__) + + +# Available Whisper models +WHISPER_MODELS = ["tiny", "tiny.en", "base", "base.en", "small", "small.en", "medium", "medium.en", "large-v2", "large-v3"] + + +@dataclass +class TranscriptionSegment: + """A single segment of transcribed audio.""" + + start: float + end: float + text: str + confidence: float | None = None + no_speech_prob: float | None = None + + +@dataclass +class TranscriptionResult: + """Result of a transcription operation.""" + + text: str + language: str + language_probability: float + duration_seconds: float + segments: list[TranscriptionSegment] + average_confidence: float | None = None + + +class STTService: + """Speech-to-Text service using faster-whisper. + + Provides lazy model loading, GPU acceleration, and high-quality transcription. 
+ """ + + def __init__(self, settings: ChatterboxSettings) -> None: + """Initialize the STT service. + + Args: + settings: Service configuration. + """ + self.settings = settings + self._model: Any | None = None + self._current_model_name: str | None = None + self._model_lock = asyncio.Lock() + self._device: str = "cpu" + self._compute_type: str = "int8" + + # Determine device and compute type + if is_cuda_available(): + self._device = "cuda" + self._compute_type = "float16" # Use fp16 on GPU for speed + logger.info("STT service will use CUDA acceleration") + else: + self._device = "cpu" + self._compute_type = "int8" # Use int8 quantization on CPU + logger.info("STT service will use CPU (int8 quantization)") + + logger.info( + f"STTService initialized: device={self._device}, compute_type={self._compute_type}" + ) + + @property + def is_model_loaded(self) -> bool: + """Check if a model is loaded.""" + return self._model is not None + + @property + def current_model(self) -> str | None: + """Get the currently loaded model name.""" + return self._current_model_name + + @property + def device(self) -> str: + """Get the current compute device.""" + return self._device + + @property + def available_models(self) -> list[str]: + """Get list of available Whisper models.""" + return WHISPER_MODELS.copy() + + async def load_model(self, model_name: str = "base") -> None: + """Load a Whisper model. + + Args: + model_name: Name of the Whisper model to load. + + Raises: + ValueError: If model_name is not a valid Whisper model. + """ + if model_name not in WHISPER_MODELS: + raise ValueError( + f"Invalid model '{model_name}'. 
Available models: {', '.join(WHISPER_MODELS)}" + ) + + # If same model already loaded, skip + if self._model is not None and self._current_model_name == model_name: + logger.debug(f"Model '{model_name}' already loaded, skipping") + return + + async with self._model_lock: + # Double-check after acquiring lock + if self._model is not None and self._current_model_name == model_name: + return + + logger.info(f"Loading Whisper model: {model_name} on {self._device}") + + # Unload existing model if different + if self._model is not None and self._current_model_name != model_name: + logger.info(f"Unloading previous model: {self._current_model_name}") + del self._model + self._model = None + self._current_model_name = None + + # Load model in thread pool to avoid blocking + self._model = await asyncio.to_thread( + self._load_model_sync, model_name + ) + self._current_model_name = model_name + + logger.info(f"Whisper model '{model_name}' loaded successfully") + + def _load_model_sync(self, model_name: str) -> Any: + """Synchronously load the model (called in thread pool). + + Args: + model_name: Name of the Whisper model to load. + + Returns: + Loaded WhisperModel instance. + """ + from faster_whisper import WhisperModel + + model = WhisperModel( + model_name, + device=self._device, + compute_type=self._compute_type, + download_root=str(self.settings.model_cache_dir), + ) + + return model + + async def get_model(self, model_name: str = "base") -> Any: + """Get the loaded model, loading if necessary. + + Args: + model_name: Name of the Whisper model to use. + + Returns: + The WhisperModel instance. 
+ """ + if self._model is None or self._current_model_name != model_name: + await self.load_model(model_name) + return self._model + + async def transcribe( + self, + audio_path: Path | str, + *, + model: str = "base", + language: str | None = None, + task: Literal["transcribe", "translate"] = "transcribe", + temperature: float = 0.0, + beam_size: int = 5, + best_of: int = 5, + patience: float = 1.0, + vad_filter: bool = True, + word_timestamps: bool = False, + ) -> TranscriptionResult: + """Transcribe audio file to text. + + Args: + audio_path: Path to audio file (supports WAV, MP3, WebM, etc.). + model: Whisper model name to use. + language: Language code (e.g., 'en', 'de'). Auto-detect if None. + task: Either 'transcribe' (keep original language) or 'translate' (to English). + temperature: Sampling temperature (0.0 = deterministic). + beam_size: Beam size for beam search. + best_of: Number of candidates when sampling. + patience: Patience value for beam search. + vad_filter: Enable voice activity detection filter. + word_timestamps: Generate word-level timestamps. + + Returns: + TranscriptionResult with text, language, and segments. + + Raises: + FileNotFoundError: If audio file doesn't exist. + ValueError: If invalid parameters provided. 
+ """ + audio_path = Path(audio_path) + if not audio_path.exists(): + raise FileNotFoundError(f"Audio file not found: {audio_path}") + + # Get or load model + whisper_model = await self.get_model(model) + + logger.info( + f"Transcribing: file={audio_path.name}, model={model}, " + f"language={language or 'auto'}, task={task}" + ) + + # Run transcription in thread pool + result = await asyncio.to_thread( + self._transcribe_sync, + whisper_model, + audio_path, + language, + task, + temperature, + beam_size, + best_of, + patience, + vad_filter, + word_timestamps, + ) + + logger.info( + f"Transcription complete: text_len={len(result.text)}, " + f"language={result.language}, duration={result.duration_seconds:.2f}s" + ) + + return result + + def _transcribe_sync( + self, + model: Any, + audio_path: Path, + language: str | None, + task: str, + temperature: float, + beam_size: int, + best_of: int, + patience: float, + vad_filter: bool, + word_timestamps: bool, + ) -> TranscriptionResult: + """Synchronously transcribe audio (called in thread pool). + + Args: + model: WhisperModel instance. + audio_path: Path to audio file. + language: Language code or None for auto-detection. + task: 'transcribe' or 'translate'. + temperature: Sampling temperature. + beam_size: Beam size for beam search. + best_of: Number of candidates when sampling. + patience: Patience value for beam search. + vad_filter: Enable VAD filter. + word_timestamps: Generate word-level timestamps. + + Returns: + TranscriptionResult. 
+ """ + # Run transcription + segments_iter, info = model.transcribe( + str(audio_path), + language=language, + task=task, + temperature=temperature, + beam_size=beam_size, + best_of=best_of, + patience=patience, + vad_filter=vad_filter, + word_timestamps=word_timestamps, + ) + + # Collect segments + segments = [] + full_text = [] + confidences = [] + + for segment in segments_iter: + # Extract confidence (average of word probabilities if available) + confidence = None + if hasattr(segment, "avg_logprob"): + # Convert log probability to probability + import math + confidence = math.exp(segment.avg_logprob) + confidences.append(confidence) + + no_speech_prob = getattr(segment, "no_speech_prob", None) + + segments.append( + TranscriptionSegment( + start=segment.start, + end=segment.end, + text=segment.text.strip(), + confidence=confidence, + no_speech_prob=no_speech_prob, + ) + ) + full_text.append(segment.text.strip()) + + # Calculate average confidence + avg_confidence = None + if confidences: + avg_confidence = sum(confidences) / len(confidences) + + return TranscriptionResult( + text=" ".join(full_text), + language=info.language, + language_probability=info.language_probability, + duration_seconds=info.duration, + segments=segments, + average_confidence=avg_confidence, + ) + + async def transcribe_bytes( + self, + audio_bytes: bytes, + *, + model: str = "base", + language: str | None = None, + task: Literal["transcribe", "translate"] = "transcribe", + **kwargs, + ) -> TranscriptionResult: + """Transcribe audio from bytes. + + Saves bytes to temporary file and transcribes. + + Args: + audio_bytes: Raw audio file bytes. + model: Whisper model name to use. + language: Language code or None for auto-detection. + task: Either 'transcribe' or 'translate'. + **kwargs: Additional arguments passed to transcribe(). + + Returns: + TranscriptionResult. 
+ """ + # Write to temporary file + with tempfile.NamedTemporaryFile(suffix=".audio", delete=False) as tmp: + tmp.write(audio_bytes) + tmp_path = Path(tmp.name) + + try: + # Transcribe from temporary file + result = await self.transcribe( + tmp_path, + model=model, + language=language, + task=task, + **kwargs, + ) + return result + finally: + # Clean up temporary file + try: + tmp_path.unlink() + except Exception as e: + logger.warning(f"Failed to delete temporary file {tmp_path}: {e}") + + async def unload_model(self) -> None: + """Unload the current model and free memory.""" + async with self._model_lock: + if self._model is not None: + logger.info(f"Unloading Whisper model: {self._current_model_name}") + + # Delete model + del self._model + self._model = None + self._current_model_name = None + + # Clear CUDA cache if on GPU + if self._device == "cuda": + try: + import torch + torch.cuda.empty_cache() + logger.info("CUDA cache cleared") + except Exception as e: + logger.warning(f"Failed to clear CUDA cache: {e}") + + logger.info("Model unloaded successfully") + + def get_model_info(self) -> dict[str, Any]: + """Get information about the STT service state. + + Returns: + Dictionary with service status information. + """ + return { + "model_loaded": self.is_model_loaded, + "current_model": self._current_model_name, + "device": self._device, + "compute_type": self._compute_type, + "available_models": self.available_models, + } diff --git a/contrib/apricot-stt-refactor/whisper-http.service.d.gpu.conf b/contrib/apricot-stt-refactor/whisper-http.service.d.gpu.conf new file mode 100644 index 0000000..29f03eb --- /dev/null +++ b/contrib/apricot-stt-refactor/whisper-http.service.d.gpu.conf @@ -0,0 +1,3 @@ +[Service] +Environment="WHISPER_HTTP_VRAM_LEASE_MB=2048" +Environment="WHISPER_HTTP_DEVICE=cuda"