Skip to content

Commit 629bc67

Browse files
authored
Merge pull request #70 from h-munakata/main
Add CASTELLA
2 parents a3e7621 + cd43ee6 commit 629bc67

File tree

10 files changed

+3913
-11
lines changed

10 files changed

+3913
-11
lines changed

README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ It supports seven models, four features (video and audio features), and six data
1212
Furthermore, Lighthouse supports [audio moment retrieval](https://h-munakata.github.io/Language-based-Audio-Moment-Retrieval/), a task to identify relevant moments from an audio input based on a given text query.
1313

1414
## News
15+
- [2025/11/20] [Version 1.2](https://github.com/line/lighthouse/releases/tag/v1.2) Our work ["CASTELLA: Long Audio Dataset with Captions and Temporal Boundaries"](https://arxiv.org/abs/2511.15131) has been released. This update adds support for a new AMR dataset called CASTELLA.
1516
- [2025/06/04] [Version 1.1](https://github.com/line/lighthouse/releases/tag/v1.1) has been released. It includes API changes, AMR gradio demo, and huggingface wrappers for the audio moment retrieval and clotho dataset.
1617
- [2024/12/24] Our work ["Language-based audio moment retrieval"](https://arxiv.org/abs/2409.15672) has been accepted at ICASSP 2025.
1718
- [2024/10/22] [Version 1.0](https://github.com/line/lighthouse/releases/tag/v1.0) has been released.
@@ -142,6 +143,7 @@ Audio moment retrieval
142143
### Pre-trained weights
143144
Pre-trained weights can be downloaded from [here](https://drive.google.com/file/d/1jxs_bvwttXTF9Lk3aKLohkqfYOonLyrO/view?usp=sharing).
144145
Download and unzip on the home directory.
146+
AMR models trained on CASTELLA and Clotho-Moment is available in [here](https://zenodo.org/uploads/17422909)
145147

146148
### Datasets
147149
Due to the copyright issue, we here distribute only feature files.
@@ -158,6 +160,7 @@ To extract features from videos, we use [HERO_Video_Feature_Extractor](https://g
158160
For [AMR](https://h-munakata.github.io/Language-based-Audio-Moment-Retrieval/), download features from here.
159161

160162
- [Clotho Moment/TUT2017/UnAV100-subset](https://zenodo.org/records/13806234)
163+
- [CASTELLA](https://zenodo.org/records/17412176) [[Mirror on HF]](https://huggingface.co/datasets/lighthouse-emnlp2024/CASTELLA_CLAP_features)
161164

162165
The whole directory should be look like this:
163166
```

configs/base.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ eval_bsz: 100
1212
grad_clip: 0.1
1313
max_q_l: 32
1414
max_v_l: 75
15+
max_a_l: 75
1516
max_windows: 5
1617
clip_length: 1
1718
eval_epoch_interval: 1

configs/dataset/castella.yml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
dset_name: castella
2+
clip_length: 1
3+
train_path: data/castella/castella_train_release.jsonl
4+
eval_path: data/castella/castella_val_release.jsonl
5+
6+
max_a_l: 300
7+
max_v_l: 300

data/castella/castella_test_release.jsonl

Lines changed: 1347 additions & 0 deletions
Large diffs are not rendered by default.

data/castella/castella_train_release.jsonl

Lines changed: 2182 additions & 0 deletions
Large diffs are not rendered by default.

data/castella/castella_val_release.jsonl

Lines changed: 352 additions & 0 deletions
Large diffs are not rendered by default.

training/cg_detr_dataset.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,7 @@ def __getitem__(self, index):
196196
else:
197197
model_inputs["saliency_pos_labels"], model_inputs["saliency_neg_labels"], model_inputs["saliency_all_labels"] = \
198198
self.get_saliency_labels_all(meta["relevant_clip_ids"], meta["saliency_scores"], ctx_l)
199-
elif self.dset_name in ['charades', 'tacos', 'activitynet', 'clotho-moment', 'unav100-subset', 'tut2017']: ## charades, tacos, nlq
199+
elif self.dset_name in ['charades', 'tacos', 'activitynet', 'clotho-moment', 'unav100-subset', 'tut2017', 'castella']: ## charades, tacos, nlq
200200
model_inputs["saliency_pos_labels"], model_inputs["saliency_neg_labels"], model_inputs["saliency_all_labels"] = \
201201
self.get_saliency_labels_sub_as_query(meta["relevant_windows"][0], meta["duration"], ctx_l) # only one gt
202202
else:
@@ -458,7 +458,7 @@ def _get_audio_feat_by_vid(self, vid):
458458
raise NotImplementedError
459459
_feat = l2_normalize_np_array(_feat) # normalize?
460460
a_feat_list.append(_feat)
461-
elif self.dset_name in ['clotho-moment', 'unav100-subset', 'tut2017']:
461+
elif self.dset_name in ['clotho-moment', 'unav100-subset', 'tut2017', 'castella']:
462462
if self.a_feat_types == "clap":
463463
_feat_path = join(_feat_dir, f"{vid}.npz")
464464
_feat = np.load(_feat_path)["features"][:self.max_a_l].astype(np.float32)

training/dataset.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,7 @@ def __getitem__(self, index):
212212
model_inputs["saliency_pos_labels"], model_inputs["saliency_neg_labels"], model_inputs["saliency_all_labels"] = \
213213
self.get_saliency_labels_all(meta["relevant_clip_ids"], meta["saliency_scores"], ctx_l)
214214

215-
elif self.dset_name in ['charades', 'tacos', 'activitynet', 'clotho-moment', 'unav100-subset', 'tut2017']:
215+
elif self.dset_name in ['charades', 'tacos', 'activitynet', 'clotho-moment', 'unav100-subset', 'tut2017', 'castella']:
216216
model_inputs["saliency_pos_labels"], model_inputs["saliency_neg_labels"], model_inputs["saliency_all_labels"] = \
217217
self.get_saliency_labels_sub_as_query(meta["relevant_windows"][0], ctx_l)
218218
else:
@@ -480,7 +480,7 @@ def _get_audio_feat_by_vid(self, vid):
480480
raise NotImplementedError
481481
_feat = l2_normalize_np_array(_feat) # normalize?
482482
a_feat_list.append(_feat)
483-
elif self.dset_name in ['clotho-moment', 'unav100-subset', 'tut2017']:
483+
elif self.dset_name in ['clotho-moment', 'unav100-subset', 'tut2017', 'castella']:
484484
if self.a_feat_types == "clap":
485485
_feat_path = join(_feat_dir, f"{vid}.npz")
486486
_feat = np.load(_feat_path)["features"][:self.max_a_l].astype(np.float32)

training/evaluate.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -261,6 +261,12 @@ def compute_mr_results(epoch_i, model, eval_loader, opt, criterion=None):
261261
min_w_l=2, max_w_l=60, move_window_method="left",
262262
process_func_names=("clip_ts", "round_multiple")
263263
)
264+
elif opt.dset_name in ['castella']:
265+
post_processor = PostProcessorDETR(
266+
clip_length=opt.clip_length, min_ts_val=0, max_ts_val=300,
267+
min_w_l=1, max_w_l=300, move_window_method="left",
268+
process_func_names=("clip_ts", "round_multiple")
269+
)
264270
elif opt.dset_name in ['tacos', 'activitynet', 'youtube_highlight']:
265271
post_processor = PostProcessorDETR(
266272
clip_length=opt.clip_length, min_ts_val=0, max_ts_val=50000,
@@ -367,6 +373,7 @@ def start_inference(opt, domain=None):
367373
a_feat_types=opt.a_feat_types,
368374
max_q_l=opt.max_q_l,
369375
max_v_l=opt.max_v_l,
376+
max_a_l=opt.max_a_l,
370377
clip_len=opt.clip_length,
371378
max_windows=opt.max_windows,
372379
span_loss_type=opt.span_loss_type,
@@ -375,7 +382,7 @@ def start_inference(opt, domain=None):
375382

376383
eval_dataset = CGDETR_StartEndDataset(**dataset_config) if opt.model_name == 'cg_detr' else StartEndDataset(**dataset_config)
377384
model, criterion, _, _ = setup_model(opt)
378-
checkpoint = torch.load(opt.model_path)
385+
checkpoint = torch.load(opt.model_path, weights_only=False)
379386
model.load_state_dict(checkpoint["model"])
380387
logger.info("Model checkpoint: {}".format(opt.model_path))
381388
if not load_labels:
@@ -402,6 +409,8 @@ def check_valid_combination(dataset, feature, domain):
402409
'tvsum': ['resnet_glove', 'clip', 'clip_slowfast', 'i3d_clip'],
403410
'youtube_highlight': ['clip', 'clip_slowfast'],
404411
'clotho-moment': ['clap'],
412+
'unav100-subset': ['clap'],
413+
'castella': ['clap'],
405414
}
406415

407416
domain_map = {
@@ -421,8 +430,8 @@ def check_valid_combination(dataset, feature, domain):
421430
choices=['moment_detr', 'qd_detr', 'eatr', 'cg_detr', 'uvcom', 'tr_detr', 'taskweave_hd2mr', 'taskweave_mr2hd'],
422431
help='model name. select from [moment_detr, qd_detr, eatr, cg_detr, uvcom, tr_detr, taskweave_hd2mr, taskweave_mr2hd]')
423432
parser.add_argument('--dataset', '-d', type=str, required=True,
424-
choices=['activitynet', 'charades', 'qvhighlight', 'qvhighlight_pretrain', 'tacos', 'tvsum', 'youtube_highlight', 'clotho-moment', 'unav100-subset', 'tut2017'],
425-
help='dataset name. select from [activitynet, charades, qvhighlight, qvhighlight_pretrain, tacos, tvsum, youtube_highlight, clotho-moment, unav100-subset, tut2017]')
433+
choices=['activitynet', 'charades', 'qvhighlight', 'qvhighlight_pretrain', 'tacos', 'tvsum', 'youtube_highlight', 'clotho-moment', 'unav100-subset', 'tut2017', 'castella'],
434+
help='dataset name. select from [activitynet, charades, qvhighlight, qvhighlight_pretrain, tacos, tvsum, youtube_highlight, clotho-moment, unav100-subset, tut2017, castella]')
426435
parser.add_argument('--feature', '-f', type=str, required=True,
427436
choices=['resnet_glove', 'clip', 'clip_slowfast', 'clip_slowfast_pann', 'i3d_clip', 'clap'],
428437
help='feature name. select from [resnet_glove, clip, clip_slowfast, clip_slowfast_pann, i3d_clip, clap].'

training/train.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,6 @@ def train_epoch(model, criterion, train_loader, optimizer, opt, epoch_i):
137137
losses.backward()
138138
else:
139139
outputs = model(**model_inputs, targets=targets) if opt.model_name == 'cg_detr' else model(**model_inputs)
140-
141140
loss_dict = criterion(outputs, targets)
142141
losses = sum(loss_dict[k] * criterion.weight_dict[k] for k in loss_dict.keys() if k in criterion.weight_dict)
143142

@@ -228,6 +227,7 @@ def main(opt, resume=None, domain=None):
228227
a_feat_types=opt.a_feat_types,
229228
max_q_l=opt.max_q_l,
230229
max_v_l=opt.max_v_l,
230+
max_a_l=opt.max_a_l,
231231
clip_len=opt.clip_length,
232232
max_windows=opt.max_windows,
233233
span_loss_type=opt.span_loss_type,
@@ -246,7 +246,7 @@ def main(opt, resume=None, domain=None):
246246

247247
# load checkpoint for QVHighlight pretrain -> finetune
248248
if resume is not None:
249-
checkpoint = torch.load(resume)
249+
checkpoint = torch.load(resume, weights_only=False)
250250
model.load_state_dict(checkpoint["model"])
251251
logger.info("Loaded model checkpoint: {}".format(resume))
252252

@@ -267,6 +267,7 @@ def check_valid_combination(dataset, feature, domain):
267267
'tvsum': ['resnet_glove', 'clip', 'clip_slowfast', 'i3d_clip'],
268268
'youtube_highlight': ['clip', 'clip_slowfast'],
269269
'clotho-moment': ['clap'],
270+
'castella': ['clap'],
270271
}
271272

272273
domain_map = {
@@ -286,8 +287,8 @@ def check_valid_combination(dataset, feature, domain):
286287
choices=['moment_detr', 'qd_detr', 'eatr', 'cg_detr', 'uvcom', 'tr_detr', 'taskweave_hd2mr', 'taskweave_mr2hd'],
287288
help='model name. select from [moment_detr, qd_detr, eatr, cg_detr, uvcom, tr_detr, taskweave_hd2mr, taskweave_mr2hd]')
288289
parser.add_argument('--dataset', '-d', type=str, required=True,
289-
choices=['activitynet', 'charades', 'qvhighlight', 'qvhighlight_pretrain', 'tacos', 'tvsum', 'youtube_highlight', 'clotho-moment'],
290-
help='dataset name. select from [activitynet, charades, qvhighlight, qvhighlight_pretrain, tacos, tvsum, youtube_highlight, clotho-moment]')
290+
choices=['activitynet', 'charades', 'qvhighlight', 'qvhighlight_pretrain', 'tacos', 'tvsum', 'youtube_highlight', 'clotho-moment', 'castella'],
291+
help='dataset name. select from [activitynet, charades, qvhighlight, qvhighlight_pretrain, tacos, tvsum, youtube_highlight, clotho-moment, castella]')
291292
parser.add_argument('--feature', '-f', type=str, required=True,
292293
choices=['resnet_glove', 'clip', 'clip_slowfast', 'clip_slowfast_pann', 'i3d_clip', 'clap'],
293294
help='feature name. select from [resnet_glove, clip, clip_slowfast, clip_slowfast_pann, i3d_clip, clap].'

0 commit comments

Comments
 (0)