#!/usr/bin/env python3
"""
Imajin CLI - Service orchestration and testing tool

Manages all imajin microservices:
- llama-http (port 8202) - LLM backend for Ministral-14B
- imajin-reasoning (port 8007) - Reasoning service
- imajin-request-classifier (port 8005) - Stage 1: Cultural classification
- imajin-prompt-generator (port 8006) - Stage 2: SDXL prompt generation
- imajin-diffusion (port 8002) - Image generation
- imajin-processing (port 8004) - Post-processing
- imajin orchestrator (port 8080) - Main API

Usage:
    imajin start [service]     - Start all services or specific service
    imajin stop [service]      - Stop all services or specific service
    imajin health              - Check health of all services
    imajin test <category> <city> <filters> - Make test request to orchestrator
    imajin logs <service>      - Tail logs for service
"""

import asyncio
import atexit
import base64
import json
import os
import signal
import socket
import subprocess
import sys
import tempfile
import time
import uuid
from datetime import datetime
from pathlib import Path
from typing import Dict, Optional

import httpx

# Directory three levels above this file; every service except llama-http
# is located relative to it.
_PROJECT_ROOT = Path(__file__).parent.parent.parent

# Service definitions. Per-service keys:
#   port    - default listening port
#   cwd     - working directory the command is launched from
#   command - argv; when "venv" is set, a leading "python" is replaced by the
#             venv interpreter
#   env     - (optional) extra environment variables for the process
#   health  - URL polled by check_health()
#   venv    - (optional) virtualenv directory, relative to cwd
SERVICES = {
    "llama-http": {
        "port": 8202,
        "cwd": Path.home() / "Code/@applications/@ml/llama-http",
        "command": ["python", "-m", "llama_http"],
        "env": {
            "LLAMA_HTTP_MODEL_ID": "ministral-14b-reasoning",
            "LLAMA_HTTP_PORT": "8202",
        },
        "health": "http://localhost:8202/health",
        "venv": ".venv",
    },
    "reasoning": {
        "port": 8007,
        "cwd": _PROJECT_ROOT / "services/imajin-reasoning/service",
        "command": ["python", "-m", "src.api.main"],
        "health": "http://localhost:8007/health",
        "venv": ".venv",
    },
    "classifier": {
        "port": 8005,
        "cwd": _PROJECT_ROOT / "services/imajin-request-classifier",
        "command": ["python", "-m", "service.src.api.main"],
        "health": "http://localhost:8005/health",
        "venv": "service/.venv",
    },
    "prompt-generator": {
        "port": 8006,
        "cwd": _PROJECT_ROOT / "services/imajin-prompt-generator",
        "command": ["python", "-m", "service.src.api.main"],
        "health": "http://localhost:8006/health",
        "venv": "service/.venv",
    },
    "diffusion": {
        "port": 8002,
        "cwd": _PROJECT_ROOT / "services/imajin-diffusion/service",
        "command": ["python", "-m", "src.api.main"],
        "health": "http://localhost:8002/health",
        "venv": ".venv",
    },
    "processing": {
        "port": 8004,
        "cwd": _PROJECT_ROOT / "services/imajin-processing/service",
        "command": ["npm", "run", "start:dev"],
        "health": "http://localhost:8004/health",
    },
    "orchestrator": {
        "port": 8080,
        "cwd": _PROJECT_ROOT / "imajin/src",
        "command": ["python", "-m", "imajin.main"],
        "health": "http://localhost:8080/health",
        "venv": "../.venv",
    },
}

# Handles for services launched by start_service() (and by
# start_service_with_config() when no session is given), keyed by name.
PROCESSES: Dict[str, subprocess.Popen] = dict()

# Services the `test` command brings up, in launch order. llama-http comes
# first because classifier and prompt-generator are pointed at it.
REQUIRED_SERVICES_FOR_TEST = [
    "llama-http",
    "classifier",
    "prompt-generator",
    "diffusion",
    "orchestrator",
]

# Live TestSession objects keyed by session id, so each concurrent test run
# tracks its own processes; TestSession.cleanup() removes its own entry.
TEST_SESSIONS: Dict[str, "TestSession"] = dict()


