Skip to content

Commit b6797d6

Browse files
committed
Merge branch 'arb/streamline_experiments' of https://github.com/alexbanwell1/tsml-eval into arb/streamline_experiments
2 parents 21fa740 + 363a8d6 commit b6797d6

File tree

6 files changed

+123
-64
lines changed

6 files changed

+123
-64
lines changed

tsml_eval/evaluation/storage/forecaster_results.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,10 @@
44
from sklearn.metrics import mean_absolute_percentage_error, mean_squared_error
55

66
from tsml_eval.evaluation.storage.estimator_results import EstimatorResults
7-
from tsml_eval.utils.results_writing import results_third_line, write_results_to_tsml_format
7+
from tsml_eval.utils.results_writing import (
8+
results_third_line,
9+
write_results_to_tsml_format,
10+
)
811

912

1013
class ForecasterResults(EstimatorResults):
@@ -157,7 +160,7 @@ def save_to_file(self, file_path, full_path=True):
157160
benchmark_time=self.benchmark_time,
158161
memory_usage=self.memory_usage,
159162
)
160-
write_results_to_tsml_format (
163+
write_results_to_tsml_format(
161164
self.predictions,
162165
self.target_labels,
163166
self.estimator_name,
@@ -169,7 +172,7 @@ def save_to_file(self, file_path, full_path=True):
169172
time_unit=self.time_unit,
170173
first_line_comment=self.description,
171174
second_line=self.parameter_info,
172-
third_line=third_line
175+
third_line=third_line,
173176
)
174177

175178
def load_from_file(self, file_path, verify_values=True):

tsml_eval/evaluation/storage/regressor_results.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,10 @@
1010
)
1111

1212
from tsml_eval.evaluation.storage.estimator_results import EstimatorResults
13-
from tsml_eval.utils.results_writing import regression_results_third_line, write_results_to_tsml_format
13+
from tsml_eval.utils.results_writing import (
14+
regression_results_third_line,
15+
write_results_to_tsml_format,
16+
)
1417

1518

1619
class RegressorResults(EstimatorResults):

tsml_eval/experiments/experiments.py

Lines changed: 103 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -443,6 +443,7 @@ def load_and_run_classification_experiment(
443443
benchmark_time=benchmark_time,
444444
)
445445

