@@ -443,6 +443,7 @@ def load_and_run_classification_experiment(
443443 benchmark_time = benchmark_time ,
444444 )
445445
446+
446447def transform_input (
447448 data_transforms ,
448449 x_train : np .ndarray ,
@@ -459,6 +460,7 @@ def transform_input(
459460 x_test = transform .transform (x_test , y_test )
460461 return x_train , x_test
461462
463+
462464def cross_validate_train_data (estimator , y_train , X_train ):
463465 cv_size = min (10 , len (y_train ))
464466 start = int (round (time .time () * 1000 ))
@@ -467,6 +469,7 @@ def cross_validate_train_data(estimator, y_train, X_train):
467469 train_estimate_method = f"{ cv_size } F-CV"
468470 return train_preds , train_time , train_estimate_method
469471
472+
470473class Experiment :
471474 """Run an experiment and save the results to file.
472475
@@ -522,6 +525,7 @@ class Experiment:
522525 Whether to benchmark the hardware used with a simple function and write the
523526 results. This will typically take ~2 seconds, but is hardware dependent.
524527 """
528+
525529 def __init__ (
526530 self ,
527531 estimator ,
@@ -547,11 +551,15 @@ def __init__(
547551 )
548552
549553 if not build_test_file and not build_train_file :
550- warnings .warn ("All files exist and not overwriting, skipping." , stacklevel = 1 )
554+ warnings .warn (
555+ "All files exist and not overwriting, skipping." , stacklevel = 1
556+ )
551557 return None
552558
553559 if write_attributes :
554- attribute_file_path = f"{ results_path } /{ estimator_name } /Workspace/{ dataset_name } /"
560+ attribute_file_path = (
561+ f"{ results_path } /{ estimator_name } /Workspace/{ dataset_name } /"
562+ )
555563 else :
556564 attribute_file_path = None
557565
@@ -571,28 +579,37 @@ def __init__(
571579 else :
572580 self .estimator_name = estimator_name
573581 self .estimator = self .validate_estimator (estimator = estimator )
574- self .second_comment = str (estimator .get_params ()).replace ("\n " , " " ).replace ("\r " , " " )
582+ self .second_comment = (
583+ str (estimator .get_params ()).replace ("\n " , " " ).replace ("\r " , " " )
584+ )
575585 if attribute_file_path is not None :
576586 estimator_attributes_to_file (
577587 self .estimator , attribute_file_path , max_list_shape = att_max_shape
578588 )
579589
580-
581590 def run_experiment (self ):
582591 x_train , y_train , x_test , y_test = self .load_experimental_data ()
583592
584593 self .first_comment = (
585594 "Generated by run_experiment on "
586595 f"{ datetime .now ().strftime ('%m/%d/%Y, %H:%M:%S' )} "
587596 )
588-
589- x_train , x_test = transform_input (data_transforms = self .data_transforms , x_train = x_train , x_test = x_test ,y_train = y_train , y_test = y_test )
597+
598+ x_train , x_test = transform_input (
599+ data_transforms = self .data_transforms ,
600+ x_train = x_train ,
601+ x_test = x_test ,
602+ y_train = y_train ,
603+ y_test = y_test ,
604+ )
590605 if self .benchmark_time :
591606 self .benchmark = timing_benchmark (random_state = self .resample_id )
592-
607+
593608 if self .build_train_file :
594609 train_preds , train_time = self .generate_train_preds (x_train , y_train )
595- self .write_results ("TRAIN" , y_train , train_preds , train_time , - 1 , self .benchmark , - 1 )
610+ self .write_results (
611+ "TRAIN" , y_train , train_preds , train_time , - 1 , self .benchmark , - 1
612+ )
596613
597614 if self .build_test_file :
598615 if self .needs_fit ():
@@ -605,28 +622,37 @@ def run_experiment(self):
605622 fit_time += int (round (getattr (self .estimator , "_fit_time_milli" , 0 )))
606623 test_preds , test_time = self .generate_test_preds (x_test , y_test )
607624 test_time += int (round (getattr (self .estimator , "_predict_time_milli" , 0 )))
608- self .write_results ("TEST" , y_test , test_preds , fit_time , test_time , self .benchmark , mem_usage )
625+ self .write_results (
626+ "TEST" ,
627+ y_test ,
628+ test_preds ,
629+ fit_time ,
630+ test_time ,
631+ self .benchmark ,
632+ mem_usage ,
633+ )
609634
def load_experimental_data(self):
    """Return placeholder experiment data for the base ``Experiment``.

    Returns
    -------
    tuple
        A 4-tuple of ``None`` values. ``run_experiment`` unpacks the result
        as ``x_train, y_train, x_test, y_test``.
    """
    # The base class has no data source of its own, so every slot is empty.
    return (None,) * 4
612637
def validate_estimator(self, estimator):
    """Validate the estimator used by the experiment and hand it back.

    The base implementation performs no checks; it returns the estimator
    unchanged.

    Parameters
    ----------
    estimator : object
        The estimator passed to the experiment.

    Returns
    -------
    object
        The same ``estimator`` object that was passed in.
    """
    # The caller stores the result as ``self.estimator``, so the estimator
    # must be returned explicitly. A bare ``estimator`` expression statement
    # would make this method return ``None``.
    return estimator
615640
616-
def generate_train_preds(self, X_train, y_train):
    """Run the estimator's ``fit_predict`` on the train data under a timer.

    Parameters
    ----------
    X_train : object
        Training data, passed as the first argument to ``fit_predict``.
    y_train : object
        Training targets, passed as the second argument to ``fit_predict``.

    Returns
    -------
    object
        Whatever ``time_function`` returns; ``run_experiment`` unpacks it as
        ``(train_preds, train_time)``.
    """
    call_args = (X_train, y_train)
    return time_function(self.estimator.fit_predict, call_args)
619-
643+
def generate_test_preds(self, x_test, y_test):
    """Run the estimator's ``predict`` on the test data under a timer.

    Parameters
    ----------
    x_test : object
        Test data handed to ``predict`` via ``time_function``.
    y_test : object
        Test targets. Not used by this base implementation.

    Returns
    -------
    object
        Whatever ``time_function`` returns; ``run_experiment`` unpacks it as
        ``(test_preds, test_time)``.
    """
    predict = self.estimator.predict
    return time_function(predict, x_test)
622646
623-
def needs_fit(self):
    """Report whether the estimator still has to be fitted.

    ``run_experiment`` queries this when it builds the test results file.

    Returns
    -------
    bool
        Always ``False`` in the base implementation.
    """
    fit_required = False
    return fit_required
626649
627-
628- def write_results (self , split , y , preds ,fit_time , predict_time , benchmark_time , memory_usage ):
629- third_line = self .get_third_line (y , preds ,fit_time , predict_time , benchmark_time , memory_usage )
650+ def write_results (
651+ self , split , y , preds , fit_time , predict_time , benchmark_time , memory_usage
652+ ):
653+ third_line = self .get_third_line (
654+ y , preds , fit_time , predict_time , benchmark_time , memory_usage
655+ )
630656 write_results_to_tsml_format (
631657 preds ,
632658 y ,
@@ -642,16 +668,20 @@ def write_results(self, split, y, preds,fit_time, predict_time, benchmark_time,
642668 second_line = self .second_comment ,
643669 third_line = third_line ,
644670 )
645- def get_third_line (self , y , preds , fit_time , predict_time , benchmark_time , memory_usage ):
671+
def get_third_line(
    self, y, preds, fit_time, predict_time, benchmark_time, memory_usage
):
    """Build the third line of a results file via ``results_third_line``.

    Every argument is forwarded by keyword, under its own name.

    Parameters
    ----------
    y : object
        True target values.
    preds : object
        Predicted values.
    fit_time : object
        Fit time value to record.
    predict_time : object
        Predict time value to record.
    benchmark_time : object
        Benchmark value to record.
    memory_usage : object
        Memory usage value to record.

    Returns
    -------
    object
        The value returned by ``results_third_line``.
    """
    resource_stats = {
        "fit_time": fit_time,
        "predict_time": predict_time,
        "benchmark_time": benchmark_time,
        "memory_usage": memory_usage,
    }
    return results_third_line(y=y, preds=preds, **resource_stats)
683+
684+
655685class ForecastingExperiment (Experiment ):
656686 def __init__ (self ):
657687 pass
@@ -670,21 +700,25 @@ def generate_test_preds(self, x_test, y_test):
670700 def validate_estimator (self , estimator ):
671701 return validate_forecaster (estimator )
672702
703+
673704class RegressionExperiment (Experiment ):
def __init__(
    self,
    ignore_custom_train_estimate=False,
    predefined_resample=False,
    problem_path="",
):
    """Store the regression-specific experiment settings.

    Parameters
    ----------
    ignore_custom_train_estimate : bool, default=False
        When true, ``generate_train_preds`` uses cross-validation even if the
        estimator can produce its own train estimate.
    predefined_resample : bool, default=False
        Forwarded to ``load_experiment_data`` by ``load_experimental_data``.
    problem_path : str, default=""
        Forwarded to ``load_experiment_data`` by ``load_experimental_data``.
    """
    # Nothing is fitted at construction; generate_train_preds sets this
    # flag to True once the estimator's own fit_predict has run.
    self.is_fitted = False

    self.problem_path = problem_path
    self.predefined_resample = predefined_resample
    self.ignore_custom_train_estimate = ignore_custom_train_estimate
684715
685716 def load_experimental_data (self ):
686717 X_train , y_train , X_test , y_test , resample = load_experiment_data (
687- self .problem_path , self .dataset_name , self .resample_id , self .predefined_resample
718+ self .problem_path ,
719+ self .dataset_name ,
720+ self .resample_id ,
721+ self .predefined_resample ,
688722 )
689723
690724 if resample :
@@ -696,30 +730,38 @@ def load_experimental_data(self):
def generate_train_preds(self, X_train, y_train):
    """Produce train-set predictions and the time taken to obtain them.

    If the estimator can estimate on its own train data and that has not
    been disabled, its ``fit_predict`` is timed directly and the estimator
    is marked as fitted. Otherwise the predictions come from
    ``cross_validate_train_data``.

    Parameters
    ----------
    X_train : object
        Training data.
    y_train : object
        Training targets.

    Returns
    -------
    tuple
        ``(train_preds, train_time)``. As a side effect,
        ``self.train_estimate_method`` is set to ``"Custom"`` or to the
        method name reported by ``cross_validate_train_data``.
    """
    use_custom = self.estimate_train_data and not self.ignore_custom_train_estimate

    if not use_custom:
        # Cross-validation path: is_fitted is left untouched.
        predictions, elapsed, self.train_estimate_method = (
            cross_validate_train_data(self.estimator, y_train, X_train)
        )
        return predictions, elapsed

    self.train_estimate_method = "Custom"
    predictions, elapsed = time_function(
        self.estimator.fit_predict, (X_train, y_train)
    )
    # fit_predict has fitted the estimator, so needs_fit() now reports False.
    self.is_fitted = True
    return predictions, elapsed
704742
def needs_fit(self):
    """Report whether the estimator still has to be fitted.

    Returns
    -------
    bool
        The negation of ``self.is_fitted``.
    """
    already_fitted = self.is_fitted
    return not already_fitted
707745
708- def get_third_line (self , y , preds , fit_time , predict_time , benchmark_time , memory_usage ):
def get_third_line(
    self, y, preds, fit_time, predict_time, benchmark_time, memory_usage
):
    """Build the third results line via ``regression_results_third_line``.

    Every argument is forwarded by keyword under its own name, together with
    the ``train_estimate_method`` stored on the instance.

    Parameters
    ----------
    y : object
        True target values.
    preds : object
        Predicted values.
    fit_time : object
        Fit time value to record.
    predict_time : object
        Predict time value to record.
    benchmark_time : object
        Benchmark value to record.
    memory_usage : object
        Memory usage value to record.

    Returns
    -------
    object
        The value returned by ``regression_results_third_line``.
    """
    resource_stats = {
        "fit_time": fit_time,
        "predict_time": predict_time,
        "benchmark_time": benchmark_time,
        "memory_usage": memory_usage,
    }
    return regression_results_third_line(
        y=y,
        preds=preds,
        **resource_stats,
        train_estimate_method=self.train_estimate_method,
    )
758+
def validate_estimator(self, estimator):
    """Validate a regressor and record whether it can estimate train data.

    Parameters
    ----------
    estimator : object
        The estimator to check with ``validate_regressor``.

    Returns
    -------
    object
        The estimator returned by ``validate_regressor``. As a side effect,
        ``self.estimate_train_data`` is set to the flag it returns.
    """
    validated, can_estimate_train = validate_regressor(estimator)
    self.estimate_train_data = can_estimate_train
    return validated
722-
763+
764+
723765def validate_forecaster (estimator ):
724766 if isinstance (estimator , BaseForecaster ):
725767 return estimator
@@ -728,28 +770,34 @@ def validate_forecaster(estimator):
728770 estimator , _ = validate_regressor (estimator )
729771 return RegressionForecaster (regressor = estimator )
730772 except TypeError :
731- raise TypeError ("forecaster must be an aeon forecaster or a tsml, aeon or sklearn regressor." )
773+ raise TypeError (
774+ "forecaster must be an aeon forecaster or a tsml, aeon or sklearn regressor."
775+ )
776+
732777
def validate_regressor(estimator):
    """Check that ``estimator`` is a supported regressor, wrapping if needed.

    Parameters
    ----------
    estimator : object
        Candidate regressor.

    Returns
    -------
    tuple
        ``(estimator, estimate_train_data)``. The estimator is returned as
        given when it is a ``BaseRegressor``, or a ``BaseTimeSeriesEstimator``
        that ``is_regressor`` accepts. Any other ``BaseEstimator`` that
        ``is_regressor`` accepts is wrapped in ``SklearnToTsmlRegressor``.
        The flag is true only for a ``BaseRegressor`` whose
        ``"capability:train_estimate"`` tag is truthy.

    Raises
    ------
    TypeError
        If ``estimator`` matches none of the supported kinds.
    """
    if isinstance(estimator, BaseRegressor):
        has_train_estimate = bool(
            estimator.get_tag("capability:train_estimate", False, False)
        )
        return estimator, has_train_estimate

    if isinstance(estimator, BaseTimeSeriesEstimator) and is_regressor(estimator):
        return estimator, False

    if isinstance(estimator, BaseEstimator) and is_regressor(estimator):
        # Reuse the wrapped estimator's seed when it exposes one.
        wrapped = SklearnToTsmlRegressor(
            regressor=estimator,
            pad_unequal=True,
            concatenate_channels=True,
            clone_estimator=False,
            random_state=getattr(estimator, "random_state", None),
        )
        return wrapped, False

    raise TypeError("regressor must be a tsml, aeon or sklearn regressor.")
755803
@@ -1200,6 +1248,7 @@ def load_and_run_clustering_experiment(
12001248 benchmark_time = benchmark_time ,
12011249 )
12021250
1251+
12031252def run_forecasting_experiment (
12041253 train ,
12051254 y_test ,
@@ -1251,7 +1300,6 @@ def run_forecasting_experiment(
12511300 pass
12521301
12531302
1254-
12551303def load_and_run_forecasting_experiment (
12561304 problem_path ,
12571305 results_path ,
@@ -1294,7 +1342,6 @@ def load_and_run_forecasting_experiment(
12941342 If set to False, this will only build results if there is not a result file
12951343 already present. If True, it will overwrite anything already there.
12961344 """
1297-
12981345 tmpdir = tempfile .mkdtemp ()
12991346 dataset = load_forecasting (dataset , tmpdir )
13001347 series = (
0 commit comments