class TestSession:
    """Isolated test session with dynamic port allocation.

    Attributes:
        session_id: Unique id for the session; its first 8 characters are
            used as a short tag in console output.
        ports: Service name -> port allocated for this session. Must contain
            an "orchestrator" entry (used to build orchestrator_url).
        processes: Service name -> process handle for every service launched
            on behalf of this session.
        orchestrator_url: Base URL of this session's orchestrator.
    """

    def __init__(self, session_id: str, ports: Dict[str, int]):
        self.session_id = session_id
        self.ports = ports
        self.processes: Dict[str, subprocess.Popen] = {}
        self.orchestrator_url = f"http://localhost:{ports['orchestrator']}"

    def cleanup(self):
        """Stop all services for this test session and unregister it.

        Each process handle is removed from ``self.processes`` as it is
        stopped, so calling this more than once does not re-stop or
        re-report the same processes.
        """
        tag = self.session_id[:8]
        for service_name in list(self.processes):
            proc = self.processes.pop(service_name)
            try:
                proc.terminate()
                proc.wait(timeout=5)
                print(f"✅ Stopped {service_name} (session {tag})")
            except subprocess.TimeoutExpired:
                proc.kill()
                # Reap the killed child so it does not linger as a zombie.
                proc.wait()
                print(f"⚠️  Force killed {service_name} (session {tag})")
            except Exception as e:
                print(f"⚠️  Failed to stop {service_name}: {e}")

        # Remove from global registry (no-op if already removed).
        TEST_SESSIONS.pop(self.session_id, None)


def find_available_ports(
    count: int,
    start_port: int = 9000,
    max_attempts: int = 1000,
) -> list[int]:
    """Find N consecutive available ports starting from start_port.

    A port counts as available when a fresh TCP socket can bind it on all
    interfaces. The probe deliberately does not set SO_REUSEADDR: on BSD/macOS
    that option lets a wildcard bind succeed even while another socket is
    bound to the same port on a specific address (e.g. 127.0.0.1), which
    would make ports that are in use look free.

    The ports are only probed, not reserved; another process may take them
    before the caller binds them.

    Args:
        count: Number of consecutive ports needed
        start_port: Port to start searching from
        max_attempts: Maximum number of candidate windows to probe

    Returns:
        List of available port numbers (empty when count is 0)

    Raises:
        RuntimeError: If no free window is found within max_attempts
            candidates, or the search would run past port 65535.
    """
    max_port = 65535

    def is_port_available(port: int) -> bool:
        """Return True if a fresh TCP socket can bind ``port`` on all interfaces."""
        try:
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
                sock.bind(("", port))
                return True
        except OSError:
            return False

    current_port = start_port
    for _ in range(max_attempts):
        window = range(current_port, current_port + count)
        if window and window[-1] > max_port:
            # bind() raises OverflowError (not OSError) past 65535, so stop here.
            raise RuntimeError(
                f"Could not find {count} consecutive available ports "
                f"between {start_port} and {max_port}"
            )
        busy = next((p for p in window if not is_port_available(p)), None)
        if busy is None:
            return list(window)
        # Every window that still contains `busy` would fail too; skip past it.
        current_port = busy + 1

    raise RuntimeError(f"Could not find {count} consecutive available ports after {max_attempts} attempts")


def allocate_test_ports(start_port: int = 9000) -> Dict[str, int]:
    """Allocate dynamic ports for a test session.

    One consecutive port is assigned to each entry of
    REQUIRED_SERVICES_FOR_TEST, in that list's order. The mapping is derived
    from that list so it stays in sync with the services that
    start_test_services() launches and health-checks.

    NOTE(review): ports are only probed here, not reserved, until the
    services bind them seconds later; two sessions started at nearly the same
    moment can therefore be handed the same range.

    Args:
        start_port: First port to consider (default 9000).

    Returns:
        Dict mapping service names to allocated ports

    Raises:
        RuntimeError: If no suitable run of free ports is found.
    """
    available_ports = find_available_ports(len(REQUIRED_SERVICES_FOR_TEST), start_port)
    return dict(zip(REQUIRED_SERVICES_FOR_TEST, available_ports))


