9 changes: 6 additions & 3 deletions chebai/loss/bce_weighted.py
@@ -33,8 +33,8 @@ def __init__(
         self.data_extractor = data_extractor

         assert (
-            isinstance(beta, float) and beta > 0.0
-        ), f"Beta parameter must be a float with value greater than 0.0, for loss class {self.__class__.__name__}."
+            isinstance(beta, float) and beta >= 0.0 and beta <= 1.0
+        ), f"Beta parameter must be a float with value between 0 and 1, for loss class {self.__class__.__name__}."

         assert (
             self.data_extractor is not None
@@ -63,13 +63,16 @@ def set_pos_weight(self, input: torch.Tensor) -> None:
         print(
             f"Computing loss-weights based on v{self.data_extractor.chebi_version} dataset (beta={self.beta})"
         )
+        print(f"loading: {self.data_extractor.processed_file_names[0]}")
         complete_labels = torch.concat(
             [
                 torch.stack(
                     [
                         torch.Tensor(row["labels"])
                         for row in self.data_extractor.load_processed_data(
-                            filename=file_name
+                            filename=os.path.join(
+                                self.data_extractor.processed_dir, file_name
+                            )
Comment on lines +73 to +75 (Member):

@sfluegel05
Passing the complete file path causes the following issue: https://wandb.ai/chebai/chebai/runs/pjql8wjk/logs?nw=nwuseraditya0by0

Passing just the file name causes the issue below:
https://wandb.ai/chebai/chebai/runs/m5rgurmi/logs?nw=nwuseraditya0by0
                         )
                     ]
                 )
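Context for the new `[0, 1]` bound on `beta`: class-balanced weighting schemes in the style of Cui et al. (2019) derive per-class weights from an "effective number of samples", which is only well-defined for `beta` in `[0, 1)` (`beta = 0` means no re-weighting; values near 1 strongly boost rare classes). The sketch below illustrates that scheme under the assumption that `BCEWeighted` follows it; the helper name and the normalisation are illustrative, not the repo's actual `set_pos_weight` code.

```python
import torch


def class_balanced_weights(labels: torch.Tensor, beta: float) -> torch.Tensor:
    """Illustrative 'effective number of samples' weighting (Cui et al. 2019).

    `labels` is an (N, C) multi-hot tensor; `beta` in [0, 1) controls how
    strongly rare classes are up-weighted (beta = 0 means no re-weighting).
    Shown only to motivate the [0, 1] bound; this is not the exact formula
    used in chebai.loss.bce_weighted.BCEWeighted.
    """
    n_pos = labels.sum(dim=0)  # positives per class
    effective_num = 1.0 - torch.pow(torch.tensor(beta), n_pos)
    weights = (1.0 - beta) / effective_num.clamp(min=1e-8)  # guard beta**0 = 1
    return weights / weights.mean()  # normalise to an average weight of 1


# Rare classes (here the 1%-frequency one) receive the largest weights.
labels = (torch.rand(1000, 4) < torch.tensor([0.5, 0.1, 0.01, 0.3])).float()
print(class_balanced_weights(labels, beta=0.99))
```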
2 changes: 1 addition & 1 deletion chebai/models/base.py
@@ -278,7 +278,7 @@ def _execute(
         loss_kwargs = dict()
         if self.pass_loss_kwargs:
             loss_kwargs = loss_kwargs_candidates
-        loss_kwargs["current_epoch"] = self.trainer.current_epoch
+            loss_kwargs["current_epoch"] = self.trainer.current_epoch
         loss = self.criterion(loss_data, loss_labels, **loss_kwargs)
         if isinstance(loss, tuple):
             unnamed_loss_index = 1
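For readers wondering what consumes the kwarg: with the assignment now inside the `pass_loss_kwargs` branch, `current_epoch` only reaches criteria that opt in to extra loss kwargs. A hedged sketch of such a criterion follows; `EpochRampedBCE` is a hypothetical example, not a class that exists in chebai.

```python
import torch


class EpochRampedBCE(torch.nn.Module):
    """Hypothetical criterion using the forwarded `current_epoch` kwarg.

    Illustration only (not part of chebai): it anneals label smoothing
    away over the first `warmup_epochs` epochs.
    """

    def __init__(self, warmup_epochs: int = 10):
        super().__init__()
        self.warmup_epochs = warmup_epochs
        self.bce = torch.nn.BCEWithLogitsLoss()

    def forward(
        self,
        input: torch.Tensor,
        target: torch.Tensor,
        current_epoch: int = 0,
        **kwargs,
    ) -> torch.Tensor:
        # Blend from uniform soft targets to the real labels over the warmup.
        ramp = min(1.0, current_epoch / max(1, self.warmup_epochs))
        smoothed = 0.5 * (1.0 - ramp) + target * ramp
        return self.bce(input, smoothed)
```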
2 changes: 1 addition & 1 deletion configs/loss/bce_weighted.yml
@@ -1,3 +1,3 @@
 class_path: chebai.loss.bce_weighted.BCEWeighted
 init_args:
-  beta: 0.99
+  beta: 0.99 # this is the default weight, change this factor to increase/decrease the weighting effect
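As a reminder of how a `class_path`/`init_args` file like this is consumed: the Lightning CLI resolves it into a constructor call (via jsonargparse, with more validation than shown here). A rough hand-rolled equivalent, assuming `BCEWeighted` also takes the `data_extractor` argument seen in the asserts above:

```python
import importlib

import yaml

# Minimal sketch of resolving a class_path/init_args config into an object;
# this is not chebai's actual loading code.
with open("configs/loss/bce_weighted.yml") as f:
    cfg = yaml.safe_load(f)

module_name, _, class_name = cfg["class_path"].rpartition(".")
loss_cls = getattr(importlib.import_module(module_name), class_name)


def make_criterion(extractor):
    """Build the loss; `extractor` must be a non-None data extractor,
    per the asserts in BCEWeighted.__init__ (injected by the full config)."""
    return loss_cls(data_extractor=extractor, **cfg["init_args"])
```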
6 changes: 0 additions & 6 deletions configs/loss/weighting_chebi100.yml

This file was deleted.

File renamed without changes.
13 changes: 0 additions & 13 deletions configs/metrics/micro-macro-f1-roc-auc-17.yml

This file was deleted.

22 changes: 0 additions & 22 deletions configs/metrics/micro-macro-f1-roc-auc-17_test.yml

This file was deleted.

13 changes: 0 additions & 13 deletions configs/metrics/micro-macro-f1-roc-auc-2.yml

This file was deleted.

13 changes: 0 additions & 13 deletions configs/metrics/micro-macro-f1-roc-auc-27.yml

This file was deleted.

2 changes: 0 additions & 2 deletions configs/metrics/micro-macro-f1-roc-auc.yml
@@ -9,5 +9,3 @@ init_args:
       class_path: chebai.callbacks.epoch_metrics.MacroF1
     roc-auc:
       class_path: torchmetrics.classification.MultilabelAUROC
-      init_args:
-        num_labels: 12
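One consequence of this deletion: `num_labels` is a required constructor argument of `torchmetrics.classification.MultilabelAUROC`, so the value must now be supplied at runtime (presumably from the dataset's class count) rather than hard-coded to 12. A minimal sketch with a placeholder class count:

```python
import torch
from torchmetrics.classification import MultilabelAUROC

num_classes = 854  # placeholder; in practice this comes from the data module
auroc = MultilabelAUROC(num_labels=num_classes)

preds = torch.rand(32, num_classes)  # per-label probabilities
target = (torch.rand(32, num_classes) > 0.5).int()
print(auroc(preds, target))
```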
26 changes: 1 addition & 25 deletions configs/training/binary_callbacks.yml
@@ -5,13 +5,6 @@
     filename: 'best_f1_{epoch:02d}_{val_loss:.4f}_{val_f1:.4f}_{val_roc-auc:.4f}'
     every_n_epochs: 1
     save_top_k: 1
-# - class_path: chebai.callbacks.model_checkpoint.CustomModelCheckpoint
-#   init_args:
-#     monitor: val_loss
-#     mode: 'min'
-#     filename: 'best_loss_{epoch:02d}_{val_loss:.4f}_{val_f1:.4f}_{val_roc-auc:.4f}'
-#     every_n_epochs: 1
-#     save_top_k: 1
 - class_path: chebai.callbacks.model_checkpoint.CustomModelCheckpoint
   init_args:
     monitor: val_roc-auc
@@ -23,21 +16,4 @@
   init_args:
     filename: 'per_{epoch:02d}_{val_loss:.4f}_{val_f1:.4f}_{val_roc-auc:.4f}'
     every_n_epochs: 25
-    save_top_k: 1
-
-# - class_path: lightning.pytorch.callbacks.early_stopping.EarlyStopping
-#   init_args:
-#     monitor: "val_roc-auc"
-#     min_delta: 0.0
-#     patience: 5
-#     verbose: False
-#     mode: "max"
-
-
-# - class_path: lightning.pytorch.callbacks.early_stopping.EarlyStopping
-#   init_args:
-#     monitor: "val_loss_epoch"
-#     min_delta: 0.0
-#     patience: 10
-#     verbose: False
-#     mode: "min"
+    save_top_k: -1
29 changes: 4 additions & 25 deletions configs/training/default_callbacks.yml
@@ -2,32 +2,11 @@
   init_args:
     monitor: val_micro-f1
     mode: 'max'
-    filename: 'best_micro_f1_{epoch:02d}_{val_loss:.4f}_{val_macro-f1:.4f}_{val_micro-f1:.4f}_{val_roc-auc:.4f}'
+    filename: 'best_micro_f1_{epoch:02d}_{val_loss:.4f}_{val_macro-f1:.4f}_{val_micro-f1:.4f}'
     every_n_epochs: 1
-    save_top_k: 1
-# - class_path: chebai.callbacks.model_checkpoint.CustomModelCheckpoint
-#   init_args:
-#     monitor: val_loss
-#     mode: 'min'
-#     filename: 'best_loss_{epoch:02d}_{val_loss:.4f}_{val_macro-f1:.4f}_{val_micro-f1:.4f}_{val_roc-auc:.4f}'
-#     every_n_epochs: 1
-#     save_top_k: 1
+    save_top_k: 3
-- class_path: chebai.callbacks.model_checkpoint.CustomModelCheckpoint
-  init_args:
-    monitor: val_roc-auc
-    mode: 'max'
-    filename: 'best_roc-auc_{epoch:02d}_{val_loss:.4f}_{val_macro-f1:.4f}_{val_micro-f1:.4f}_{val_roc-auc:.4f}'
-    every_n_epochs: 1
-    save_top_k: 1
-# - class_path: lightning.pytorch.callbacks.early_stopping.EarlyStopping
-#   init_args:
-#     monitor: "val_roc-auc"
-#     min_delta: 0.0
-#     patience: 5
-#     verbose: False
-#     mode: "max"
 - class_path: chebai.callbacks.model_checkpoint.CustomModelCheckpoint
   init_args:
-    filename: 'per_{epoch:02d}_{val_loss:.4f}_{val_macro-f1:.4f}_{val_micro-f1:.4f}_{val_roc-auc:.4f}'
+    filename: 'per_{epoch:02d}_{val_loss:.4f}_{val_macro-f1:.4f}_{val_micro-f1:.4f}'
     every_n_epochs: 25
-    save_top_k: 1
+    save_top_k: -1
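The `save_top_k` values follow the semantics of Lightning's stock `ModelCheckpoint`, which `CustomModelCheckpoint` presumably extends: a positive `k` keeps the `k` best checkpoints by the monitored metric, while `-1` keeps every checkpoint written. Roughly equivalent callbacks in plain Lightning (metric names taken from the config above):

```python
from lightning.pytorch.callbacks import ModelCheckpoint

# Keep the three best checkpoints ranked by validation micro-F1 ...
best_micro_f1 = ModelCheckpoint(
    monitor="val_micro-f1",
    mode="max",
    filename="best_micro_f1_{epoch:02d}_{val_micro-f1:.4f}",
    every_n_epochs=1,
    save_top_k=3,
)

# ... and keep every periodic checkpoint: save_top_k=-1 disables pruning.
periodic = ModelCheckpoint(
    filename="per_{epoch:02d}",
    every_n_epochs=25,
    save_top_k=-1,
)
```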
2 changes: 1 addition & 1 deletion configs/training/default_trainer.yml
@@ -1,4 +1,4 @@
-min_epochs: 20
+min_epochs: 100
 max_epochs: 100
 default_root_dir: &default_root_dir logs
 logger: csv_logger.yml
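Since `max_epochs` was already 100, raising `min_epochs` to 100 pins every run to exactly 100 epochs; `min_epochs` is also the Trainer knob that keeps early-stopping callbacks from ending a run sooner, consistent with the early-stopping blocks removed elsewhere in this diff. The equivalent Trainer call:

```python
from lightning.pytorch import Trainer

# min_epochs == max_epochs fixes the run length at exactly 100 epochs.
trainer = Trainer(min_epochs=100, max_epochs=100, default_root_dir="logs")
```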
19 changes: 0 additions & 19 deletions configs/training/early_stop_callbacks_regression.yml

This file was deleted.

@@ -1,10 +1,3 @@
-# - class_path: chebai.callbacks.model_checkpoint.CustomModelCheckpoint
-#   init_args:
-#     monitor: val_loss
-#     mode: 'min'
-#     filename: 'best_loss_{epoch:02d}_{val_loss:.4f}_{val_mse:.4f}_{val_rmse:.4f}_{val_r2:.4f}'
-#     every_n_epochs: 1
-#     save_top_k: 1
 - class_path: chebai.callbacks.model_checkpoint.CustomModelCheckpoint
   init_args:
     monitor: val_r2
@@ -19,16 +12,8 @@
     filename: 'best_rmse_{epoch:02d}_{val_loss:.4f}_{val_mse:.4f}_{val_rmse:.4f}_{val_r2:.4f}'
     every_n_epochs: 1
     save_top_k: 1
-# - class_path: lightning.pytorch.callbacks.early_stopping.EarlyStopping
-#   init_args:
-#     monitor: "val_rmse"
-#     min_delta: 0.0
-#     patience: 5
-#     verbose: False
-#     mode: "min"
-
 - class_path: chebai.callbacks.model_checkpoint.CustomModelCheckpoint
   init_args:
     filename: 'per_{epoch:02d}_{val_loss:.4f}_{val_mse:.4f}_{val_rmse:.4f}_{val_r2:.4f}'
     every_n_epochs: 25
-    save_top_k: 1
+    save_top_k: -1
13 changes: 0 additions & 13 deletions configs/training/single_class_callbacks.yml

This file was deleted.

3 changes: 0 additions & 3 deletions configs/weightings/chebi100.yml

This file was deleted.
