@@ -443,6 +443,7 @@ def load_and_run_classification_experiment(
443443 benchmark_time = benchmark_time ,
444444 )
445445
446+
446447def transform_input (
447448 data_transforms ,
448449 x_train : np .ndarray ,
@@ -459,6 +460,7 @@ def transform_input(
459460 x_test = transform .transform (x_test , y_test )
460461 return x_train , x_test
461462
463+
462464def cross_validate_train_data (estimator , y_train , X_train ):
463465 cv_size = min (10 , len (y_train ))
464466 start = int (round (time .time () * 1000 ))
@@ -467,6 +469,7 @@ def cross_validate_train_data(estimator, y_train, X_train):
467469 train_estimate_method = f"{ cv_size } F-CV"
468470 return train_preds , train_time , train_estimate_method
469471
472+
470473class Experiment :
471474 """Run an experiment and save the results to file.
472475
@@ -522,6 +525,7 @@ class Experiment:
522525 Whether to benchmark the hardware used with a simple function and write the
523526 results. This will typically take ~2 seconds, but is hardware dependent.
524527 """
528+
525529 def __init__ (
526530 self ,
527531 estimator ,
@@ -547,11 +551,15 @@ def __init__(
547551 )
548552
549553 if not build_test_file and not build_train_file :
550- warnings .warn ("All files exist and not overwriting, skipping." , stacklevel = 1 )
554+ warnings .warn (
555+ "All files exist and not overwriting, skipping." , stacklevel = 1
556+ )
551557 return None
552558
553559 if write_attributes :
554- attribute_file_path = f"{ results_path } /{ estimator_name } /Workspace/{ dataset_name } /"
560+ attribute_file_path = (
561+ f"{ results_path } /{ estimator_name } /Workspace/{ dataset_name } /"
562+ )
555563 else :
556564 attribute_file_path = None
557565
@@ -571,28 +579,37 @@ def __init__(
571579 else :
572580 self .estimator_name = estimator_name
573581 self .estimator = self .validate_estimator (estimator = estimator )
574- self .second_comment = str (estimator .get_params ()).replace ("\n " , " " ).replace ("\r " , " " )
582+ self .second_comment = (
583+ str (estimator .get_params ()).replace ("\n " , " " ).replace ("\r " , " " )
584+ )
575585 if attribute_file_path is not None :
576586 estimator_attributes_to_file (
577587 self .estimator , attribute_file_path , max_list_shape = att_max_shape
578588 )
579589
580-
581590 def run_experiment (self ):
582591 x_train , y_train , x_test , y_test = self .load_experimental_data ()
583592
584593 self .first_comment = (
585594 "Generated by run_experiment on "
586595 f"{ datetime .now ().strftime ('%m/%d/%Y, %H:%M:%S' )} "
587596 )
588-
589- x_train , x_test = transform_input (data_transforms = self .data_transforms , x_train = x_train , x_test = x_test ,y_train = y_train , y_test = y_test )
597+
598+ x_train , x_test = transform_input (
599+ data_transforms = self .data_transforms ,
600+ x_train = x_train ,
601+ x_test = x_test ,
602+ y_train = y_train ,
603+ y_test = y_test ,
604+ )
590605 if self .benchmark_time :
591606 self .benchmark = timing_benchmark (random_state = self .resample_id )
592-
607+
593608 if self .build_train_file :
594609 train_preds , train_time = self .generate_train_preds (x_train , y_train )
595- self .write_results ("TRAIN" , y_train , train_preds , train_time , - 1 , self .benchmark , - 1 )
610+ self .write_results (
611+ "TRAIN" , y_train , train_preds , train_time , - 1 , self .benchmark , - 1
612+ )
596613
597614 if self .build_test_file :
598615 if self .needs_fit ():
@@ -605,28 +622,37 @@ def run_experiment(self):
605622 fit_time += int (round (getattr (self .estimator , "_fit_time_milli" , 0 )))
606623 test_preds , test_time = self .generate_test_preds (x_test , y_test )
607624 test_time += int (round (getattr (self .estimator , "_predict_time_milli" , 0 )))
608- self .write_results ("TEST" , y_test , test_preds , fit_time , test_time , self .benchmark , mem_usage )
625+ self .write_results (
626+ "TEST" ,
627+ y_test ,
628+ test_preds ,
629+ fit_time ,
630+ test_time ,
631+ self .benchmark ,
632+ mem_usage ,
633+ )
609634
610635 def load_experimental_data (self ):
611636 return None , None , None , None
612637
613638 def validate_estimator (self , estimator ):
614639 estimator
615640
616-
617641 def generate_train_preds (self , X_train , y_train ):
618642 return time_function (self .estimator .fit_predict , (X_train , y_train ))
619-
643+
620644 def generate_test_preds (self , x_test , y_test ):
621645 return time_function (self .estimator .predict , x_test )
622646
623-
624647 def needs_fit (self ):
625648 return False
626649
627-
628- def write_results (self , split , y , preds ,fit_time , predict_time , benchmark_time , memory_usage ):
629- third_line = self .get_third_line (y , preds ,fit_time , predict_time , benchmark_time , memory_usage )
650+ def write_results (
651+ self , split , y , preds , fit_time , predict_time , benchmark_time , memory_usage
652+ ):
653+ third_line = self .get_third_line (
654+ y , preds , fit_time , predict_time , benchmark_time , memory_usage
655+ )
630656 write_results_to_tsml_format (
631657 preds ,
632658 y ,
@@ -642,16 +668,20 @@ def write_results(self, split, y, preds,fit_time, predict_time, benchmark_time,
642668 second_line = self .second_comment ,
643669 third_line = third_line ,
644670 )
645- def get_third_line (self , y , preds , fit_time , predict_time , benchmark_time , memory_usage ):
671+
672+ def get_third_line (
673+ self , y , preds , fit_time , predict_time , benchmark_time , memory_usage
674+ ):
646675 return results_third_line (
647- y = y ,
648- preds = preds ,
649- fit_time = fit_time ,
650- predict_time = predict_time ,
651- benchmark_time = benchmark_time ,
652- memory_usage = memory_usage ,
653- )
654-
676+ y = y ,
677+ preds = preds ,
678+ fit_time = fit_time ,
679+ predict_time = predict_time ,
680+ benchmark_time = benchmark_time ,
681+ memory_usage = memory_usage ,
682+ )
683+
684+
655685class ForecastingExperiment (Experiment ):
656686 def __init__ (self ):
657687 pass
@@ -677,21 +707,25 @@ def generate_test_preds(self, x_test, y_test):
677707 def validate_estimator (self , estimator ):
678708 return validate_forecaster (estimator )
679709
710+
680711class RegressionExperiment (Experiment ):
681712 def __init__ (
682- self ,
683- ignore_custom_train_estimate = False ,
684- predefined_resample = False ,
685- problem_path = "" ,
686- ):
713+ self ,
714+ ignore_custom_train_estimate = False ,
715+ predefined_resample = False ,
716+ problem_path = "" ,
717+ ):
687718 self .is_fitted = False
688719 self .ignore_custom_train_estimate = ignore_custom_train_estimate
689720 self .problem_path = problem_path
690721 self .predefined_resample = predefined_resample
691722
692723 def load_experimental_data (self ):
693724 X_train , y_train , X_test , y_test , resample = load_experiment_data (
694- self .problem_path , self .dataset_name , self .resample_id , self .predefined_resample
725+ self .problem_path ,
726+ self .dataset_name ,
727+ self .resample_id ,
728+ self .predefined_resample ,
695729 )
696730
697731 if resample :
@@ -703,30 +737,38 @@ def load_experimental_data(self):
703737 def generate_train_preds (self , X_train , y_train ):
704738 if self .estimate_train_data and not self .ignore_custom_train_estimate :
705739 self .train_estimate_method = "Custom"
706- train_preds , train_time = time_function (self .estimator .fit_predict , (X_train , y_train ))
740+ train_preds , train_time = time_function (
741+ self .estimator .fit_predict , (X_train , y_train )
742+ )
707743 self .is_fitted = True
708744 else :
709- train_preds , train_time , self .train_estimate_method = cross_validate_train_data (self .estimator ,y_train ,X_train )
745+ train_preds , train_time , self .train_estimate_method = (
746+ cross_validate_train_data (self .estimator , y_train , X_train )
747+ )
710748 return train_preds , train_time
711749
712750 def needs_fit (self ):
713751 return not self .is_fitted
714752
715- def get_third_line (self , y , preds , fit_time , predict_time , benchmark_time , memory_usage ):
753+ def get_third_line (
754+ self , y , preds , fit_time , predict_time , benchmark_time , memory_usage
755+ ):
716756 return regression_results_third_line (
717- y = y ,
718- preds = preds ,
719- fit_time = fit_time ,
720- predict_time = predict_time ,
721- benchmark_time = benchmark_time ,
722- memory_usage = memory_usage ,
723- train_estimate_method = self .train_estimate_method ,
724- )
757+ y = y ,
758+ preds = preds ,
759+ fit_time = fit_time ,
760+ predict_time = predict_time ,
761+ benchmark_time = benchmark_time ,
762+ memory_usage = memory_usage ,
763+ train_estimate_method = self .train_estimate_method ,
764+ )
765+
725766 def validate_estimator (self , estimator ):
726767 estimator , estimate_train_data = validate_regressor (estimator )
727768 self .estimate_train_data = estimate_train_data
728769 return estimator
729-
770+
771+
730772def validate_forecaster (estimator ):
731773 if isinstance (estimator , BaseForecaster ):
732774 return estimator
@@ -735,28 +777,34 @@ def validate_forecaster(estimator):
735777 estimator , _ = validate_regressor (estimator )
736778 return RegressionForecaster (regressor = estimator )
737779 except TypeError :
738- raise TypeError ("forecaster must be an aeon forecaster or a tsml, aeon or sklearn regressor." )
780+ raise TypeError (
781+ "forecaster must be an aeon forecaster or a tsml, aeon or sklearn regressor."
782+ )
783+
739784
740785def validate_regressor (estimator ):
741786 estimate_train_data = False
742787 if isinstance (estimator , BaseRegressor ):
743- if estimator .get_tag (
744- "capability:train_estimate" , False , False
745- ):
788+ if estimator .get_tag ("capability:train_estimate" , False , False ):
746789 estimate_train_data = True
747790 return estimator , estimate_train_data
748791 elif isinstance (estimator , BaseTimeSeriesEstimator ) and is_regressor (estimator ):
749792 return estimator , estimate_train_data
750793 elif isinstance (estimator , BaseEstimator ) and is_regressor (estimator ):
751- return SklearnToTsmlRegressor (
752- regressor = estimator ,
753- pad_unequal = True ,
754- concatenate_channels = True ,
755- clone_estimator = False ,
756- random_state = (
757- estimator .random_state if hasattr (estimator , "random_state" ) else None
794+ return (
795+ SklearnToTsmlRegressor (
796+ regressor = estimator ,
797+ pad_unequal = True ,
798+ concatenate_channels = True ,
799+ clone_estimator = False ,
800+ random_state = (
801+ estimator .random_state
802+ if hasattr (estimator , "random_state" )
803+ else None
804+ ),
758805 ),
759- ), estimate_train_data
806+ estimate_train_data ,
807+ )
760808 else :
761809 raise TypeError ("regressor must be a tsml, aeon or sklearn regressor." )
762810
@@ -1207,6 +1255,7 @@ def load_and_run_clustering_experiment(
12071255 benchmark_time = benchmark_time ,
12081256 )
12091257
1258+
12101259def run_forecasting_experiment (
12111260 train ,
12121261 y_test ,
@@ -1258,7 +1307,6 @@ def run_forecasting_experiment(
12581307 pass
12591308
12601309
1261-
12621310def load_and_run_forecasting_experiment (
12631311 problem_path ,
12641312 results_path ,
@@ -1301,7 +1349,6 @@ def load_and_run_forecasting_experiment(
13011349 If set to False, this will only build results if there is not a result file
13021350 already present. If True, it will overwrite anything already there.
13031351 """
1304-
13051352 tmpdir = tempfile .mkdtemp ()
13061353 dataset = load_forecasting (dataset , tmpdir )
13071354 series = (
0 commit comments