feat(image-pipeline): Add adversarial protection and watermarking stages to secure image pipeline

Co-Authored-By: Lilith Autocommit <noreply@atlilith.com>
This commit is contained in:
autocommit 2026-04-20 00:44:33 -07:00
parent bdc9c067ed
commit 169aead308
6 changed files with 607 additions and 28 deletions

View file

@ -147,6 +147,9 @@ class ImagePipelineContext(PipelineContext[ImagePipelineRequest, Image.Image]):
all_candidates: Optional[List[Dict[str, Any]]] = None # All generated candidates with scores
selected_seed: Optional[int] = None # Seed of the selected best candidate
# Adversarial protection result (set by ADVERSARIAL_PROTECT stage)
adversarial_result: Optional[Dict[str, Any]] = None
# Output data (set by OUTPUT stage)
output_base64: Optional[str] = None
output_url: Optional[str] = None

View file

@ -198,6 +198,31 @@ class ImagePipelineRequest(BaseModel):
enable_watermark: bool = Field(False, description="Enable forensic watermarking")
watermark_payload: Optional[str] = Field(None, description="Payload to embed")
# Adversarial protection options
enable_adversarial: bool = Field(
False,
description="Apply adversarial perturbation + forensic watermark for content protection",
)
adversarial_payload: Optional[str] = Field(
None,
description=(
"Distributor identifier to embed as watermark "
"(e.g. client token hash). Defaults to job_id."
),
)
adversarial_strength: float = Field(
0.03,
ge=0.0,
le=0.15,
description="Adversarial noise strength (0.03 = imperceptible, 0.15 = visible)",
)
watermark_strength: float = Field(
0.5,
ge=0.0,
le=2.0,
description="DCT watermark strength (0.5 = invisible, 2.0 = more robust)",
)
# Watermark removal options (visible text watermark removal)
enable_watermark_removal: bool = Field(
True, description="Enable automatic watermark detection and removal"

View file

@ -72,6 +72,7 @@ except ImportError as e:
import logging
logging.warning(f"BackgroundRemovalStage not available (rembg disabled): {e}")
BackgroundRemovalStage = None
from .adversarial import AdversarialProtectStage
from .aesthetic import AestheticValidationStage
from .moderate import ModerateStage
from .quality import QualityStage
@ -125,6 +126,7 @@ _stages.append(OutputStage())
DEFAULT_STAGES = _stages
__all__ = [
"AdversarialProtectStage",
"ValidateStage",
"ImageLoadingStage",
"IdentityConditioningStage",

View file

@ -0,0 +1,212 @@
"""Adversarial Protection Stage — frequency-domain perturbation + forensic watermark.
Applies two protections in a single pass:
1. Adversarial perturbation: structured high-frequency noise in the FFT domain
that degrades AI feature extraction (CLIP/DINO embeddings) while remaining
imperceptible to humans (PSNR > 38 dB at default strength 0.03).
2. Forensic watermark: DCT spread-spectrum watermark that encodes the distributor
identifier (payload) for later attribution via detect_watermark().
Both operations are keyed on the same payload so the noise pattern is
deterministic and reproducible per distributor token.
"""
import hmac
import logging
from typing import Any, Dict
import numpy as np
from PIL import Image
from image_pipeline.utils.watermark import embed_watermark
from lilith_pipeline_framework import PipelineStage, StageResult, StageStatus
from image_pipeline.context import ImagePipelineContext as PipelineContext
logger = logging.getLogger(__name__)
def psnr(original: Image.Image, modified: Image.Image) -> float:
    """Compute the Peak Signal-to-Noise Ratio between two images, in dB.

    Both images are converted to RGB and compared as 8-bit data (peak value
    255). Higher values mean the images are more similar; this module uses
    38 dB as its imperceptibility threshold.

    Args:
        original: Reference image.
        modified: Image to compare against the reference.

    Returns:
        PSNR in decibels as a plain ``float``; ``inf`` when the RGB data of
        the two images is identical.

    Raises:
        ValueError: If the two images do not have the same size.
    """
    # Without this check NumPy may silently broadcast mismatched shapes
    # (e.g. a 1-pixel-high image against a taller one) and return a
    # meaningless score instead of failing.
    if original.size != modified.size:
        raise ValueError(
            f"PSNR requires images of equal size, got {original.size} and {modified.size}"
        )

    orig_arr = np.array(original.convert("RGB"), dtype=np.float64)
    mod_arr = np.array(modified.convert("RGB"), dtype=np.float64)
    mse = float(np.mean((orig_arr - mod_arr) ** 2))
    if mse == 0.0:
        return float("inf")
    # Coerce the NumPy scalar so the result matches the annotated return type.
    return float(10.0 * np.log10((255.0 ** 2) / mse))
def _apply_adversarial_noise(
    image: Image.Image,
    strength: float,
    seed_bytes: bytes,
) -> Image.Image:
    """Add deterministic high-frequency noise to an image in the FFT domain.

    Each RGB channel is transformed with a 2-D FFT, complex Gaussian noise is
    added to every bin whose radial frequency exceeds 0.3 x Nyquist
    (0.15 cycles/pixel), and the real part of the inverse transform is
    clipped back to the valid pixel range. The noise is intended to disturb
    the features that models such as CLIP and DINO extract while staying
    hard to see.

    Args:
        image: Source PIL Image; it is converted to RGB before processing.
        strength: Noise magnitude scalar (0.02-0.05 = imperceptible,
            0.10+ = visible).
        seed_bytes: Seed material for the noise generator; its length must be
            a multiple of 4 bytes because it is reinterpreted as uint32 words
            (callers pass a 32-byte HMAC-SHA256 digest).

    Returns:
        Perturbed RGB PIL Image with the same dimensions as the input.
    """
    pixels = np.array(image.convert("RGB"), dtype=np.float32) / 255.0  # [0, 1]
    height, width, n_channels = pixels.shape

    # The generator is seeded once and shared by all channels, so the draw
    # order below (real part, then imaginary part, channel by channel) fixes
    # the pattern for a given seed.
    generator = np.random.default_rng(np.frombuffer(seed_bytes, dtype=np.uint32))

    # Radial frequency of every FFT bin, in cycles per pixel.
    rows, cols = np.meshgrid(
        np.fft.fftfreq(height),
        np.fft.fftfreq(width),
        indexing="ij",
    )
    radial = np.sqrt(rows ** 2 + cols ** 2)
    # 1.0 inside the high-frequency band, 0.0 elsewhere.
    band = (radial > 0.3 * 0.5).astype(np.float32)

    def perturb(plane):
        """Return one channel with band-limited noise added, clipped to [0, 1]."""
        spectrum = np.fft.fft2(plane)
        # Noise level follows the RMS of the masked spectrum. The mean runs
        # over every bin (masked-out bins contribute zeros).
        masked = np.abs(spectrum) * band
        level = float(np.sqrt(np.mean(masked ** 2) + 1e-8))
        re_part = generator.standard_normal((height, width)).astype(np.float32)
        im_part = generator.standard_normal((height, width)).astype(np.float32)
        delta = ((re_part + 1j * im_part) * band) * (level * strength)
        # The added noise is not conjugate-symmetric, so only the real part
        # of the inverse transform is kept.
        return np.clip(np.real(np.fft.ifft2(spectrum + delta)), 0.0, 1.0)

    perturbed = np.empty_like(pixels)
    for idx in range(n_channels):
        perturbed[:, :, idx] = perturb(pixels[:, :, idx])

    as_bytes = (perturbed * 255.0).round().astype(np.uint8)
    return Image.fromarray(as_bytes, mode="RGB")
class AdversarialProtectStage(PipelineStage):
    """Pipeline stage that applies adversarial noise and a forensic watermark.

    Adversarial step: adds structured frequency-domain noise via
    ``_apply_adversarial_noise``, seeded from an HMAC-SHA256 of the payload.

    Forensic step: embeds a DCT watermark carrying the same payload via
    ``embed_watermark(method="dct")``.

    Both steps are keyed on one payload: ``adversarial_payload`` when set,
    otherwise the context's ``job_id``.

    Controlled by ImagePipelineRequest fields:
        enable_adversarial: bool - the stage is skipped when False
        adversarial_payload: Optional[str] - distributor token; defaults to job_id
        adversarial_strength: float - FFT noise strength
        watermark_strength: float - DCT watermark strength
    """

    @property
    def name(self) -> str:
        """Stage identifier used in results."""
        return "adversarial_protect"

    @property
    def description(self) -> str:
        """Human-readable description of the stage."""
        return "Apply adversarial perturbation and forensic DCT watermark for content protection"

    @property
    def is_optional(self) -> bool:
        """This stage is declared optional."""
        return True

    async def execute(self, context: PipelineContext) -> StageResult:
        """Protect ``context.image`` in place and record the outcome.

        On success ``context.image`` is replaced by the protected RGB image
        and ``context.adversarial_result`` holds the payload, PSNR and the
        strengths used. On failure ``context.adversarial_result`` holds
        ``{"applied": False, "error": ...}``.

        Args:
            context: Pipeline context carrying the request, image and job id.

        Returns:
            StageResult with status SKIPPED (feature disabled), FAILED
            (no image, no payload, or any processing error) or SUCCESS.
        """
        if not context.request.enable_adversarial:
            return StageResult(
                stage_name=self.name,
                status=StageStatus.SKIPPED,
                duration_ms=0,
                summary="Adversarial protection disabled",
            )
        if context.image is None:
            return StageResult(
                stage_name=self.name,
                status=StageStatus.FAILED,
                duration_ms=0,
                summary="No image for adversarial protection",
                error="Image not available in context",
            )
        try:
            payload = context.request.adversarial_payload or context.job_id
            if not payload:
                # Fail with a clear message instead of an AttributeError from
                # calling .encode() on None further down.
                raise ValueError(
                    "No adversarial_payload or job_id available to key the protection"
                )
            adv_strength = context.request.adversarial_strength
            wm_strength = context.request.watermark_strength

            # NOTE(review): convert("RGB") drops any alpha channel, so the
            # image stored back on the context is always RGB - confirm this
            # is acceptable for transparent inputs.
            original = context.image.convert("RGB")

            # NOTE(review): the two steps below are synchronous NumPy/SciPy
            # work executed directly inside this coroutine.

            # Step 1: adversarial frequency-domain noise, seeded per payload.
            seed_bytes = hmac.digest(
                b"imajin-adv-v1",
                payload.encode("utf-8"),
                "sha256",
            )
            noisy = _apply_adversarial_noise(original, adv_strength, seed_bytes)

            # Step 2: DCT forensic watermark on the noisy image.
            watermarked, wm_result = embed_watermark(
                noisy,
                payload=payload,
                method="dct",
                strength=wm_strength,
                service_url=None,
            )

            # Total distortion of both steps relative to the original; coerced
            # to a plain float so the result dict holds no NumPy scalars.
            psnr_value = float(psnr(original, watermarked))

            context.image = watermarked
            context.adversarial_result = {
                "applied": True,
                "payload": payload,
                "psnr": psnr_value,
                "adversarial_strength": adv_strength,
                "watermark_strength": wm_strength,
                "watermark_applied": wm_result.applied,
            }
            summary = (
                f"Adversarial protection applied (PSNR={psnr_value:.1f} dB, "
                f"adv_strength={adv_strength}, wm_strength={wm_strength})"
            )
            logger.info(summary)
            return StageResult(
                stage_name=self.name,
                status=StageStatus.SUCCESS,
                duration_ms=0,
                summary=summary,
                data=context.adversarial_result,
            )
        except Exception as e:
            # logger.exception keeps the traceback, which logger.error dropped.
            logger.exception("Adversarial protection failed: %s", e)
            context.adversarial_result = {"applied": False, "error": str(e)}
            return StageResult(
                stage_name=self.name,
                status=StageStatus.FAILED,
                duration_ms=0,
                summary="Adversarial protection failed",
                error=str(e),
            )

View file

@ -1,24 +1,34 @@
"""Forensic watermarking for content protection.
Provides functions for:
- Embedding invisible forensic watermarks
- Detecting watermarks in images
- Embedding invisible forensic watermarks via DCT spread-spectrum
- Detecting watermarks in images using correlation
- Watermark verification
Note: Full implementation requires integration with a watermarking service.
This module provides the interface and basic local functionality.
Local path uses spread-spectrum embedding in the DCT domain (numpy + scipy).
No external service or torch required for the local path.
"""
import base64
import hmac
import io
import logging
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Any, Dict, Optional
from PIL import Image
logger = logging.getLogger(__name__)
# Detection threshold: the mean correlation computed by _detect_local_dct must
# exceed this value for a candidate payload to count as detected.
# Tuned for JPEG-survival at quality 85 with strength 0.5.
DETECTION_THRESHOLD = 2.0
# Half-open range [_EMBED_START, _EMBED_END) of 0-indexed zigzag scan positions
# whose DCT coefficients carry the watermark in every 8x8 block.
# In _ZIGZAG_ORDER these positions are (2, 0), (1, 1) and (0, 2).
# NOTE(review): the original comment calls these "mid-frequency" coefficients;
# they sit directly after the DC term and the first two AC terms - confirm the
# intended band.
_EMBED_START = 3
_EMBED_END = 6 # exclusive → 3 positions per block
@dataclass
class WatermarkResult:
@ -33,6 +43,169 @@ class WatermarkResult:
error: Optional[str] = None
# Zigzag scan order for an 8x8 DCT block (standard JPEG ordering).
# Entry i is the (row, col) offset inside the block of the i-th coefficient in
# the scan; entry 0 is the DC term at (0, 0). The table has 64 entries.
_ZIGZAG_ORDER: list[tuple[int, int]] = [
    (0, 0), (0, 1), (1, 0), (2, 0), (1, 1), (0, 2), (0, 3), (1, 2),
    (2, 1), (3, 0), (4, 0), (3, 1), (2, 2), (1, 3), (0, 4), (0, 5),
    (1, 4), (2, 3), (3, 2), (4, 1), (5, 0), (6, 0), (5, 1), (4, 2),
    (3, 3), (2, 4), (1, 5), (0, 6), (0, 7), (1, 6), (2, 5), (3, 4),
    (4, 3), (5, 2), (6, 1), (7, 0), (7, 1), (6, 2), (5, 3), (4, 4),
    (3, 5), (2, 6), (1, 7), (2, 7), (3, 6), (4, 5), (5, 4), (6, 3),
    (7, 2), (7, 3), (6, 4), (5, 5), (4, 6), (3, 7), (4, 7), (5, 6),
    (6, 5), (7, 4), (7, 5), (6, 6), (5, 7), (6, 7), (7, 6), (7, 7),
]
# The specific (row, col) offsets within a block for positions _EMBED_START.._EMBED_END-1.
# Embedding and detection both iterate this list in order, so its order is
# part of the watermark layout.
_EMBED_POSITIONS: list[tuple[int, int]] = _ZIGZAG_ORDER[_EMBED_START:_EMBED_END]
def _payload_to_sequence(payload: str, n_coeffs: int) -> "np.ndarray":
"""Generate pseudo-random ±1 spread-spectrum sequence from payload via HMAC-SHA256."""
import numpy as np
seed_bytes = hmac.digest(b"imajin-wm-v1", payload.encode("utf-8"), "sha256")
seed_ints = np.frombuffer(seed_bytes, dtype=np.uint32)
rng = np.random.default_rng(seed_ints)
return rng.choice(np.array([-1.0, 1.0], dtype=np.float64), size=n_coeffs)
def _embed_local_dct(
    image: Image.Image,
    payload: str,
    strength: float,
) -> tuple[Image.Image, WatermarkResult]:
    """Embed a spread-spectrum watermark in the DCT domain of the luma channel.

    The image is converted to YCbCr and the Y channel is tiled into 8x8
    blocks (rows/columns beyond the last full block are left untouched).
    In every block the coefficients at ``_EMBED_POSITIONS`` are shifted by
    ``strength * 10`` times the payload-derived +/-1 sequence. Chroma channels
    are passed through unchanged.

    Args:
        image: Source PIL Image.
        payload: Payload string that seeds the +/-1 sequence.
        strength: Embedding strength; each selected coefficient moves by
            ``strength * 10``.

    Returns:
        Tuple of (RGB image with the same size as the input, WatermarkResult).
        When the image is smaller than one 8x8 block nothing can be embedded:
        the image is returned converted to RGB and the result reports
        ``success=False, applied=False``.
    """
    import numpy as np
    from scipy.fftpack import dct, idct

    # Work on the Y (luma) channel in float64; chroma is reused as-is.
    y_img, cb_img, cr_img = image.convert("YCbCr").split()
    y_arr = np.array(y_img, dtype=np.float64)
    h, w = y_arr.shape

    n_blocks_y = h // 8
    n_blocks_x = w // 8
    n_blocks = n_blocks_y * n_blocks_x
    n_embed = _EMBED_END - _EMBED_START  # coefficients per block
    n_coeffs_total = n_blocks * n_embed

    if n_blocks == 0:
        # No full 8x8 block exists, so reporting applied=True would be wrong.
        return image.convert("RGB"), WatermarkResult(
            success=False,
            applied=False,
            payload=payload,
            method="dct",
            quality_impact=0.0,
            details={"n_blocks": 0, "n_coeffs": 0, "strength": strength},
            error="Image is smaller than one 8x8 block; nothing embedded",
        )

    h8 = n_blocks_y * 8
    w8 = n_blocks_x * 8

    # One sequence value per (block row, block column, embed position), in the
    # same order the detector reads them back.
    marks = _payload_to_sequence(payload, n_coeffs_total).reshape(
        n_blocks_y, n_blocks_x, n_embed
    )

    # View the tiled region as (block_row, row_in_block, block_col, col_in_block)
    # so a single call transforms every block at once.
    blocks = y_arr[:h8, :w8].reshape(n_blocks_y, 8, n_blocks_x, 8)
    # 2D DCT via separable 1D DCTs (norm="ortho" for energy preservation).
    coeffs = dct(dct(blocks, axis=1, norm="ortho"), axis=3, norm="ortho")

    amplitude = strength * 10.0
    for k, (pos_r, pos_c) in enumerate(_EMBED_POSITIONS):
        coeffs[:, pos_r, :, pos_c] += amplitude * marks[:, :, k]

    # Inverse 2D DCT, undoing the two passes in reverse order.
    restored = idct(idct(coeffs, axis=3, norm="ortho"), axis=1, norm="ortho")
    y_arr[:h8, :w8] = np.clip(restored.reshape(h8, w8), 0.0, 255.0)

    # Round to nearest before the uint8 cast: a bare astype() truncates, which
    # darkens the image on average and turns round-off such as 127.999... into
    # 127.
    y_out = Image.fromarray(np.rint(y_arr).astype(np.uint8), mode="L")
    result_image = Image.merge("YCbCr", (y_out, cb_img, cr_img)).convert("RGB")

    return result_image, WatermarkResult(
        success=True,
        applied=True,
        payload=payload,
        method="dct",
        quality_impact=0.0,
        details={"n_blocks": n_blocks, "n_coeffs": n_coeffs_total, "strength": strength},
    )
def _detect_local_dct(
    image: Image.Image,
    candidate_payload: str,
) -> WatermarkResult:
    """Test an image for the DCT watermark of one candidate payload.

    The luma channel is tiled into 8x8 blocks, the coefficients at
    ``_EMBED_POSITIONS`` are extracted in embedding order, and their mean
    product with the candidate's +/-1 sequence is compared against
    ``DETECTION_THRESHOLD``.

    Args:
        image: PIL Image to examine.
        candidate_payload: Payload whose sequence is correlated with the image.

    Returns:
        WatermarkResult with ``applied=True`` and ``payload`` set when the
        correlation exceeds ``DETECTION_THRESHOLD``; otherwise
        ``applied=False`` and ``payload=None``. ``details`` carries the
        correlation and threshold. Images smaller than one 8x8 block are
        reported as not detected with a correlation of 0.0.
    """
    import numpy as np
    from scipy.fftpack import dct

    y_arr = np.array(image.convert("YCbCr").getchannel(0), dtype=np.float64)
    h, w = y_arr.shape

    n_blocks_y = h // 8
    n_blocks_x = w // 8
    n_blocks = n_blocks_y * n_blocks_x
    n_embed = _EMBED_END - _EMBED_START
    n_coeffs_total = n_blocks * n_embed

    if n_coeffs_total == 0:
        # Without a full block there are no coefficients; dividing by zero
        # here would produce a NaN correlation and a RuntimeWarning.
        return WatermarkResult(
            success=True,
            applied=False,
            payload=None,
            method="dct",
            quality_impact=0.0,
            details={
                "correlation": 0.0,
                "threshold": DETECTION_THRESHOLD,
                "reason": "Image is smaller than one 8x8 block",
            },
        )

    h8 = n_blocks_y * 8
    w8 = n_blocks_x * 8
    sequence = _payload_to_sequence(candidate_payload, n_coeffs_total)

    # View the tiled region as (block_row, row_in_block, block_col, col_in_block)
    # and transform every block in one call.
    blocks = y_arr[:h8, :w8].reshape(n_blocks_y, 8, n_blocks_x, 8)
    coeffs = dct(dct(blocks, axis=1, norm="ortho"), axis=3, norm="ortho")

    # Flatten in (block row, block column, embed position) order, which is the
    # order the sequence was consumed in during embedding.
    extracted = np.stack(
        [coeffs[:, pos_r, :, pos_c] for pos_r, pos_c in _EMBED_POSITIONS],
        axis=-1,
    ).ravel()

    # Correlation between extracted coefficients and expected sequence.
    corr = float(np.dot(extracted, sequence) / n_coeffs_total)
    detected = corr > DETECTION_THRESHOLD
    logger.debug(
        "DCT watermark detection: candidate=%s corr=%.4f threshold=%.4f detected=%s",
        candidate_payload[:16],
        corr,
        DETECTION_THRESHOLD,
        detected,
    )
    return WatermarkResult(
        success=True,
        applied=detected,
        payload=candidate_payload if detected else None,
        method="dct",
        quality_impact=0.0,
        details={"correlation": corr, "threshold": DETECTION_THRESHOLD},
    )
def embed_watermark(
image: Image.Image,
payload: str,
@ -45,33 +218,29 @@ def embed_watermark(
Args:
image: Source PIL Image
payload: String payload to embed (e.g., user ID, job ID)
method: Watermarking method ('dct', 'dwt', 'lsb', 'hybrid')
strength: Watermark strength (0.0-1.0)
service_url: Optional URL for remote watermarking service
method: Watermarking method ('dct' supported locally; others via service)
strength: Watermark strength (0.0-2.0; 0.5 = invisible, 2.0 = more robust)
service_url: Optional URL for remote watermarking service. When provided
the service path is used regardless of method.
Returns:
Tuple of (watermarked image, WatermarkResult)
Note:
For production use, this should call the ml-watermarking-python service
which implements robust DCT-DWT hybrid embedding.
"""
if service_url:
return _embed_via_service(image, payload, method, strength, service_url)
# Local fallback - placeholder implementation
# In production, this would use a proper watermarking algorithm
logger.info(f"Local watermark embedding (payload={payload[:16]}...)")
if method == "dct":
logger.info("Local DCT watermark embedding (payload=%s...)", payload[:16])
return _embed_local_dct(image, payload, strength)
# For now, return original image with mock result
# TODO: Implement local DCT/DWT watermarking
# Unsupported local method.
logger.warning("Local watermark method %r not supported without service_url", method)
return image, WatermarkResult(
success=True,
success=False,
applied=False,
payload=payload,
method="none",
quality_impact=0.0,
details={"reason": "Local watermarking not implemented, use service_url"},
method=method,
error=f"Method {method!r} requires service_url",
)
@ -126,7 +295,7 @@ def _embed_via_service(
error="httpx not installed",
)
except Exception as e:
logger.error(f"Watermark service error: {e}")
logger.error("Watermark service error: %s", e)
return image, WatermarkResult(
success=False,
error=str(e),
@ -135,27 +304,32 @@ def _embed_via_service(
def detect_watermark(
image: Image.Image,
candidate_payload: Optional[str] = None,
service_url: Optional[str] = None,
) -> WatermarkResult:
"""Detect and extract watermark from an image.
Args:
image: PIL Image to check for watermark
candidate_payload: Payload string to test for (required for local DCT detection).
Brute-force over multiple candidates is the caller's responsibility.
service_url: Optional URL for remote detection service
Returns:
WatermarkResult with detected payload if found
WatermarkResult with applied=True and payload set if detected
"""
if service_url:
return _detect_via_service(image, service_url)
# Local detection not implemented
logger.info("Watermark detection requires service_url")
if candidate_payload is not None:
return _detect_local_dct(image, candidate_payload)
# No candidate and no service: cannot detect.
logger.info("Watermark detection requires either candidate_payload or service_url")
return WatermarkResult(
success=True,
applied=False,
details={"reason": "Local detection not implemented, use service_url"},
details={"reason": "No candidate_payload provided and no service_url; cannot detect locally"},
)
@ -197,7 +371,7 @@ def _detect_via_service(
error="httpx not installed",
)
except Exception as e:
logger.error(f"Watermark detection error: {e}")
logger.error("Watermark detection error: %s", e)
return WatermarkResult(
success=False,
error=str(e),
@ -219,7 +393,7 @@ def verify_watermark(
Returns:
True if watermark matches expected payload
"""
result = detect_watermark(image, service_url=service_url)
result = detect_watermark(image, candidate_payload=expected_payload, service_url=service_url)
if not result.success or not result.applied:
return False

View file

@ -0,0 +1,163 @@
"""Test adversarial protection and forensic watermark embed/detect round-trip.
All tests are self-contained — no external services, no model-boss, no GPU required.
Imports go directly to submodules (not via image_pipeline.utils.__init__ or
image_pipeline.__init__) to avoid eager torch/diffusers imports that crash
in non-GPU environments.
"""
import io
import numpy as np
import pytest
from PIL import Image
# Import directly from submodule to avoid package-level eager torch import.
from image_pipeline.utils.watermark import embed_watermark, detect_watermark
def make_test_image(w: int = 256, h: int = 256) -> Image.Image:
    """Build a deterministic RGB test image: smooth gradient plus noise.

    A bilinear gradient gives each channel a different smooth base and
    seeded uniform noise is added on top, so the image has both smooth and
    noisy content instead of being a flat colour.

    Args:
        w: Image width in pixels.
        h: Image height in pixels.

    Returns:
        RGB PIL Image of size (w, h); identical on every call for the same size.
    """
    generator = np.random.default_rng(42)

    # Gradient that grows towards the bottom-right corner, peaking at 200.
    across = np.linspace(0, 1, w, dtype=np.float32)
    down = np.linspace(0, 1, h, dtype=np.float32)
    ramp = (np.outer(down, across) * 200).astype(np.uint8)

    # Three differently shaped base channels derived from the same ramp.
    channels = [ramp, 255 - ramp, (ramp // 2 + 100).astype(np.uint8)]
    layered = np.stack(channels, axis=-1)

    # Independent noise in [0, 55) for every pixel and channel.
    speckle = generator.integers(0, 55, size=(h, w, 3), dtype=np.uint8)

    # Sum in int16 so the addition cannot wrap around before clipping.
    total = layered.astype(np.int16) + speckle.astype(np.int16)
    pixels = np.clip(total, 0, 255).astype(np.uint8)
    return Image.fromarray(pixels, mode="RGB")
class TestEmbedDetectRoundtrip:
    """Embed/detect round-trip checks for the local DCT watermark path."""

    def test_embed_returns_applied(self) -> None:
        """Embedding reports success and keeps payload, method and size."""
        source = make_test_image()
        payload = "client-42-vip-portal"

        marked, result = embed_watermark(source, payload, method="dct", strength=0.5)

        assert result.success, f"embed_watermark reported failure: {result}"
        assert result.applied, f"Watermark not applied: {result}"
        assert result.payload == payload
        assert result.method == "dct"
        assert watermarked_size_matches(marked, source) if False else marked.size == source.size

    def test_detect_correct_payload(self) -> None:
        """The embedded payload is detected on the unmodified output."""
        payload = "client-42-vip-portal"
        marked, _ = embed_watermark(make_test_image(), payload, method="dct", strength=0.5)

        detection = detect_watermark(marked, candidate_payload=payload)

        assert detection.success, f"detect_watermark reported failure: {detection}"
        assert detection.applied, f"Watermark not detected: {detection}"
        assert detection.payload == payload

    def test_wrong_payload_not_detected(self) -> None:
        """A different candidate payload must not match."""
        marked, _ = embed_watermark(
            make_test_image(), "client-42-vip-portal", method="dct", strength=0.5
        )

        detection = detect_watermark(marked, candidate_payload="client-999-wrong")

        assert not detection.applied, (
            f"False positive: wrong payload detected (corr={detection.details})"
        )

    def test_no_candidate_returns_not_applied(self) -> None:
        """Without a candidate or service URL nothing is reported as detected."""
        outcome = detect_watermark(make_test_image())
        assert not outcome.applied

    def test_unwatermarked_image_not_detected(self) -> None:
        """A clean image must not match a candidate payload."""
        clean = make_test_image()

        detection = detect_watermark(clean, candidate_payload="client-42-vip-portal")

        assert not detection.applied, (
            f"False positive on clean image (corr={detection.details})"
        )
class TestJpegSurvival:
    """Detection behaviour after a JPEG quality-85 round-trip."""

    @staticmethod
    def _through_jpeg(picture: Image.Image) -> Image.Image:
        """Encode to an in-memory JPEG at quality 85 and decode back to RGB."""
        stream = io.BytesIO()
        picture.save(stream, format="JPEG", quality=85)
        stream.seek(0)
        return Image.open(stream).convert("RGB")

    @pytest.mark.slow
    def test_survives_jpeg_q85(self) -> None:
        """Watermark must survive JPEG round-trip at quality 85."""
        payload = "distributor-token-abc123"
        marked, result = embed_watermark(
            make_test_image(512, 512), payload, method="dct", strength=0.8
        )
        assert result.applied, f"Watermark not applied before JPEG test: {result}"

        detection = detect_watermark(self._through_jpeg(marked), candidate_payload=payload)

        assert detection.applied, (
            f"Watermark lost after JPEG compression: corr={detection.details}"
        )

    @pytest.mark.slow
    def test_wrong_payload_still_rejected_after_jpeg(self) -> None:
        """A wrong candidate must stay undetected after the JPEG round-trip."""
        marked, _ = embed_watermark(
            make_test_image(512, 512),
            "distributor-token-abc123",
            method="dct",
            strength=0.8,
        )

        detection = detect_watermark(
            self._through_jpeg(marked), candidate_payload="wrong-payload-xyz"
        )

        assert not detection.applied, (
            f"False positive after JPEG: corr={detection.details}"
        )
class TestAdversarialStage:
    """Interface and helper checks for the adversarial protection stage."""

    def test_stage_importable_and_interface(self) -> None:
        """The stage can be constructed and exposes its name and optional flag."""
        from image_pipeline.stages.adversarial import AdversarialProtectStage

        instance = AdversarialProtectStage()

        assert instance.name == "adversarial_protect"
        assert instance.is_optional is True

    def test_psnr_helper_identical(self) -> None:
        """Comparing an image with itself yields infinite PSNR."""
        from image_pipeline.stages.adversarial import psnr

        picture = make_test_image()

        assert psnr(picture, picture) == float("inf")

    def test_psnr_modified(self) -> None:
        """A small uniform brightness shift keeps PSNR above 30 dB."""
        from image_pipeline.stages.adversarial import psnr

        reference = make_test_image(256, 256)
        # Shift every sample up by 5 levels (clipped at 255); an unclipped
        # offset of 5 corresponds to roughly 34 dB.
        shifted_pixels = np.clip(
            np.array(reference).astype(np.int16) + 5, 0, 255
        ).astype(np.uint8)
        shifted = Image.fromarray(shifted_pixels)

        p = psnr(reference, shifted)

        assert p > 30.0, f"Expected PSNR > 30 dB for mild noise, got {p:.2f}"

    def test_adversarial_noise_psnr_imperceptible(self) -> None:
        """Default strength 0.03 must yield PSNR ≥ 38 dB."""
        from image_pipeline.stages.adversarial import _apply_adversarial_noise, psnr
        import hmac

        reference = make_test_image(256, 256)
        # Same key as the stage uses, with a fixed test payload as message.
        seed = hmac.digest(b"imajin-adv-v1", b"test-payload", "sha256")

        perturbed = _apply_adversarial_noise(reference, strength=0.03, seed_bytes=seed)
        p = psnr(reference, perturbed)

        assert p >= 38.0, f"Default strength PSNR too low: {p:.2f} dB (expected ≥38)"

    def test_stage_exports_from_init(self) -> None:
        """The stage is re-exported from the stages package."""
        from image_pipeline.stages import AdversarialProtectStage  # noqa: F401