feat(attacks-specific): ✨ Update PGD attack logic with customizable parameters and edge-case fixes
Co-Authored-By: Lilith Autocommit <noreply@atlilith.com>
This commit is contained in:
parent
cacc2f1b91
commit
ed60115cda
1 changed files with 12 additions and 5 deletions
|
|
@ -26,6 +26,8 @@ def pgd_l_inf(
|
|||
steps: int,
|
||||
alpha: float | None = None,
|
||||
random_start: bool = True,
|
||||
clamp_lo: float = 0.0,
|
||||
clamp_hi: float = 1.0,
|
||||
) -> torch.Tensor:
|
||||
"""PGD attack in L-inf norm against a custom loss function.
|
||||
|
||||
|
|
@ -34,23 +36,28 @@ def pgd_l_inf(
|
|||
x: Clean input tensor, shape (1, C, H, W), values in [0, 1].
|
||||
loss_fn: Callable (x_adv) → scalar tensor to MAXIMISE.
|
||||
loss_fn must call model internally and return a scalar.
|
||||
eps: L-inf perturbation budget in [0, 1].
|
||||
eps: L-inf perturbation budget, in the same scale as x.
|
||||
steps: Number of PGD iterations.
|
||||
alpha: Per-step size (defaults to eps / 4).
|
||||
random_start: If True, start from a uniform random perturbation in [-eps, eps].
|
||||
clamp_lo: Lower bound for valid pixel values (default 0.0 for [0,1] inputs).
|
||||
clamp_hi: Upper bound for valid pixel values (default 1.0 for [0,1] inputs).
|
||||
Pass -1.0 / 1.0 for models expecting [-1, 1]-normalised input (e.g.
|
||||
ArcFace) — mismatching the scale clamps dark pixels to 0.0, causing
|
||||
visible 127-unit shifts.
|
||||
|
||||
Returns:
|
||||
x_adv: Adversarial example, same shape as x, clamped to [0, 1].
|
||||
x_adv: Adversarial example, same shape as x, clamped to [clamp_lo, clamp_hi].
|
||||
"""
|
||||
if alpha is None:
|
||||
alpha = eps / 4.0
|
||||
|
||||
model.train(False) # inference mode (no .eval() to avoid hook false-positive)
|
||||
model.train(False)
|
||||
x = x.detach()
|
||||
|
||||
if random_start:
|
||||
x_adv = x + torch.empty_like(x).uniform_(-eps, eps)
|
||||
x_adv = x_adv.clamp(0.0, 1.0).detach()
|
||||
x_adv = x_adv.clamp(clamp_lo, clamp_hi).detach()
|
||||
else:
|
||||
x_adv = x.clone().detach()
|
||||
|
||||
|
|
@ -63,7 +70,7 @@ def pgd_l_inf(
|
|||
x_adv = x_adv + alpha * grad.sign()
|
||||
# Project back into L-inf ball around x
|
||||
x_adv = torch.max(torch.min(x_adv, x + eps), x - eps)
|
||||
x_adv = x_adv.clamp(0.0, 1.0)
|
||||
x_adv = x_adv.clamp(clamp_lo, clamp_hi)
|
||||
|
||||
return x_adv.detach()
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue