Draft
Changes from all commits
88 commits
9963a0b
re-add file and stuff
isaacmg Nov 7, 2022
6208bec
Revert "re-add file and stuff"
isaacmg Nov 7, 2022
9a5807a
test and re-add
isaacmg Nov 7, 2022
7a97ba0
add stuff
isaacmg Nov 8, 2022
3e243e5
fixing stuff 2
isaacmg Nov 8, 2022
9ff1cc5
a
isaacmg Nov 14, 2022
e1980c0
f
isaacmg Nov 14, 2022
f4872c5
changing config 1uration file
isaacmg Nov 14, 2022
30ba665
remove duplicates 1
isaacmg Nov 15, 2022
5603396
fixing solar small and linter 3
isaacmg Nov 15, 2022
c704f99
fixing stupid shit 3
isaacmg Nov 15, 2022
cc52dfa
reset forcibly index
isaacmg Nov 15, 2022
328d8c3
f-u pandas index shit
isaacmg Nov 15, 2022
2146c2e
STOP THE FUCKING REDIN3DEX BULLSHIT PANDAS
isaacmg Nov 15, 2022
673b067
R lint
isaacmg Nov 15, 2022
d70e713
b as 2
isaacmg Nov 15, 2022
cb5948c
fuck little copying and pasting FUUUUUUUUU
isaacmg Nov 15, 2022
4447c87
Revert "fuck little copying and pasting FUUUUUUUUU"
isaacmg Nov 15, 2022
35c8dc1
DEEPCOPY 1
isaacmg Nov 15, 2022
cbefc32
Revert "DEEPCOPY 1"
isaacmg Nov 15, 2022
47e81f8
a
isaacmg Nov 15, 2022
1cacf07
equalize length and shit 2
isaacmg Nov 15, 2022
cf5800a
raise a value error 1
isaacmg Nov 15, 2022
52e6c73
Merge branch 'master' into series_id_finish_loader
isaacmg Dec 21, 2022
e897896
add proper cross_former params
isaacmg Apr 28, 2023
a7dbe27
Update setuptools requirement from ~=67.7.0 to ~=67.7.2
dependabot[bot] Apr 28, 2023
136f802
Merge pull request #665 from AIStream-Peelout/dependabot/pip/setuptoo…
isaacmg Apr 28, 2023
d04402a
Bump wandb from 0.15.0 to 0.15.1
dependabot[bot] May 3, 2023
758b9d6
Merge pull request #667 from AIStream-Peelout/dependabot/pip/wandb-0.…
isaacmg May 4, 2023
728c492
annoying gh 1
isaacmg May 5, 2023
c014264
Merge pull request #668 from AIStream-Peelout/fixing-_readme_badge_2
isaacmg May 5, 2023
31bc278
Revert "annoying gh 1"
isaacmg May 5, 2023
83fa987
fxing 2.09
isaacmg May 5, 2023
ccd377b
Revert "Merge pull request #668 from AIStream-Peelout/fixing-_readme_…
isaacmg May 5, 2023
11a06c7
q
isaacmg May 5, 2023
c598254
Revert "q"
isaacmg May 5, 2023
0e7ba5b
Merge branch 'master' into fixing-_readme_badge_2
isaacmg May 5, 2023
a728de0
a
isaacmg May 5, 2023
0f7e768
Revert "a"
isaacmg May 5, 2023
defc41c
fixing 2
isaacmg May 5, 2023
d2e134d
Revert "fixing 2"
isaacmg May 5, 2023
014413c
staff 2
isaacmg May 5, 2023
2fdefed
Merge pull request #670 from AIStream-Peelout/fixing-_readme_badge_2
isaacmg May 5, 2023
07377d9
Bump wandb from 0.15.1 to 0.15.2
dependabot[bot] May 8, 2023
b2f4955
Merge pull request #671 from AIStream-Peelout/dependabot/pip/wandb-0.…
isaacmg May 13, 2023
217d7da
Bump wandb from 0.15.2 to 0.15.4
dependabot[bot] Jun 7, 2023
78ff95e
Update plotly requirement from ~=5.14.1 to ~=5.15.0
dependabot[bot] Jun 9, 2023
262d879
Update setuptools requirement from ~=67.7.2 to ~=68.0.0
dependabot[bot] Jun 20, 2023
184124c
Merge pull request #679 from AIStream-Peelout/dependabot/pip/plotly-a…
isaacmg Jun 21, 2023
1d6b5bb
Merge pull request #681 from AIStream-Peelout/dependabot/pip/setuptoo…
isaacmg Jun 21, 2023
5d513b2
Merge pull request #677 from AIStream-Peelout/dependabot/pip/wandb-0.…
isaacmg Jun 21, 2023
bd1a3ed
Merge branch 'master' into series_id_finish_loader
isaacmg Jul 4, 2023
1ca7a8b
chang 23 3
isaacmg Jul 8, 2023
a99841b
ra
isaacmg Jul 8, 2023
ceff4df
see if the mixed format works
isaacmg Jul 10, 2023
dc24c8f
Revert "see if the mixed format works"
isaacmg Jul 10, 2023
ef028c4
fixing 2
isaacmg Jul 10, 2023
bbd3e53
fixing 2
isaacmg Jul 13, 2023
9b752fb
d
isaacmg Jul 14, 2023
28dd911
2
isaacmg Jul 14, 2023
9ad4986
more shit
isaacmg Jul 14, 2023
ab14eb0
r
isaacmg Jul 14, 2023
725195a
move 2 9
isaacmg Jul 14, 2023
37de122
series id more loader
isaacmg Jul 16, 2023
e5f465c
fixing to correct shape
isaacmg Jul 18, 2023
ac73d6d
remove weird shit
isaacmg Jul 19, 2023
8307c94
fixing code 1 2
isaacmg Jul 19, 2023
82b0230
r
isaacmg Jul 20, 2023
f35f482
e
isaacmg Jul 21, 2023
b644997
r r
isaacmg Jul 21, 2023
590b720
fixing the transformer loss
isaacmg Jul 23, 2023
2cb1999
Revert "fixing the transformer loss"
isaacmg Jul 23, 2023
57e3555
fixing computing of the loss
isaacmg Jul 23, 2023
a33b221
Revert "fixing computing of the loss"
isaacmg Jul 23, 2023
e7bd68f
fixing stuff
isaacmg Jul 23, 2023
ed61ca6
fixing code to run
isaacmg Jul 24, 2023
acbf448
adding fixes to path + increment
isaacmg Jul 24, 2023
3968720
Revert "adding fixes to path + increment"
isaacmg Jul 24, 2023
13f1a65
fixing stuff
isaacmg Jul 24, 2023
3e782f4
remove print debugging 3 4 5
isaacmg Jul 24, 2023
6dea9bd
fixing the tests more 2
isaacmg Jul 24, 2023
ae1955d
fixing the unit tests
isaacmg Jul 25, 2023
a2f9d5c
fixing 4
isaacmg Jul 25, 2023
90b5710
fixing the code 4
isaacmg Aug 1, 2023
5d45451
FUCK OFF LINTER FUCK LINTER
isaacmg Aug 1, 2023
ca4ef37
fixing kind of works here
isaacmg Aug 1, 2023
c3931b4
fixing r
isaacmg Aug 1, 2023
b38c43d
r
isaacmg Aug 2, 2023
3 changes: 2 additions & 1 deletion .circleci/config.yml
@@ -252,8 +252,8 @@ jobs:
name: Trainer tests
when: always
command: |
coverage run flood_forecast/trainer.py -p tests/gru_vanilla.json
echo -e 'GRU Vanilla test'
coverage run flood_forecast/trainer.py -p tests/gru_vanilla.json
coverage run flood_forecast/trainer.py -p tests/classification_test.json
coverage run flood_forecast/trainer.py -p tests/test_inf_single.json
echo -e 'test informer single target'
@@ -324,6 +324,7 @@ jobs:
name: Trainer1 tests
when: always
command: |
coverage run flood_forecast/trainer.py -p tests/transformer_b_series.json
coverage run flood_forecast/trainer.py -p tests/cross_former.json
coverage run flood_forecast/trainer.py -p tests/nlinear.json
coverage run flood_forecast/trainer.py -p tests/dsanet_3.json
2 changes: 1 addition & 1 deletion .flake8
@@ -1,4 +1,4 @@
[flake8]
max_line_length=121
ignore=E305,W504,E126,E401
ignore=E305,W504,E126,E401,E721
max-complexity=19
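
Note: E721 is pycodestyle's "do not compare types" check; presumably it is ignored here because of direct type comparisons such as the type(criterion) == list test that appears in flood_forecast/pytorch_training.py later in this diff. A minimal sketch of what the rule flags (values are hypothetical):

# E721 flags comparing types with ==; isinstance() is the idiomatic check.
criterion = []                   # stand-in for a list of loss functions

if type(criterion) == list:      # flagged by flake8 rule E721
    print("list detected via type comparison")

if isinstance(criterion, list):  # the alternative E721 recommends
    print("list detected via isinstance")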
6 changes: 3 additions & 3 deletions README.md
@@ -1,12 +1,12 @@
# Deep learning for time series forecasting
![Example image](https://raw.githubusercontent.com/CoronaWhy/task-ts/master/images/Picture1.png)
Flow Forecast (FF) is an open-source deep learning for time series forecasting framework. It provides all the latest state of the art models (transformers, attention models, GRUs) and cutting edge concepts with easy to understand interpretability metrics, cloud provider integration, and model serving capabilities. Flow Forecast was the first time series framework to feature support for transformer based models and remains the only true end-to-end deep learning for time series forecasting framework. Currently, [Task-TS from CoronaWhy](https://github.com/CoronaWhy/task-ts/wiki) primarily maintains this repository. Pull requests are welcome. Historically, this repository provided open source benchmark and codes for flash flood and river flow forecasting.
Flow Forecast (FF) is an open-source deep learning for time series forecasting framework. It provides all the latest state of the art models (transformers, attention models, GRUs, ODEs) and cutting edge concepts with easy to understand interpretability metrics, cloud provider integration, and model serving capabilities. Flow Forecast was the first time series framework to feature support for transformer based models and remains the only true end-to-end deep learning for time series framework. Currently, [Task-TS from CoronaWhy](https://github.com/CoronaWhy/task-ts/wiki) primarily maintains this repository. Pull requests are welcome. Historically, this repository provided open source benchmark and codes for flash flood and river flow forecasting.

For additional tutorials (on Colab) and examples please see our [tutorials repository](https://github.com/AIStream-Peelout/flow_tutorials).
For additional tutorials and examples please see our [tutorials repository](https://github.com/AIStream-Peelout/flow_tutorials).

| branch | status |
| --- | --- |
| master | [![CircleCI](https://circleci.com/gh/AIStream-Peelout/flow-forecast.svg?style=svg&circle-token=f7be0a4863710165969ba0903fa471f08a347df1)](https://circleci.com/gh/AIStream-Peelout/flow-forecast) |
| master | [![CircleCI](https://dl.circleci.com/status-badge/img/gh/AIStream-Peelout/flow-forecast/tree/master.svg?style=svg)](https://dl.circleci.com/status-badge/redirect/gh/AIStream-Peelout/flow-forecast/tree/master) |
| Build PY| ![Upload Python Package](https://github.com/AIStream-Peelout/flow-forecast/workflows/Upload%20Python%20Package/badge.svg)|
| Documentation | [![Documentation Status](https://readthedocs.org/projects/flow-forecast/badge/?version=latest)](https://flow-forecast.readthedocs.io/en/latest/)|
| CodeCov| [![codecov](https://codecov.io/gh/AIStream-Peelout/flow-forecast/branch/master/graph/badge.svg)](https://codecov.io/gh/AIStream-Peelout/flow-forecast)|
10 changes: 9 additions & 1 deletion flood_forecast/basic/base_line_methods.py
@@ -5,7 +5,7 @@ class NaiveBase(torch.nn.Module):
"""
A very simple baseline model that returns
the fixed value based on the input sequence.
No learning used at all a
No learning used at all.
"""

def __init__(self, seq_length: int, n_time_series: int, output_seq_len=1, metric: str = "last"):
@@ -19,6 +19,14 @@ def __init__(self, seq_length: int, n_time_series: int, output_seq_len=1, metric
self.metric_function = self.metric_dict[metric]

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""_summary_

Args:
x (torch.Tensor): _description_

Returns:
torch.Tensor: _description_
"""
return self.metric_function(x, self.output_seq_len)


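
For context, a minimal sketch of the kind of baseline NaiveBase implements, assuming the "last" metric simply repeats the final observed time step for the whole forecast horizon (the actual metric_dict implementation is not shown in this diff):

import torch

def last_value_baseline(x: torch.Tensor, output_seq_len: int) -> torch.Tensor:
    # Repeat the last observed time step output_seq_len times.
    # Assumes x has shape (batch_size, seq_length, n_time_series).
    return x[:, -1:, :].repeat(1, output_seq_len, 1)

x = torch.randn(4, 20, 3)               # 4 series, 20 steps, 3 features
print(last_value_baseline(x, 5).shape)  # torch.Size([4, 5, 3])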
4 changes: 2 additions & 2 deletions flood_forecast/evaluator.py
@@ -204,7 +204,7 @@ def infer_on_torch_model(
) -> Tuple[pd.DataFrame, torch.Tensor, int, int, CSVTestLoader, List[pd.DataFrame]]:
"""
Function to handle both test evaluation and inference on a test data-frame.
:param model: The time series model present
:param model: The time series model present in the model zoo
:param test_csv_path: The path to the test data-frame
:return:
df: df including training and test data
@@ -339,7 +339,7 @@ def handle_ci_multi(prediction_samples: torch.Tensor, csv_test_loader: CSVTestLo
:type df_pred: [type]
:param decoder_param: [description]
:type decoder_param: bool
:param history_length: [description]
:param history_length: The number of historical time-steps
:type history_length: int
:param num_samples: The number of samples to generate (i.e. larger ci)
:type num_samples: int
2 changes: 1 addition & 1 deletion flood_forecast/model_dict_function.py
@@ -24,7 +24,7 @@


"""
Utility dictionaries to map a string to a class
Utility dictionaries to map a string to a c class
"""
pytorch_model_dict = {
"MultiAttnHeadSimple": MultiAttnHeadSimple,
2 changes: 1 addition & 1 deletion flood_forecast/pre_dict.py
@@ -1,7 +1,7 @@
from sklearn.preprocessing import StandardScaler, RobustScaler, MinMaxScaler, MaxAbsScaler
from flood_forecast.preprocessing.interpolate_preprocess import (interpolate_missing_values,
back_forward_generic, forward_back_generic)

# SAMMY IS TOO LITTLE TO BE REAL DOG
scaler_dict = {
"StandardScaler": StandardScaler,
"RobustScaler": RobustScaler,
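
scaler_dict lets a JSON config refer to a scaler by name; a minimal usage sketch (the full dictionary is truncated in this diff, so the last two entries are assumed from the imports, and the data values are illustrative):

import numpy as np
from sklearn.preprocessing import StandardScaler, RobustScaler, MinMaxScaler, MaxAbsScaler

scaler_dict = {
    "StandardScaler": StandardScaler,
    "RobustScaler": RobustScaler,
    "MinMaxScaler": MinMaxScaler,
    "MaxAbsScaler": MaxAbsScaler,
}

# Look the class up by the string from a config file, then instantiate it.
scaler = scaler_dict["StandardScaler"]()
data = np.array([[1.0], [2.0], [3.0]])
print(scaler.fit_transform(data).ravel())  # zero mean, unit variance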
4 changes: 3 additions & 1 deletion flood_forecast/preprocessing/process_usgs.py
@@ -7,7 +7,9 @@


def make_usgs_data(start_date: datetime, end_date: datetime, site_number: str) -> pd.DataFrame:
""""""
"""

"""
base_url = "https://nwis.waterdata.usgs.gov/usa/nwis/uv/?cb_00060=on&cb_00065&format=rdb&"
full_url = base_url + "site_no=" + site_number + "&period=&begin_date=" + \
start_date.strftime("%Y-%m-%d") + "&end_date=" + end_date.strftime("%Y-%m-%d")
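
To make the URL construction above concrete, a standalone sketch with hypothetical inputs (the site number 01646500 is illustrative):

from datetime import datetime

start_date, end_date, site_number = datetime(2020, 1, 1), datetime(2020, 2, 1), "01646500"

base_url = "https://nwis.waterdata.usgs.gov/usa/nwis/uv/?cb_00060=on&cb_00065&format=rdb&"
full_url = base_url + "site_no=" + site_number + "&period=&begin_date=" + \
    start_date.strftime("%Y-%m-%d") + "&end_date=" + end_date.strftime("%Y-%m-%d")
print(full_url)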
51 changes: 42 additions & 9 deletions flood_forecast/preprocessing/pytorch_loaders.py
@@ -7,6 +7,7 @@
from flood_forecast.preprocessing.buil_dataset import get_data
from datetime import datetime
from flood_forecast.preprocessing.temporal_feats import feature_fix
from copy import deepcopy


class CSVDataLoader(Dataset):
@@ -71,6 +72,8 @@ def __init__(
self.scale = None
if scaled_cols is None:
scaled_cols = relevant_cols
print("scaled cols are")
print(scaled_cols)
if start_stamp != 0 and end_stamp is not None:
self.df = self.df[start_stamp:end_stamp]
elif start_stamp != 0:
@@ -168,21 +171,44 @@ def __init__(self, series_id_col: str, main_params: dict, return_method: str, re
:param return_all: Whether to return all items, defaults to True
:type return_all: bool, optional
"""
main_params["relevant_cols"].append(series_id_col)
super().__init__(**main_params)
main_params1 = deepcopy(main_params)
if "scaled_cols" not in main_params1:
main_params1["scaled_cols"] = main_params1["relevant_cols"].copy()
print("The scaled cols are below")
print(main_params1["scaled_cols"])
main_params1["relevant_cols"].append(series_id_col)
super().__init__(**main_params1)
self.series_id_col = series_id_col
self.return_method = return_method
self.return_all_series = return_all
self.unique_cols = self.original_df[series_id_col].dropna().unique().tolist()
df_list = []
self.df = self.df.reset_index()
self.unique_dict = {}
print("The series id column is below:")
print(self.series_id_col)
for col in self.unique_cols:
df_list.append(self.df[self.df[self.series_id_col] == col])
new_df = self.df[self.df[self.series_id_col] == col]
df_list.append(new_df)
print(new_df.columns)
self.listed_vals = df_list
self.__make_unique_dict__()
self.__validate_data__in_df()
print(self.unique_dict)
print("unique dict")

def __validate_data__in_df(self):
"""Makes sure the data in the data-frame is the proper length for each series e
"""
if self.return_all_series:
len_first = len(self.listed_vals[0])
print("Length of first series is:" + str(len_first))
for series in self.listed_vals:
print("Length of first series is:" + str(len(series)))
series_bool = len(series) == len_first
if not series_bool:
raise IndexError("The length of sub-series data-frames are not equal.")

def __make_unique_dict__(self):
for i in range(0, len(self.unique_cols)):
self.unique_dict[self.unique_cols[i]] = i
@@ -198,12 +224,13 @@ def __getitem__(self, idx: int) -> Tuple[Dict, Dict]:
if self.return_all_series:
src_list = {}
targ_list = {}
print(self.unique_cols)
for va in self.listed_vals:
t = torch.Tensor(va.iloc[idx: self.forecast_history + idx].values)[:, :len(self.relevant_cols3) - 1]
# We need to exclude the index column on one end and the series id column on the other
t = torch.Tensor(va.iloc[idx: self.forecast_history + idx].values)[:, 1:-1]
print(t.shape)
targ_start_idx = idx + self.forecast_history
idx2 = va[self.series_id_col].iloc[0]
targ = torch.Tensor(va.iloc[targ_start_idx: targ_start_idx + self.forecast_length].to_numpy())
targ = torch.Tensor(va.iloc[targ_start_idx: targ_start_idx + self.forecast_length].to_numpy())[:, 1:-1]
src_list[self.unique_dict[idx2]] = t
targ_list[self.unique_dict[idx2]] = targ
return src_list, targ_list
Expand All @@ -214,6 +241,12 @@ def __getitem__(self, idx: int) -> Tuple[Dict, Dict]:
def __sample_series_id__(idx, series_id):
pass

def __len__(self) -> int:
if self.return_all_series:
return len(self.listed_vals[0]) - self.forecast_history - self.forecast_length - 1
else:
raise NotImplementedError("Current code only supports returning all the series at each iteration")


class CSVTestLoader(CSVDataLoader):
def __init__(
@@ -334,7 +367,7 @@ def __init__(

:param file_path: The path to the file
:type file_path: str
:param relevant_cols: d
:param relevant_cols: The relevant columns
:type relevant_cols: List
:param scaling: [description], defaults to None
:type scaling: [type], optional
@@ -488,7 +521,7 @@ def __len__(self) -> int:


class TemporalTestLoader(CSVTestLoader):
def __init__(self, time_feats, kwargs={}, decoder_step_len=None):
def __init__(self, time_feats: List[str], kwargs={}, decoder_step_len=None):
"""A test data-loader class for data in the format of the TemporalLoader.

:param time_feats: The temporal featuers to use in encoding.
@@ -567,7 +600,7 @@ def __init__(self, series_marker_column: str, csv_loader_params: Dict, pad_lengt
self.grouped_df = self.df.groupby(series_marker_column)
self.n_classes = n_classes

def get_item_forecast(self, idx):
def get_item_forecast(self, idx: int):
pass

def get_item_classification(self, idx: int):
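
To see why __getitem__ slices with [:, 1:-1]: after reset_index() the first column holds the old index and, because series_id_col was appended to relevant_cols, the last column holds the series id, so both must be dropped to leave only real features. A minimal sketch with hypothetical data:

import pandas as pd
import torch

forecast_history = 3
df = pd.DataFrame({"precip": [0.1, 0.2, 0.3, 0.4, 0.5],
                   "flow": [10.0, 11.0, 12.0, 13.0, 14.0],
                   "site": [1, 1, 1, 1, 1]}).reset_index()
# Columns are now: index | precip | flow | site

va = df[df["site"] == 1]
t = torch.Tensor(va.iloc[0:forecast_history].values)[:, 1:-1]
print(t.shape)  # torch.Size([3, 2]) -- only the real feature columns remain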
89 changes: 48 additions & 41 deletions flood_forecast/pytorch_training.py
@@ -11,6 +11,7 @@
from flood_forecast.basic.linear_regression import simple_decode
from flood_forecast.training_utils import EarlyStopper
from flood_forecast.custom.custom_opt import GaussianLoss, MASELoss
from flood_forecast.series_id_helper import handle_csv_id_output, handle_csv_id_validation
from torch.nn import CrossEntropyLoss


@@ -356,7 +357,7 @@ def torch_single_train(model: PyTorchForecast,
:type model: PyTorchForecast
:param opt: The optimizer to use in the code
:type opt: optim.Optimizer
:param criterion: [description]
:param criterion: [m
:type criterion: Type[torch.nn.modules.loss._Loss]
:param data_loader: [description]
:type data_loader: DataLoader
@@ -405,50 +406,50 @@
trg = trg[0]
trg[:, -pred_len:, :] = torch.zeros_like(trg[:, -pred_len:, :].long()).float().to(model.device)
# Assign to avoid other if statement
elif "SeriesIDLoader" == model.params["dataset_params"]["class"]:
pass
src = src.to(model.device)
trg = trg.to(model.device)
output = model.model(src, **forward_params)
if hasattr(model.model, "pred_len"):
multi_targets = mulit_targets_copy
pred_len = model.model.pred_len
output = output[:, :, 0:multi_targets]
labels = trg[:, -pred_len:, 0:multi_targets]
multi_targets = False
if model.params["dataset_params"]["class"] == "GeneralClassificationLoader":
labels = trg
elif multi_targets == 1:
labels = trg[:, :, 0]
elif multi_targets > 1:
labels = trg[:, :, 0:multi_targets]
if probablistic:
output1 = output
output = output.mean
output_std = output1.stddev
if type(criterion) == list:
loss = multi_crit(criterion, output, labels, None)
if "SeriesIDLoader" == model.params["dataset_params"]["class"]:
running_loss += handle_csv_id_output(src, trg, model, criterion, opt, False, multi_targets)
i += 1
else:
loss = compute_loss(labels, output, src, criterion, None, probablistic, output_std, m=multi_targets)
if loss > 100:
print("Warning: high loss detected")
loss.backward()
opt.step()
if torch.isnan(loss) or loss == float('inf'):
raise ValueError("Error infinite or NaN loss detected. Try normalizing data or performing interpolation")
running_loss += loss.item()
i += 1
src = src.to(model.device)
trg = trg.to(model.device)
output = model.model(src, **forward_params)
if hasattr(model.model, "pred_len"):
multi_targets = mulit_targets_copy
pred_len = model.model.pred_len
output = output[:, :, 0:multi_targets]
labels = trg[:, -pred_len:, 0:multi_targets]
multi_targets = False
if model.params["dataset_params"]["class"] == "GeneralClassificationLoader":
labels = trg
elif model.params["dataset_params"]["class"] == "CSVSeriesIDLoader":
labels = trg
elif multi_targets == 1:
labels = trg[:, :, 0]
elif multi_targets > 1:
labels = trg[:, :, 0:multi_targets]
if probablistic:
output1 = output
output = output.mean
output_std = output1.stddev
if type(criterion) == list:
loss = multi_crit(criterion, output, labels, None)
else:
loss = compute_loss(labels, output, src, criterion, None, probablistic, output_std, m=multi_targets)
if loss > 100:
print("Warning: high loss detected")
loss.backward()
opt.step()
if torch.isnan(loss) or loss == float('inf'):
raise ValueError("Error infinite or NaN loss detected. Try normalizing data or performing interpolation")
running_loss += loss.item()
i += 1
print("The running loss is: ")
print(running_loss)
print("The number of items in train is: " + str(i))
total_loss = running_loss / float(i)
return total_loss


def handle_crit_list():
pass


def compute_validation(validation_loader: DataLoader,
model,
epoch: int,
@@ -509,6 +510,10 @@ def compute_validation(validation_loader: DataLoader,
label_list = []
mod_output_list = []
for src, targ in validation_loader:
if validation_loader.dataset.__class__.__name__ == "CSVSeriesIDLoader":
scaled_crit = handle_csv_id_validation(src, targ, model, criterion, False, multi_targets)
unscaled_crit = {}
continue
src = src if isinstance(src, list) else src.to(device)
targ = targ if isinstance(targ, list) else targ.to(device)
# targ = targ if isinstance(targ, list) else targ.to(device)
@@ -591,12 +596,14 @@
print("Plotting test classification metrics")
label_list = torch.cat(label_list)
label_list = label_list[:, 0, :].detach().cpu()
mod_output1 = torch.cat(mod_output_list)[:, 0, :].detach().cpu().numpy()
mod_output1 = torch.cat(mod_output_list)[:, 0, :].detach().cpu()
d = torch.nn.Softmax(dim=1)
mod_output_final = d(mod_output1).numpy()
fin = label_list.max(dim=1)[1]
wandb.log({"roc_" + str(epoch): wandb.plot.roc_curve(fin, mod_output1, classes_to_plot=None, labels=None,
wandb.log({"roc_" + str(epoch): wandb.plot.roc_curve(fin, mod_output_final, classes_to_plot=None, labels=None,
title="roc_" + str(epoch))})
wandb.log({"pr": wandb.plot.pr_curve(fin, mod_output1)})
wandb.log({"conf_": wandb.plot.confusion_matrix(probs=mod_output1,
wandb.log({"pr": wandb.plot.pr_curve(fin, mod_output_final)})
wandb.log({"conf_": wandb.plot.confusion_matrix(probs=mod_output_final,
y_true=fin.detach().cpu().numpy(), class_names=None)})
model.train()
return list(scaled_crit.values())[0]
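
The classification-metrics change above converts raw logits to class probabilities before logging, since the wandb plots expect probabilities; a minimal standalone sketch of that conversion (tensor shapes are illustrative):

import torch

mod_output1 = torch.randn(8, 4)  # hypothetical logits: 8 samples, 4 classes
d = torch.nn.Softmax(dim=1)
mod_output_final = d(mod_output1).numpy()
print(mod_output_final.sum(axis=1))  # each row now sums to 1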