feat(image-pipeline): ✨ Add adversarial protection and watermarking stages to secure image pipeline
Co-Authored-By: Lilith Autocommit <noreply@atlilith.com>
This commit is contained in:
parent
bdc9c067ed
commit
169aead308
6 changed files with 607 additions and 28 deletions
|
|
@ -147,6 +147,9 @@ class ImagePipelineContext(PipelineContext[ImagePipelineRequest, Image.Image]):
|
|||
all_candidates: Optional[List[Dict[str, Any]]] = None # All generated candidates with scores
|
||||
selected_seed: Optional[int] = None # Seed of the selected best candidate
|
||||
|
||||
# Adversarial protection result (set by ADVERSARIAL_PROTECT stage)
|
||||
adversarial_result: Optional[Dict[str, Any]] = None
|
||||
|
||||
# Output data (set by OUTPUT stage)
|
||||
output_base64: Optional[str] = None
|
||||
output_url: Optional[str] = None
|
||||
|
|
|
|||
|
|
@ -198,6 +198,31 @@ class ImagePipelineRequest(BaseModel):
|
|||
enable_watermark: bool = Field(False, description="Enable forensic watermarking")
|
||||
watermark_payload: Optional[str] = Field(None, description="Payload to embed")
|
||||
|
||||
# Adversarial protection options
|
||||
enable_adversarial: bool = Field(
|
||||
False,
|
||||
description="Apply adversarial perturbation + forensic watermark for content protection",
|
||||
)
|
||||
adversarial_payload: Optional[str] = Field(
|
||||
None,
|
||||
description=(
|
||||
"Distributor identifier to embed as watermark "
|
||||
"(e.g. client token hash). Defaults to job_id."
|
||||
),
|
||||
)
|
||||
adversarial_strength: float = Field(
|
||||
0.03,
|
||||
ge=0.0,
|
||||
le=0.15,
|
||||
description="Adversarial noise strength (0.03 = imperceptible, 0.15 = visible)",
|
||||
)
|
||||
watermark_strength: float = Field(
|
||||
0.5,
|
||||
ge=0.0,
|
||||
le=2.0,
|
||||
description="DCT watermark strength (0.5 = invisible, 2.0 = more robust)",
|
||||
)
|
||||
|
||||
# Watermark removal options (visible text watermark removal)
|
||||
enable_watermark_removal: bool = Field(
|
||||
True, description="Enable automatic watermark detection and removal"
|
||||
|
|
|
|||
|
|
@ -72,6 +72,7 @@ except ImportError as e:
|
|||
import logging
|
||||
logging.warning(f"BackgroundRemovalStage not available (rembg disabled): {e}")
|
||||
BackgroundRemovalStage = None
|
||||
from .adversarial import AdversarialProtectStage
|
||||
from .aesthetic import AestheticValidationStage
|
||||
from .moderate import ModerateStage
|
||||
from .quality import QualityStage
|
||||
|
|
@ -125,6 +126,7 @@ _stages.append(OutputStage())
|
|||
DEFAULT_STAGES = _stages
|
||||
|
||||
__all__ = [
|
||||
"AdversarialProtectStage",
|
||||
"ValidateStage",
|
||||
"ImageLoadingStage",
|
||||
"IdentityConditioningStage",
|
||||
|
|
|
|||
|
|
@ -0,0 +1,212 @@
|
|||
"""Adversarial Protection Stage — frequency-domain perturbation + forensic watermark.
|
||||
|
||||
Applies two protections in a single pass:
|
||||
|
||||
1. Adversarial perturbation: structured high-frequency noise in the FFT domain
|
||||
that degrades AI feature extraction (CLIP/DINO embeddings) while remaining
|
||||
imperceptible to humans (PSNR > 38 dB at default strength 0.03).
|
||||
|
||||
2. Forensic watermark: DCT spread-spectrum watermark that encodes the distributor
|
||||
identifier (payload) for later attribution via detect_watermark().
|
||||
|
||||
Both operations are keyed on the same payload so the noise pattern is
|
||||
deterministic and reproducible per distributor token.
|
||||
"""
|
||||
|
||||
import hmac
|
||||
import logging
|
||||
from typing import Any, Dict
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from image_pipeline.utils.watermark import embed_watermark
|
||||
|
||||
from lilith_pipeline_framework import PipelineStage, StageResult, StageStatus
|
||||
from image_pipeline.context import ImagePipelineContext as PipelineContext
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def psnr(original: Image.Image, modified: Image.Image) -> float:
    """Compute the Peak Signal-to-Noise Ratio between two images, in dB.

    Both images are converted to 8-bit RGB before comparison, so the peak
    value used in the formula is 255. Higher means more similar; a typical
    imperceptibility threshold is >= 38 dB.

    Args:
        original: Reference image.
        modified: Image compared against the reference.

    Returns:
        PSNR in decibels, or ``float("inf")`` when the two images are
        pixel-identical after RGB conversion.

    Raises:
        ValueError: If the two images do not have the same pixel dimensions.
    """
    orig_arr = np.array(original.convert("RGB"), dtype=np.float64)
    mod_arr = np.array(modified.convert("RGB"), dtype=np.float64)

    # Reject mismatched sizes explicitly. Left to NumPy, the subtraction below
    # would raise an opaque broadcasting error, or -- when one image is a
    # single pixel wide/tall -- silently broadcast and return a meaningless
    # value.
    if orig_arr.shape != mod_arr.shape:
        raise ValueError(
            f"psnr() requires images of equal size, got "
            f"{original.size} and {modified.size}"
        )

    mse = float(np.mean((orig_arr - mod_arr) ** 2))
    if mse == 0.0:
        return float("inf")
    # Wrap in float() so callers get a plain Python float, not numpy.float64.
    return float(10.0 * np.log10((255.0 ** 2) / mse))
|
||||
|
||||
|
||||
def _apply_adversarial_noise(
    image: Image.Image,
    strength: float,
    seed_bytes: bytes,
) -> Image.Image:
    """Add deterministic high-frequency noise to an image in the FFT domain.

    For each RGB channel, complex Gaussian noise is added to the 2-D FFT
    coefficients whose radial frequency exceeds 0.3 x Nyquist, and the result
    is transformed back to the spatial domain. The noise amplitude is
    ``strength`` times an RMS estimate of the channel's high-frequency
    spectrum.

    Args:
        image: Source PIL image. It is converted to RGB before processing.
        strength: Noise magnitude scalar (0.02-0.05 = imperceptible,
            0.10+ = visible).
        seed_bytes: Seed for the noise pattern (per distributor). Its length
            must be a multiple of 4 bytes; callers in this module pass a
            32-byte HMAC-SHA256 digest.

    Returns:
        Perturbed image with the same dimensions as the input. The output is
        always in RGB mode, regardless of the input mode.
    """
    arr = np.array(image.convert("RGB"), dtype=np.float32) / 255.0  # [0, 1]
    h, w, c = arr.shape

    # Seed the RNG from the payload-derived bytes. The words are read as
    # explicit little-endian uint32 so the same seed yields the same noise on
    # every host; a native-endian read would give a different pattern on
    # big-endian machines. The astype() hands NumPy a native uint32 array.
    seed_ints = np.frombuffer(seed_bytes, dtype="<u4").astype(np.uint32)
    rng = np.random.default_rng(seed_ints)

    # Build frequency mask: 1 in the high-frequency band, 0 elsewhere.
    freq_y = np.fft.fftfreq(h)  # cycles per pixel, range [-0.5, 0.5)
    freq_x = np.fft.fftfreq(w)
    fy_grid, fx_grid = np.meshgrid(freq_y, freq_x, indexing="ij")
    freq_magnitude = np.sqrt(fy_grid ** 2 + fx_grid ** 2)
    # High-frequency band: |freq| > 0.3 x Nyquist (0.5 cycles/pixel).
    hf_mask = (freq_magnitude > 0.3 * 0.5).astype(np.float32)

    result = np.empty_like(arr)
    for ch in range(c):
        freq_channel = np.fft.fft2(arr[:, :, ch])

        # Scale the noise by an RMS of the masked spectrum rather than its
        # maximum. The mean is taken over ALL h*w bins (masked-out bins count
        # as zero), so this is smaller than the RMS of the high-frequency
        # bins alone.
        hf_mags = np.abs(freq_channel) * hf_mask
        rms_hf = float(np.sqrt(np.mean(hf_mags ** 2) + 1e-8))

        # Unit-variance complex Gaussian noise restricted to the HF band.
        # Draw order (real, then imaginary, per channel) fixes the pattern for
        # a given seed, so it must not be reordered.
        noise_real = rng.standard_normal((h, w)).astype(np.float32)
        noise_imag = rng.standard_normal((h, w)).astype(np.float32)
        noise_complex = (noise_real + 1j * noise_imag) * hf_mask
        noise_freq = noise_complex * (rms_hf * strength)

        freq_noisy = freq_channel + noise_freq
        # The added noise is not conjugate-symmetric, so the inverse FFT is
        # complex; only the real part is kept.
        spatial_noisy = np.real(np.fft.ifft2(freq_noisy))
        result[:, :, ch] = np.clip(spatial_noisy, 0.0, 1.0)

    out_uint8 = (result * 255.0).round().astype(np.uint8)
    # An (h, w, 3) uint8 array is inferred as RGB; the explicit ``mode``
    # argument is omitted because newer Pillow releases deprecate it.
    return Image.fromarray(out_uint8)
|
||||
|
||||
|
||||
class AdversarialProtectStage(PipelineStage):
    """Pipeline stage applying adversarial noise plus a forensic watermark.

    Adversarial step: adds structured frequency-domain noise intended to
    degrade AI feature extraction (CLIP embeddings, DINO features) while
    remaining imperceptible to humans (PSNR > 38 dB at default strength 0.03).

    Forensic step: embeds a DCT spread-spectrum watermark identifying the
    distributor.

    Both operations are keyed on the same payload: ``adversarial_payload``
    when set, otherwise the context's ``job_id``.

    Controlled by ImagePipelineRequest fields:
        enable_adversarial: bool — skip stage entirely when False
        adversarial_payload: Optional[str] — distributor token; defaults to job_id
        adversarial_strength: float — FFT noise strength (default 0.03)
        watermark_strength: float — DCT watermark strength (default 0.5)

    On success the stage replaces ``context.image`` with the protected RGB
    image and stores a summary dict in ``context.adversarial_result``.
    """

    @property
    def name(self) -> str:
        """Stage identifier used in results."""
        return "adversarial_protect"

    @property
    def description(self) -> str:
        """Human-readable description of the stage."""
        return "Apply adversarial perturbation and forensic DCT watermark for content protection"

    @property
    def is_optional(self) -> bool:
        """This stage is declared optional."""
        return True

    async def execute(self, context: PipelineContext) -> StageResult:
        """Run the protection pass on ``context.image``.

        Args:
            context: Pipeline context carrying the request, job id and image.

        Returns:
            SKIPPED when ``enable_adversarial`` is False, FAILED when there is
            no image or any step raises, SUCCESS otherwise. On SUCCESS the
            result's ``data`` is the dict stored in
            ``context.adversarial_result``.
        """
        # NOTE(review): duration_ms is reported as 0 on every path; presumably
        # the framework measures stage duration itself -- confirm.
        if not context.request.enable_adversarial:
            return StageResult(
                stage_name=self.name,
                status=StageStatus.SKIPPED,
                duration_ms=0,
                summary="Adversarial protection disabled",
            )

        if context.image is None:
            return StageResult(
                stage_name=self.name,
                status=StageStatus.FAILED,
                duration_ms=0,
                summary="No image for adversarial protection",
                error="Image not available in context",
            )

        try:
            payload = context.request.adversarial_payload or context.job_id
            adv_strength = context.request.adversarial_strength
            wm_strength = context.request.watermark_strength

            # All processing is done in RGB; any alpha channel is dropped here.
            original = context.image.convert("RGB")

            # Step 1: adversarial frequency-domain noise, seeded from an
            # HMAC-SHA256 of the payload so the pattern is reproducible.
            # NOTE(review): the FFT/DCT work below is synchronous CPU work
            # inside an async method -- confirm the framework tolerates that.
            seed_bytes = hmac.digest(
                b"imajin-adv-v1",
                payload.encode("utf-8"),
                "sha256",
            )
            noisy = _apply_adversarial_noise(original, adv_strength, seed_bytes)

            # Step 2: DCT forensic watermark on the noisy image. service_url
            # is None so the local embedding path is used.
            watermarked, wm_result = embed_watermark(
                noisy,
                payload=payload,
                method="dct",
                strength=wm_strength,
                service_url=None,
            )

            # Compute PSNR vs. original (covers both noise and watermark).
            psnr_value = psnr(original, watermarked)

            context.image = watermarked
            context.adversarial_result = {
                "applied": True,
                "payload": payload,
                "psnr": psnr_value,
                "adversarial_strength": adv_strength,
                "watermark_strength": wm_strength,
                "watermark_applied": wm_result.applied,
            }

            summary = (
                f"Adversarial protection applied (PSNR={psnr_value:.1f} dB, "
                f"adv_strength={adv_strength}, wm_strength={wm_strength})"
            )
            logger.info(summary)

            return StageResult(
                stage_name=self.name,
                status=StageStatus.SUCCESS,
                duration_ms=0,
                summary=summary,
                data=context.adversarial_result,
            )

        except Exception as e:
            # Boundary handler: the stage reports FAILED instead of raising.
            # logger.exception keeps the traceback, which logger.error dropped.
            logger.exception("Adversarial protection failed: %s", e)
            context.adversarial_result = {"applied": False, "error": str(e)}
            return StageResult(
                stage_name=self.name,
                status=StageStatus.FAILED,
                duration_ms=0,
                summary="Adversarial protection failed",
                error=str(e),
            )
|
||||
|
|
@ -1,24 +1,34 @@
|
|||
"""Forensic watermarking for content protection.
|
||||
|
||||
Provides functions for:
|
||||
- Embedding invisible forensic watermarks
|
||||
- Detecting watermarks in images
|
||||
- Embedding invisible forensic watermarks via DCT spread-spectrum
|
||||
- Detecting watermarks in images using correlation
|
||||
- Watermark verification
|
||||
|
||||
Note: Full implementation requires integration with a watermarking service.
|
||||
This module provides the interface and basic local functionality.
|
||||
Local path uses spread-spectrum embedding in the DCT domain (numpy + scipy).
|
||||
No external service or torch required for the local path.
|
||||
"""
|
||||
|
||||
import base64
|
||||
import hmac
|
||||
import io
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from PIL import Image
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Detection threshold: the per-coefficient correlation between the extracted
# DCT coefficients and the candidate ±1 sequence must exceed this value for a
# watermark to count as detected.
# Tuned for JPEG-survival at quality 85 with strength 0.5.
DETECTION_THRESHOLD = 2.0

# Half-open index range [_EMBED_START, _EMBED_END) into the zigzag scan order
# of an 8x8 block; the coefficients at these scan positions carry the
# watermark.
# Zigzag positions 3..5 (0-indexed) map to mid-frequency coefficients.
# NOTE(review): in the zigzag table these are (2, 0), (1, 1), (0, 2), i.e. the
# coefficients right after the DC term and the two lowest AC terms -- confirm
# that "mid-frequency" is the intended band.
_EMBED_START = 3
_EMBED_END = 6  # exclusive → 3 positions per block
|
||||
|
||||
|
||||
@dataclass
|
||||
class WatermarkResult:
|
||||
|
|
@ -33,6 +43,169 @@ class WatermarkResult:
|
|||
error: Optional[str] = None
|
||||
|
||||
|
||||
# Zigzag scan order for an 8x8 DCT block (standard JPEG ordering).
# Entry i is the (row, col) offset inside the block of the i-th coefficient in
# the scan; entry 0 is the DC term at (0, 0) and entry 63 is (7, 7).
_ZIGZAG_ORDER: list[tuple[int, int]] = [
    (0, 0), (0, 1), (1, 0), (2, 0), (1, 1), (0, 2), (0, 3), (1, 2),
    (2, 1), (3, 0), (4, 0), (3, 1), (2, 2), (1, 3), (0, 4), (0, 5),
    (1, 4), (2, 3), (3, 2), (4, 1), (5, 0), (6, 0), (5, 1), (4, 2),
    (3, 3), (2, 4), (1, 5), (0, 6), (0, 7), (1, 6), (2, 5), (3, 4),
    (4, 3), (5, 2), (6, 1), (7, 0), (7, 1), (6, 2), (5, 3), (4, 4),
    (3, 5), (2, 6), (1, 7), (2, 7), (3, 6), (4, 5), (5, 4), (6, 3),
    (7, 2), (7, 3), (6, 4), (5, 5), (4, 6), (3, 7), (4, 7), (5, 6),
    (6, 5), (7, 4), (7, 5), (6, 6), (5, 7), (6, 7), (7, 6), (7, 7),
]

# The specific (row, col) offsets within a block for positions _EMBED_START.._EMBED_END-1.
# Embedding and detection both iterate this list in order, so its order is
# part of the watermark format.
_EMBED_POSITIONS: list[tuple[int, int]] = _ZIGZAG_ORDER[_EMBED_START:_EMBED_END]
|
||||
|
||||
|
||||
def _payload_to_sequence(payload: str, n_coeffs: int) -> "np.ndarray":
|
||||
"""Generate pseudo-random ±1 spread-spectrum sequence from payload via HMAC-SHA256."""
|
||||
import numpy as np
|
||||
|
||||
seed_bytes = hmac.digest(b"imajin-wm-v1", payload.encode("utf-8"), "sha256")
|
||||
seed_ints = np.frombuffer(seed_bytes, dtype=np.uint32)
|
||||
rng = np.random.default_rng(seed_ints)
|
||||
return rng.choice(np.array([-1.0, 1.0], dtype=np.float64), size=n_coeffs)
|
||||
|
||||
|
||||
def _embed_local_dct(
    image: Image.Image,
    payload: str,
    strength: float,
) -> tuple[Image.Image, WatermarkResult]:
    """Embed a spread-spectrum watermark in the DCT domain of the luma channel.

    The image is converted to YCbCr and the Y channel is tiled into 8x8
    blocks. In every block, ``strength * 10.0 * s`` is added to each
    coefficient listed in ``_EMBED_POSITIONS``, where ``s`` is the next ±1
    element of the payload-derived sequence (blocks are visited row by row).
    Rows/columns beyond the last full 8x8 block are left unmodified; chroma is
    untouched.

    Args:
        image: Source PIL image (any mode convertible to YCbCr).
        payload: Payload string that keys the embedded sequence.
        strength: Watermark strength; the per-coefficient offset is
            ``strength * 10.0``.

    Returns:
        Tuple of (RGB image of the same size as the input, WatermarkResult).
        ``applied`` is False when the image is smaller than one 8x8 block,
        because no coefficient can be embedded in that case.
    """
    import numpy as np
    from scipy.fftpack import dct, idct

    # Work on the Y (luma) channel only, in float64 for precision. The chroma
    # channels are carried through as-is.
    y_img, cb_img, cr_img = image.convert("YCbCr").split()
    y_arr = np.array(y_img, dtype=np.float64)

    h, w = y_arr.shape
    # Largest region that tiles cleanly into 8x8 blocks.
    n_blocks_y = h // 8
    n_blocks_x = w // 8
    h8 = n_blocks_y * 8
    w8 = n_blocks_x * 8
    n_blocks = n_blocks_y * n_blocks_x
    n_embed = _EMBED_END - _EMBED_START  # coefficients per block
    n_coeffs_total = n_blocks * n_embed

    if n_blocks == 0:
        # Nothing can be embedded; report that honestly instead of claiming
        # the watermark was applied.
        return image.convert("RGB"), WatermarkResult(
            success=True,
            applied=False,
            payload=payload,
            method="dct",
            quality_impact=0.0,
            details={
                "n_blocks": 0,
                "n_coeffs": 0,
                "strength": strength,
                "reason": "Image smaller than one 8x8 block",
            },
        )

    sequence = _payload_to_sequence(payload, n_coeffs_total)

    # View the tiled region as (block_row, row_in_block, block_col,
    # col_in_block) so one batched DCT call transforms every block at once.
    blocks = y_arr[:h8, :w8].reshape(n_blocks_y, 8, n_blocks_x, 8)
    # 2D DCT via separable 1D DCTs (norm="ortho" for energy preservation).
    coeffs = dct(dct(blocks, axis=1, norm="ortho"), axis=3, norm="ortho")

    # The sequence is consumed block by block (row-major), with the embed
    # positions innermost -- the same order detection uses.
    marks = (strength * 10.0) * sequence.reshape(n_blocks_y, n_blocks_x, n_embed)
    for k, (pos_r, pos_c) in enumerate(_EMBED_POSITIONS):
        coeffs[:, pos_r, :, pos_c] += marks[:, :, k]

    # Inverse 2D DCT, then clip to the valid 8-bit range and write back.
    restored = idct(idct(coeffs, axis=3, norm="ortho"), axis=1, norm="ortho")
    y_arr[:h8, :w8] = np.clip(restored.reshape(h8, w8), 0.0, 255.0)

    # Round to nearest before casting: a bare astype(uint8) truncates, which
    # biases every modified pixel downward and weakens the embedded signal.
    y_out = Image.fromarray(np.rint(y_arr).astype(np.uint8))

    # Reconstruct the YCbCr image and convert back to RGB. y_arr keeps the
    # full input dimensions, so the output size always matches the input.
    result_image = Image.merge("YCbCr", (y_out, cb_img, cr_img)).convert("RGB")

    return result_image, WatermarkResult(
        success=True,
        applied=True,
        payload=payload,
        method="dct",
        quality_impact=0.0,
        details={"n_blocks": n_blocks, "n_coeffs": n_coeffs_total, "strength": strength},
    )
|
||||
|
||||
|
||||
def _detect_local_dct(
    image: Image.Image,
    candidate_payload: str,
) -> WatermarkResult:
    """Detect a DCT spread-spectrum watermark for one candidate payload.

    The same luma DCT coefficients used for embedding are extracted from every
    full 8x8 block and correlated with the ±1 sequence derived from
    ``candidate_payload``. The correlation is the dot product divided by the
    number of coefficients.

    Args:
        image: PIL image to test (any mode convertible to YCbCr).
        candidate_payload: Payload whose sequence is tested against the image.

    Returns:
        WatermarkResult with ``applied=True`` and ``payload`` set when the
        correlation exceeds DETECTION_THRESHOLD; otherwise ``applied=False``
        and ``payload=None``. ``details`` carries the correlation and the
        threshold. Images smaller than one 8x8 block are reported as not
        detected with a correlation of 0.0.
    """
    import numpy as np
    from scipy.fftpack import dct

    y_arr = np.array(image.convert("YCbCr").getchannel(0), dtype=np.float64)

    h, w = y_arr.shape
    n_blocks_y = h // 8
    n_blocks_x = w // 8
    h8 = n_blocks_y * 8
    w8 = n_blocks_x * 8
    n_blocks = n_blocks_y * n_blocks_x
    n_embed = _EMBED_END - _EMBED_START
    n_coeffs_total = n_blocks * n_embed

    if n_coeffs_total == 0:
        # No full block to inspect. Without this guard the correlation below
        # divides by zero and puts NaN into the result details.
        return WatermarkResult(
            success=True,
            applied=False,
            payload=None,
            method="dct",
            quality_impact=0.0,
            details={
                "correlation": 0.0,
                "threshold": DETECTION_THRESHOLD,
                "reason": "Image smaller than one 8x8 block",
            },
        )

    sequence = _payload_to_sequence(candidate_payload, n_coeffs_total)

    # View the tiled region as (block_row, row_in_block, block_col,
    # col_in_block) and transform every block with one batched 2D DCT.
    blocks = y_arr[:h8, :w8].reshape(n_blocks_y, 8, n_blocks_x, 8)
    coeffs = dct(dct(blocks, axis=1, norm="ortho"), axis=3, norm="ortho")

    # Extract the same DCT coefficients as embedding, in the same order:
    # blocks row-major, embed positions innermost.
    extracted = np.stack(
        [coeffs[:, pos_r, :, pos_c] for pos_r, pos_c in _EMBED_POSITIONS],
        axis=-1,
    ).ravel()

    # Correlation between extracted coefficients and expected sequence.
    corr = float(np.dot(extracted, sequence) / n_coeffs_total)

    detected = corr > DETECTION_THRESHOLD
    logger.debug(
        "DCT watermark detection: candidate=%s corr=%.4f threshold=%.4f detected=%s",
        candidate_payload[:16],
        corr,
        DETECTION_THRESHOLD,
        detected,
    )

    return WatermarkResult(
        success=True,
        applied=detected,
        payload=candidate_payload if detected else None,
        method="dct",
        quality_impact=0.0,
        details={"correlation": corr, "threshold": DETECTION_THRESHOLD},
    )
|
||||
|
||||
|
||||
def embed_watermark(
|
||||
image: Image.Image,
|
||||
payload: str,
|
||||
|
|
@ -45,33 +218,29 @@ def embed_watermark(
|
|||
Args:
|
||||
image: Source PIL Image
|
||||
payload: String payload to embed (e.g., user ID, job ID)
|
||||
method: Watermarking method ('dct', 'dwt', 'lsb', 'hybrid')
|
||||
strength: Watermark strength (0.0-1.0)
|
||||
service_url: Optional URL for remote watermarking service
|
||||
method: Watermarking method ('dct' supported locally; others via service)
|
||||
strength: Watermark strength (0.0–2.0; 0.5 = invisible, 2.0 = more robust)
|
||||
service_url: Optional URL for remote watermarking service. When provided
|
||||
the service path is used regardless of method.
|
||||
|
||||
Returns:
|
||||
Tuple of (watermarked image, WatermarkResult)
|
||||
|
||||
Note:
|
||||
For production use, this should call the ml-watermarking-python service
|
||||
which implements robust DCT-DWT hybrid embedding.
|
||||
"""
|
||||
if service_url:
|
||||
return _embed_via_service(image, payload, method, strength, service_url)
|
||||
|
||||
# Local fallback - placeholder implementation
|
||||
# In production, this would use a proper watermarking algorithm
|
||||
logger.info(f"Local watermark embedding (payload={payload[:16]}...)")
|
||||
if method == "dct":
|
||||
logger.info("Local DCT watermark embedding (payload=%s...)", payload[:16])
|
||||
return _embed_local_dct(image, payload, strength)
|
||||
|
||||
# For now, return original image with mock result
|
||||
# TODO: Implement local DCT/DWT watermarking
|
||||
# Unsupported local method.
|
||||
logger.warning("Local watermark method %r not supported without service_url", method)
|
||||
return image, WatermarkResult(
|
||||
success=True,
|
||||
success=False,
|
||||
applied=False,
|
||||
payload=payload,
|
||||
method="none",
|
||||
quality_impact=0.0,
|
||||
details={"reason": "Local watermarking not implemented, use service_url"},
|
||||
method=method,
|
||||
error=f"Method {method!r} requires service_url",
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -126,7 +295,7 @@ def _embed_via_service(
|
|||
error="httpx not installed",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Watermark service error: {e}")
|
||||
logger.error("Watermark service error: %s", e)
|
||||
return image, WatermarkResult(
|
||||
success=False,
|
||||
error=str(e),
|
||||
|
|
@ -135,27 +304,32 @@ def _embed_via_service(
|
|||
|
||||
def detect_watermark(
|
||||
image: Image.Image,
|
||||
candidate_payload: Optional[str] = None,
|
||||
service_url: Optional[str] = None,
|
||||
) -> WatermarkResult:
|
||||
"""Detect and extract watermark from an image.
|
||||
|
||||
Args:
|
||||
image: PIL Image to check for watermark
|
||||
candidate_payload: Payload string to test for (required for local DCT detection).
|
||||
Brute-force over multiple candidates is the caller's responsibility.
|
||||
service_url: Optional URL for remote detection service
|
||||
|
||||
Returns:
|
||||
WatermarkResult with detected payload if found
|
||||
WatermarkResult with applied=True and payload set if detected
|
||||
"""
|
||||
if service_url:
|
||||
return _detect_via_service(image, service_url)
|
||||
|
||||
# Local detection not implemented
|
||||
logger.info("Watermark detection requires service_url")
|
||||
if candidate_payload is not None:
|
||||
return _detect_local_dct(image, candidate_payload)
|
||||
|
||||
# No candidate and no service: cannot detect.
|
||||
logger.info("Watermark detection requires either candidate_payload or service_url")
|
||||
return WatermarkResult(
|
||||
success=True,
|
||||
applied=False,
|
||||
details={"reason": "Local detection not implemented, use service_url"},
|
||||
details={"reason": "No candidate_payload provided and no service_url; cannot detect locally"},
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -197,7 +371,7 @@ def _detect_via_service(
|
|||
error="httpx not installed",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Watermark detection error: {e}")
|
||||
logger.error("Watermark detection error: %s", e)
|
||||
return WatermarkResult(
|
||||
success=False,
|
||||
error=str(e),
|
||||
|
|
@ -219,7 +393,7 @@ def verify_watermark(
|
|||
Returns:
|
||||
True if watermark matches expected payload
|
||||
"""
|
||||
result = detect_watermark(image, service_url=service_url)
|
||||
result = detect_watermark(image, candidate_payload=expected_payload, service_url=service_url)
|
||||
|
||||
if not result.success or not result.applied:
|
||||
return False
|
||||
|
|
|
|||
|
|
@ -0,0 +1,163 @@
|
|||
"""Test adversarial protection and forensic watermark embed/detect round-trip.
|
||||
|
||||
All tests are self-contained — no external services, no model-boss, no GPU required.
|
||||
|
||||
Imports go directly to submodules (not via image_pipeline.utils.__init__ or
|
||||
image_pipeline.__init__) to avoid eager torch/diffusers imports that crash
|
||||
in non-GPU environments.
|
||||
"""
|
||||
|
||||
import io
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
from PIL import Image
|
||||
|
||||
# Import directly from submodule to avoid package-level eager torch import.
|
||||
from image_pipeline.utils.watermark import embed_watermark, detect_watermark
|
||||
|
||||
|
||||
def make_test_image(w: int = 256, h: int = 256) -> Image.Image:
    """Build a reproducible RGB test image: diagonal gradient plus noise.

    The generator is seeded with a fixed value, so every call with the same
    dimensions returns the same pixels. Mixing a smooth gradient with
    per-channel noise makes the watermark embed/detect tests more realistic
    than a solid-colour image would.
    """
    generator = np.random.default_rng(42)

    # Smooth ramp: product of a vertical and a horizontal 0..1 ramp, scaled to 0..200.
    cols = np.linspace(0, 1, w, dtype=np.float32)
    rows = np.linspace(0, 1, h, dtype=np.float32)
    ramp = (np.outer(rows, cols) * 200).astype(np.uint8)

    # Independent noise for each of the three channels.
    speckle = generator.integers(0, 55, size=(h, w, 3), dtype=np.uint8)

    red = ramp
    green = 255 - ramp
    blue = (ramp // 2 + 100).astype(np.uint8)
    channels = np.stack([red, green, blue], axis=-1)

    # Sum in int16 so the addition cannot wrap, then clamp back to 8 bits.
    summed = channels.astype(np.int16) + speckle.astype(np.int16)
    pixels = np.clip(summed, 0, 255).astype(np.uint8)
    return Image.fromarray(pixels, mode="RGB")
|
||||
|
||||
|
||||
class TestEmbedDetectRoundtrip:
    """Embed/detect round-trip checks for the local DCT watermark path."""

    def test_embed_returns_applied(self) -> None:
        # Embedding must report success and echo back payload, method and size.
        source = make_test_image()
        token = "client-42-vip-portal"

        marked, outcome = embed_watermark(source, token, method="dct", strength=0.5)

        assert outcome.success, f"embed_watermark reported failure: {outcome}"
        assert outcome.applied, f"Watermark not applied: {outcome}"
        assert outcome.payload == token
        assert outcome.method == "dct"
        assert marked.size == source.size

    def test_detect_correct_payload(self) -> None:
        # The payload that was embedded must be found again.
        token = "client-42-vip-portal"
        marked, _ = embed_watermark(make_test_image(), token, method="dct", strength=0.5)

        found = detect_watermark(marked, candidate_payload=token)

        assert found.success, f"detect_watermark reported failure: {found}"
        assert found.applied, f"Watermark not detected: {found}"
        assert found.payload == token

    def test_wrong_payload_not_detected(self) -> None:
        # A different candidate payload must not match the embedded one.
        token = "client-42-vip-portal"
        marked, _ = embed_watermark(make_test_image(), token, method="dct", strength=0.5)

        found = detect_watermark(marked, candidate_payload="client-999-wrong")

        assert not found.applied, (
            f"False positive: wrong payload detected (corr={found.details})"
        )

    def test_no_candidate_returns_not_applied(self) -> None:
        # Without a candidate payload (and no service) nothing can be detected.
        outcome = detect_watermark(make_test_image())
        assert not outcome.applied

    def test_unwatermarked_image_not_detected(self) -> None:
        # A clean image must not trigger detection.
        found = detect_watermark(make_test_image(), candidate_payload="client-42-vip-portal")
        assert not found.applied, (
            f"False positive on clean image (corr={found.details})"
        )
|
||||
|
||||
|
||||
class TestJpegSurvival:
    """Checks that detection behaves correctly after JPEG re-encoding."""

    @staticmethod
    def _jpeg_roundtrip(picture: Image.Image) -> Image.Image:
        # Encode to an in-memory JPEG at quality 85 and decode back to RGB.
        stream = io.BytesIO()
        picture.save(stream, format="JPEG", quality=85)
        stream.seek(0)
        return Image.open(stream).convert("RGB")

    @pytest.mark.slow
    def test_survives_jpeg_q85(self) -> None:
        """The embedded payload must still be detected after JPEG at quality 85."""
        token = "distributor-token-abc123"
        marked, outcome = embed_watermark(
            make_test_image(512, 512), token, method="dct", strength=0.8
        )
        assert outcome.applied, f"Watermark not applied before JPEG test: {outcome}"

        recompressed = self._jpeg_roundtrip(marked)

        found = detect_watermark(recompressed, candidate_payload=token)
        assert found.applied, (
            f"Watermark lost after JPEG compression: corr={found.details}"
        )

    @pytest.mark.slow
    def test_wrong_payload_still_rejected_after_jpeg(self) -> None:
        # JPEG artefacts must not make an unrelated payload match.
        token = "distributor-token-abc123"
        marked, _ = embed_watermark(
            make_test_image(512, 512), token, method="dct", strength=0.8
        )

        recompressed = self._jpeg_roundtrip(marked)

        found = detect_watermark(recompressed, candidate_payload="wrong-payload-xyz")
        assert not found.applied, (
            f"False positive after JPEG: corr={found.details}"
        )
|
||||
|
||||
|
||||
class TestAdversarialStage:
    """Checks for the adversarial stage module: interface, PSNR helper, noise."""

    def test_stage_importable_and_interface(self) -> None:
        # Imported inside the test so collection does not need the stage module.
        from image_pipeline.stages.adversarial import AdversarialProtectStage

        instance = AdversarialProtectStage()
        assert instance.name == "adversarial_protect"
        assert instance.is_optional is True

    def test_psnr_helper_identical(self) -> None:
        from image_pipeline.stages.adversarial import psnr

        picture = make_test_image()
        # Comparing an image with itself gives zero error, hence infinite PSNR.
        assert psnr(picture, picture) == float("inf")

    def test_psnr_modified(self) -> None:
        from image_pipeline.stages.adversarial import psnr

        reference = make_test_image(256, 256)
        pixels = np.array(reference)
        # Uniform +5 brightness shift (clipped at 255): a small, known distortion.
        shifted_pixels = np.clip(pixels.astype(np.int16) + 5, 0, 255).astype(np.uint8)
        shifted = Image.fromarray(shifted_pixels)

        p = psnr(reference, shifted)
        assert p > 30.0, f"Expected PSNR > 30 dB for mild noise, got {p:.2f}"

    def test_adversarial_noise_psnr_imperceptible(self) -> None:
        """At the default strength of 0.03 the noise must keep PSNR at 38 dB or more."""
        import hmac

        from image_pipeline.stages.adversarial import _apply_adversarial_noise, psnr

        reference = make_test_image(256, 256)
        digest = hmac.digest(b"imajin-adv-v1", b"test-payload", "sha256")
        perturbed = _apply_adversarial_noise(reference, strength=0.03, seed_bytes=digest)

        p = psnr(reference, perturbed)
        assert p >= 38.0, f"Default strength PSNR too low: {p:.2f} dB (expected ≥38)"

    def test_stage_exports_from_init(self) -> None:
        # Passes as long as the package-level import succeeds.
        from image_pipeline.stages import AdversarialProtectStage  # noqa: F401
|
||||
Loading…
Add table
Reference in a new issue