446+
446447
def transform_input(
447448
data_transforms,
448449
x_train: np.ndarray,
@@ -459,6 +460,7 @@ def transform_input(
459460
x_test = transform.transform(x_test, y_test)
460461
return x_train, x_test
461462

463+
462464
def cross_validate_train_data(estimator, y_train, X_train):
463465
cv_size = min(10, len(y_train))
464466
start = int(round(time.time() * 1000))
@@ -467,6 +469,7 @@ def cross_validate_train_data(estimator, y_train, X_train):
467469
train_estimate_method = f"{cv_size}F-CV"
468470
return train_preds, train_time, train_estimate_method
469471

472+
470473
class Experiment:
471474
"""Run an experiment and save the results to file.
472475
@@ -522,6 +525,7 @@ class Experiment:
522525
Whether to benchmark the hardware used with a simple function and write the
523526
results. This will typically take ~2 seconds, but is hardware dependent.
524527
"""
528+
525529
def __init__(
526530
self,
527531
estimator,
@@ -547,11 +551,15 @@ def __init__(
547551
)
548552

549553
if not build_test_file and not build_train_file:
550-
warnings.warn("All files exist and not overwriting, skipping.", stacklevel=1)
554+
warnings.warn(
555+
"All files exist and not overwriting, skipping.", stacklevel=1
556+
)
551557
return None
552558

553559
if write_attributes:
554-
attribute_file_path = f"{results_path}/{estimator_name}/Workspace/{dataset_name}/"
560+
attribute_file_path = (
561+
f"{results_path}/{estimator_name}/Workspace/{dataset_name}/"
562+
)
555563
else:
556564
attribute_file_path = None
557565

@@ -571,28 +579,37 @@ def __init__(
571579
else:
572580
self.estimator_name = estimator_name
573581
self.estimator = self.validate_estimator(estimator=estimator)
574-
self.second_comment = str(estimator.get_params()).replace("\n", " ").replace("\r", " ")
582+
self.second_comment = (
583+
str(estimator.get_params()).replace("\n", " ").replace("\r", " ")
584+
)
575585
if attribute_file_path is not None:
576586
estimator_attributes_to_file(
577587
self.estimator, attribute_file_path, max_list_shape=att_max_shape
578588
)
579589

580-
581590
def run_experiment(self):
582591
x_train, y_train, x_test, y_test = self.load_experimental_data()
583592

584593
self.first_comment = (
585594
"Generated by run_experiment on "
586595
f"{datetime.now().strftime('%m/%d/%Y, %H:%M:%S')}"
587596
)
588-
589-
x_train, x_test = transform_input(data_transforms=self.data_transforms, x_train=x_train, x_test=x_test,y_train=y_train, y_test=y_test)
597+
598+
x_train, x_test = transform_input(
599+
data_transforms=self.data_transforms,
600+
x_train=x_train,
601+
x_test=x_test,
602+
y_train=y_train,
603+
y_test=y_test,
604+
)
590605
if self.benchmark_time:
591606
self.benchmark = timing_benchmark(random_state=self.resample_id)
592-
607+
593608
if self.build_train_file:
594609
train_preds, train_time = self.generate_train_preds(x_train, y_train)
595-
self.write_results("TRAIN", y_train, train_preds, train_time, -1, self.benchmark, -1)
610+
self.write_results(
611+
"TRAIN", y_train, train_preds, train_time, -1, self.benchmark, -1
612+
)
596613

597614
if self.build_test_file:
598615
if self.needs_fit():
@@ -605,28 +622,37 @@ def run_experiment(self):
605622
fit_time += int(round(getattr(self.estimator, "_fit_time_milli", 0)))
606623
test_preds, test_time = self.generate_test_preds(x_test, y_test)
607624
test_time += int(round(getattr(self.estimator, "_predict_time_milli", 0)))
608-
self.write_results("TEST", y_test, test_preds, fit_time, test_time, self.benchmark, mem_usage)
625+
self.write_results(
626+
"TEST",
627+
y_test,
628+
test_preds,
629+
fit_time,
630+
test_time,
631+
self.benchmark,
632+
mem_usage,
633+
)
609634

610635
def load_experimental_data(self):
611636
return None, None, None, None
612637

613638
def validate_estimator(self, estimator):
614639
estimator
615640

616-
617641
def generate_train_preds(self, X_train, y_train):
618642
return time_function(self.estimator.fit_predict, (X_train, y_train))
619-
643+
620644
def generate_test_preds(self, x_test, y_test):
621645
return time_function(self.estimator.predict, x_test)
622646

623-
624647
def needs_fit(self):
625648
return False
626649

627-
628-
def write_results(self, split, y, preds,fit_time, predict_time, benchmark_time, memory_usage):
629-
third_line = self.get_third_line(y, preds,fit_time, predict_time, benchmark_time, memory_usage)
650+
def write_results(
651+
self, split, y, preds, fit_time, predict_time, benchmark_time, memory_usage
652+
):
653+
third_line = self.get_third_line(
654+
y, preds, fit_time, predict_time, benchmark_time, memory_usage
655+
)
630656
write_results_to_tsml_format(
631657
preds,
632658
y,
@@ -642,16 +668,20 @@ def write_results(self, split, y, preds,fit_time, predict_time, benchmark_time,
642668
second_line=self.second_comment,
643669
third_line=third_line,
644670
)
645-
def get_third_line(self, y, preds, fit_time, predict_time, benchmark_time, memory_usage):
671+
672+
def get_third_line(
673+
self, y, preds, fit_time, predict_time, benchmark_time, memory_usage
674+
):
646675
return results_third_line(
647-
y=y,
648-
preds=preds,
649-
fit_time=fit_time,
650-
predict_time=predict_time,
651-
benchmark_time=benchmark_time,
652-
memory_usage=memory_usage,
653-
)
654-
676+
y=y,
677+
preds=preds,
678+
fit_time=fit_time,
679+
predict_time=predict_time,
680+
benchmark_time=benchmark_time,
681+
memory_usage=memory_usage,
682+
)
683+
684+
655685
class ForecastingExperiment(Experiment):
656686
def __init__(self):
657687
pass
@@ -677,21 +707,25 @@ def generate_test_preds(self, x_test, y_test):
677707
def validate_estimator(self, estimator):
678708
return validate_forecaster(estimator)
679709

710+
680711
class RegressionExperiment(Experiment):
681712
def __init__(
682-
self,
683-
ignore_custom_train_estimate=False,
684-
predefined_resample = False,
685-
problem_path="",
686-
):
713+
self,
714+
ignore_custom_train_estimate=False,
715+
predefined_resample=False,
716+
problem_path="",
717+
):
687718
self.is_fitted = False
688719
self.ignore_custom_train_estimate = ignore_custom_train_estimate
689720
self.problem_path = problem_path
690721
self.predefined_resample = predefined_resample
691722

692723
def load_experimental_data(self):
693724
X_train, y_train, X_test, y_test, resample = load_experiment_data(
694-
self.problem_path, self.dataset_name, self.resample_id, self.predefined_resample
725+
self.problem_path,
726+
self.dataset_name,
727+
self.resample_id,
728+
self.predefined_resample,
695729
)
696730

697731
if resample:
@@ -703,30 +737,38 @@ def load_experimental_data(self):
703737
def generate_train_preds(self, X_train, y_train):
704738
if self.estimate_train_data and not self.ignore_custom_train_estimate:
705739
self.train_estimate_method = "Custom"
706-
train_preds, train_time = time_function(self.estimator.fit_predict, (X_train, y_train))
740+
train_preds, train_time = time_function(
741+
self.estimator.fit_predict, (X_train, y_train)
742+
)
707743
self.is_fitted = True
708744
else:
709-
train_preds, train_time, self.train_estimate_method = cross_validate_train_data(self.estimator,y_train,X_train)
745+
train_preds, train_time, self.train_estimate_method = (
746+
cross_validate_train_data(self.estimator, y_train, X_train)
747+
)
710748
return train_preds, train_time
711749

712750
def needs_fit(self):
713751
return not self.is_fitted
714752

715-
def get_third_line(self, y, preds, fit_time, predict_time, benchmark_time, memory_usage):
753+
def get_third_line(
754+
self, y, preds, fit_time, predict_time, benchmark_time, memory_usage
755+
):
716756
return regression_results_third_line(
717-
y=y,
718-
preds=preds,
719-
fit_time=fit_time,
720-
predict_time=predict_time,
721-
benchmark_time=benchmark_time,
722-
memory_usage=memory_usage,
723-
train_estimate_method=self.train_estimate_method,
724-
)
757+
y=y,
758+
preds=preds,
759+
fit_time=fit_time,
760+
predict_time=predict_time,
761+
benchmark_time=benchmark_time,
762+
memory_usage=memory_usage,
763+
train_estimate_method=self.train_estimate_method,
764+
)
765+
725766
def validate_estimator(self, estimator):
726767
estimator, estimate_train_data = validate_regressor(estimator)
727768
self.estimate_train_data = estimate_train_data
728769
return estimator
729-
770+
771+
730772
def validate_forecaster(estimator):
731773
if isinstance(estimator, BaseForecaster):
732774
return estimator
@@ -735,28 +777,34 @@ def validate_forecaster(estimator):
735777
estimator, _ = validate_regressor(estimator)
736778
return RegressionForecaster(regressor=estimator)
737779
except TypeError:
738-
raise TypeError("forecaster must be an aeon forecaster or a tsml, aeon or sklearn regressor.")
780+
raise TypeError(
781+
"forecaster must be an aeon forecaster or a tsml, aeon or sklearn regressor."
782+
)
783+
739784

740785
def validate_regressor(estimator):
741786
estimate_train_data = False
742787
if isinstance(estimator, BaseRegressor):
743-
if estimator.get_tag(
744-
"capability:train_estimate", False, False
745-
):
788+
if estimator.get_tag("capability:train_estimate", False, False):
746789
estimate_train_data = True
747790
return estimator, estimate_train_data
748791
elif isinstance(estimator, BaseTimeSeriesEstimator) and is_regressor(estimator):
749792
return estimator, estimate_train_data
750793
elif isinstance(estimator, BaseEstimator) and is_regressor(estimator):
751-
return SklearnToTsmlRegressor(
752-
regressor=estimator,
753-
pad_unequal=True,
754-
concatenate_channels=True,
755-
clone_estimator=False,
756-
random_state=(
757-
estimator.random_state if hasattr(estimator, "random_state") else None
794+
return (
795+
SklearnToTsmlRegressor(
796+
regressor=estimator,
797+
pad_unequal=True,
798+
concatenate_channels=True,
799+
clone_estimator=False,
800+
random_state=(
801+
estimator.random_state
802+
if hasattr(estimator, "random_state")
803+
else None
804+
),
758805
),
759-
), estimate_train_data
806+
estimate_train_data,
807+
)
760808
else:
761809
raise TypeError("regressor must be a tsml, aeon or sklearn regressor.")
762810

@@ -1207,6 +1255,7 @@ def load_and_run_clustering_experiment(
12071255
benchmark_time=benchmark_time,
12081256
)
12091257

1258+
12101259
def run_forecasting_experiment(
12111260
train,
12121261
y_test,
@@ -1258,7 +1307,6 @@ def run_forecasting_experiment(
12581307
pass
12591308

12601309

1261-
12621310
def load_and_run_forecasting_experiment(
12631311
problem_path,
12641312
results_path,
@@ -1301,7 +1349,6 @@ def load_and_run_forecasting_experiment(
13011349
If set to False, this will only build results if there is not a result file
13021350
already present. If True, it will overwrite anything already there.
13031351
"""
1304-
13051352
tmpdir = tempfile.mkdtemp()
13061353
dataset = load_forecasting(dataset, tmpdir)
13071354
series = (

tsml_eval/utils/functions.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
]
99

1010
import time
11+
1112
import numpy as np
1213

1314

@@ -102,7 +103,8 @@ def rank_array(arr, higher_better=True):
102103

103104
return ranks
104105

106+
105107
def time_function(function, args=None, kwargs=None):
106108
start = int(round(time.time() * 1000))
107109
output = function(*args, **kwargs)
108-
return int(round(time.time() * 1000)) - start, output
110+
return int(round(time.time() * 1000)) - start, output

0 commit comments

Comments
 (0)