Skip to content

Commit 579a240

Browse files
HaoZhang534 and hao zhang authored
Support MaskDINO COCO instance/panoptic segmentation (#154)
* add maskdino * delete useless op * add MaskDINO coco panoptic * add README for dino and bound to v0.2.1 Co-authored-by: hao zhang <[email protected]>
1 parent 57d7527 commit 579a240

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

56 files changed

+9512
-1
lines changed

detrex/data/__init__.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,12 @@
1414
# limitations under the License.
1515

1616
from .detr_dataset_mapper import DetrDatasetMapper
17+
from .dataset_mappers import (
18+
COCOInstanceNewBaselineDatasetMapper,
19+
COCOPanopticNewBaselineDatasetMapper,
20+
MaskFormerSemanticDatasetMapper,
21+
MaskFormerInstanceDatasetMapper,
22+
MaskFormerPanopticDatasetMapper,
23+
)
24+
from . import datasets
25+
from .transforms import ColorAugSSDTransform
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
# coding=utf-8
2+
# Copyright 2022 The IDEA Authors. All rights reserved.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
from .coco_instance_new_baseline_dataset_mapper import build_transform_gen as coco_instance_transform_gen
17+
from .coco_panoptic_new_baseline_dataset_mapper import build_transform_gen as coco_panoptic_transform_gen
18+
from .coco_instance_new_baseline_dataset_mapper import COCOInstanceNewBaselineDatasetMapper
19+
from .coco_panoptic_new_baseline_dataset_mapper import COCOPanopticNewBaselineDatasetMapper
20+
from .mask_former_instance_dataset_mapper import MaskFormerInstanceDatasetMapper
21+
from .mask_former_panoptic_dataset_mapper import MaskFormerPanopticDatasetMapper
22+
from .mask_former_semantic_dataset_mapper import MaskFormerSemanticDatasetMapper
Lines changed: 179 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,179 @@
1+
# coding=utf-8
2+
# Copyright 2022 The IDEA Authors. All rights reserved.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
# ------------------------------------------------------------------------------------------------
16+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
17+
# ------------------------------------------------------------------------------------------------
18+
# COCO Instance Segmentation with LSJ Augmentation
19+
# Modified from:
20+
# https://github.com/facebookresearch/Mask2Former/blob/main/mask2former/data/dataset_mappers/coco_instance_new_baseline_dataset_mapper.py
21+
# ------------------------------------------------------------------------------------------------
22+
23+
import copy
24+
import logging
25+
import numpy as np
26+
import torch
27+
28+
from detectron2.data import detection_utils as utils
29+
from detectron2.data import transforms as T
30+
31+
from pycocotools import mask as coco_mask
32+
33+
34+
def convert_coco_poly_to_mask(segmentations, height, width):
    """
    Rasterise per-instance COCO polygon lists into one stacked mask tensor.

    Args:
        segmentations: iterable with one entry per instance; each entry is the
            polygon data handed to ``coco_mask.frPyObjects``.
        height: mask height passed to pycocotools.
        width: mask width passed to pycocotools.

    Returns:
        torch.Tensor: the per-instance masks stacked along a new leading
        dimension. When ``segmentations`` is empty, an all-zero ``uint8``
        tensor of shape ``(0, height, width)`` is returned instead.
    """

    def _decode_instance(polygons):
        # Encode the polygons as RLE, then decode them into a dense array.
        encoded = coco_mask.frPyObjects(polygons, height, width)
        dense = coco_mask.decode(encoded)
        if dense.ndim < 3:
            # Make sure there is a trailing axis to reduce over.
            dense = dense[..., None]
        # A pixel belongs to the instance if any slice along the last axis covers it.
        return torch.as_tensor(dense, dtype=torch.uint8).any(dim=2)

    per_instance = [_decode_instance(polygons) for polygons in segmentations]
    if not per_instance:
        return torch.zeros((0, height, width), dtype=torch.uint8)
    return torch.stack(per_instance, dim=0)
49+
50+
51+
def build_transform_gen(
    image_size,
    min_scale,
    max_scale,
    random_flip: str = "horizontal",
    is_train: bool = True,
):
    """
    Build the training augmentation list: optional random flip, then
    scale-jitter resize followed by a fixed-size square crop.

    Args:
        image_size: side length used as the resize target and as the crop size.
        min_scale: forwarded to ``T.ResizeScale`` as ``min_scale``.
        max_scale: forwarded to ``T.ResizeScale`` as ``max_scale``.
        random_flip: one of "none", "horizontal" or "vertical".
        is_train: must be truthy; only training augmentation is supported.

    Returns:
        list[Augmentation]
    """
    assert is_train, "Only support training augmentation."
    assert random_flip in ["none", "horizontal", "vertical"], f"Only support none/horizontal/vertical flip, but got {random_flip}"

    # The flip, when enabled, always comes first in the pipeline.
    flip_stage = []
    if random_flip != "none":
        flip = T.RandomFlip(
            horizontal=(random_flip == "horizontal"),
            vertical=(random_flip == "vertical"),
        )
        flip_stage = [flip]

    resize = T.ResizeScale(
        min_scale=min_scale,
        max_scale=max_scale,
        target_height=image_size,
        target_width=image_size,
    )
    crop = T.FixedSizeCrop(crop_size=(image_size, image_size))

    return flip_stage + [resize, crop]
86+
87+
88+
class COCOInstanceNewBaselineDatasetMapper:
    """
    A callable which takes a dataset dict in Detectron2 Dataset format,
    and maps it into the format used for COCO instance segmentation training.

    The callable currently does the following:

    1. Read the image from "file_name"
    2. Apply the configured augmentations to the image and derive a padding mask
    3. Apply the same transforms to the non-crowd annotations
    4. Convert image, padding mask and annotations to Tensors
    """

    def __init__(
        self,
        is_train=True,
        *,
        augmentation,
        image_format,
    ):
        """
        Args:
            is_train: whether the mapper is used for training. When False,
                annotations are dropped and no "instances" entry is produced.
            augmentation: augmentations handed to ``T.apply_transform_gens``.
            image_format: image format forwarded to ``utils.read_image``.
        """
        self.augmentation = augmentation
        logging.getLogger(__name__).info(
            "[COCO_Instance_LSJ_Augment_Dataset_Mapper] Full TransformGens used in training: {}".format(str(self.augmentation))
        )

        self.img_format = image_format
        self.is_train = is_train

    def __call__(self, dataset_dict):
        """
        Args:
            dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.

        Returns:
            dict: a deep copy of the input with "image" and "padding_mask"
            tensors added and, in training mode when "annotations" is present,
            an "instances" entry replacing "annotations".
        """
        dataset_dict = copy.deepcopy(dataset_dict)  # it will be modified by code below
        image = utils.read_image(dataset_dict["file_name"], format=self.img_format)
        utils.check_image_size(dataset_dict, image)

        # Push an all-ones map through the same transforms as the image, then
        # invert its boolean view: entries that come back as zero end up True.
        # NOTE(review): presumably padded pixels come back as zero so that the
        # result flags padding - this depends on the segmentation pad value
        # used by the configured transforms; confirm.
        padding_mask = np.ones(image.shape[:2])
        image, transforms = T.apply_transform_gens(self.augmentation, image)

        padding_mask = transforms.apply_segmentation(padding_mask)
        padding_mask = ~padding_mask.astype(bool)

        image_shape = image.shape[:2]  # h, w

        # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,
        # but not efficient on large generic data structures due to the use of pickle & mp.Queue.
        # Therefore it's important to use torch.Tensor.
        dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))
        dataset_dict["padding_mask"] = torch.as_tensor(np.ascontiguousarray(padding_mask))

        if not self.is_train:
            # USER: Modify this if you want to keep them for some reason.
            dataset_dict.pop("annotations", None)
            return dataset_dict

        if "annotations" in dataset_dict:
            # Keypoints are not used by this mapper.
            for anno in dataset_dict["annotations"]:
                anno.pop("keypoints", None)

            # Crowd annotations are skipped; the rest follow the image transforms.
            annos = [
                utils.transform_instance_annotations(obj, transforms, image_shape)
                for obj in dataset_dict.pop("annotations")
                if obj.get("iscrowd", 0) == 0
            ]
            # NOTE: does not support BitMask due to augmentation
            # Current BitMask cannot handle empty objects
            instances = utils.annotations_to_instances(annos, image_shape)

            # The mask-dependent steps below only make sense when the instances
            # actually carry masks; guard both of them the same way so that a
            # mask-less result does not fail on a missing "gt_masks" field.
            if hasattr(instances, "gt_masks"):
                # After transforms such as cropping are applied, the bounding box may no longer
                # tightly bound the object. As an example, imagine a triangle object
                # [(0,0), (2,0), (0,2)] cropped by a box [(1,0),(2,2)] (XYXY format). The tight
                # bounding box of the cropped triangle should be [(1,0),(2,1)], which is not equal to
                # the intersection of original bounding box and the cropping box.
                instances.gt_boxes = instances.gt_masks.get_bounding_boxes()

            # Need to filter empty instances first (due to augmentation)
            instances = utils.filter_empty_instances(instances)

            # Generate masks from polygon
            h, w = instances.image_size
            if hasattr(instances, "gt_masks"):
                instances.gt_masks = convert_coco_poly_to_mask(instances.gt_masks.polygons, h, w)

            dataset_dict["instances"] = instances

        return dataset_dict
