2626
2727
2828from ultralytics import FastSAM
29- from ultralytics .models .fastsam import FastSAMPrompt
29+ from ultralytics .models .fastsam import FastSAMPredictor
3030from ultralytics .models .sam import Predictor as SAMPredictor
3131import fire
32- import numpy as np
3332import ultralytics
3433
3534from openadapt import cache
4140SAM_MODEL_NAMES = (
4241 "sam_b.pt" , # base
4342 "sam_l.pt" , # large
44- # "mobile_sam.pt",
4543)
4644MODEL_NAMES = FASTSAM_MODEL_NAMES + SAM_MODEL_NAMES
4745DEFAULT_MODEL_NAME = MODEL_NAMES [0 ]
4846
4947
50- # TODO: rename
5148def fetch_segmented_image (
5249 image : Image .Image ,
5350 model_name : str = DEFAULT_MODEL_NAME ,
@@ -74,14 +71,12 @@ def fetch_segmented_image(
7471def do_fastsam (
7572 image : Image ,
7673 model_name : str ,
77- # TODO: inject from config
7874 device : str = "cpu" ,
7975 retina_masks : bool = True ,
8076 imgsz : int | tuple [int , int ] | None = 1024 ,
81- # threshold below which boxes will be filtered out
8277 min_confidence_threshold : float = 0.4 ,
83- # discards all overlapping boxes with IoU > iou_threshold
8478 max_iou_threshold : float = 0.9 ,
79+ max_det : int = 1000 ,
8580 max_retries : int = 5 ,
8681 retry_delay_seconds : float = 0.1 ,
8782) -> Image :
@@ -90,100 +85,35 @@ def do_fastsam(
9085 For usage of thresholds see:
9186 github.com/ultralytics/ultralytics/blob/dacbd48fcf8407098166c6812eeb751deaac0faf
9287 /ultralytics/utils/ops.py#L164
93-
94- Args:
95- TODO
96- min_confidence_threshold (float, optional): The minimum confidence score
97- that a detection must meet or exceed to be considered valid. Detections
98- below this threshold will not be marked. Defaults to 0.00.
99- max_iou_threshold (float, optional): The maximum allowed Intersection over
100- Union (IoU) value for overlapping detections. Detections that exceed this
101- IoU threshold are considered for suppression, keeping only the
102- detection with the highest confidence. Defaults to 0.05.
10388 """
10489 model = FastSAM (model_name )
105-
10690 imgsz = imgsz or image .size
10791
108- # Run inference on image
10992 everything_results = model (
11093 image ,
11194 device = device ,
11295 retina_masks = retina_masks ,
11396 imgsz = imgsz ,
11497 conf = min_confidence_threshold ,
11598 iou = max_iou_threshold ,
99+ max_det = max_det ,
116100 )
117-
118- # Prepare a Prompt Process object
119- prompt_process = FastSAMPrompt (image , everything_results , device = "cpu" )
120-
121- # Everything prompt
122- annotations = prompt_process .everything_prompt ()
123-
124- # TODO: support other modes once issues are fixed
125- # https://github.com/ultralytics/ultralytics/issues/13218#issuecomment-2142960103
126-
127- # Bbox default shape [0,0,0,0] -> [x1,y1,x2,y2]
128- # annotations = prompt_process.box_prompt(bbox=[200, 200, 300, 300])
129-
130- # Text prompt
131- # annotations = prompt_process.text_prompt(text='a photo of a dog')
132-
133- # Point prompt
134- # points default [[0,0]] [[x1,y1],[x2,y2]]
135- # point_label default [0] [1,0] 0:background, 1:foreground
136- # annotations = prompt_process.point_prompt(points=[[200, 200]], pointlabel=[1])
137-
138- assert len (annotations ) == 1 , len (annotations )
139- annotation = annotations [0 ]
140-
141- # hide original image
142- annotation .orig_img = np .ones (annotation .orig_img .shape )
143-
144- # TODO: in memory, e.g. with prompt_process.fast_show_mask()
145- with TemporaryDirectory () as tmp_dir :
146- # Force the output format to PNG to prevent JPEG compression artefacts
147- annotation .path = annotation .path .replace (".jpg" , ".png" )
148- prompt_process .plot (
149- [annotation ],
150- tmp_dir ,
151- with_contours = False ,
152- retina = False ,
101+ assert len (everything_results ) == 1 , len (everything_results )
102+ annotation = everything_results [0 ]
103+
104+ segmented_image = Image .fromarray (
105+ annotation .plot (
106+ img = np .ones (annotation .orig_img .shape , dtype = annotation .orig_img .dtype ),
107+ kpt_line = False ,
108+ labels = False ,
109+ boxes = False ,
110+ probs = False ,
111+ color_mode = "instance" ,
153112 )
154- result_name = os .path .basename (annotation .path )
155- logger .info (f"{ annotation .path = } " )
156- segmented_image_path = Path (tmp_dir ) / result_name
157- segmented_image = Image .open (segmented_image_path )
158-
159- # Ensure the image is fully loaded before deletion to avoid errors or incomplete operations,
160- # as some operating systems and file systems lock files during read or processing.
161- segmented_image .load ()
162-
163- # Attempt to delete the file with retries and delay
164- retries = 0
165-
166- while retries < max_retries :
167- try :
168- os .remove (segmented_image_path )
169- break # If deletion succeeds, exit loop
170- except OSError as e :
171- if e .errno == errno .ENOENT : # File not found
172- break
173- else :
174- retries += 1
175- time .sleep (retry_delay_seconds )
176-
177- if retries == max_retries :
178- logger .warning (f"Failed to delete { segmented_image_path } " )
179- # Check if the dimensions of the original and segmented images differ
180- # XXX TODO this is a hack, this plotting code should be refactored, but the
181- # bug may exist in ultralytics, since they seem to resize as well; see:
182- # https://github.com/ultralytics/ultralytics/blob/main/ultralytics/utils/plotting.py#L238
183- # https://github.com/ultralytics/ultralytics/issues/561#issuecomment-1403079910
113+ )
114+
184115 if image .size != segmented_image .size :
185116 logger .warning (f"{ image .size = } != { segmented_image .size = } , resizing..." )
186- # Resize segmented_image to match original using nearest neighbor interpolation
187117 segmented_image = segmented_image .resize (image .size , Image .NEAREST )
188118
189119 assert image .size == segmented_image .size , (image .size , segmented_image .size )
@@ -194,7 +124,6 @@ def do_fastsam(
194124def do_sam (
195125 image : Image .Image ,
196126 model_name : str ,
197- # TODO: add params
198127) -> Image .Image :
199128 # Create SAMPredictor
200129 overrides = dict (
@@ -207,20 +136,7 @@ def do_sam(
207136 predictor = SAMPredictor (overrides = overrides )
208137
209138 # Segment with additional args
210- # results = predictor(source=image, crop_n_layers=1, points_stride=64)
211- results = predictor (
212- source = image ,
213- # crop_n_layers=3,
214- # crop_overlap_ratio=0.5,
215- # crop_downscale_factor=1,
216- # point_grids=None,
217- # points_stride=12,
218- # points_batch_size=128,
219- # conf_thres=0.8,
220- # stability_score_thresh=0.95,
221- # stability_score_offset=0.95,
222- # crop_nms_thresh=0.8,
223- )
139+ results = predictor (source = image )
224140 mask_ims = results_to_mask_images (results )
225141 segmented_image = colorize_masks (mask_ims )
226142 return segmented_image
@@ -238,8 +154,7 @@ def results_to_mask_images(
238154
239155
240156def colorize_masks (masks : list [Image .Image ]) -> Image .Image :
241- """
242- Takes a list of PIL images containing binary masks and returns a new PIL.Image
157+ """Takes a list of PIL images containing binary masks and returns a new PIL.Image
243158 where each mask is colored differently using a unique color for each mask.
244159
245160 Args:
@@ -249,15 +164,11 @@ def colorize_masks(masks: list[Image.Image]) -> Image.Image:
249164 PIL.Image: A new image with each mask in a different color.
250165 """
251166 if not masks :
252- return None # Return None if the list is empty
167+ return None
253168
254- # Assuming all masks are the same size, get dimensions
255169 width , height = masks [0 ].size
256-
257- # Create an empty array with 3 color channels (RGB)
258170 result_image = np .zeros ((height , width , 3 ), dtype = np .uint8 )
259171
260- # Generate unique colors using HSV color space
261172 num_masks = len (masks )
262173 colors = [
263174 tuple (
@@ -271,17 +182,12 @@ def colorize_masks(masks: list[Image.Image]) -> Image.Image:
271182 ]
272183
273184 for idx , mask in enumerate (masks ):
274- # Convert PIL Image to numpy array
275185 mask_array = np .array (mask )
276-
277- # Apply the color to the mask
278186 for c in range (3 ):
279- # Only colorize where the mask is True (assuming mask is binary: 0 or 255)
280187 result_image [:, :, c ] += (mask_array / 255 * colors [idx ][c ]).astype (
281188 np .uint8
282189 )
283190
284- # Convert the result back to a PIL image
285191 return Image .fromarray (result_image )
286192
287193
0 commit comments