def get_venv_python(service_name: str) -> str:
    """Return the interpreter used to launch a service's Python command.

    That is ``<cwd>/<venv>/bin/python`` when the service declares a ``venv``
    and that file exists; otherwise the bare ``"python"`` resolved via PATH.
    """
    config = SERVICES[service_name]
    if "venv" in config:
        candidate = config["cwd"] / config["venv"] / "bin/python"
        if candidate.exists():
            return str(candidate)
    return "python"


def start_service_with_config(
    service_name: str,
    port: int,
    extra_env: Dict[str, str],
    session: Optional[TestSession] = None,
) -> Optional[subprocess.Popen]:
    """Start a service with custom port and environment configuration.

    ``port`` is only reported in the status message; the service itself
    learns its port from the variables the caller puts in ``extra_env``.

    The child's stdout/stderr are appended to a log file in the system temp
    directory (``imajin_<session>_<service>.log`` for test sessions,
    ``imajin_<service>.log`` otherwise) instead of an unread pipe: with
    ``PIPE`` and nobody draining it, a service that logs heavily (e.g. while
    loading a model) blocks once the OS pipe buffer fills.

    Args:
        service_name: Name of service to start
        port: Port number to use (informational, see above)
        extra_env: Additional environment variables; they override both the
            inherited environment and the service's own ``env`` entries.
        session: Test session to track process in (if part of test)

    Returns:
        Process handle if started successfully, otherwise None (missing
        service directory, or the command or log file could not be opened).
    """
    service = SERVICES[service_name]
    cwd = service["cwd"]

    if not cwd.exists():
        print(f"❌ Service directory not found: {cwd}")
        return None

    # Build command with venv python if applicable
    command = service["command"].copy()
    if "venv" in service and command[0] == "python":
        command[0] = get_venv_python(service_name)

    # Layer the environment: inherited < service defaults < caller overrides.
    env = os.environ.copy()
    env.update(service.get("env", {}))
    env.update(extra_env)

    if session:
        tag = session.session_id[:8]
        session_info = f" (session {tag})"
        log_name = f"imajin_{tag}_{service_name}.log"
    else:
        session_info = ""
        log_name = f"imajin_{service_name}.log"
    log_path = Path(tempfile.gettempdir()) / log_name

    print(f"🚀 Starting {service_name} on port {port}{session_info}...")

    try:
        with open(log_path, "ab") as log_file:
            # The child gets its own copy of the descriptor, so closing ours
            # when this block ends does not affect the running service.
            proc = subprocess.Popen(
                command,
                cwd=str(cwd),
                env=env,
                stdout=log_file,
                stderr=subprocess.STDOUT,
            )
    except OSError as e:
        # e.g. "python"/"npm" not found on PATH, or the log file is unwritable.
        print(f"❌ Failed to start {service_name}: {e}")
        return None

    # Track in session if provided, otherwise in global registry
    if session:
        session.processes[service_name] = proc
    else:
        PROCESSES[service_name] = proc

    print(f"✅ Started {service_name} (PID {proc.pid})")
    print(f"   📄 Logs: {log_path}")
    return proc


async def check_health(service_name: str, port: Optional[int] = None) -> bool:
    """Probe a service's health endpoint.

    Args:
        service_name: Key into SERVICES.
        port: Port to probe instead of the service's configured default
            (test sessions run services on dynamically allocated ports).

    Returns:
        True when the endpoint answers HTTP 200 within 5 seconds, or when the
        service declares no health URL; False on any other status or error.
    """
    config = SERVICES[service_name]
    url = config.get("health")
    if not url:
        # Nothing to probe, so treat the service as healthy.
        return True

    if port is not None:
        url = url.replace(f":{config['port']}", f":{port}")

    try:
        async with httpx.AsyncClient() as http:
            reply = await http.get(url, timeout=5)
    except Exception:
        # Connection refused, timeouts, etc. all simply mean "not healthy".
        return False
    return reply.status_code == 200