179+
Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,173 @@
1+
# coding=utf-8
2+
# Copyright 2022 The IDEA Authors. All rights reserved.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
# ------------------------------------------------------------------------------------------------
16+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
17+
# ------------------------------------------------------------------------------------------------
18+
# COCO Panoptic Segmentation with LSJ Augmentation
19+
# Modified from:
20+
# https://github.com/facebookresearch/Mask2Former/blob/main/mask2former/data/dataset_mappers/coco_panoptic_new_baseline_dataset_mapper.py
21+
# ------------------------------------------------------------------------------------------------
22+
23+
import copy
24+
import logging
25+
26+
import numpy as np
27+
import torch
28+
29+
from detectron2.config import configurable
30+
from detectron2.data import detection_utils as utils
31+
from detectron2.data import transforms as T
32+
from detectron2.data.transforms import TransformGen
33+
from detectron2.structures import BitMasks, Boxes, Instances
34+
35+
__all__ = ["COCOPanopticNewBaselineDatasetMapper"]
36+
37+
38+
def build_transform_gen(
    image_size,
    min_scale,
    max_scale,
    random_flip: str = "horizontal",
    is_train: bool = True,
):
    """
    Create the list of training augmentations: an optional random flip,
    followed by a scale-jitter resize and a fixed-size square crop.

    Args:
        image_size: side length used as the resize target and as the crop size.
        min_scale: forwarded to ``T.ResizeScale`` as ``min_scale``.
        max_scale: forwarded to ``T.ResizeScale`` as ``max_scale``.
        random_flip: one of "none", "horizontal" or "vertical".
        is_train: must be truthy; only training augmentation is supported.

    Returns:
        list[Augmentation]

    Raises:
        AssertionError: if ``is_train`` is falsy.
        ValueError: if ``random_flip`` is not a supported value.
    """
    assert is_train, "Only support training augmentation"
    # Reject unknown flip modes up front with a clear message; otherwise an
    # unrecognised value would reach T.RandomFlip as horizontal=False,
    # vertical=False, which is not a meaningful flip configuration.
    if random_flip not in ("none", "horizontal", "vertical"):
        raise ValueError(f"Only support none/horizontal/vertical flip, but got {random_flip}")

    augmentation = []

    if random_flip != "none":
        augmentation.append(
            T.RandomFlip(
                horizontal=random_flip == "horizontal",
                vertical=random_flip == "vertical",
            )
        )

    augmentation.extend([
        T.ResizeScale(
            min_scale=min_scale, max_scale=max_scale, target_height=image_size, target_width=image_size
        ),
        T.FixedSizeCrop(crop_size=(image_size, image_size)),
    ])

    return augmentation
