diff --git a/.gitignore b/.gitignore index 6606d1085..85d4ed089 100644 --- a/.gitignore +++ b/.gitignore @@ -157,5 +157,8 @@ dmypy.json # Log folders **/log/ +# Datasets folders +**/eeg_data/ + # Mac OS .DS_Store \ No newline at end of file diff --git a/benchmarks/MOABB/Commands_run_experiment.txt b/benchmarks/MOABB/Commands_run_experiment.txt new file mode 100644 index 000000000..45f46f541 --- /dev/null +++ b/benchmarks/MOABB/Commands_run_experiment.txt @@ -0,0 +1,3 @@ +python run_experiments.py --hparams=hparams/MotorImagery/BNCI2014001/EEGNet.yaml --data_folder=./eeg_data --output_folder=./results/test_run --nsbj=9 --nsess=2 --seed=12346 --nruns=1 --train_mode=leave-one-session-out + +python run_sweep.py --hparams hparams/MotorImagery/BNCI2014001/EEGNet.yaml --sweep_type optuna --n_trials 2 --data_folder eeg_data --output_folder results/htop/test_run5 --cached_data_folder eeg_data/cache --nsbj 9 --nsess 2 --seed 1234 --nruns 1 --eval_metric acc --eval_set test --data_iterator_name leave-one-session-out --device cuda diff --git a/benchmarks/MOABB/dataio/preprocessing.py b/benchmarks/MOABB/dataio/preprocessing.py index 7fd7a9dd3..83a12eca4 100644 --- a/benchmarks/MOABB/dataio/preprocessing.py +++ b/benchmarks/MOABB/dataio/preprocessing.py @@ -12,6 +12,9 @@ from speechbrain.utils.data_pipeline import provides, takes +mne.set_log_level("ERROR") + + @takes("epoch") @provides("epoch") def to_tensor(epoch): @@ -24,39 +27,44 @@ def to_tensor(epoch): cached_create_filter = cache(mne.filter.create_filter) -@takes("epoch", "info", "target_sfreq", "fmin", "fmax") -@provides("epoch", "sfreq", "target_sfreq", "fmin", "fmax") -def bandpass_resample(epoch, info, target_sfreq, fmin, fmax): - """Bandpass filter and resample an epoch.""" - - bandpass = cached_create_filter( - None, - info["sfreq"], - l_freq=fmin, - h_freq=fmax, - method="fir", - fir_design="firwin", - verbose=False, +def bandpass_resample(target_sfreq, fmin, fmax): + @takes( + "epoch", "info", ) + @provides("epoch") + def _bandpass_resample(epoch, info): + """Bandpass filter and resample an epoch.""" - # Check that filter length is reasonable - filter_length = len(bandpass) - len_x = epoch.shape[-1] - if filter_length > len_x: - # TODO: These long filters result in massive performance degradation... Do we - # want to throw an error instead? This usually happens when fmin is used - logging.warning( - "filter_length (%i) is longer than the signal (%i), " - "distortion is likely. Reduce filter length or filter a longer signal.", - filter_length, - len_x, + bandpass = cached_create_filter( + None, + info["sfreq"], + l_freq=fmin, + h_freq=fmax, + method="fir", + fir_design="firwin", + verbose=False, ) - yield mne.filter.resample( - epoch, - up=target_sfreq, - down=info["sfreq"], - method="polyphase", - window=bandpass, - ) - yield target_sfreq + # Check that filter length is reasonable + filter_length = len(bandpass) + len_x = epoch.shape[-1] + if filter_length > len_x: + # TODO: These long filters result in massive performance degradation... Do we + # want to throw an error instead? This usually happens when fmin is used + logging.warning( + "filter_length (%i) is longer than the signal (%i), " + "distortion is likely. 
Reduce filter length or filter a longer signal.", + filter_length, + len_x, + ) + + yield mne.filter.resample( + epoch, + up=target_sfreq, + down=info["sfreq"], + method="polyphase", + window=bandpass, + ) + yield target_sfreq + + return _bandpass_resample diff --git a/benchmarks/MOABB/hparams/MotorImagery/BNCI2014001/EEGNet.yaml b/benchmarks/MOABB/hparams/MotorImagery/BNCI2014001/EEGNet.yaml index 7dca191ac..e5ae96424 100644 --- a/benchmarks/MOABB/hparams/MotorImagery/BNCI2014001/EEGNet.yaml +++ b/benchmarks/MOABB/hparams/MotorImagery/BNCI2014001/EEGNet.yaml @@ -1,6 +1,9 @@ seed: 1234 __set_torchseed: !apply:torch.manual_seed [!ref ] +#OVERRIDES +num_workers: 4 + # DIRECTORIES data_folder: !PLACEHOLDER #'/path/to/dataset'. The dataset will be automatically downloaded in this folder cached_data_folder: !PLACEHOLDER #'path/to/pickled/dataset' @@ -8,7 +11,7 @@ output_folder: !PLACEHOLDER #'path/to/results' # DATASET HPARS # Defining the MOABB dataset. -dataset: !new:moabb.datasets.BNCI2014001 +dataset: !new:moabb.datasets.BNCI2014_001 save_prepared_dataset: True # set to True if you want to save the prepared dataset as a pkl file to load and use afterwards data_iterator_name: !PLACEHOLDER target_subject_idx: !PLACEHOLDER @@ -17,7 +20,7 @@ events_to_load: null # all events will be loaded original_sample_rate: 250 # Original sampling rate provided by dataset authors sample_rate: 125 # Target sampling rate (Hz) # band-pass filtering cut-off frequencies -fmin: 0.13 # @orion_step1: --fmin~"uniform(0.1, 5, precision=2)" +fmin: 1.0 # @orion_step1: --fmin~"uniform(0.1, 5, precision=2)" # note undefined when under .5 fmax: 46.0 # @orion_step1: --fmax~"uniform(20.0, 50.0, precision=3)" n_classes: 4 # tmin, tmax respect to stimulus onset that define the interval attribute of the dataset class @@ -39,6 +42,53 @@ C: 22 test_with: 'last' # 'last' or 'best' test_key: "acc" # Possible opts: "loss", "f1", "auc", "acc" +# DATASET +# ─── Subject extraction helpers ────────────────────────────────────────────── +# 1) Grab the whole subject_list from the BNCI2014001 object +subject_list: !apply:getattr # → dataset.subject_list + - !ref # first arg = the object + - subject_list # second arg = attribute name + +# 2) Pick the single subject we want with operator.getitem(list, idx) +target_subject: !apply:operator.getitem + - !ref # the list + - !ref # the integer index supplied on CLI + +# Get target subject +#target_subject: # TODD + +# Create the subjects list +subjects: [!ref ] + +# Create dataset using EpochedEEGDataset +#dataset_class: !new:dataio.datasets.EpochedEEGDataset + +json_path: !apply:os.path.join [!ref , "index.json"] +save_path: !ref +# dynamic items list +bandpass_resample: !apply:dataio.preprocessing.bandpass_resample + target_sfreq: !ref + fmin: !ref + fmax: !ref + +dynamic_items: + - !ref + - !name:dataio.preprocessing.to_tensor +output_keys: ["label", "subject", "session", "epoch"] +preload: True + +EEG_dataset: !apply:dataio.datasets.EpochedEEGDataset.from_moabb + dataset: !ref + json_path: !ref + subjects: !ref + save_path: !ref + dynamic_items: !ref + output_keys: !ref + preload: !ref + tmin: !ref + tmax: !ref + + # METRICS f1: !name:sklearn.metrics.f1_score average: 'macro' @@ -52,7 +102,7 @@ metrics: n_train_examples: 100 # it will be replaced in the train script # checkpoints to average avg_models: 10 # @orion_step1: --avg_models~"uniform(1, 15,discrete=True)" -number_of_epochs: 862 # @orion_step1: --number_of_epochs~"uniform(250, 1000, discrete=True)" +number_of_epochs: 10 # 
@orion_step1: --number_of_epochs~"uniform(250, 1000, discrete=True)" lr: 0.0001 # @orion_step1: --lr~"choices([0.01, 0.005, 0.001, 0.0005, 0.0001])" # Learning rate scheduling (cyclic learning rate is used here) max_lr: !ref # Upper bound of the cycle (max value of the lr) @@ -165,3 +215,29 @@ model: !new:models.EEGNet.EEGNet dense_max_norm: !ref dropout: !ref dense_n_neurons: !ref + +# Search Space +# the search space is defined as a dictionary of parameter names and a dictionary of possible values +# the values can be sampled from a uniform distribution, a discrete uniform distribution or a choice of values +search_space: + fmin: + type: uniform + min: 0.1 + max: 5.0 + precision: 2 + dropout: + type: uniform + min: 0.0 + max: 0.5 + precision: 3 + cnn_temporal_kernels: + type: discrete_uniform + min: 4 + max: 64 + batch_size_exponent: + type: discrete_uniform + min: 4 + max: 6 + lr: + type: choice + values: [0.01, 0.005, 0.001, 0.0005, 0.0001] diff --git a/benchmarks/MOABB/models/EEGNet.py b/benchmarks/MOABB/models/EEGNet.py index dcb62d47e..b7e840922 100644 --- a/benchmarks/MOABB/models/EEGNet.py +++ b/benchmarks/MOABB/models/EEGNet.py @@ -240,6 +240,7 @@ def forward(self, x): x : torch.Tensor (batch, time, EEG channel, channel) Input to convolve. 4d tensors are expected. """ + x = x.transpose(1, 2) x = self.conv_module(x) x = self.dense_module(x) return x diff --git a/benchmarks/MOABB/run_experiments.py b/benchmarks/MOABB/run_experiments.py new file mode 100644 index 000000000..9c498065a --- /dev/null +++ b/benchmarks/MOABB/run_experiments.py @@ -0,0 +1,224 @@ +""" +Script to run leave-one-subject-out and/or leave-one-session-out training, optionally with multiple seeds. +This script loops over the different subjects and sessions and trains different models. +At the end, the final performance is computed with the aggregate_results.py script that provides the average performance. 
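+Each run writes its trained models under <output_folder>/run<N>/<seed>/, a per-run summary to
+<output_folder>/run<N>_results.txt (via utils/parse_results.py), and the cross-run average to
+<output_folder>/aggregated_performance.txt (via utils/aggregate_results.py).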
+ +Usage: +python run_experiments.py --hparams=hparams/MotorImagery/BNCI2014001/EEGNet.yaml --data_folder=eeg_data \ +--output_folder=results/MotorImagery/BNCI2014001/EEGNet --nsbj=9 --nsess=2 --seed=1986 --nruns=2 --number_of_epochs=10 + + +Authors +------- +Victor Cruz, 2025 +""" + +import sys +import subprocess +from pathlib import Path +import argparse +import random + +# import logging +# from typing import Optional +import string + + +from train import load_hparams_and_prepare_data, run_experiment + + +class ExperimentRunner: + """Manages multiple MOABB experiment runs.""" + + def __init__(self, args: list): + self.args = self.validate_args(args) + self.setup_experiment() + + def validate_args(self, args) -> argparse.Namespace: + """Validate and parse command line arguments.""" + parser = argparse.ArgumentParser( + description="Run multiple MOABB experiments" + ) + + # Required arguments + parser.add_argument( + "--hparams", required=True, help="Path to hyperparameter file" + ) + parser.add_argument( + "--data_folder", required=True, help="Path to data directory" + ) + parser.add_argument( + "--output_folder", required=True, help="Path to output directory" + ) + parser.add_argument( + "--nsbj", type=int, required=True, help="Number of subjects" + ) + parser.add_argument( + "--nsess", type=int, required=True, help="Number of sessions" + ) + + # Optional arguments + parser.add_argument("--cached_data_folder", help="Path to cached data") + parser.add_argument("--seed", type=int, help="Random seed") + parser.add_argument( + "--nruns", type=int, default=1, help="Number of runs" + ) + parser.add_argument( + "--eval_metric", default="acc", help="Evaluation metric (acc, f1)" + ) + parser.add_argument( + "--eval_set", + default="test", + choices=["test", "dev"], + help="Evaluation set", + ) + parser.add_argument( + "--train_mode", + default="leave-one-session-out", + choices=["leave-one-session-out", "leave-one-subject-out"], + help="Training mode", + ) + parser.add_argument( + "--rnd_dir", + type=bool, + default=False, + help="Use random directory name", + ) + parser.add_argument( + "--dry_run", + type=bool, + default=False, + help="Validate setup without running", + ) + parser.add_argument( + "--device", + type=str, + default="cuda", + help="Device to run on (cuda or cpu)", + ) + + args = parser.parse_args(args) + + # Validate arguments + if args.eval_set == "dev": + args.metric_file = "valid_metrics.pkl" + else: + args.metric_file = "test_metrics.pkl" + + if args.seed is None: + args.seed = random.randint(0, 99999) + + if not args.cached_data_folder: + args.cached_data_folder = str(Path(args.data_folder) / "cache") + + return args + + def setup_experiment(self): + """Setup experiment directories and logging.""" + # Setup random directory if requested + if self.args.rnd_dir: + rnd_name = "".join(random.choices(string.ascii_letters, k=6)) + self.args.output_folder = str( + Path(self.args.output_folder) / rnd_name + ) + + # Create directories + Path(self.args.output_folder).mkdir(parents=True, exist_ok=True) + Path(self.args.data_folder).mkdir(parents=True, exist_ok=True) + Path(self.args.cached_data_folder).mkdir(parents=True, exist_ok=True) + + # Save configuration + self.save_configuration() + + def save_configuration(self): + """Save experiment configuration.""" + config_file = Path(self.args.output_folder) / "flags.txt" + with open(config_file, "w") as f: + for key, value in vars(self.args).items(): + f.write(f"{key}: {value}\n") + + def run_experiment(self, target_session_idx: int, 
output_folder_exp: Path): + """Run experiments for all subjects.""" + # try: + for target_subject_idx in range(self.args.nsbj): + print(f"Subject {target_subject_idx}") + + # create overrides + overrides = { + "seed": self.args.seed, + "data_folder": self.args.data_folder, + "cached_data_folder": self.args.cached_data_folder, + "output_folder": str(output_folder_exp), + "target_subject_idx": target_subject_idx, + "target_session_idx": target_session_idx, + "data_iterator_name": self.args.train_mode, + } + + # Create run_opts (empty for now, add parameters if needed) + run_opts = {"device": self.args.device} + + # Load hyperparameters and prepare data + hparams, datasets = load_hparams_and_prepare_data( + self.args.hparams, run_opts, overrides + ) + + # Run experiment + run_experiment(hparams, run_opts, datasets) + + def parse_results(self, output_folder_exp: Path, run_name: str): + """Parse results for current run.""" + cmd = [ + "python", + "utils/parse_results.py", + str(output_folder_exp), + self.args.metric_file, + self.args.eval_metric, + ] + + with open( + Path(self.args.output_folder) / f"{run_name}_results.txt", "a" + ) as f: + subprocess.run(cmd, stdout=f) + + def aggregate_final_results(self): + """Aggregate results across all runs.""" + cmd = [ + "python", + "utils/aggregate_results.py", + self.args.output_folder, + self.args.eval_metric, + ] + + with open( + Path(self.args.output_folder) / "aggregated_performance.txt", "a" + ) as f: + subprocess.run(cmd, stdout=f) + + def run(self): + """Execute all experiment runs.""" + # try: + for run_idx in range(self.args.nruns): + run_name = f"run{run_idx + 1}" + output_folder_exp = ( + Path(self.args.output_folder) / run_name / str(self.args.seed) + ) + + if self.args.train_mode == "leave-one-subject-out": + self.run_experiment(0, output_folder_exp) + elif self.args.train_mode == "leave-one-session-out": + for sess_idx in range(self.args.nsess): + self.run_experiment(sess_idx, output_folder_exp) + + # Store results + self.parse_results(output_folder_exp, run_name) + + # Update seed for next run + self.args.seed += 1 + + # Final aggregation + self.aggregate_final_results() + + +if __name__ == "__main__": + runner = ExperimentRunner(sys.argv[1:]) + sys.exit(0 if runner.run() else 1) diff --git a/benchmarks/MOABB/run_sweep.py b/benchmarks/MOABB/run_sweep.py new file mode 100644 index 000000000..5264f29a2 --- /dev/null +++ b/benchmarks/MOABB/run_sweep.py @@ -0,0 +1,246 @@ +""" +run_sweep.py +============ + +Generic sweep launcher for the MOABB benchmark. + +Supported sweep engines +----------------------- +* grid Cartesian product of all values in ``search_space``. +* random Random samples from ``search_space``. +* optuna Optuna sampler (default TPE). +* orion Print the Orion CLI space and exit (no runs). + +Design goals +------------ +1. **Zero YAML parsing here.** We *never* call `load_hyperpyyaml` + from this file, so we cannot trigger the dreaded + `'…' is a !PLACEHOLDER and must be replaced` error. +2. The heavy work is delegated to :class:`run_experiments.ExperimentRunner`, + which already loops over subjects/sessions and injects the correct + overrides for every split. +3. All CLI flags that ExperimentRunner expects are forwarded unchanged + from the sweep launcher, so the user can keep using the exact same + command-line interface they use for a single run. 
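+Example
+-------
+An illustrative invocation (the same one recorded in Commands_run_experiment.txt for this
+benchmark; adjust folders, subject/session counts and flags to your setup):
+
+python run_sweep.py --hparams hparams/MotorImagery/BNCI2014001/EEGNet.yaml --sweep_type optuna --n_trials 2 --data_folder eeg_data --output_folder results/htop/test_run5 --cached_data_folder eeg_data/cache --nsbj 9 --nsess 2 --seed 1234 --nruns 1 --eval_metric acc --eval_set test --data_iterator_name leave-one-session-out --device cuda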
+ +Author +------ +Victor Cruz +""" + +from __future__ import annotations + +import sys +import json +from typing import Dict, List +import hashlib +import os + +import optuna +import speechbrain as sb +from hyperpyyaml import load_hyperpyyaml + +from run_experiments import ExperimentRunner +from utils.search import ( + generate_grid, + get_optuna_space, + get_orion_space, + load_search_space_only, + sample_random, +) + +# ----------------------------------------------------------------------------- + + +def parse_top_level_cli(argv: List[str]) -> tuple[str, Dict, Dict]: + """ + Parse the CLI the same way SpeechBrain does, but keep the + ``overrides`` dictionary so we can forward it. + """ + hparams_file, run_opts, overrides = sb.parse_arguments(argv) + + # Make sure overrides is a *dict* + if isinstance(overrides, str): + overrides = load_hyperpyyaml(overrides) + overrides = dict(overrides or {}) + + return hparams_file, run_opts, overrides + + +def params_hash(params): + return hashlib.md5(str(sorted(params.items())).encode()).hexdigest()[:8] + + +# ----------------------------------------------------------------------------- + + +SB_CLI_KEYS = { + # Required by ExperimentRunner + "data_folder", + "cached_data_folder", + "output_folder", + "nsbj", + "nsess", + # Optional / misc + "seed", + "nruns", + "eval_metric", + "eval_set", + "train_mode", + "rnd_dir", + "dry_run", +} + + +def build_experimentrunner_argv( + hparams_file: str, common_cli: Dict, +) -> List[str]: + """ + Convert a dictionary of CLI options into a flat list of CLI tokens + accepted by ExperimentRunner. + """ + argv = ["--hparams", hparams_file] + for k, v in common_cli.items(): + if k in SB_CLI_KEYS: + argv.extend([f"--{k}", str(v)]) + return argv + + +# ----------------------------------------------------------------------------- + + +def run_single_experiment( + hparams_file: str, common_cli: Dict, hyperparams: Dict, +): + """ + Launch one ExperimentRunner with the supplied hyper-parameters. + + * `common_cli` -Static CLI flags copied from the user invocation. + * `hyperparams` -The hyper-parameters sampled by the sweep engine. + These are **not** CLI flags; they are passed as + overrides to SpeechBrain via the environment + variable ``SB_YAML_OVERRIDES``. + """ + # ExperimentRunner uses CLI only for infrastructure flags; the + # *actual* YAML overrides are injected inside its Python code. + argv = build_experimentrunner_argv(hparams_file, common_cli) + + # Pass the sampled hyper-params to the child process via an env var + # understood by SpeechBrain. (Simplest zero-boilerplate path.) 
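+    # e.g. SB_YAML_OVERRIDES='{"lr": 0.001, "dropout": 0.25, "cnn_temporal_kernels": 32}'
+    # (values shown here are only an illustration; the real ones come from the sweep engine)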
if hyperparams:
+        # Expose the sampled overrides to train.py, which reads
+        # SB_YAML_OVERRIDES from os.environ inside
+        # load_hparams_and_prepare_data and merges it into the YAML overrides.
+        os.environ["SB_YAML_OVERRIDES"] = json.dumps(hyperparams)
+    else:
+        # Make sure a previous trial's overrides do not leak into this run.
+        os.environ.pop("SB_YAML_OVERRIDES", None)
+    ExperimentRunner(argv).run()
+
+
+# -----------------------------------------------------------------------------
+
+
+def grid_sweep(hparams_file: str, common_cli: Dict):
+    space = load_search_space_only(hparams_file)
+    for params in generate_grid(space):
+        run_single_experiment(hparams_file, common_cli, params)
+
+
+def random_sweep(hparams_file: str, common_cli: Dict, n_samples: int):
+    space = load_search_space_only(hparams_file)
+    for params in sample_random(space, n_samples):
+        run_single_experiment(hparams_file, common_cli, params)
+
+
+def optuna_sweep(hparams_file: str, common_cli: Dict, n_trials: int):
+    space = load_search_space_only(hparams_file)
+    optuna_space = get_optuna_space(space)
+
+    def objective(trial: optuna.trial.Trial):
+        params = {}
+        for k, spec in optuna_space.items():
+            if spec[0] == "suggest_float":
+                params[k] = trial.suggest_float(k, spec[1], spec[2])
+            elif spec[0] == "suggest_int":
+                params[k] = trial.suggest_int(k, spec[1], spec[2])
+            elif spec[0] == "suggest_categorical":
+                params[k] = trial.suggest_categorical(k, spec[1])
+        trial_id = params_hash(params)
+        common_cli_trial = dict(common_cli)
+
+        # Give each trial its own output folder, then run the experiment.
+        common_cli_trial[
+            "output_folder"
+        ] = f"{common_cli['output_folder']}/trial-{trial_id}"
+        run_single_experiment(hparams_file, common_cli_trial, params)
+        metrics_path = os.path.join(
+            common_cli_trial["output_folder"], "aggregated_performance.txt"
+        )  # Or your own metric file
+        acc = None
+        with open(metrics_path, "r") as f:
+            for line in f:
+                tokens = line.strip().split()
+                if tokens and tokens[0].lower() == "acc":
+                    # Preferred: take the value right after "avg:"
+                    if "avg:" in tokens:
+                        avg_idx = tokens.index("avg:") + 1
+                        acc = float(tokens[avg_idx])
+                    else:  # fallback to the first number
+                        acc = float(tokens[1].lstrip("[").rstrip("]"))
+                    break
+        if acc is None:
+            raise RuntimeError(
+                f"Could not find 'acc' in {metrics_path}. "
+                "Check that parse_results / aggregate_results ran correctly."
+ ) + return acc + + study = optuna.create_study(direction="maximize") + study.optimize(objective, n_trials=n_trials) + + +def print_orion_space(hparams_file: str): + space = load_search_space_only(hparams_file) + print(" ".join(get_orion_space(space).values())) + + +# ----------------------------------------------------------------------------- + + +def main(): + hparams_file, run_opts, overrides = parse_top_level_cli(sys.argv[1:]) + + # ----------------------------------------------------------------- + # Decide which sweep engine to use + sweep_type = overrides.pop("sweep_type", "grid") + n_samples = int(overrides.pop("n_samples", 20)) + n_trials = int(overrides.pop("n_trials", 20)) + # ----------------------------------------------------------------- + + if sweep_type not in {"grid", "random", "optuna", "orion"}: + raise ValueError(f"Unknown sweep type: {sweep_type}") + + # ----------------------------------------------------------------- + # Build *common* CLI flags that every ExperimentRunner call needs + # ----------------------------------------------------------------- + common_cli = {k: v for k, v in overrides.items() if k in SB_CLI_KEYS} + + # Sensible defaults + common_cli.setdefault("train_mode", "leave-one-session-out") + common_cli.setdefault("nsbj", 1) + common_cli.setdefault("nsess", 1) + + # ----------------------------------------------------------------- + # Dispatch + # ----------------------------------------------------------------- + if sweep_type == "grid": + grid_sweep(hparams_file, common_cli) + elif sweep_type == "random": + random_sweep(hparams_file, common_cli, n_samples) + elif sweep_type == "optuna": + optuna_sweep(hparams_file, common_cli, n_trials) + else: # orion + print_orion_space(hparams_file) + + +if __name__ == "__main__": + main() diff --git a/benchmarks/MOABB/train.py b/benchmarks/MOABB/train.py index 236319f50..62407595d 100644 --- a/benchmarks/MOABB/train.py +++ b/benchmarks/MOABB/train.py @@ -1,42 +1,181 @@ """ -This script implements raining neural networks to decode single EEG trials using various paradigms on MOABB datasets. -For a list of supported datasets and paradigms, please refer to the official documentation at http://moabb.neurotechx.com/docs/api.html. +Training neural networks for MOABB datasets using the new data loading system. -To run training (e.g., architecture: EEGNet; dataset: BNCI2014001) for a specific subject, recording session and training strategy: - > python train.py hparams/MotorImagery/BNCI2014001/EEGNet.yaml --data_folder=eeg_data --cached_data_folder=eeg_pickled_data --target_subject_idx=0 --target_session_idx=0 --data_iterator_name=leave-one-session-out - -Author ------- -Davide Borra, 2022 -Mirco Ravanelli, 2023 +Authors +------- +Victor Cruz, 2025 +(Based on original work by Davide Borra and Mirco Ravanelli) """ import pickle import os import torch from hyperpyyaml import load_hyperpyyaml -from torch.nn import init + import numpy as np import logging import sys -from utils.dataio_iterators import LeaveOneSessionOut, LeaveOneSubjectOut -from torchinfo import summary -import speechbrain as sb import yaml +import speechbrain as sb +from torch.nn import init +import json +from torch.utils.data import random_split -class MOABBBrain(sb.Brain): - """ - This class implements a brain for the MOABB benchmark. 
+from dataio.splitters import CrossSessionSplitter, CrossSubjectSplitter + + +def prepare_dataset(hparams): + """Create and preprocess dataset using new data loading system.""" + + dataset = hparams["EEG_dataset"] + # 1) Create and update label encoder with all raw labels from the dataset + label_encoder = sb.dataio.encoder.CategoricalEncoder() + label_encoder.update_from_didataset(dataset, "label") + + # 2) Define a small helper function that calls the encoder + def encode_label_func(raw_label): + # This returns a Tensor containing the encoded label + return label_encoder.encode_label_torch(raw_label) + + # 3) Add a dynamic item that calls our helper function + dataset.add_dynamic_item( + encode_label_func, takes=["label"], provides="encoded_label", + ) + + # 4) Change the dataset output keys to produce encoded_label instead of raw "label" + # (You can keep "label" too if you want both.) + dataset.set_output_keys(["encoded_label", "subject", "session", "epoch"]) + + return dataset + + +def prepare_splits(hparams, dataset): + """Create train/valid/test splits using new splitter system.""" + + # Create appropriate splitter + if hparams["data_iterator_name"] == "leave-one-session-out": + splitter = CrossSessionSplitter(dataset, leave_k_out=1) + elif hparams["data_iterator_name"] == "leave-one-subject-out": + splitter = CrossSubjectSplitter(dataset, leave_k_out=1) + else: + raise ValueError(f"Unknown split type: {hparams['data_iterator_name']}") + + # Get specific split based on session index + split = list(splitter)[hparams["target_session_idx"]] + train_dataset = split["train"] + total_len = len(train_dataset) + val_ratio = hparams["valid_ratio"] + train_len = int(total_len * (1 - val_ratio)) + val_len = total_len - train_len + + generator = torch.Generator().manual_seed(hparams["seed"]) + train_subset, valid_subset = random_split( + train_dataset, [train_len, val_len], generator=generator + ) + num_workers = hparams["num_workers"] + + if num_workers is None: + num_workers = torch.get_num_threads() - 1 + + # Create dataloaders + train_loader = torch.utils.data.DataLoader( + train_subset, + batch_size=hparams["batch_size"], + shuffle=True, + num_workers=num_workers, + ) + valid_loader = torch.utils.data.DataLoader( + valid_subset, batch_size=hparams["batch_size"], num_workers=num_workers + ) + test_loader = torch.utils.data.DataLoader( + split["test"], batch_size=hparams["batch_size"], num_workers=num_workers + ) + + return {"train": train_loader, "valid": valid_loader, "test": test_loader} + + +def load_hparams_and_prepare_data(hparams_file, run_opts, overrides): + """Load hyperparameters and prepare datasets.""" + if "SB_YAML_OVERRIDES" in os.environ: + overrides = dict(overrides) # make a copy + overrides.update(json.loads(os.environ["SB_YAML_OVERRIDES"])) + # Initial hparams load + with open(hparams_file) as fin: + hparams = load_hyperpyyaml(fin, overrides) + + # Prepare dataset + dataset = prepare_dataset(hparams) + + # Update overrides based on actual data shape + example_batch = next(iter(dataset)) + + overrides.update( + T=example_batch["epoch"].shape[1], # Time dimension + C=example_batch["epoch"].shape[0], # Channel dimension + n_train_examples=len(dataset), + ) + + # Reload hparams with shape information + with open(hparams_file) as fin: + hparams = load_hyperpyyaml(fin, overrides) + + # Create splits + datasets = prepare_splits(hparams, dataset) + + # Setup experiment directory + hparams["exp_dir"] = os.path.join( + hparams["output_folder"], + hparams["data_iterator_name"], + 
f"sub-{hparams['target_subject_idx']:03d}", + f"sess-{hparams['target_session_idx']:03d}", + ) + + # Create experiment directory and save config + sb.create_experiment_directory( + experiment_directory=hparams["exp_dir"], + hyperparams_to_save=hparams_file, + overrides=overrides, + ) + + return hparams, datasets + + +def perform_evaluation(brain, hparams, datasets, dataset_key="test"): + """This function perform the evaluation stage on a dataset and save the performance metrics in a pickle file""" + brain.log_test_as_valid = dataset_key == "valid" + + min_key, max_key = None, None + if hparams["test_key"] == "loss": + min_key = hparams["test_key"] + else: + max_key = hparams["test_key"] + # perform evaluation + brain.evaluate( + datasets[dataset_key], + progressbar=False, + min_key=min_key, + max_key=max_key, + ) + # saving metrics on the desired dataset in a pickle file + metrics_fpath = os.path.join( + hparams["exp_dir"], "{0}_metrics.pkl".format(dataset_key) + ) + with open(metrics_fpath, "wb") as handle: + pickle.dump( + brain.last_eval_stats, handle, protocol=pickle.HIGHEST_PROTOCOL + ) + + +# Keep existing MOABBBrain class and run_experiment function +# Only modify their data handling to work with new dataset format - This class inherits from the Brain class in SpeechBrain. - The Brain class is the main class that handles training, validation, - testing, and checkpointing. - """ +class MOABBBrain(sb.Brain): + """Modified Brain class for MOABB experiments with new data format.""" def init_model(self, model): - """Function to initialize neural network modules""" + """Initialize neural network modules""" for mod in model.modules(): if hasattr(mod, "weight"): if not ("Norm" in mod.__class__.__name__): @@ -48,8 +187,13 @@ def init_model(self, model): init.constant_(mod.bias, 0) def compute_forward(self, batch, stage): - "Given an input batch it computes the model output." - inputs = batch[0].to(self.device) + """Given an input batch it computes the model output.""" + # Extract EEG data from batch dictionary + inputs = batch["epoch"].to(self.device) + + # Add channel dimension if needed + if len(inputs.shape) == 3: + inputs = inputs.unsqueeze(-1) # Perform data augmentation if stage == sb.Stage.TRAIN and hasattr(self.hparams, "augment"): @@ -62,11 +206,13 @@ def compute_forward(self, batch, stage): # Normalization if hasattr(self.hparams, "normalize"): inputs = self.hparams.normalize(inputs) + return self.modules.model(inputs) def compute_objectives(self, predictions, batch, stage): - "Given the network predictions and targets computes the loss." 
- targets = batch[1].to(self.device) + """Compute loss given predictions and targets.""" + # Get labels from batch + targets = batch["encoded_label"].to(self.device) # Target augmentation N_augments = int(predictions.shape[0] / targets.shape[0]) @@ -74,46 +220,32 @@ def compute_objectives(self, predictions, batch, stage): loss = self.hparams.loss( predictions, - targets, + targets.squeeze(-1), weight=torch.FloatTensor(self.hparams.class_weights).to( self.device ), ) + if stage != sb.Stage.TRAIN: # From log to linear predictions tmp_preds = torch.exp(predictions) self.preds.extend(tmp_preds.detach().cpu().numpy()) - self.targets.extend(batch[1].detach().cpu().numpy()) + self.targets.extend(batch["encoded_label"].detach().cpu().numpy()) else: if hasattr(self.hparams, "lr_annealing"): self.hparams.lr_annealing.on_batch_end(self.optimizer) return loss - def on_fit_start(self,): - """Gets called at the beginning of ``fit()``""" - self.init_model(self.hparams.model) - self.init_optimizers() - in_shape = ( - (1,) - + tuple(np.floor(self.hparams.input_shape[1:-1]).astype(int)) - + (1,) - ) - model_summary = summary( - self.hparams.model, input_size=in_shape, device=self.device - ) - with open( - os.path.join(self.hparams.exp_dir, "model.txt"), "w" - ) as text_file: - text_file.write(str(model_summary)) - def on_stage_start(self, stage, epoch=None): - "Gets called when a stage (either training, validation, test) starts." + """Gets called when a stage (either training, validation, test) starts.""" if stage != sb.Stage.TRAIN: self.preds = [] self.targets = [] def on_stage_end(self, stage, stage_loss, epoch=None): """Gets called at the end of a epoch.""" + # Rest of the method remains the same as it handles metrics and checkpointing + # which don't need to change for the new data format if stage == sb.Stage.TRAIN: self.train_loss = stage_loss else: @@ -127,6 +259,8 @@ def on_stage_end(self, stage, stage_loss, epoch=None): self.last_eval_stats[metric_key] = self.hparams.metrics[ metric_key ](y_true=y_true, y_pred=y_pred) + + # ... rest of the method stays the same ... 
if stage == sb.Stage.VALID: # Learning rate scheduler if hasattr(self.hparams, "lr_annealing"): @@ -253,18 +387,22 @@ def check_if_best( def run_experiment(hparams, run_opts, datasets): - """This function performs a single training (e.g., single cross-validation fold)""" - idx_examples = np.arange(datasets["train"].dataset.tensors[0].shape[0]) + """Run a single experiment with the new data format.""" + # Calculate class weights + train_labels = [batch["encoded_label"] for batch in datasets["train"]] + train_labels = torch.cat(train_labels) + # train_labels = [ label for batch in datasets["train"] for label in batch["label"]] + # unique_labels, label_indices = np.unique(train_labels, return_inverse=True) + # train_labels_tensor = torch.tensor(label_indices) + n_examples_perclass = [ - idx_examples[ - np.where(datasets["train"].dataset.tensors[1] == c)[0] - ].shape[0] - for c in range(hparams["n_classes"]) + (train_labels == c).sum().item() for c in range(hparams["n_classes"]) ] n_examples_perclass = np.array(n_examples_perclass) class_weights = n_examples_perclass.max() / n_examples_perclass hparams["class_weights"] = class_weights + # Setup checkpointer checkpointer = sb.utils.checkpoints.Checkpointer( checkpoints_dir=os.path.join(hparams["exp_dir"], "save"), recoverables={ @@ -272,28 +410,31 @@ def run_experiment(hparams, run_opts, datasets): "counter": hparams["epoch_counter"], }, ) + + # Setup logger hparams["train_logger"] = sb.utils.train_logger.FileTrainLogger( save_file=os.path.join(hparams["exp_dir"], "train_log.txt") ) + + # Log dataset info logger = logging.getLogger(__name__) - logger.info("Experiment directory: {0}".format(hparams["exp_dir"])) - logger.info( - "Input shape: {0}".format( - datasets["train"].dataset.tensors[0].shape[1:] - ) - ) - logger.info( - "Training set avg value: {0}".format( - datasets["train"].dataset.tensors[0].mean() - ) - ) - datasets_summary = "Number of examples: {0} (training), {1} (validation), {2} (test)".format( - datasets["train"].dataset.tensors[0].shape[0], - datasets["valid"].dataset.tensors[0].shape[0], - datasets["test"].dataset.tensors[0].shape[0], + logger.info(f"Experiment directory: {hparams['exp_dir']}") + + # Get example batch for logging + example_batch = next(iter(datasets["train"])) + logger.info(f"Input shape: {example_batch['epoch'].shape[1:]}") + logger.info(f"Training set avg value: {example_batch['epoch'].mean():.3f}") + + # Log dataset sizes + datasets_summary = ( + f"Number of examples: " + f"{len(datasets['train'].dataset)} (training), " + f"{len(datasets['valid'].dataset)} (validation), " + f"{len(datasets['test'].dataset)} (test)" ) logger.info(datasets_summary) + # Create brain and run training brain = MOABBBrain( modules={"model": hparams["model"]}, opt_class=hparams["optimizer"], @@ -301,126 +442,54 @@ def run_experiment(hparams, run_opts, datasets): run_opts=run_opts, checkpointer=checkpointer, ) - # training + # if False: # hparams["dry_run"]: + # try: + # # Test forward pass with a batch + # batch = next(iter(datasets["train"])) + # with torch.no_grad(): + # brain.compute_forward(batch, sb.Stage.TRAIN) + # logger.info("✓ Dry run successful - model forward pass works") + # raise DryRunComplete("Model validation successful") + # except DryRunComplete: + # raise + # except Exception as e: + # logger.error(f"✗ Dry run failed: {str(e)}") + # raise + + # Training brain.fit( epoch_counter=hparams["epoch_counter"], train_set=datasets["train"], valid_set=datasets["valid"], progressbar=False, ) - # evaluation after loading 
model using specific key + + # Evaluation perform_evaluation(brain, hparams, datasets, dataset_key="test") - # After the first evaluation only 1 checkpoint (best overall or averaged) is stored. - # Setting avg_models to 1 to prevent deleting the checkpoint in subsequent calls of the evaluation stage. brain.hparams.avg_models = 1 perform_evaluation(brain, hparams, datasets, dataset_key="valid") -def perform_evaluation(brain, hparams, datasets, dataset_key="test"): - """This function perform the evaluation stage on a dataset and save the performance metrics in a pickle file""" - brain.log_test_as_valid = dataset_key == "valid" - - min_key, max_key = None, None - if hparams["test_key"] == "loss": - min_key = hparams["test_key"] - else: - max_key = hparams["test_key"] - # perform evaluation - brain.evaluate( - datasets[dataset_key], - progressbar=False, - min_key=min_key, - max_key=max_key, - ) - # saving metrics on the desired dataset in a pickle file - metrics_fpath = os.path.join( - hparams["exp_dir"], "{0}_metrics.pkl".format(dataset_key) - ) - with open(metrics_fpath, "wb") as handle: - pickle.dump( - brain.last_eval_stats, handle, protocol=pickle.HIGHEST_PROTOCOL - ) - - -def prepare_dataset_iterators(hparams): - """Preprocesses the dataset and partitions it into train, valid and test sets.""" - # defining data iterator to use - print("Prepare dataset iterators...") - if hparams["data_iterator_name"] == "leave-one-session-out": - data_iterator = LeaveOneSessionOut( - seed=hparams["seed"] - ) # within-subject and cross-session - elif hparams["data_iterator_name"] == "leave-one-subject-out": - data_iterator = LeaveOneSubjectOut( - seed=hparams["seed"] - ) # cross-subject and cross-session - else: - raise ValueError( - "Unknown data_iterator_name: %s" % hparams["data_iterator_name"] - ) - - tail_path, datasets = data_iterator.prepare( - data_folder=hparams["data_folder"], - dataset=hparams["dataset"], - cached_data_folder=hparams["cached_data_folder"], - batch_size=hparams["batch_size"], - valid_ratio=hparams["valid_ratio"], - target_subject_idx=hparams["target_subject_idx"], - target_session_idx=hparams["target_session_idx"], - events_to_load=hparams["events_to_load"], - original_sample_rate=hparams["original_sample_rate"], - sample_rate=hparams["sample_rate"], - fmin=hparams["fmin"], - fmax=hparams["fmax"], - tmin=hparams["tmin"], - tmax=hparams["tmax"], - save_prepared_dataset=hparams["save_prepared_dataset"], - n_steps_channel_selection=hparams["n_steps_channel_selection"], - seed_nodes=hparams.get("seed_nodes", ["Cz"]), - ) - return tail_path, datasets - - -def load_hparams_and_dataset_iterators(hparams_file, run_opts, overrides): - """Loads the hparams and datasets, injecting appropriate overrides - for the shape of the dataset. 
- """ - with open(hparams_file) as fin: - hparams = load_hyperpyyaml(fin, overrides) - - tail_path, datasets = prepare_dataset_iterators(hparams) - # override C and T, to be sure that network input shape matches the dataset (e.g., after time cropping or channel sampling) - overrides.update( - T=datasets["train"].dataset.tensors[0].shape[1], - C=datasets["train"].dataset.tensors[0].shape[-2], - n_train_examples=datasets["train"].dataset.tensors[0].shape[0], - ) - - # loading hparams for the each training and evaluation processes - with open(hparams_file) as fin: - hparams = load_hyperpyyaml(fin, overrides) - hparams["exp_dir"] = os.path.join(hparams["output_folder"], tail_path) - - # creating experiment directory - sb.create_experiment_directory( - experiment_directory=hparams["exp_dir"], - hyperparams_to_save=hparams_file, - overrides=overrides, - ) - - return hparams, datasets - - if __name__ == "__main__": argv = sys.argv[1:] + # try: # loading hparams to prepare the dataset and the data iterators hparams_file, run_opts, overrides = sb.core.parse_arguments(argv) overrides = yaml.load( overrides, yaml.SafeLoader ) # Convert overrides to a dict - hparams, datasets = load_hparams_and_dataset_iterators( + hparams, datasets = load_hparams_and_prepare_data( hparams_file, run_opts, overrides ) - + # print("Start Training") # Run training run_experiment(hparams, run_opts, datasets) + # except DryRunComplete: + # print("Dry run successful") + # sys.exit(0) + # except Exception as e: + # print(f"Error during execution: {str(e)}") + # if overrides.get("dry_run", False): + # print("Dry run failed") + # sys.exit(1) + # raise diff --git a/benchmarks/MOABB/utils/search.py b/benchmarks/MOABB/utils/search.py new file mode 100644 index 000000000..a4fb6f5a2 --- /dev/null +++ b/benchmarks/MOABB/utils/search.py @@ -0,0 +1,127 @@ +import yaml +import numpy as np +from itertools import product +import random +from hyperpyyaml import load_hyperpyyaml + + +def load_config_with_search_space(yaml_path): + """Load YAML config and return (base_config, search_space_dict)""" + with open(yaml_path) as f: + config = load_hyperpyyaml(f) # <- THIS IS THE CORRECT WAY + base_config = {k: v for k, v in config.items() if k != "search_space"} + search_space = config.get("search_space", {}) + return base_config, search_space + + +def load_search_space_only(yaml_path): + """Extract only the search_space dict from a YAML file (avoids parsing custom tags).""" + with open(yaml_path) as f: + lines = f.readlines() + in_search_space = False + search_space_lines = [] + for line in lines: + if line.strip().startswith("search_space:"): + in_search_space = True + search_space_lines.append(line) + continue + if in_search_space: + # End if a non-indented line or a new top-level key + if line.startswith((" ", "\t", "-")) or not line.strip(): + search_space_lines.append(line) + else: + break + # Now parse the collected search_space YAML lines + search_space_yaml = "".join(search_space_lines) + search_space = yaml.safe_load(search_space_yaml) + if search_space is None: + return {} + return search_space.get("search_space", {}) + + +def get_optuna_space(search_space): + """ + Convert search space dict to Optuna format. 
+ Returns a dict mapping param names to (suggest_func, *args, **kwargs) + """ + space = {} + for key, spec in search_space.items(): + t = spec["type"] + if t == "uniform": + space[key] = ( + "suggest_float", + spec["min"], + spec["max"], + spec.get("precision", 4), + ) + elif t == "discrete_uniform": + space[key] = ("suggest_int", spec["min"], spec["max"]) + elif t == "choice": + space[key] = ("suggest_categorical", spec["values"]) + else: + raise ValueError(f"Unknown type: {t}") + return space + + +def get_orion_space(search_space): + """ + Convert search space dict to Orion CLI format. + Returns a dict mapping param names to CLI strings. + """ + cli_space = {} + for key, spec in search_space.items(): + t = spec["type"] + if t == "uniform": + cli_space[ + key + ] = f'--{key}~"uniform({spec["min"]}, {spec["max"]}, precision={spec.get("precision",4)})"' + elif t == "discrete_uniform": + cli_space[ + key + ] = f'--{key}~"uniform({spec["min"]}, {spec["max"]}, discrete=True)"' + elif t == "choice": + cli_space[key] = f'--{key}~"choices({spec["values"]})"' + else: + raise ValueError(f"Unknown type: {t}") + return cli_space + + +def generate_grid(search_space): + """Yield all combos as dicts (for grid search).""" + keys, values = [], [] + for k, spec in search_space.items(): + t = spec["type"] + if t == "uniform": + precision = spec.get("precision", 4) + step = 10 ** -precision + vals = np.round( + np.arange(spec["min"], spec["max"] + step, step), precision + ).tolist() + elif t == "discrete_uniform": + vals = list(range(spec["min"], spec["max"] + 1)) + elif t == "choice": + vals = spec["values"] + else: + raise ValueError(f"Unknown type: {t}") + keys.append(k) + values.append(vals) + for combo in product(*values): + yield dict(zip(keys, combo)) + + +def sample_random(search_space, n_samples): + """Yield n_samples random combos.""" + for _ in range(n_samples): + params = {} + for k, spec in search_space.items(): + t = spec["type"] + if t == "uniform": + precision = spec.get("precision", 4) + params[k] = round( + random.uniform(spec["min"], spec["max"]), precision + ) + elif t == "discrete_uniform": + params[k] = random.randint(spec["min"], spec["max"]) + elif t == "choice": + params[k] = random.choice(spec["values"]) + yield params
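A minimal usage sketch (illustrative only, not part of the patch) of how the helpers above consume a search_space dict shaped like the block added to EEGNet.yaml:

    from utils.search import generate_grid, sample_random

    space = {
        "lr": {"type": "choice", "values": [0.01, 0.001]},
        "cnn_temporal_kernels": {"type": "discrete_uniform", "min": 4, "max": 8},
    }

    # Exhaustive grid: 2 choices x 5 integers = 10 parameter dicts
    for params in generate_grid(space):
        print(params)  # e.g. {'lr': 0.01, 'cnn_temporal_kernels': 4}

    # Three random samples from the same space
    for params in sample_random(space, 3):
        print(params)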