async def start_test_services(session: TestSession) -> bool:
    """Start all required services for a test session with allocated ports.

    Services are launched in dependency order, 2 s apart. Each one is
    configured through environment variables with its own port and the URLs
    of the services it calls. Each service is then polled until healthy, for
    at most 180 s.

    The session is cleaned up and False is returned if any service fails in
    one of these ways:
    - it cannot be launched;
    - it exits before becoming healthy;
    - it does not become healthy in time.

    Args:
        session: Test session with allocated ports

    Returns:
        True if all services started and became healthy
    """
    ports = session.ports
    tag = session.session_id[:8]
    llm_url = f"http://localhost:{ports['llama-http']}"

    # (service name, extra environment) in launch order.
    launch_plan = [
        ("llama-http", {"LLAMA_HTTP_PORT": str(ports["llama-http"])}),
        # classifier / prompt-generator need the llama-http base URL only;
        # the path is appended by their client.
        ("classifier", {
            "PORT": str(ports["classifier"]),
            "LLM_SERVICE_URL": llm_url,
        }),
        ("prompt-generator", {
            "PORT": str(ports["prompt-generator"]),
            "LLM_SERVICE_URL": llm_url,
        }),
        ("diffusion", {"IMAGE_GEN_PORT": str(ports["diffusion"])}),
        # orchestrator needs every downstream URL; it uses the IMAJIN_ prefix
        # for env vars (see imajin/src/imajin/settings.py).
        ("orchestrator", {
            "IMAJIN_API_PORT": str(ports["orchestrator"]),
            "IMAJIN_IMAJIN_CLASSIFIER_URL": f"http://localhost:{ports['classifier']}",
            "IMAJIN_IMAJIN_PROMPT_GENERATOR_URL": f"http://localhost:{ports['prompt-generator']}",
            "IMAJIN_IMAJIN_DIFFUSION_URL": f"http://localhost:{ports['diffusion']}",
        }),
    ]

    last_index = len(launch_plan) - 1
    for index, (service_name, extra_env) in enumerate(launch_plan):
        proc = start_service_with_config(service_name, ports[service_name], extra_env, session)
        if proc is None:
            # Waiting on a service that was never launched would only burn the timeout.
            print(f"   ❌ {service_name} could not be started (session {tag})")
            session.cleanup()
            return False
        if index < last_index:
            # asyncio.sleep rather than time.sleep so the event loop is not blocked.
            await asyncio.sleep(2)

    # Wait for all services to become healthy
    print(f"\n⏳ Waiting for services to become healthy (session {tag})...\n")
    max_wait = 180

    for service_name in REQUIRED_SERVICES_FOR_TEST:
        port = ports[service_name]
        proc = session.processes.get(service_name)
        # Wall-clock deadline: a probe can itself take a while to time out,
        # which a sleep-only counter would not account for.
        deadline = time.monotonic() + max_wait
        while not await check_health(service_name, port):
            if proc is not None and proc.poll() is not None:
                # Crashed during startup; no point waiting out the timeout.
                print(f"   ❌ {service_name} exited with code {proc.returncode} before becoming healthy")
                session.cleanup()
                return False
            if time.monotonic() >= deadline:
                print(f"   ❌ {service_name} failed to become healthy after {max_wait}s")
                session.cleanup()
                return False
            await asyncio.sleep(2)
        print(f"   ✅ {service_name} is healthy (port {port})")

    print(f"\n✅ All services healthy for session {tag}!\n")
    return True


def start_service(service_name: str, background: bool = True) -> Optional[subprocess.Popen]:
    """Start a service on its default port.

    Background launches are recorded in PROCESSES. The child's stdout/stderr
    are appended to ``<tempdir>/imajin_<service>.log`` rather than to a pipe,
    for two reasons:
    - Nothing reads such a pipe, so a chatty service would block once the
      pipe buffer fills.
    - After ``imajin start`` returns, the pipe has no reader at all, so the
      service's later writes would fail with a broken-pipe error.

    Args:
        service_name: Key into SERVICES.
        background: Launch detached and return the handle (True), or run in
            the foreground until the service exits (False).

    Returns:
        The process handle for a background launch (the existing handle if
        this invocation already started the service and it is still
        running), or None for foreground runs and failed launches.
    """
    existing = PROCESSES.get(service_name)
    if existing is not None:
        if existing.poll() is None:
            print(f"⚠️  {service_name} already running (PID {existing.pid})")
            return existing
        # The tracked process has exited; forget it and start a fresh one.
        del PROCESSES[service_name]

    service = SERVICES[service_name]
    cwd = service["cwd"]

    if not cwd.exists():
        print(f"❌ Service directory not found: {cwd}")
        return None

    # Build command with venv python if applicable
    command = service["command"].copy()
    if "venv" in service and command[0] == "python":
        command[0] = get_venv_python(service_name)

    # Build environment
    env = os.environ.copy()
    env.update(service.get("env", {}))

    print(f"🚀 Starting {service_name} on port {service['port']}...")

    if not background:
        # Foreground: output goes straight to this terminal.
        try:
            subprocess.run(command, cwd=str(cwd), env=env)
        except OSError as e:
            print(f"❌ Failed to start {service_name}: {e}")
        return None

    log_path = Path(tempfile.gettempdir()) / f"imajin_{service_name}.log"
    try:
        with open(log_path, "ab") as log_file:
            # The child gets its own copy of the descriptor, so closing ours
            # when this block ends does not affect the running service.
            proc = subprocess.Popen(
                command,
                cwd=str(cwd),
                env=env,
                stdout=log_file,
                stderr=subprocess.STDOUT,
            )
    except OSError as e:
        # e.g. "python"/"npm" not found on PATH, or the log file is unwritable.
        print(f"❌ Failed to start {service_name}: {e}")
        return None

    PROCESSES[service_name] = proc
    print(f"✅ Started {service_name} (PID {proc.pid})")
    print(f"   📄 Logs: {log_path}")
    return proc


def stop_service(service_name: str):
    """Stop a service.

    If this invocation started the service (it is tracked in PROCESSES), the
    process is terminated, escalating to SIGKILL after 5 s.

    Otherwise, the process(es) *listening* on the service's default TCP port
    are found with ``lsof`` and sent SIGTERM. Only listeners are matched: a
    plain ``lsof -i :PORT`` also lists processes that merely hold a client
    connection to that port (e.g. another service calling it over HTTP), and
    those must not be killed.
    """
    if service_name not in PROCESSES:
        port = SERVICES[service_name]["port"]
        try:
            result = subprocess.run(
                ["lsof", "-t", "-nP", f"-iTCP:{port}", "-sTCP:LISTEN"],
                capture_output=True,
                text=True,
            )
        except FileNotFoundError:
            print(f"⚠️  Cannot stop {service_name}: 'lsof' is not installed")
            return

        # dict.fromkeys de-duplicates while keeping lsof's order.
        pids = list(dict.fromkeys(result.stdout.split()))
        if not pids:
            print(f"⚠️  {service_name} not running")
            return

        for pid in pids:
            try:
                os.kill(int(pid), signal.SIGTERM)
                print(f"✅ Stopped {service_name} (PID {pid})")
            except Exception as e:
                print(f"⚠️  Failed to stop {service_name}: {e}")
        return

    proc = PROCESSES[service_name]
    proc.terminate()
    try:
        proc.wait(timeout=5)
        print(f"✅ Stopped {service_name} (PID {proc.pid})")
    except subprocess.TimeoutExpired:
        proc.kill()
        # Reap the killed child so it does not linger as a zombie.
        proc.wait()
        print(f"⚠️  Force killed {service_name} (PID {proc.pid})")

    del PROCESSES[service_name]


async def health_check_all():
    """Check health of all services and print a status table.

    The probes run concurrently; results are printed in SERVICES order.
    check_health() never raises for network errors, so one unreachable
    service cannot abort the others.
    """
    print("\n🏥 Health Check\n" + "=" * 50)

    names = list(SERVICES)
    results = await asyncio.gather(*(check_health(name) for name in names))

    for service_name, healthy in zip(names, results):
        status = "✅ HEALTHY" if healthy else "❌ UNHEALTHY"
        port = SERVICES[service_name]["port"]
        print(f"{service_name:20} (port {port:5}) {status}")

    print()


async def ensure_services_running(required_services: list[str]) -> bool:
    """
    Ensure all required services are running and healthy.

    Unhealthy services are launched with start_service(), 2 s apart. Each is
    then polled every 2 s until healthy, for at most 180 s (enough for SDXL
    model loading).

    Args:
        required_services: List of service names that must be running

    Returns:
        True if all services are healthy. False if a service could not be
        launched or did not become healthy in time; services already launched
        are left running.
    """
    print("\n🔍 Checking service health...\n")

    # Check current health status
    unhealthy = []
    for service_name in required_services:
        healthy = await check_health(service_name)
        status = "✅" if healthy else "❌"
        print(f"   {service_name:20} {status}")
        if not healthy:
            unhealthy.append(service_name)

    if unhealthy:
        print(f"\n🚀 Starting {len(unhealthy)} service(s)...\n")
        for service_name in unhealthy:
            if start_service(service_name, background=True) is None:
                # Nothing was launched, so waiting for it would only burn the timeout.
                print(f"   ❌ {service_name} could not be started")
                return False
            # asyncio.sleep rather than time.sleep so the event loop is not blocked.
            await asyncio.sleep(2)

        print("\n⏳ Waiting for services to become healthy...\n")
        max_wait = 180
        for service_name in unhealthy:
            # Wall-clock deadline: each probe can itself take up to its
            # timeout, which a sleep-only counter would not account for.
            deadline = time.monotonic() + max_wait
            while not await check_health(service_name):
                if time.monotonic() >= deadline:
                    print(f"   ❌ {service_name} failed to become healthy after {max_wait}s")
                    return False
                await asyncio.sleep(2)
            print(f"   ✅ {service_name} is healthy")

    print("\n✅ All required services are healthy!\n")
    return True


async def test_request(category: str, city: str, filters: str):
    """Make a test request with isolated services on dynamic ports.

    Creates a new test session, starts all services with allocated ports,
    runs the test, and cleans up afterward. Request errors are printed
    rather than raised.

    Args:
        category: Request category sent to the orchestrator.
        city: City sent to the orchestrator.
        filters: Comma-separated filter list; an empty string means none.

    On success the returned image (if any) is written to
    ``/tmp/imajin_test_<YYYYmmdd_HHMMSS>.png``.
    """
    # Create test session with allocated ports
    session_id = str(uuid.uuid4())
    try:
        ports = allocate_test_ports()
    except RuntimeError as e:
        print(f"❌ Port allocation failed: {e}")
        return

    session = TestSession(session_id, ports)
    TEST_SESSIONS[session_id] = session

    print(f"\n🔧 Test Session {session_id[:8]}")
    print(f"   Ports allocated: llama-http={ports['llama-http']}, "
          f"classifier={ports['classifier']}, prompt-gen={ports['prompt-generator']}, "
          f"diffusion={ports['diffusion']}, orchestrator={ports['orchestrator']}\n")

    # Backup cleanup at interpreter exit. Normally the finally block below has
    # already cleaned up (removing the session from TEST_SESSIONS), making
    # this a no-op; it is unregistered there as well.
    def cleanup_on_exit():
        if session_id in TEST_SESSIONS:
            TEST_SESSIONS[session_id].cleanup()

    atexit.register(cleanup_on_exit)

    try:
        # Start all services for this test session
        services_ready = await start_test_services(session)
        if not services_ready:
            # Cleanup is performed by the finally block.
            print("\n❌ Failed to start required services. Aborting test.")
            return

        filter_list = filters.split(",") if filters else []

        request_data = {
            "category": category,
            "city": city,
            "role": "hero",
            "filters": filter_list,
        }

        print("📝 Test Request")
        print(f"   Category: {category}")
        print(f"   City: {city}")
        print(f"   Filters: {filter_list}")
        print()

        # Make request to orchestrator using session URL
        async with httpx.AsyncClient(timeout=180) as client:
            response = await client.post(
                f"{session.orchestrator_url}/generate",
                json=request_data,
            )
            response.raise_for_status()

            data = response.json()

            if not data.get("success"):
                print(f"❌ Generation failed: {data.get('error', 'Unknown error')}")
                return

            # Missing metadata must not prevent the image from being saved.
            metadata = data.get("metadata") or {}
            prompt = metadata.get("prompt") or ""
            print("✅ Success!")
            print(f"   Model: {metadata.get('model', 'unknown')}")
            print(f"   Prompt: {prompt[:100]}...")

            # Save image to /tmp
            if data.get("image_base64"):
                image_data = base64.b64decode(data["image_base64"])
                timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
                output_path = f"/tmp/imajin_test_{timestamp}.png"

                with open(output_path, "wb") as f:
                    f.write(image_data)

                print(f"   💾 Image saved to: {output_path}")
                print(f"   📏 Size: {len(image_data)} bytes")
            else:
                print("   ⚠️  No image data in response")

    except httpx.HTTPStatusError as e:
        print(f"❌ Request failed with status {e.response.status_code}")
        print(f"   Response: {e.response.text}")
    except httpx.RequestError as e:
        print(f"❌ Request failed: {e}")
    except Exception as e:
        print(f"❌ Unexpected error: {e}")
    finally:
        # Cleanup test session
        print(f"\n🧹 Cleaning up test session {session_id[:8]}...")
        session.cleanup()
        atexit.unregister(cleanup_on_exit)
        print("✅ Cleanup complete")


def main():
    """Parse ``sys.argv`` and run the requested CLI command.

    Commands: ``start [service]``, ``stop [service]``, ``health``,
    ``test <category> <city> [filters]``, ``logs <service>``. Unknown commands
    and invalid arguments print usage and exit with status 1.
    """
    if len(sys.argv) < 2:
        print(__doc__)
        sys.exit(1)

    command = sys.argv[1]

    if command == "start":
        service_name = sys.argv[2] if len(sys.argv) > 2 else None

        if service_name:
            if service_name not in SERVICES:
                print(f"❌ Unknown service: {service_name}")
                print(f"   Available: {', '.join(SERVICES.keys())}")
                sys.exit(1)
            start_service(service_name)
        else:
            # Start all services in dependency order. "reasoning" is not in
            # this list and has to be started explicitly.
            print("🚀 Starting all imajin services...\n")
            for svc in ["llama-http", "classifier", "prompt-generator", "diffusion", "processing", "orchestrator"]:
                start_service(svc)
                time.sleep(2)  # Brief delay between services

            print("\n✅ All services started!")
            asyncio.run(health_check_all())

    elif command == "stop":
        service_name = sys.argv[2] if len(sys.argv) > 2 else None

        if service_name:
            if service_name not in SERVICES:
                print(f"❌ Unknown service: {service_name}")
                print(f"   Available: {', '.join(SERVICES.keys())}")
                sys.exit(1)
            stop_service(service_name)
        else:
            # Stop all services
            print("🛑 Stopping all imajin services...\n")
            for svc in SERVICES:
                stop_service(svc)

    elif command == "health":
        asyncio.run(health_check_all())

    elif command == "test":
        # <filters> is optional: test_request treats "" as "no filters".
        if len(sys.argv) < 4:
            print("Usage: imajin test <category> <city> [filters]")
            print("Example: imajin test escorts Tokyo 'femboy,latex'")
            sys.exit(1)

        category = sys.argv[2]
        city = sys.argv[3]
        filters = sys.argv[4] if len(sys.argv) > 4 else ""

        asyncio.run(test_request(category, city, filters))

    elif command == "logs":
        if len(sys.argv) < 3:
            print("Usage: imajin logs <service>")
            sys.exit(1)

        service_name = sys.argv[2]
        # NOTE(review): PROCESSES starts empty in every invocation and this
        # command launches nothing, so a standalone `imajin logs <service>`
        # always takes the "not running" branch below.
        if service_name not in PROCESSES:
            print(f"❌ {service_name} not running")
            sys.exit(1)

        proc = PROCESSES[service_name]
        print(f"📋 Logs for {service_name} (PID {proc.pid})")
        print("=" * 50)

        # Tail the output
        for line in proc.stdout:
            print(line, end="")

    else:
        print(f"❌ Unknown command: {command}")
        print(__doc__)
        sys.exit(1)


if __name__ == "__main__":
    try:
        main()
    except KeyboardInterrupt:
        # Ctrl-C: stop every service tracked in PROCESSES, then exit cleanly.
        print("\n\n🛑 Interrupted - stopping all services...")
        for name in tuple(PROCESSES):
            stop_service(name)
        sys.exit(0)