71+
72+
73+
# This is specifically designed for the COCO dataset.
74+
class COCOPanopticNewBaselineDatasetMapper:
    """
    Callable that converts one dataset dict in Detectron2 Dataset format into
    the inputs used for COCO panoptic segmentation training.

    ``__call__`` performs these steps:
    1. Read the image referenced by "file_name"
    2. Run the configured augmentations on the image
    3. Apply the same transforms to the panoptic ground truth, when present
    4. Convert the image and the per-segment classes, masks and boxes to Tensors
    """

    def __init__(
        self,
        is_train=True,
        *,
        augmentation,
        image_format,
    ):
        """
        NOTE: this interface is experimental.
        Args:
            is_train: for training or inference; when False no ground truth is loaded
            augmentation: augmentations handed to ``T.apply_transform_gens``
            image_format: an image format supported by :func:`detection_utils.read_image`.
        """
        self.is_train = is_train
        self.img_format = image_format
        self.augmentation = augmentation

        message = "[COCOPanopticNewBaselineDatasetMapper] Full TransformGens used in training: {}".format(
            str(self.augmentation)
        )
        logging.getLogger(__name__).info(message)

    def __call__(self, dataset_dict):
        """
        Args:
            dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.
        Returns:
            dict: a deep copy of the input with an "image" tensor added and,
            in training mode when "pan_seg_file_name" is present, an
            "instances" entry built from the panoptic ground truth.
        """
        # Work on a private copy; the dict is mutated below.
        dataset_dict = copy.deepcopy(dataset_dict)

        image = utils.read_image(dataset_dict["file_name"], format=self.img_format)
        utils.check_image_size(dataset_dict, image)
        image, transforms = T.apply_transform_gens(self.augmentation, image)
        image_shape = image.shape[:2]  # h, w

        # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,
        # but not efficient on large generic data structures due to the use of pickle & mp.Queue.
        # Therefore it's important to use torch.Tensor.
        channels_first = np.ascontiguousarray(image.transpose(2, 0, 1))
        dataset_dict["image"] = torch.as_tensor(channels_first)

        if not self.is_train:
            # USER: Modify this if you want to keep them for some reason.
            dataset_dict.pop("annotations", None)
            return dataset_dict

        # Nothing more to do when there is no panoptic ground truth.
        if "pan_seg_file_name" not in dataset_dict:
            return dataset_dict

        pan_seg_gt = utils.read_image(dataset_dict.pop("pan_seg_file_name"), "RGB")
        segments_info = dataset_dict["segments_info"]

        # The ground truth must follow the exact transforms applied to the image.
        pan_seg_gt = transforms.apply_segmentation(pan_seg_gt)

        # Imported here so panopticapi is only needed when panoptic ground truth is used.
        from panopticapi.utils import rgb2id

        # Presumably turns the RGB-encoded map into a per-pixel segment id map
        # (compared against each segment's "id" below) - confirm with panopticapi.
        pan_seg_gt = rgb2id(pan_seg_gt)

        instances = Instances(image_shape)
        kept_classes = []
        kept_masks = []
        for segment in segments_info:
            category = segment["category_id"]
            if segment["iscrowd"]:
                # Crowd segments contribute no instance.
                continue
            kept_classes.append(category)
            kept_masks.append(pan_seg_gt == segment["id"])

        instances.gt_classes = torch.tensor(np.array(kept_classes), dtype=torch.int64)

        if kept_masks:
            stacked = torch.stack(
                [torch.from_numpy(np.ascontiguousarray(m.copy())) for m in kept_masks]
            )
            bit_masks = BitMasks(stacked)
            instances.gt_masks = bit_masks.tensor
            instances.gt_boxes = bit_masks.get_bounding_boxes()
        else:
            # Some image does not have annotation (all ignored)
            instances.gt_masks = torch.zeros((0, pan_seg_gt.shape[-2], pan_seg_gt.shape[-1]))
            instances.gt_boxes = Boxes(torch.zeros((0, 4)))

        dataset_dict["instances"] = instances
        return dataset_dict
173+

0 commit comments

Comments
 (0)