Skip to content

Commit 363a8d6

Browse files
alexbanwell1github-actions[bot]
authored andcommitted
Automatic pre-commit fixes
1 parent 5c582c8 commit 363a8d6

File tree

6 files changed

+123
-64
lines changed

6 files changed

+123
-64
lines changed

tsml_eval/evaluation/storage/forecaster_results.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,10 @@
44
from sklearn.metrics import mean_absolute_percentage_error, mean_squared_error
55

66
from tsml_eval.evaluation.storage.estimator_results import EstimatorResults
7-
from tsml_eval.utils.results_writing import results_third_line, write_results_to_tsml_format
7+
from tsml_eval.utils.results_writing import (
8+
results_third_line,
9+
write_results_to_tsml_format,
10+
)
811

912

1013
class ForecasterResults(EstimatorResults):
@@ -157,7 +160,7 @@ def save_to_file(self, file_path, full_path=True):
157160
benchmark_time=self.benchmark_time,
158161
memory_usage=self.memory_usage,
159162
)
160-
write_results_to_tsml_format (
163+
write_results_to_tsml_format(
161164
self.predictions,
162165
self.target_labels,
163166
self.estimator_name,
@@ -169,7 +172,7 @@ def save_to_file(self, file_path, full_path=True):
169172
time_unit=self.time_unit,
170173
first_line_comment=self.description,
171174
second_line=self.parameter_info,
172-
third_line=third_line
175+
third_line=third_line,
173176
)
174177

175178
def load_from_file(self, file_path, verify_values=True):

tsml_eval/evaluation/storage/regressor_results.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,10 @@
1010
)
1111

1212
from tsml_eval.evaluation.storage.estimator_results import EstimatorResults
13-
from tsml_eval.utils.results_writing import regression_results_third_line, write_results_to_tsml_format
13+
from tsml_eval.utils.results_writing import (
14+
regression_results_third_line,
15+
write_results_to_tsml_format,
16+
)
1417

1518

1619
class RegressorResults(EstimatorResults):

tsml_eval/experiments/experiments.py

Lines changed: 103 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -443,6 +443,7 @@ def load_and_run_classification_experiment(
443443
benchmark_time=benchmark_time,
444444
)
445445

446+
446447
def transform_input(
447448
data_transforms,
448449
x_train: np.ndarray,
@@ -459,6 +460,7 @@ def transform_input(
459460
x_test = transform.transform(x_test, y_test)
460461
return x_train, x_test
461462

463+
462464
def cross_validate_train_data(estimator, y_train, X_train):
463465
cv_size = min(10, len(y_train))
464466
start = int(round(time.time() * 1000))
@@ -467,6 +469,7 @@ def cross_validate_train_data(estimator, y_train, X_train):
467469
train_estimate_method = f"{cv_size}F-CV"
468470
return train_preds, train_time, train_estimate_method
469471

472+
470473
class Experiment:
471474
"""Run an experiment and save the results to file.
472475
@@ -522,6 +525,7 @@ class Experiment:
522525
Whether to benchmark the hardware used with a simple function and write the
523526
results. This will typically take ~2 seconds, but is hardware dependent.
524527
"""
528+
525529
def __init__(
526530
self,
527531
estimator,
@@ -547,11 +551,15 @@ def __init__(
547551
)
548552

549553
if not build_test_file and not build_train_file:
550-
warnings.warn("All files exist and not overwriting, skipping.", stacklevel=1)
554+
warnings.warn(
555+
"All files exist and not overwriting, skipping.", stacklevel=1
556+
)
551557
return None
552558

553559
if write_attributes:
554-
attribute_file_path = f"{results_path}/{estimator_name}/Workspace/{dataset_name}/"
560+
attribute_file_path = (
561+
f"{results_path}/{estimator_name}/Workspace/{dataset_name}/"
562+
)
555563
else:
556564
attribute_file_path = None
557565

@@ -571,28 +579,37 @@ def __init__(
571579
else:
572580
self.estimator_name = estimator_name
573581
self.estimator = self.validate_estimator(estimator=estimator)
574-
self.second_comment = str(estimator.get_params()).replace("\n", " ").replace("\r", " ")
582+
self.second_comment = (
583+
str(estimator.get_params()).replace("\n", " ").replace("\r", " ")
584+
)
575585
if attribute_file_path is not None:
576586
estimator_attributes_to_file(
577587
self.estimator, attribute_file_path, max_list_shape=att_max_shape
578588
)
579589

580-
581590
def run_experiment(self):
582591
x_train, y_train, x_test, y_test = self.load_experimental_data()
583592

584593
self.first_comment = (
585594
"Generated by run_experiment on "
586595
f"{datetime.now().strftime('%m/%d/%Y, %H:%M:%S')}"
587596
)
588-
589-
x_train, x_test = transform_input(data_transforms=self.data_transforms, x_train=x_train, x_test=x_test,y_train=y_train, y_test=y_test)
597+
598+
x_train, x_test = transform_input(
599+
data_transforms=self.data_transforms,
600+
x_train=x_train,
601+
x_test=x_test,
602+
y_train=y_train,
603+
y_test=y_test,
604+
)
590605
if self.benchmark_time:
591606
self.benchmark = timing_benchmark(random_state=self.resample_id)
592-
607+
593608
if self.build_train_file:
594609
train_preds, train_time = self.generate_train_preds(x_train, y_train)
595-
self.write_results("TRAIN", y_train, train_preds, train_time, -1, self.benchmark, -1)
610+
self.write_results(
611+
"TRAIN", y_train, train_preds, train_time, -1, self.benchmark, -1
612+
)
596613

597614
if self.build_test_file:
598615
if self.needs_fit():
@@ -605,28 +622,37 @@ def run_experiment(self):
605622
fit_time += int(round(getattr(self.estimator, "_fit_time_milli", 0)))
606623
test_preds, test_time = self.generate_test_preds(x_test, y_test)
607624
test_time += int(round(getattr(self.estimator, "_predict_time_milli", 0)))
608-
self.write_results("TEST", y_test, test_preds, fit_time, test_time, self.benchmark, mem_usage)
625+
self.write_results(
626+
"TEST",
627+
y_test,
628+
test_preds,
629+
fit_time,
630+
test_time,
631+
self.benchmark,
632+
mem_usage,
633+
)
609634

610635
def load_experimental_data(self):
611636
return None, None, None, None
612637

613638
def validate_estimator(self, estimator):
614639
estimator
615640

616-
617641
def generate_train_preds(self, X_train, y_train):
618642
return time_function(self.estimator.fit_predict, (X_train, y_train))
619-
643+
620644
def generate_test_preds(self, x_test, y_test):
621645
return time_function(self.estimator.predict, x_test)
622646

623-
624647
def needs_fit(self):
625648
return False
626649

627-
628-
def write_results(self, split, y, preds,fit_time, predict_time, benchmark_time, memory_usage):
629-
third_line = self.get_third_line(y, preds,fit_time, predict_time, benchmark_time, memory_usage)
650+
def write_results(
651+
self, split, y, preds, fit_time, predict_time, benchmark_time, memory_usage
652+
):
653+
third_line = self.get_third_line(
654+
y, preds, fit_time, predict_time, benchmark_time, memory_usage
655+
)
630656
write_results_to_tsml_format(
631657
preds,
632658
y,
@@ -642,16 +668,20 @@ def write_results(self, split, y, preds,fit_time, predict_time, benchmark_time,
642668
second_line=self.second_comment,
643669
third_line=third_line,
644670
)
645-
def get_third_line(self, y, preds, fit_time, predict_time, benchmark_time, memory_usage):
671+
672+
def get_third_line(
673+
self, y, preds, fit_time, predict_time, benchmark_time, memory_usage
674+
):
646675
return results_third_line(
647-
y=y,
648-
preds=preds,
649-
fit_time=fit_time,
650-
predict_time=predict_time,
651-
benchmark_time=benchmark_time,
652-
memory_usage=memory_usage,
653-
)
654-
676+
y=y,
677+
preds=preds,
678+
fit_time=fit_time,
679+
predict_time=predict_time,
680+
benchmark_time=benchmark_time,
681+
memory_usage=memory_usage,
682+
)
683+
684+
655685
class ForecastingExperiment(Experiment):
656686
def __init__(self):
657687
pass
@@ -670,21 +700,25 @@ def generate_test_preds(self, x_test, y_test):
670700
def validate_estimator(self, estimator):
671701
return validate_forecaster(estimator)
672702

703+
673704
class RegressionExperiment(Experiment):
674705
def __init__(
675-
self,
676-
ignore_custom_train_estimate=False,
677-
predefined_resample = False,
678-
problem_path="",
679-
):
706+
self,
707+
ignore_custom_train_estimate=False,
708+
predefined_resample=False,
709+
problem_path="",
710+
):
680711
self.is_fitted = False
681712
self.ignore_custom_train_estimate = ignore_custom_train_estimate
682713
self.problem_path = problem_path
683714
self.predefined_resample = predefined_resample
684715

685716
def load_experimental_data(self):
686717
X_train, y_train, X_test, y_test, resample = load_experiment_data(
687-
self.problem_path, self.dataset_name, self.resample_id, self.predefined_resample
718+
self.problem_path,
719+
self.dataset_name,
720+
self.resample_id,
721+
self.predefined_resample,
688722
)
689723

690724
if resample:
@@ -696,30 +730,38 @@ def load_experimental_data(self):
696730
def generate_train_preds(self, X_train, y_train):
697731
if self.estimate_train_data and not self.ignore_custom_train_estimate:
698732
self.train_estimate_method = "Custom"
699-
train_preds, train_time = time_function(self.estimator.fit_predict, (X_train, y_train))
733+
train_preds, train_time = time_function(
734+
self.estimator.fit_predict, (X_train, y_train)
735+
)
700736
self.is_fitted = True
701737
else:
702-
train_preds, train_time, self.train_estimate_method = cross_validate_train_data(self.estimator,y_train,X_train)
738+
train_preds, train_time, self.train_estimate_method = (
739+
cross_validate_train_data(self.estimator, y_train, X_train)
740+
)
703741
return train_preds, train_time
704742

705743
def needs_fit(self):
706744
return not self.is_fitted
707745

708-
def get_third_line(self, y, preds, fit_time, predict_time, benchmark_time, memory_usage):
746+
def get_third_line(
747+
self, y, preds, fit_time, predict_time, benchmark_time, memory_usage
748+
):
709749
return regression_results_third_line(
710-
y=y,
711-
preds=preds,
712-
fit_time=fit_time,
713-
predict_time=predict_time,
714-
benchmark_time=benchmark_time,
715-
memory_usage=memory_usage,
716-
train_estimate_method=self.train_estimate_method,
717-
)
750+
y=y,
751+
preds=preds,
752+
fit_time=fit_time,
753+
predict_time=predict_time,
754+
benchmark_time=benchmark_time,
755+
memory_usage=memory_usage,
756+
train_estimate_method=self.train_estimate_method,
757+
)
758+
718759
def validate_estimator(self, estimator):
719760
estimator, estimate_train_data = validate_regressor(estimator)
720761
self.estimate_train_data = estimate_train_data
721762
return estimator
722-
763+
764+
723765
def validate_forecaster(estimator):
724766
if isinstance(estimator, BaseForecaster):
725767
return estimator
@@ -728,28 +770,34 @@ def validate_forecaster(estimator):
728770
estimator, _ = validate_regressor(estimator)
729771
return RegressionForecaster(regressor=estimator)
730772
except TypeError:
731-
raise TypeError("forecaster must be an aeon forecaster or a tsml, aeon or sklearn regressor.")
773+
raise TypeError(
774+
"forecaster must be an aeon forecaster or a tsml, aeon or sklearn regressor."
775+
)
776+
732777

733778
def validate_regressor(estimator):
734779
estimate_train_data = False
735780
if isinstance(estimator, BaseRegressor):
736-
if estimator.get_tag(
737-
"capability:train_estimate", False, False
738-
):
781+
if estimator.get_tag("capability:train_estimate", False, False):
739782
estimate_train_data = True
740783
return estimator, estimate_train_data
741784
elif isinstance(estimator, BaseTimeSeriesEstimator) and is_regressor(estimator):
742785
return estimator, estimate_train_data
743786
elif isinstance(estimator, BaseEstimator) and is_regressor(estimator):
744-
return SklearnToTsmlRegressor(
745-
regressor=estimator,
746-
pad_unequal=True,
747-
concatenate_channels=True,
748-
clone_estimator=False,
749-
random_state=(
750-
estimator.random_state if hasattr(estimator, "random_state") else None
787+
return (
788+
SklearnToTsmlRegressor(
789+
regressor=estimator,
790+
pad_unequal=True,
791+
concatenate_channels=True,
792+
clone_estimator=False,
793+
random_state=(
794+
estimator.random_state
795+
if hasattr(estimator, "random_state")
796+
else None
797+
),
751798
),
752-
), estimate_train_data
799+
estimate_train_data,
800+
)
753801
else:
754802
raise TypeError("regressor must be a tsml, aeon or sklearn regressor.")
755803

@@ -1200,6 +1248,7 @@ def load_and_run_clustering_experiment(
12001248
benchmark_time=benchmark_time,
12011249
)
12021250

1251+
12031252
def run_forecasting_experiment(
12041253
train,
12051254
y_test,
@@ -1251,7 +1300,6 @@ def run_forecasting_experiment(
12511300
pass
12521301

12531302

1254-
12551303
def load_and_run_forecasting_experiment(
12561304
problem_path,
12571305
results_path,
@@ -1294,7 +1342,6 @@ def load_and_run_forecasting_experiment(
12941342
If set to False, this will only build results if there is not a result file
12951343
already present. If True, it will overwrite anything already there.
12961344
"""
1297-
12981345
tmpdir = tempfile.mkdtemp()
12991346
dataset = load_forecasting(dataset, tmpdir)
13001347
series = (

tsml_eval/utils/functions.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
]
99

1010
import time
11+
1112
import numpy as np
1213

1314

@@ -102,7 +103,8 @@ def rank_array(arr, higher_better=True):
102103

103104
return ranks
104105

106+
105107
def time_function(function, args=None, kwargs=None):
106108
start = int(round(time.time() * 1000))
107109
output = function(*args, **kwargs)
108-
return int(round(time.time() * 1000)) - start, output
110+
return int(round(time.time() * 1000)) - start, output

0 commit comments

Comments
 (0)