-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathdata_utils.py
More file actions
executable file
·424 lines (349 loc) · 16.3 KB
/
data_utils.py
File metadata and controls
executable file
·424 lines (349 loc) · 16.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
import math
import os
from io import BytesIO
import cartopy.crs as ccrs
import cartopy.io.img_tiles as cimgt
import fsspec
import numpy as np
import pyarrow.parquet as pq
from matplotlib.backends.backend_agg import FigureCanvasAgg
from matplotlib.figure import Figure
from matplotlib.patches import Rectangle
from PIL import Image, ImageDraw, ImageFont
from rasterio.io import MemoryFile
def preprocess_s2_true_color(rgb_array):
    """
    Turn raw Sentinel-2 RGB values into true-color values suitable for display.

    The input is cast to float32, divided by 10,000, multiplied by a gain of
    2.5, and finally clipped so that every value lies in [0, 1].

    Args:
        rgb_array (np.ndarray): Raw Sentinel-2 RGB array (H, W, 3) in uint16.

    Returns:
        np.ndarray: Normalized true-color array in range [0, 1] (float32).
    """
    reflectance = rgb_array.astype(np.float32) / 10000.0
    brightened = 2.5 * reflectance
    return brightened.clip(0, 1)
def crop_center(img_array, cropx, cropy):
    """
    Return the centered ``cropy`` x ``cropx`` window of an image array.

    Args:
        img_array (np.ndarray): Array whose first two axes are (rows, columns).
            Any trailing axes (e.g. a channel axis) are kept untouched, so
            both (H, W) and (H, W, C) inputs are accepted.
        cropx (int): Requested window width in pixels.
        cropy (int): Requested window height in pixels.

    Returns:
        np.ndarray: The centered slice of ``img_array``. When a requested
        dimension is larger than the array, the whole extent along that axis
        is returned instead.
    """
    height, width = img_array.shape[:2]
    # Clamp at zero: a negative start index would be read as an offset from
    # the end of the axis and silently yield an off-centre, truncated window.
    startx = max(0, width // 2 - cropx // 2)
    starty = max(0, height // 2 - cropy // 2)
    return img_array[starty:starty + cropy, startx:startx + cropx]
def read_tif_bytes(tif_bytes):
    """
    Decode GeoTIFF bytes held in memory and return the pixel data.

    Args:
        tif_bytes (bytes): Raw contents of a GeoTIFF file.

    Returns:
        The array produced by the dataset's ``read()`` with all length-one
        axes removed via ``squeeze()``.
    """
    with MemoryFile(tif_bytes) as memory_file, memory_file.open(driver='GTiff') as dataset:
        pixels = dataset.read()
    return pixels.squeeze()
def read_row_memory(row_dict, columns=None):
    """
    Read selected columns of one parquet row group and decode them in memory.

    Args:
        row_dict (dict): Must provide 'parquet_url' (location of the parquet
            file) and 'parquet_row' (index of the row group to read).
        columns (list[str] | None): Column names to fetch. Defaults to
            ["thumbnail"].

    Returns:
        dict: Maps each requested column to its decoded value. The
        'thumbnail' column is opened as a PIL Image; every other column is
        decoded with read_tif_bytes.
    """
    if columns is None:
        columns = ["thumbnail"]
    parquet_url = row_dict['parquet_url']
    row_group = row_dict['parquet_row']

    # Read-ahead caching with 5 MiB blocks for the remote file handle.
    open_kwargs = {"cache_type": "readahead", "block_size": 5 * 1024 * 1024}

    with fsspec.open(parquet_url, mode='rb', **open_kwargs) as handle, pq.ParquetFile(handle) as parquet:
        table = parquet.read_row_group(row_group, columns=columns)

        decoded = {}
        for name in columns:
            # Only the first entry of the row group is used.
            payload = table[name][0].as_py()
            if name == 'thumbnail':
                decoded[name] = Image.open(BytesIO(payload))
            else:
                decoded[name] = read_tif_bytes(payload)

    return decoded
def _prepare_row_dict(product_id, df_source, verbose=True):
    """
    Locate the metadata row for a product and rewrite its parquet URL.

    URLs that point at huggingface.co or hf-mirror.com are redirected to
    modelscope.cn, with 'resolve/main' replaced by 'resolve/master'. Any
    other URL, including a relative path, is passed through unchanged.

    Args:
        product_id: Value matched against the 'product_id' column.
        df_source: DataFrame holding the metadata, or None.
        verbose (bool): Print an error message when the lookup fails.

    Returns:
        tuple: ``(row_dict, None)`` on success, where ``row_dict`` is the
        first matching row converted to a dict. ``(None, (None, None))`` when
        no DataFrame was given, the product is not found, or the row has no
        usable 'parquet_url' (column absent or cell not a string).
    """
    failure = (None, (None, None))

    if df_source is None:
        if verbose:
            print("❌ Error: No DataFrame provided.")
        return failure

    matches = df_source[df_source['product_id'] == product_id]
    if len(matches) == 0:
        if verbose:
            print(f"❌ Error: Product ID {product_id} not found in DataFrame.")
        return failure

    row_dict = matches.iloc[0].to_dict()
    url = row_dict.get('parquet_url')

    # An empty cell surfaces as None/NaN rather than a string. Treat it like
    # an absent column instead of crashing on `.startswith`.
    if not isinstance(url, str):
        if verbose:
            print("❌ Error: 'parquet_url' missing in metadata.")
        return failure

    # Relative paths are intentionally left untouched: resolving them against
    # local dataset directories is disabled.
    if url.startswith(('http://', 'https://', '/')):
        for host in ('huggingface.co', 'hf-mirror.com'):
            if host in url:
                row_dict['parquet_url'] = (
                    url.replace(f'https://{host}', 'https://modelscope.cn')
                    .replace('resolve/main', 'resolve/master')
                )
                break

    return row_dict, None
def _bands_to_rgb_pil(bands_data, verbose=True, normalize=True):
    """
    Compose B04/B03/B02 into an RGB image and return it in two sizes.

    Args:
        bands_data (dict): Dictionary with 'B04', 'B03', 'B02' band arrays.
        verbose (bool): Whether to print debug info.
        normalize (bool): If True, apply true-color normalization
            (2.5 * value / 1e4) before scaling to 8 bits. If False, the raw
            values are clipped to [0, 255] and cast to uint8 directly.

    Returns:
        tuple: (img_384, img_full) as PIL Images. ``img_384`` is the centered
        384x384 crop when both image sides are at least 384 pixels, otherwise
        the full image resized to 384x384.
    """
    channels = [bands_data[name] for name in ('B04', 'B03', 'B02')]
    rgb_img = np.stack(channels, axis=-1)
    if verbose:
        print(f"Raw RGB stats: Min={rgb_img.min()}, Max={rgb_img.max()}, Mean={rgb_img.mean()}, Dtype={rgb_img.dtype}")

    if not normalize:
        rgb_uint8 = rgb_img.clip(0, 255).astype(np.uint8)
    else:
        rgb_uint8 = (preprocess_s2_true_color(rgb_img) * 255).astype(np.uint8)
    if verbose:
        print(f"Processed RGB stats: Min={rgb_uint8.min()}, Max={rgb_uint8.max()}, Mean={rgb_uint8.mean()}")

    img_full = Image.fromarray(rgb_uint8)

    # Too small to crop: fall back to resizing the whole image.
    if rgb_uint8.shape[0] < 384 or rgb_uint8.shape[1] < 384:
        if verbose:
            print(f"⚠️ Image too small {rgb_uint8.shape}, resizing to 384x384.")
        return img_full.resize((384, 384)), img_full

    img_384 = Image.fromarray(crop_center(rgb_uint8, 384, 384))
    return img_384, img_full
def _thumbnail_to_pil(thumb_img, verbose=True):
    """
    Turn a thumbnail PIL Image into a (384x384, full-size) RGB pair.

    Args:
        thumb_img: PIL Image of the thumbnail; it is converted to RGB first.
        verbose (bool): Print a warning when the thumbnail must be resized.

    Returns:
        tuple: (img_384, img_full). ``img_384`` is the centered 384x384 crop
        when both sides are at least 384 pixels, otherwise the full image
        resized to 384x384.
    """
    img_full = thumb_img.convert("RGB")
    w, h = img_full.size

    if w < 384 or h < 384:
        if verbose:
            print(f"⚠️ Thumbnail too small ({w}x{h}), resizing to 384x384.")
        return img_full.resize((384, 384)), img_full

    center_pixels = crop_center(np.array(img_full), 384, 384)
    return Image.fromarray(center_pixels), img_full
# Names of the 12 Sentinel-2 band columns read from MajorTOM parquet files.
# The order of this list is the channel order of stacked multiband arrays.
MULTIBAND_COLUMNS = [
    'B01', 'B02', 'B03', 'B04',
    'B05', 'B06', 'B07', 'B08',
    'B8A', 'B09', 'B11', 'B12',
]
def reorder_multiband(multiband_array, target_bands, source_bands=None):
    """
    Select and reorder the channels of a multiband array by band name.

    This is the single place that maps between the 12-band MajorTOM layout
    and whatever band subset/order a particular model expects.

    Args:
        multiband_array (np.ndarray): Array of shape [..., C] where C matches
            len(source_bands). Typically [H, W, 12] from
            download_and_process_image.
        target_bands (list[str]): Band names the model expects, in the order
            it expects them.
        source_bands (list[str] | None): Band names present along the last
            axis of multiband_array. Defaults to MULTIBAND_COLUMNS.

    Returns:
        np.ndarray: Array with the last axis reordered to target_bands, shape
        [..., len(target_bands)]. The input object itself is returned when
        the two band lists are already identical.

    Raises:
        ValueError: If a band in target_bands is not found in source_bands.
    """
    source_bands = MULTIBAND_COLUMNS if source_bands is None else source_bands

    # Nothing to do when the requested layout already matches the source.
    already_ordered = (
        len(target_bands) == len(source_bands)
        and list(target_bands) == list(source_bands)
    )
    if already_ordered:
        return multiband_array

    band_map = {}
    for position, name in enumerate(source_bands):
        band_map[name] = position

    indices = []
    missing = []
    for band in target_bands:
        if band in band_map:
            indices.append(band_map[band])
        else:
            missing.append(band)

    if missing:
        raise ValueError(
            f"Target bands not found in source: {missing}. "
            f"Source has {source_bands}, target asked for {target_bands}."
        )

    return multiband_array[..., indices]
def _stack_multiband(data, verbose=True):
    """
    Stack the bands named in MULTIBAND_COLUMNS into one (H, W, 12) array.

    Args:
        data (dict): Maps band names to arrays whose first two axes are
            (H, W). Entries may be absent or None.
        verbose (bool): Print a warning for every missing or resized band.

    Returns:
        np.ndarray: Bands stacked on the last axis in MULTIBAND_COLUMNS order.
        Missing bands are filled with uint16 zeros. Bands whose spatial shape
        differs from the reference shape are resized with bicubic resampling.
    """
    # Reference grid: B04/B03/B02 first (the 10m bands), then any available
    # band. (224, 224) is the last resort when no band was read at all.
    candidates = ('B04', 'B03', 'B02', *MULTIBAND_COLUMNS)
    ref_shape = next(
        (data[b].shape[:2] for b in candidates if data.get(b) is not None),
        (224, 224),
    )

    band_arrays = []
    for band_name in MULTIBAND_COLUMNS:
        arr = data.get(band_name)
        if arr is None:
            if verbose:
                print(f"⚠️ Band {band_name} missing, filling with zeros.")
            band_arrays.append(np.zeros(ref_shape, dtype=np.uint16))
            continue

        if arr.shape[:2] != ref_shape:
            if verbose:
                print(f"⚠️ Band {band_name} shape {arr.shape} != ref {ref_shape}, resizing.")
            # PIL takes (width, height), i.e. the reverse of the array shape.
            arr_pil = Image.fromarray(arr)
            arr_pil = arr_pil.resize((ref_shape[1], ref_shape[0]), resample=Image.BICUBIC)
            arr = np.array(arr_pil)
        band_arrays.append(arr)

    return np.stack(band_arrays, axis=-1)


def download_and_process_image(product_id, df_source=None, verbose=True, mode="thumbnail", normalize=True):
    """
    Download and process a MajorTOM image.

    Args:
        product_id: The product identifier in df_source.
        df_source: DataFrame with metadata (product_id, parquet_url, parquet_row, …).
        verbose: Print progress / debug info.
        mode: Download mode — one of:
            "thumbnail" (default) — read the pre-rendered thumbnail column (fastest).
            "rgb"                 — read B04/B03/B02 bands and compose true-color RGB.
            "multiband"           — read all 12 S2 bands + thumbnail for preview.
        normalize: For mode="rgb", whether to apply true-color normalization.
            Set to False if you need raw band values for model preprocessing.
            It is also honoured when thumbnail mode falls back to rgb mode.

    Returns:
        mode="thumbnail" → (img_384, img_full) — PIL Images from thumbnail.
        mode="rgb"       → (img_384, img_full) — PIL Images from RGB bands.
        mode="multiband" → (img_384, img_full, bands) — thumbnail preview +
                           np.ndarray (H, W, 12).
        On any failure every element of the tuple is None. An unknown mode
        yields (None, None). Exceptions raised while fetching or decoding are
        caught, reported when verbose, and turned into the failure tuple.

    Side effects:
        Sets the MODEL_DOMAIN environment variable to "modelscope.cn".
    """
    # Set unconditionally: the former check for ./configs/modelscope_ai.yaml
    # assigned the same value in both branches.
    os.environ["MODEL_DOMAIN"] = "modelscope.cn"

    failure = (None, None, None) if mode == "multiband" else (None, None)

    row_dict, _err = _prepare_row_dict(product_id, df_source, verbose)
    if row_dict is None:
        return failure

    if verbose:
        print(f"⬇️ Fetching data for {product_id} [mode={mode}] from {row_dict['parquet_url']}...")

    try:
        # ---- thumbnail mode ----
        if mode == "thumbnail":
            data = read_row_memory(row_dict, columns=['thumbnail'])
            if data.get('thumbnail') is None:
                if verbose:
                    print("⚠️ Thumbnail unavailable, falling back to rgb mode.")
                # Forward `normalize` so the fallback respects the caller's choice.
                return download_and_process_image(
                    product_id, df_source, verbose, mode="rgb", normalize=normalize
                )
            img_384, img_full = _thumbnail_to_pil(data['thumbnail'], verbose)
            if verbose:
                print(f"✅ Successfully processed {product_id} (thumbnail)")
            return img_384, img_full

        # ---- rgb mode ----
        elif mode == "rgb":
            bands_data = read_row_memory(row_dict, columns=['B04', 'B03', 'B02'])
            if not all(b in bands_data for b in ['B04', 'B03', 'B02']):
                if verbose:
                    print(f"❌ Error: Missing bands in fetched data for {product_id}")
                return None, None
            img_384, img_full = _bands_to_rgb_pil(bands_data, verbose, normalize=normalize)
            if verbose:
                print(f"✅ Successfully processed {product_id} (rgb)")
            return img_384, img_full

        # ---- multiband mode ----
        elif mode == "multiband":
            columns_to_read = ['thumbnail', *MULTIBAND_COLUMNS]
            data = read_row_memory(row_dict, columns=columns_to_read)

            # Preview from thumbnail (fallback to RGB composite)
            if data.get('thumbnail') is not None:
                img_384, img_full = _thumbnail_to_pil(data['thumbnail'], verbose)
            elif all(b in data for b in ['B04', 'B03', 'B02']):
                img_384, img_full = _bands_to_rgb_pil(data, verbose)
            else:
                img_384, img_full = None, None

            multiband_array = _stack_multiband(data, verbose)  # (H, W, 12)
            if verbose:
                print(f"✅ Successfully processed {product_id} (multiband {multiband_array.shape})")
            return img_384, img_full, multiband_array

        else:
            if verbose:
                print(f"❌ Unknown mode: {mode}")
            return None, None

    except Exception as e:
        # Best-effort boundary: callers receive the failure tuple instead of
        # an exception.
        if verbose:
            print(f"❌ Error processing {product_id}: {e}")
            import traceback
            traceback.print_exc()
        return failure
# Tile source for the Esri World Imagery map service
class EsriImagery(cimgt.GoogleTiles):
    """GoogleTiles subclass whose tile URLs point at Esri's World_Imagery MapServer."""

    def _image_url(self, tile):
        """Return the URL of one tile; the path order is zoom / y / x."""
        x, y, z = tile
        tile_url = f'https://server.arcgisonline.com/ArcGIS/rest/services/World_Imagery/MapServer/tile/{z}/{y}/{x}'
        return tile_url
def get_placeholder_image(text="Image Unavailable", size=(384, 384)):
    """
    Build a light-grey RGB placeholder image with a short message on it.

    Args:
        text (str): Message drawn in black on the image.
        size (tuple): (width, height) of the image in pixels.

    Returns:
        PIL.Image.Image: The placeholder image.
    """
    img = Image.new('RGB', size, color=(200, 200, 200))

    try:
        font = ImageFont.load_default()
    except Exception:
        # Drawing proceeds with font=None when no default font can be loaded.
        font = None

    # The text starts 20px from the left edge at mid-height. It is not
    # centered, since that would need font metrics.
    text_origin = (20, size[1] // 2)
    ImageDraw.Draw(img).text(text_origin, text, fill=(0, 0, 0), font=font)
    return img
def get_esri_satellite_image(lat, lon, score=None, rank=None, query=None):
    """
    Render an Esri World Imagery map around a point and return it as an image.

    The figure is built with Matplotlib's object-oriented API (Figure +
    FigureCanvasAgg) rather than pyplot; the original author notes this is
    for thread safety. The map spans 0.05 degrees on each side of the point,
    marks the point with a yellow '+', and outlines a red 3840m x 3840m box
    centered on it.

    Args:
        lat (float): Latitude of the center point in degrees.
        lon (float): Longitude of the center point in degrees.
        score (float | None): Optional score shown in the title (4 decimals).
        rank: Optional rank shown in the title.
        query: Optional query text shown as the first title line.

    Returns:
        PIL.Image.Image: The rendered PNG, or a placeholder image when
        anything goes wrong while building the map.
    """
    try:
        imagery = EsriImagery()
        plate_carree = ccrs.PlateCarree()

        # Figure created through the OO API; the canvas attaches itself to it.
        fig = Figure(figsize=(5, 5), dpi=100)
        _canvas = FigureCanvasAgg(fig)
        ax = fig.add_subplot(1, 1, 1, projection=imagery.crs)

        # Roughly 10km x 10km around the point.
        extent_deg = 0.05
        west, east = lon - extent_deg, lon + extent_deg
        south, north = lat - extent_deg, lat + extent_deg
        ax.set_extent([west, east, south, north], crs=plate_carree)

        # Imagery is requested at zoom level 14.
        ax.add_image(imagery, 14)

        # Center marker.
        ax.plot(lon, lat, marker='+', color='yellow', markersize=12, markeredgewidth=2, transform=plate_carree)

        # Bounding box of 384 pixels at 10m each (3840m), converted to degrees
        # with the approximation 1 deg lat = 111320m and
        # 1 deg lon = 111320m * cos(lat).
        box_size_m = 384 * 10
        dlat = (box_size_m / 111320)
        dlon = (box_size_m / (111320 * math.cos(math.radians(lat))))
        bottom_left = (lon - dlon / 2, lat - dlat / 2)
        outline = Rectangle(bottom_left, dlon, dlat,
                            linewidth=2, edgecolor='red', facecolor='none', transform=plate_carree)
        ax.add_patch(outline)

        # Title: one line per piece of information that was supplied.
        title_parts = [f"{query}"] if query else []
        if rank is not None:
            title_parts.append(f"Rank {rank}")
        if score is not None:
            title_parts.append(f"Score: {score:.4f}")
        ax.set_title("\n".join(title_parts), fontsize=10)

        # Render into an in-memory PNG and hand it back as a PIL image.
        buf = BytesIO()
        fig.savefig(buf, format='png', bbox_inches='tight')
        buf.seek(0)
        return Image.open(buf)

    except Exception as e:
        # Known network failures get a short warning to avoid log spam; other
        # errors are reported with their message. No traceback is printed.
        error_msg = str(e)
        network_markers = ("Connection reset by peer", "Network is unreachable", "urlopen error")
        if any(marker in error_msg for marker in network_markers):
            print(f"⚠️ Network warning: Could not fetch Esri satellite map for ({lat:.4f}, {lon:.4f}). Server might be offline.")
        else:
            print(f"Error generating Esri image for {lat}, {lon}: {e}")

        # Fall back to a placeholder carrying the coordinates.
        return get_placeholder_image(f"Map Unavailable\n({lat:.2f}, {lon:.2f})")