Skip to content

Commit 1360ab0

Browse files
authored
Add unbounded patch adversary for object detection (#241)
* Create a folder for attack.composer. * Add composer modules for unbounded patch adversary. * Add config of Adam optimizer. * Add LoadCoords for patch adversary. * Add a config of unbounded patch adversary. * Add a datamodule config for carla patch adversary.
1 parent 5ccf5f0 commit 1360ab0

File tree

10 files changed

+198
-2
lines changed

10 files changed

+198
-2
lines changed

mart/attack/composer/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
from .modular import *
2+
from .patch import *

mart/attack/composer.py renamed to mart/attack/composer/modular.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,9 @@
1313
from mart.nn import SequentialDict
1414

1515
if TYPE_CHECKING:
16-
from .perturber import Perturber
16+
from ..perturber import Perturber
1717

18-
__all__ = ["Composer"]
18+
__all__ = ["Composer", "Additive", "Mask", "Overlay"]
1919

2020

2121
class Composer(torch.nn.Module):

mart/attack/composer/patch.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
#
2+
# Copyright (C) 2022 Intel Corporation
3+
#
4+
# SPDX-License-Identifier: BSD-3-Clause
5+
#
6+
7+
from __future__ import annotations
8+
9+
import torch
10+
import torchvision.transforms.functional as F
11+
12+
__all__ = [
13+
"PertRectSize",
14+
"PertExtractRect",
15+
"PertRectPerspective",
16+
]
17+
18+
19+
class PertRectSize(torch.nn.Module):
20+
"""Calculate the size of the smallest rectangle that can be transformed with the highest pixel
21+
fidelity."""
22+
23+
@staticmethod
24+
def get_smallest_rect(coords):
25+
# Calculate the distance between two points.
26+
coords_shifted = torch.cat([coords[1:], coords[0:1]])
27+
w1, h2, w2, h1 = torch.sqrt(
28+
torch.sum(torch.pow(torch.subtract(coords, coords_shifted), 2), dim=1)
29+
)
30+
31+
height = int(max(h1, h2).round())
32+
width = int(max(w1, w2).round())
33+
return height, width
34+
35+
def forward(self, coords):
36+
height, width = self.get_smallest_rect(coords)
37+
return {"height": height, "width": width}
38+
39+
40+
class PertExtractRect(torch.nn.Module):
41+
"""Extract a small rectangular patch from the input size one."""
42+
43+
def forward(self, perturbation, height, width):
44+
perturbation = perturbation[:, :height, :width]
45+
return perturbation
46+
47+
48+
class PertRectPerspective(torch.nn.Module):
49+
"""Pad perturbation to input size, then perspective transform the top-left rectangle."""
50+
51+
def forward(self, perturbation, input, coords):
52+
# Pad to the input size.
53+
height_inp, width_inp = input.shape[-2:]
54+
height_pert, width_pert = perturbation.shape[-2:]
55+
height_pad = height_inp - height_pert
56+
width_pad = width_inp - width_pert
57+
perturbation = F.pad(perturbation, padding=[0, 0, width_pad, height_pad])
58+
59+
# F.perspective() requires startpoints and endpoints in CPU.
60+
startpoints = torch.tensor(
61+
[[0, 0], [width_pert, 0], [width_pert, height_pert], [0, height_pert]]
62+
)
63+
endpoints = coords.cpu()
64+
65+
perturbation = F.perspective(
66+
img=perturbation,
67+
startpoints=startpoints,
68+
endpoints=endpoints,
69+
interpolation=F.InterpolationMode.BILINEAR,
70+
fill=0,
71+
)
72+
73+
return perturbation
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
pert_extract_rect:
2+
_target_: mart.attack.composer.PertExtractRect
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
pert_rect_perspective:
2+
_target_: mart.attack.composer.PertRectPerspective
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
pert_rect_size:
2+
_target_: mart.attack.composer.PertRectSize
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
defaults:
2+
- adversary
3+
- /optimizer@optimizer: adam
4+
- enforcer: default
5+
- composer: default
6+
- composer/perturber/initializer: uniform
7+
- composer/perturber/projector: range
8+
- composer/modules:
9+
[pert_rect_size, pert_extract_rect, pert_rect_perspective, overlay]
10+
- gradient_modifier: sign
11+
- gain: rcnn_training_loss
12+
- objective: zero_ap
13+
- override /callbacks@callbacks: [progress_bar, image_visualizer]
14+
15+
max_iters: ???
16+
lr: ???
17+
18+
optimizer:
19+
maximize: True
20+
lr: ${..lr}
21+
22+
enforcer:
23+
# No constraints with complex renderer in the pipeline.
24+
# TODO: Constraint on digital perturbation?
25+
constraints: {}
26+
27+
composer:
28+
perturber:
29+
initializer:
30+
min: 0
31+
max: 255
32+
projector:
33+
min: 0
34+
max: 255
35+
sequence:
36+
seq010:
37+
pert_rect_size: ["target.coords"]
38+
seq020:
39+
pert_extract_rect:
40+
["perturbation", "pert_rect_size.height", "pert_rect_size.width"]
41+
seq040:
42+
pert_rect_perspective: ["pert_extract_rect", "input", "target.coords"]
43+
seq050:
44+
overlay: ["pert_rect_perspective", "input", "target.perturbable_mask"]
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
defaults:
2+
- default.yaml
3+
4+
train_dataset: null
5+
6+
val_dataset: null
7+
8+
test_dataset:
9+
_target_: mart.datamodules.coco.CocoDetection
10+
root: ???
11+
annFile: ${.root}/kwcoco_annotations.json
12+
modalities: ["rgb"]
13+
transforms:
14+
_target_: mart.transforms.Compose
15+
transforms:
16+
- _target_: torchvision.transforms.ToTensor
17+
- _target_: mart.transforms.ConvertCocoPolysToMask
18+
- _target_: mart.transforms.LoadPerturbableMask
19+
perturb_mask_folder: ${....root}/foreground_mask/
20+
- _target_: mart.transforms.LoadCoords
21+
folder: ${....root}/patch_metadata/
22+
- _target_: mart.transforms.Denormalize
23+
center: 0
24+
scale: 255
25+
- _target_: torch.fake_quantize_per_tensor_affine
26+
_partial_: true
27+
# (x/1+0).round().clamp(0, 255) * 1
28+
scale: 1
29+
zero_point: 0
30+
quant_min: 0
31+
quant_max: 255
32+
33+
num_workers: 0
34+
ims_per_batch: 1
35+
36+
collate_fn:
37+
_target_: hydra.utils.get_method
38+
path: mart.datamodules.coco.collate_fn

mart/configs/optimizer/adam.yaml

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
_target_: mart.optim.OptimizerFactory
2+
optimizer:
3+
_target_: hydra.utils.get_method
4+
path: torch.optim.Adam
5+
lr: ???
6+
betas:
7+
- 0.9
8+
- 0.999
9+
eps: 1e-08
10+
weight_decay: 0
11+
bias_decay: 0
12+
norm_decay: 0

mart/transforms/extended.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import os
99
from typing import Dict, Optional, Tuple
1010

11+
import numpy as np
1112
import torch
1213
from PIL import Image, ImageOps
1314
from torch import Tensor
@@ -26,6 +27,7 @@
2627
"Lambda",
2728
"SplitLambda",
2829
"LoadPerturbableMask",
30+
"LoadCoords",
2931
"ConvertInstanceSegmentationToPerturbable",
3032
"RandomHorizontalFlip",
3133
"ConvertCocoPolysToMask",
@@ -139,6 +141,25 @@ def __call__(self, image, target):
139141
return image, target
140142

141143

144+
class LoadCoords(ExTransform):
145+
"""Load perturbable masks and add to target."""
146+
147+
def __init__(self, folder) -> None:
148+
self.folder = folder
149+
self.to_tensor = T.ToTensor()
150+
151+
def __call__(self, image, target):
152+
file_name = os.path.splitext(target["file_name"])[0]
153+
coords_fname = f"{file_name}_coords.npy"
154+
coords_fpath = os.path.join(self.folder, coords_fname)
155+
coords = np.load(coords_fpath)
156+
157+
coords = self.to_tensor(coords)[0]
158+
# Convert to float to be differentiable.
159+
target["coords"] = coords
160+
return image, target
161+
162+
142163
class RandomHorizontalFlip(T.RandomHorizontalFlip, ExTransform):
143164
"""Flip the image and annotations including boxes, masks, keypoints and the
144165
perturable_masks."""

0 commit comments

Comments
 (0)