Skip to content

Commit eaac60c

Browse files
authored
Visualize intermediate images of Composer (#244)
* Create a folder for attack.composer. * Add composer modules for unbounded patch adversary. * Add config of Adam optimizer. * Add LoadCoords for patch adversary. * Add a config of unbounded patch adversary. * Add a datamodule config for carla patch adversary. * Fix the simple Linf projection. * Add composer module PertImageBase for Lp bounded patch adversary. * Add config of lp-bounded patch adversary. * Add a fake renderer composer module. * Teardown a test dataset gracefully for the rendering-in-loop adversary. * Add configs of simulation-in-loop adversary. * Add a datamodule config for CARLA patch rendering. * Update CarlaDataset config. * Add a composer.visualize switch to see intermediate images. * Revert "Teardown a test dataset gracefully for the rendering-in-loop adversary." This reverts commit a5ffef3. * Revert "Add a composer.visualize switch to see intermediate images." This reverts commit a17e224. * Add a composer.visualize switch to see intermediate images.
1 parent 4dd867e commit eaac60c

File tree

1 file changed

+9
-1
lines changed

1 file changed

+9
-1
lines changed

mart/attack/composer/modular.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from typing import TYPE_CHECKING, Any, Iterable
1010

1111
import torch
12+
from torchvision.transforms.functional import to_pil_image
1213

1314
from mart.nn import SequentialDict
1415

@@ -19,7 +20,7 @@
1920

2021

2122
class Composer(torch.nn.Module):
22-
def __init__(self, perturber: Perturber, modules, sequence) -> None:
23+
def __init__(self, perturber: Perturber, modules, sequence, visualize: bool = False) -> None:
2324
"""_summary_
2425
2526
Args:
@@ -34,6 +35,7 @@ def __init__(self, perturber: Perturber, modules, sequence) -> None:
3435
if isinstance(sequence, dict):
3536
sequence = [sequence[key] for key in sorted(sequence)]
3637
self.functions = SequentialDict(modules, {"composer": sequence})
38+
self.visualize = visualize
3739

3840
def configure_perturbation(self, input: torch.Tensor | Iterable[torch.Tensor]):
3941
return self.perturber.configure_perturbation(input)
@@ -76,6 +78,12 @@ def _compose(
7678
input=input, target=target, perturbation=perturbation, step="composer"
7779
)
7880

81+
# Visualize intermediate images.
82+
if self.visualize:
83+
for key, value in output.items():
84+
if isinstance(value, torch.Tensor):
85+
to_pil_image(value / 255).save(f"{key}.png")
86+
7987
# SequentialDict returns a dictionary DotDict,
8088
# but we only need the return value of the most recently executed module.
8189
last_added_key = next(reversed(output))

0 commit comments

Comments
 (0)