diff --git a/src/mlpack/core/cv/CMakeLists.txt b/src/mlpack/core/cv/CMakeLists.txt index 27dd4a9f345..67673a67a72 100644 --- a/src/mlpack/core/cv/CMakeLists.txt +++ b/src/mlpack/core/cv/CMakeLists.txt @@ -5,6 +5,8 @@ add_subdirectory(metrics) set(SOURCES cv_base.hpp cv_base_impl.hpp + k_fold_cv.hpp + k_fold_cv_impl.hpp meta_info_extractor.hpp simple_cv.hpp simple_cv_impl.hpp diff --git a/src/mlpack/core/cv/k_fold_cv.hpp b/src/mlpack/core/cv/k_fold_cv.hpp new file mode 100644 index 00000000000..544bba38cba --- /dev/null +++ b/src/mlpack/core/cv/k_fold_cv.hpp @@ -0,0 +1,198 @@ +/** + * @file k_fold_cv.hpp + * @author Kirill Mishchenko + * + * k-fold cross-validation. + * + * mlpack is free software; you may redistribute it and/or modify it under the + * terms of the 3-clause BSD license. You should have received a copy of the + * 3-clause BSD license along with mlpack. If not, see + * http://www.opensource.org/licenses/BSD-3-Clause for more information. + */ +#ifndef MLPACK_CORE_CV_K_FOLD_CV_HPP +#define MLPACK_CORE_CV_K_FOLD_CV_HPP + +#include +#include + +namespace mlpack { +namespace cv { + +/** + * The class KFoldCV implements k-fold cross-validation for regression and + * classification algorithms. + * + * To construct a KFoldCV object you need to pass the k parameter and arguments + * that specify data. For the latter see the CVBase constructors as a reference + * - the CVBase constructors take exactly the same arguments as ones that are + * supposed to be passed after the k parameter in the KFoldCV constructor. + * + * For example, you can run 10-fold cross-validation for SoftmaxRegression in + * the following way. + * + * @code + * // 100-point 5-dimensional random dataset. + * arma::mat data = arma::randu(5, 100); + * // Random labels in the [0, 4] interval. 
+ * arma::Row labels = + * arma::randi>(100, arma::distr_param(0, 4)); + * size_t numClasses = 5; + * + * KFoldCV, Accuracy> cv(10, data, labels, numClasses); + * + * double lambda = 0.1; + * double softmaxAccuracy = cv.Evaluate(lambda); + * @endcode + * + * @tparam MLAlgorithm A machine learning algorithm. + * @tparam Metric A metric to assess the quality of a trained model. + * @tparam MatType The type of data. + * @tparam PredictionsType The type of predictions (should be passed when the + * predictions type is a template parameter in Train methods of + * MLAlgorithm). + * @tparam WeightsType The type of weights (should be passed when weighted + * learning is supported, and the weights type is a template parameter in + * Train methods of MLAlgorithm). + */ +template::PredictionsType, + typename WeightsType = + typename MetaInfoExtractor::WeightsType> +class KFoldCV : + private CVBase +{ + public: + /** + * Construct an object for running k-fold cross-validation. + * + * @param k Number of folds (from 2 up to the number of data points). + * @param args Basic constructor arguments for MLAlgorithm (see the CVBase + * constructors for reference). + */ + template + KFoldCV(const size_t k, const CVBaseArgs&... args); + + /** + * Run k-fold cross-validation. + * + * @param args Arguments for MLAlgorithm (in addition to the passed + * ones in the constructor). + */ + template + double Evaluate(const MLAlgorithmArgs& ...args); + + //! Access and modify a model from the last run of k-fold cross-validation. + MLAlgorithm& Model(); + + private: + //! A short alias for CVBase. + using Base = CVBase; + + //! The number of bins in the dataset. + const size_t k; + + //! The extended (by repeating the first k - 2 bins) data points. + MatType xs; + //! The extended (by repeating the first k - 2 bins) predictions. + PredictionsType ys; + //! The extended (by repeating the first k - 2 bins) weights. + WeightsType weights; + + //! The size of each bin in terms of data points. 
+ size_t binSize; + + //! The size of each training subset in terms of data points. + size_t trainingSubsetSize; + + //! A pointer to a model from the last run of k-fold cross-validation. + std::unique_ptr modelPtr; + + /** + * Initialize without weights. + */ + template::value == 2>::type> + void Init(const DataArgsTupleT& dataArgsTuple); + + /** + * Initialize with weights. + */ + template::value == 3>::type, + typename = void> + void Init(const DataArgsTupleT& dataArgsTuple); + + /** + * Initialize the given destination matrix with the given source joined with + * its first k - 2 bins (for k == 2 the source is copied unchanged). + */ + template + void InitKFoldCVMat(const SourceType& source, DestinationType& destination); + + /** + * Train and run evaluation in the case of non-weighted learning. + */ + template::type> + double TrainAndEvaluate(const MLAlgorithmArgs& ...mlAlgorithmArgs); + + /** + * Train and run evaluation in the case of supporting weighted learning. + */ + template::type, + typename = void> + double TrainAndEvaluate(const MLAlgorithmArgs& ...mlAlgorithmArgs); + + /** + * Calculate the index of the first column of the ith validation subset. + * + * We take the ith validation subset after the ith training subset if + * i < k - 1 and before it otherwise. + */ + inline size_t ValidationSubsetFirstCol(const size_t i); + + /** + * Get the ith training subset from a variable of a matrix type. + */ + template + inline arma::Mat GetTrainingSubset(arma::Mat& m, + const size_t i); + + /** + * Get the ith training subset from a variable of a row type. + */ + template + inline arma::Row GetTrainingSubset(arma::Row& r, + const size_t i); + + /** + * Get the ith validation subset from a variable of a matrix type. + */ + template + inline arma::Mat GetValidationSubset(arma::Mat& m, + const size_t i); + + /** + * Get the ith validation subset from a variable of a row type. 
+ */ + template + inline arma::Row GetValidationSubset(arma::Row& r, + const size_t i); +}; + +} // namespace cv +} // namespace mlpack + +// Include implementation +#include "k_fold_cv_impl.hpp" + +#endif diff --git a/src/mlpack/core/cv/k_fold_cv_impl.hpp b/src/mlpack/core/cv/k_fold_cv_impl.hpp new file mode 100644 index 00000000000..ddecf8ca2b1 --- /dev/null +++ b/src/mlpack/core/cv/k_fold_cv_impl.hpp @@ -0,0 +1,286 @@ +/** + * @file k_fold_cv_impl.hpp + * @author Kirill Mishchenko + * + * The implementation of k-fold cross-validation. + * + * mlpack is free software; you may redistribute it and/or modify it under the + * terms of the 3-clause BSD license. You should have received a copy of the + * 3-clause BSD license along with mlpack. If not, see + * http://www.opensource.org/licenses/BSD-3-Clause for more information. + */ +#ifndef MLPACK_CORE_CV_K_FOLD_CV_IMPL_HPP +#define MLPACK_CORE_CV_K_FOLD_CV_IMPL_HPP + +namespace mlpack { +namespace cv { + +template +template +KFoldCV::KFoldCV(const size_t k, + const CVBaseArgs&... args) : + Base(args...), k(k) +{ + if (k < 2) + throw std::invalid_argument("KFoldCV: k should not be less than 2"); + + Init(Base::ExtractDataArgs(args...)); +} + +template +template +double KFoldCV::Evaluate(const MLAlgorithmArgs&... 
args) +{ + return TrainAndEvaluate(args...); +} + +template +MLAlgorithm& KFoldCV::Model() +{ + if (modelPtr == nullptr) + throw std::logic_error( + "KFoldCV::Model(): attempted to access an uninitialized model"); + + return *modelPtr; +} + +template +template +void KFoldCV::Init(const DataArgsTupleT& dataArgsTuple) +{ + InitKFoldCVMat(std::get<0>(dataArgsTuple), xs); + InitKFoldCVMat(std::get<1>(dataArgsTuple), ys); + + Base::AssertDataConsistency(xs, ys); +} + +template +template +void KFoldCV::Init(const DataArgsTupleT& dataArgsTuple) +{ + InitKFoldCVMat(std::get<0>(dataArgsTuple), xs); + InitKFoldCVMat(std::get<1>(dataArgsTuple), ys); + InitKFoldCVMat(std::get<2>(dataArgsTuple), weights); + + Base::AssertDataConsistency(xs, ys, weights); +} + +template +template +void KFoldCV::InitKFoldCVMat(const SourceType& source, + DestinationType& destination) +{ + const DestinationType& sourceAsDT = source; + + // Each bin needs at least one point, otherwise binSize would be zero. + if (sourceAsDT.n_cols < k) + throw std::invalid_argument( + "KFoldCV: k should not be greater than the number of points"); + + binSize = sourceAsDT.n_cols / k; + trainingSubsetSize = binSize * (k - 1); + + // For k == 2 there are no bins to repeat (k - 2 == 0), and the last column + // index below would wrap around, so the source is copied unchanged. + if (k == 2) + destination = sourceAsDT; + else + destination = arma::join_rows(sourceAsDT, + sourceAsDT.cols(0, trainingSubsetSize - binSize - 1)); +} + +template +template +double KFoldCV::TrainAndEvaluate(const MLAlgorithmArgs&... args) +{ + arma::vec evaluations(k); + + for (size_t i = 0; i < k; ++i) + { + modelPtr = this->Train(GetTrainingSubset(xs, i), GetTrainingSubset(ys, i), + args...); + evaluations(i) = Metric::Evaluate(*modelPtr, GetValidationSubset(xs, i), + GetValidationSubset(ys, i)); + } + + return arma::mean(evaluations); +} + +template +template +double KFoldCV::TrainAndEvaluate(const MLAlgorithmArgs&... 
args) +{ + arma::vec evaluations(k); + + for (size_t i = 0; i < k; ++i) + { + if (weights.n_elem > 0) + modelPtr = this->Train(GetTrainingSubset(xs, i), GetTrainingSubset(ys, i), + GetTrainingSubset(weights, i), args...); + else + modelPtr = this->Train(GetTrainingSubset(xs, i), GetTrainingSubset(ys, i), + args...); + evaluations(i) = Metric::Evaluate(*modelPtr, GetValidationSubset(xs, i), + GetValidationSubset(ys, i)); + } + + return arma::mean(evaluations); +} + +template +size_t KFoldCV::ValidationSubsetFirstCol(const size_t i) +{ + return (i < k - 1) ? (binSize * i + trainingSubsetSize) : (binSize * (i - 1)); +} + +template +template +arma::Mat KFoldCV::GetTrainingSubset( + arma::Mat& m, + const size_t i) +{ + return arma::Mat(m.colptr(binSize * i), m.n_rows, + trainingSubsetSize, false, true); +} + +template +template +arma::Row KFoldCV::GetTrainingSubset( + arma::Row& r, + const size_t i) +{ + return arma::Row(r.colptr(binSize * i), trainingSubsetSize, + false, true); +} + +template +template +arma::Mat KFoldCV::GetValidationSubset( + arma::Mat& m, + const size_t i) +{ + return arma::Mat(m.colptr(ValidationSubsetFirstCol(i)), m.n_rows, + binSize, false, true); +} + +template +template +arma::Row KFoldCV::GetValidationSubset( + arma::Row& r, + const size_t i) +{ + return arma::Row(r.colptr(ValidationSubsetFirstCol(i)), binSize, + false, true); +} + +} // namespace cv +} // namespace mlpack + +#endif diff --git a/src/mlpack/tests/cv_test.cpp b/src/mlpack/tests/cv_test.cpp index 5e921e9d793..cd4fafabff2 100644 --- a/src/mlpack/tests/cv_test.cpp +++ b/src/mlpack/tests/cv_test.cpp @@ -15,6 +15,7 @@ #include #include #include +#include #include #include #include @@ -25,6 +26,7 @@ #include #include #include +#include #include #include @@ -34,6 +36,7 @@ using namespace mlpack; using namespace mlpack::ann; using namespace mlpack::cv; using namespace mlpack::optimization; +using namespace mlpack::naive_bayes; using namespace mlpack::regression; using namespace 
mlpack::tree; @@ -304,4 +307,47 @@ BOOST_AUTO_TEST_CASE(SimpleCVWithDTTest) } } +/** + * Test k-fold cross-validation with the MSE metric. + */ +BOOST_AUTO_TEST_CASE(KFoldCVMSETest) +{ + // Defining dataset with two sets of responses for the same two data points. + arma::mat data("0 1 0 1"); + arma::rowvec responses("0 1 1 3"); + + // 2-fold cross-validation. + KFoldCV cv(2, data, responses); + + // In each of two validation tests the MSE value should be the same. + double expectedMSE = + double((1 - 0) * (1 - 0) + (3 - 1) * (3 - 1)) / 2 * 2 / 2; + + + BOOST_REQUIRE_CLOSE(cv.Evaluate(), expectedMSE, 1e-5); +} + +/** + * Test k-fold cross-validation with the Accuracy metric. + */ +BOOST_AUTO_TEST_CASE(KFoldCVAccuracyTest) +{ + // Making a 10-points dataset. The last point should be classified wrong when + // it is tested separately. + arma::mat data("0 1 2 3 100 101 102 103 104 5"); + arma::Row labels("0 0 0 0 1 1 1 1 1 1"); + // This parameter should be passed into the cross-validation constructor + // after merging #1038. + size_t numClasses = 2; + + // 10-fold cross-validation. + KFoldCV, Accuracy> cv(10, data, labels); + + // We should succeed in classifying separately the first nine samples, and + // fail with the remaining one. + double expectedAccuracy = (9 * 1.0 + 0.0) / 10; + + BOOST_REQUIRE_CLOSE(cv.Evaluate(numClasses), expectedAccuracy, 1e-5); +} + BOOST_AUTO_TEST_SUITE_END();