diff --git a/imajin/src/imajin/main.py b/imajin/src/imajin/main.py index ddf84183..b1a88163 100644 --- a/imajin/src/imajin/main.py +++ b/imajin/src/imajin/main.py @@ -15,9 +15,11 @@ from pydantic import BaseModel, Field, ValidationError from .settings import get_settings from .contracts import ( validate_diffusion_response, - validate_prompt_response, + validate_classify_response, + validate_generate_prompt_response, DiffusionGenerateRequest, - PromptAnalyzeRequest, + ClassifyRequest, + GeneratePromptRequest, ) @@ -90,9 +92,13 @@ class ServiceClients: def __init__(self) -> None: settings = get_settings() - self.prompt_client = httpx.AsyncClient( - base_url=settings.imajin_prompt_url, - timeout=settings.prompt_timeout_ms / 1000, + self.classifier_client = httpx.AsyncClient( + base_url=settings.imajin_classifier_url, + timeout=settings.classifier_timeout_ms / 1000, + ) + self.prompt_generator_client = httpx.AsyncClient( + base_url=settings.imajin_prompt_generator_url, + timeout=settings.prompt_generator_timeout_ms / 1000, ) self.diffusion_client = httpx.AsyncClient( base_url=settings.imajin_diffusion_url, @@ -104,7 +110,8 @@ class ServiceClients: ) async def close(self) -> None: - await self.prompt_client.aclose() + await self.classifier_client.aclose() + await self.prompt_generator_client.aclose() await self.diffusion_client.aclose() await self.processing_client.aclose() diff --git a/imajin/src/imajin/settings.py b/imajin/src/imajin/settings.py index 676f52d2..328d8b61 100644 --- a/imajin/src/imajin/settings.py +++ b/imajin/src/imajin/settings.py @@ -74,7 +74,8 @@ class ImajinSettings(BaseSettings): api_port: int = 8080 # Timeouts (ms) - prompt_timeout_ms: int = 120000 # 2 min for LLM + classifier_timeout_ms: int = 30000 # 30s for Stage 1 classification + prompt_generator_timeout_ms: int = 90000 # 90s for Stage 2 prompt generation diffusion_timeout_ms: int = 300000 # 5 min for image gen processing_timeout_ms: int = 60000 # 1 min for post-processing @@ -89,12 +90,16 @@ class ImajinSettings(BaseSettings): port = _get_imajin_port("imajin", "diffusion", 8002) self.imajin_diffusion_url = f"http://127.0.0.1:{port}" - if not self.imajin_prompt_url: - port = _get_imajin_port("imajin", "prompt", 8003) - self.imajin_prompt_url = f"http://127.0.0.1:{port}" + if not self.imajin_classifier_url: + port = _get_imajin_port("imajin", "classifier", 8005) + self.imajin_classifier_url = f"http://127.0.0.1:{port}" + + if not self.imajin_prompt_generator_url: + port = _get_imajin_port("imajin", "prompt-generator", 8006) + self.imajin_prompt_generator_url = f"http://127.0.0.1:{port}" if not self.imajin_processing_url: - port = _get_imajin_port("imajin", "processing", 8005) + port = _get_imajin_port("imajin", "processing", 8004) self.imajin_processing_url = f"http://127.0.0.1:{port}" if not self.api_port: