imajin/scripts/run/repaint_command.py
2026-03-31 22:11:29 -07:00

198 lines
7.3 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Repaint command — background replacement via SDXL inpainting.
Usage:
./run repaint --source photo.jpg --prompt "luxury hotel suite, city view"
./run repaint --source photo.jpg --prompt "hotel suite" --count 4 --out ./results/
./run repaint --source photo.jpg --prompt "..." --seed 42 --count 3
"""
import argparse
import base64
import io
import json
import sys
import time
from pathlib import Path
from typing import Optional
import requests
def _encode_image(path: Path) -> str:
    """Return the bytes of the file at *path* encoded as a plain base64 string.

    No data-URI prefix is added; the file contents are encoded verbatim.
    """
    raw_bytes = path.read_bytes()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode()
def _submit_job(
    url: str,
    source_b64: str,
    background_prompt: str,
    negative_prompt: Optional[str],
    steps: int,
    guidance_scale: float,
    seed: int,
    rating: str,
) -> str:
    """Submit one async repaint job to the service at *url* and return its job id.

    Raises:
        requests.HTTPError: the submit endpoint answered with a non-2xx status.
        RuntimeError: the reply lacks a truthy ``success`` flag or ``jobId``.
    """
    endpoint = f"{url}/generate/repaint-background/async"
    body = dict(
        sourceImage=source_b64,
        backgroundPrompt=background_prompt,
        negativePrompt=negative_prompt,
        steps=steps,
        guidanceScale=guidance_scale,
        seed=seed,
        maturityRating=rating,
    )
    response = requests.post(endpoint, json=body, timeout=30)
    response.raise_for_status()
    reply = response.json()
    job_id = reply.get("jobId")
    if reply.get("success") and job_id:
        return job_id
    raise RuntimeError(f"Submit failed: {reply}")
def _poll_jobs(
    url: str,
    job_ids: list[str],
    interval: float = 3.0,
    timeout: Optional[float] = None,
) -> dict[str, dict]:
    """Poll every job until it reaches a terminal state.

    Args:
        url: Base URL of the diffusion service.
        job_ids: Job ids to track. Duplicates are polled only once.
        interval: Seconds to sleep before each polling round.
        timeout: Optional overall limit in seconds. When it elapses, every job
            still pending is recorded as ``{"error": "timed out after ...s"}``
            instead of being polled forever. ``None`` (default) waits
            indefinitely.

    Returns:
        ``{job_id: result}`` in the order of *job_ids*, not in completion
        order. Completed jobs map to the JSON from ``/jobs/<id>/result``.
        Failed, cancelled or timed-out jobs map to ``{"error": message}``.

    Raises:
        requests.RequestException (including HTTPError) from any status or
        result request. These are not caught here.
    """
    # Statuses that end a job without a result. "failed" is the status this
    # poller has always handled. The others are guarded defensively so an
    # unexpected terminal state cannot spin forever.
    # NOTE(review): confirm the exact set of terminal statuses the service emits.
    failure_statuses = {"failed", "cancelled", "canceled", "error"}
    # dict.fromkeys preserves submission order (a set would not) and drops duplicates.
    ordered_ids = list(dict.fromkeys(job_ids))
    pending = list(ordered_ids)
    results: dict[str, dict] = {}
    deadline = None if timeout is None else time.monotonic() + timeout
    while pending:
        if deadline is not None and time.monotonic() >= deadline:
            for job_id in pending:
                results[job_id] = {"error": f"timed out after {timeout}s"}
                print(f"{job_id[:8]} timed out", file=sys.stderr)
            break
        time.sleep(interval)
        still_pending: list[str] = []
        for job_id in pending:
            resp = requests.get(f"{url}/jobs/{job_id}", timeout=10)
            resp.raise_for_status()
            data = resp.json()
            status = data.get("status")
            if status == "completed":
                result_resp = requests.get(f"{url}/jobs/{job_id}/result", timeout=30)
                result_resp.raise_for_status()
                results[job_id] = result_resp.json()
                print(f"{job_id[:8]} done")
            elif status in failure_statuses:
                # `or` also covers an explicit null "error" field, not just a missing key.
                results[job_id] = {"error": data.get("error") or status}
                print(f"{job_id[:8]} {status}: {data.get('error') or '?'}", file=sys.stderr)
            else:
                still_pending.append(job_id)
        pending = still_pending
    # Return results in the caller's order so positional consumers line results up
    # with the job ids (and seeds) they submitted.
    return {job_id: results[job_id] for job_id in ordered_ids}
def _build_repaint_parser() -> argparse.ArgumentParser:
    """Return the argparse parser for ``./run repaint`` (flags, defaults, help text)."""
    parser = argparse.ArgumentParser(
        prog="./run repaint",
        description="Replace image background using SDXL inpainting (BiRefNet segmentation)",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
# Single repaint
./run repaint --source photo.jpg --prompt "luxury hotel suite, city skyline"
# Batch of 4 variants from the same source
./run repaint --source photo.jpg --prompt "hotel suite" --count 4 --out ./results/
# Deterministic batch (seed+0, seed+1, ...)
./run repaint --source photo.jpg --prompt "hotel suite" --seed 42 --count 3
# NSFW rating (removes SFW clothing-lock from prompt)
./run repaint --source photo.jpg --prompt "hotel suite" --rating nsfw
""",
    )
    parser.add_argument("--source", "-s", required=True, type=Path, help="Source photo path")
    parser.add_argument("--prompt", "-p", required=True, help="Background scene description")
    parser.add_argument("--negative", "-n", default=None, help="Negative prompt")
    parser.add_argument("--count", "-c", type=int, default=1, help="Number of variants (default: 1)")
    parser.add_argument("--seed", type=int, default=None, help="Starting seed (increments per variant)")
    parser.add_argument("--steps", type=int, default=35, help="Inference steps (default: 35)")
    parser.add_argument("--guidance", type=float, default=7.5, help="CFG guidance scale (default: 7.5)")
    parser.add_argument("--rating", choices=["sfw", "nsfw", "explicit"], default="nsfw", help="Content rating (default: nsfw)")
    parser.add_argument("--out", "-o", type=Path, default=None, help="Output directory (or file if count=1)")
    parser.add_argument("--url", default="http://localhost:8002", help="Diffusion service URL")
    return parser


def _resolve_output_paths(out: Optional[Path], count: int) -> tuple[Path, Optional[Path]]:
    """Split ``--out`` into ``(output_directory, explicit_single_file_or_None)``.

    An explicit file is honoured only when exactly one variant is requested,
    ``--out`` has an image suffix (case-insensitive), and it is not an existing
    directory. Otherwise the output directory is ``--out`` itself (an existing
    directory, or a path with no suffix) or its parent.
    """
    out_path = (out or Path(".")).expanduser().resolve()
    if out_path.is_dir():
        # An existing directory wins even if its name contains a dot (e.g. "runs.v2").
        return out_path, None
    if count == 1 and out_path.suffix.lower() in (".png", ".jpg", ".jpeg", ".webp"):
        return out_path.parent, out_path
    return (out_path if out_path.suffix == "" else out_path.parent), None


def _save_results(
    submitted: list[tuple[int, str, int]],
    results: dict[str, dict],
    out_dir: Path,
    single_out: Optional[Path],
    stem: str,
) -> int:
    """Decode and write each successful result and return the number of files written.

    *submitted* holds ``(variant_number, job_id, seed)`` in submission order, so
    file names and reported seeds match the variant that produced them,
    regardless of the order in which jobs finished.
    """
    saved = 0
    for variant, job_id, seed in submitted:
        result = results.get(job_id)
        if not isinstance(result, dict):
            print(f" {job_id[:8]}: no result returned", file=sys.stderr)
            continue
        # The image payload may be nested under "result" or be the top-level object.
        payload = result.get("result", result)
        if not isinstance(payload, dict):
            payload = result
        b64 = payload.get("output_base64", "")
        if not b64:
            # Failed jobs were already reported while polling. Only flag jobs that
            # completed without an image.
            if "error" not in result:
                print(f" {job_id[:8]}: result has no output_base64", file=sys.stderr)
            continue
        if single_out is not None:
            out_file = single_out
        else:
            # Build the name without str.format, so braces in the source stem are harmless.
            out_file = out_dir / f"{stem}_repaint_{variant}.png"
        try:
            out_file.write_bytes(base64.b64decode(b64))
        except (ValueError, OSError) as exc:  # binascii.Error subclasses ValueError
            print(f" Could not save {out_file.name}: {exc}", file=sys.stderr)
            continue
        w, h = payload.get("width", "?"), payload.get("height", "?")
        seed_used = payload.get("seed", seed)
        print(f" Saved {out_file.name} ({w}×{h}, seed={seed_used})")
        saved += 1
    return saved


def repaint_command(args: list[str], workspace_root: Path) -> int:
    """Entry point for ``./run repaint``: submit N repaint jobs, wait for them, save the images.

    Args:
        args: Command-line arguments following ``repaint``.
        workspace_root: Not used by this command. It is accepted to match the
            runner's command signature.

    Returns:
        0 if at least one image was saved. 1 otherwise: missing source,
        unreachable service, unusable output directory, every submission
        failed, polling error, or nothing saved. Invalid flags (including
        ``--count`` < 1) exit through argparse with status 2.
    """
    import random

    parser = _build_repaint_parser()
    parsed = parser.parse_args(args)
    if parsed.count < 1:
        parser.error("--count must be at least 1")

    source_path = parsed.source.expanduser().resolve()
    # is_file() rather than exists(), so a directory is rejected here instead of
    # crashing later in read_bytes().
    if not source_path.is_file():
        print(f"Source photo not found: {source_path}", file=sys.stderr)
        return 1

    # Check service health
    try:
        requests.get(f"{parsed.url}/health", timeout=5).raise_for_status()
    except requests.RequestException:
        print(f"Diffusion service not reachable at {parsed.url}", file=sys.stderr)
        print("Start with: ./run dev diffusion", file=sys.stderr)
        return 1

    out_dir, single_out = _resolve_output_paths(parsed.out, parsed.count)
    try:
        out_dir.mkdir(parents=True, exist_ok=True)
    except OSError as exc:
        print(f"Cannot create output directory {out_dir}: {exc}", file=sys.stderr)
        return 1

    # Consecutive seeds make a batch reproducible from --seed alone.
    base_seed = parsed.seed if parsed.seed is not None else random.randint(0, 2**31 - 1)
    seeds = [base_seed + i for i in range(parsed.count)]
    print(f"Repainting {source_path.name} × {parsed.count}")
    print(f" Prompt: {parsed.prompt[:80]}{'...' if len(parsed.prompt) > 80 else ''}")
    print(f" Seeds: {seeds[:5]}{'...' if len(seeds) > 5 else ''}")
    print()

    try:
        source_b64 = _encode_image(source_path)
    except OSError as exc:
        print(f"Cannot read source photo {source_path}: {exc}", file=sys.stderr)
        return 1

    # Record (variant, job_id, seed) so each result stays paired with its own seed,
    # even when some submissions fail or jobs finish out of order.
    submitted: list[tuple[int, str, int]] = []
    for variant, seed in enumerate(seeds, start=1):
        try:
            job_id = _submit_job(
                parsed.url, source_b64, parsed.prompt, parsed.negative,
                parsed.steps, parsed.guidance, seed, parsed.rating,
            )
        except Exception as e:  # per-seed boundary: report it and keep submitting the rest
            print(f" Submit failed (seed={seed}): {e}", file=sys.stderr)
            continue
        submitted.append((variant, job_id, seed))
        print(f"{job_id[:8]} seed={seed}")
    if not submitted:
        print("All submissions failed.", file=sys.stderr)
        return 1

    job_ids = [job_id for _, job_id, _ in submitted]
    print(f"\nPolling {len(job_ids)} job(s)...")
    try:
        results = _poll_jobs(parsed.url, job_ids)
    except (requests.RequestException, ValueError) as exc:
        print(f"Polling failed: {exc}", file=sys.stderr)
        return 1

    saved = _save_results(submitted, results, out_dir, single_out, source_path.stem)
    print(f"\n{saved}/{len(job_ids)} images saved to {out_dir}")
    return 0 if saved > 0 else 1
def register_repaint_command(runner) -> None:
    """Register the ``repaint`` subcommand, backed by ``repaint_command``, on *runner*."""
    name = "repaint"
    summary = "Replace image background via SDXL inpainting"
    runner.register_command(name, repaint_command, summary)