diff --git a/services/imajin-adversarial/service/src/attacks/pgd.py b/services/imajin-adversarial/service/src/attacks/pgd.py index 067e3b50..a1d9881e 100644 --- a/services/imajin-adversarial/service/src/attacks/pgd.py +++ b/services/imajin-adversarial/service/src/attacks/pgd.py @@ -26,6 +26,8 @@ def pgd_l_inf( steps: int, alpha: float | None = None, random_start: bool = True, + clamp_lo: float = 0.0, + clamp_hi: float = 1.0, ) -> torch.Tensor: """PGD attack in L-inf norm against a custom loss function. @@ -34,23 +36,28 @@ def pgd_l_inf( x: Clean input tensor, shape (1, C, H, W), values in [0, 1]. loss_fn: Callable (x_adv) → scalar tensor to MAXIMISE. loss_fn must call model internally and return a scalar. - eps: L-inf perturbation budget in [0, 1]. + eps: L-inf perturbation budget, in the same scale as x. steps: Number of PGD iterations. alpha: Per-step size (defaults to eps / 4). random_start: If True, start from a uniform random perturbation in [-eps, eps]. + clamp_lo: Lower bound for valid pixel values (default 0.0 for [0,1] inputs). + clamp_hi: Upper bound for valid pixel values (default 1.0 for [0,1] inputs). + Pass -1.0 / 1.0 for models expecting [-1, 1]-normalised input (e.g. + ArcFace) — mismatching the scale clamps dark pixels to 0.0, causing + visible 127-unit shifts. Returns: - x_adv: Adversarial example, same shape as x, clamped to [0, 1]. + x_adv: Adversarial example, same shape as x, clamped to [clamp_lo, clamp_hi]. """ if alpha is None: alpha = eps / 4.0 - model.train(False) # inference mode (no .eval() to avoid hook false-positive) + model.train(False) x = x.detach() if random_start: x_adv = x + torch.empty_like(x).uniform_(-eps, eps) - x_adv = x_adv.clamp(0.0, 1.0).detach() + x_adv = x_adv.clamp(clamp_lo, clamp_hi).detach() else: x_adv = x.clone().detach() @@ -63,7 +70,7 @@ def pgd_l_inf( x_adv = x_adv + alpha * grad.sign() # Project back into L-inf ball around x x_adv = torch.max(torch.min(x_adv, x + eps), x - eps) - x_adv = x_adv.clamp(0.0, 1.0) + x_adv = x_adv.clamp(clamp_lo, clamp_hi) return x_adv.detach()