-
Notifications
You must be signed in to change notification settings - Fork 0
Add k-fold cross-validation #1
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: simple_cv
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,198 @@ | ||
| /** | ||
| * @file k_fold_cv.hpp | ||
| * @author Kirill Mishchenko | ||
| * | ||
| * k-fold cross-validation. | ||
| * | ||
| * mlpack is free software; you may redistribute it and/or modify it under the | ||
| * terms of the 3-clause BSD license. You should have received a copy of the | ||
| * 3-clause BSD license along with mlpack. If not, see | ||
| * http://www.opensource.org/licenses/BSD-3-Clause for more information. | ||
| */ | ||
| #ifndef MLPACK_CORE_CV_K_FOLD_CV_HPP | ||
| #define MLPACK_CORE_CV_K_FOLD_CV_HPP | ||
|
|
||
| #include <mlpack/core/cv/meta_info_extractor.hpp> | ||
| #include <mlpack/core/cv/cv_base.hpp> | ||
|
|
||
| namespace mlpack { | ||
| namespace cv { | ||
|
|
||
| /** | ||
| * The class KFoldCV implements k-fold cross-validation for regression and | ||
| * classification algorithms. | ||
| * | ||
| * To construct a KFoldCV object you need to pass the k parameter and arguments | ||
| * that specify data. For the latter see the CVBase constructors as a reference | ||
| * - the CVBase constructors take exactly the same arguments as ones that are | ||
| * supposed to be passed after the k parameter in the KFoldCV constructor. | ||
| * | ||
| * For example, you can run 10-fold cross-validation for SoftmaxRegression in | ||
| * the following way. | ||
| * | ||
| * @code | ||
| * // 100-point 5-dimensional random dataset. | ||
| * arma::mat data = arma::randu<arma::mat>(5, 100); | ||
| * // Random labels in the [0, 4] interval. | ||
| * arma::Row<size_t> labels = | ||
| * arma::randi<arma::Row<size_t>>(100, arma::distr_param(0, 4)); | ||
| * size_t numClasses = 5; | ||
| * | ||
| * KFoldCV<SoftmaxRegression<>, Accuracy> cv(10, data, labels, numClasses); | ||
| * | ||
| * double lambda = 0.1; | ||
| * double softmaxAccuracy = cv.Evaluate(lambda); | ||
| * @endcode | ||
| * | ||
| * @tparam MLAlgorithm A machine learning algorithm. | ||
| * @tparam Metric A metric to assess the quality of a trained model. | ||
| * @tparam MatType The type of data. | ||
| * @tparam PredictionsType The type of predictions (should be passed when the | ||
| * predictions type is a template parameter in Train methods of | ||
| * MLAlgorithm). | ||
| * @tparam WeightsType The type of weights (should be passed when weighted | ||
| * learning is supported, and the weights type is a template parameter in | ||
| * Train methods of MLAlgorithm). | ||
| */ | ||
| template<typename MLAlgorithm, | ||
| typename Metric, | ||
| typename MatType = arma::mat, | ||
| typename PredictionsType = | ||
| typename MetaInfoExtractor<MLAlgorithm, MatType>::PredictionsType, | ||
| typename WeightsType = | ||
| typename MetaInfoExtractor<MLAlgorithm, MatType, | ||
| PredictionsType>::WeightsType> | ||
| class KFoldCV : | ||
| private CVBase<MLAlgorithm, MatType, PredictionsType, WeightsType> | ||
| { | ||
| public: | ||
| /** | ||
| * Construct an object for running k-fold cross-validation. | ||
| * | ||
| * @param k Number of folds (should be at least 2). | ||
| * @param args Basic constructor arguments for MLAlgortithm (see the CVBase | ||
| * constructors for reference). | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Minor misspelling---this should be MLAlgorithm not MLAlgortithm. :)
Owner
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I will remove this constructor eventually, so it can be skipped. |
||
| */ | ||
| template<typename... CVBaseArgs> | ||
| KFoldCV(const size_t k, const CVBaseArgs&... args); | ||
|
|
||
| /** | ||
| * Run k-fold cross-validation. | ||
| * | ||
| * @param args Arguments for MLAlgorithm (in addition to the passed | ||
| * ones in the constructor). | ||
| */ | ||
| template<typename... MLAlgorithmArgs> | ||
| double Evaluate(const MLAlgorithmArgs& ...args); | ||
|
|
||
| //! Access and modify a model from the last run of k-fold cross-validation. | ||
| MLAlgorithm& Model(); | ||
|
|
||
| private: | ||
| //! A short alias for CVBase. | ||
| using Base = CVBase<MLAlgorithm, MatType, PredictionsType, WeightsType>; | ||
|
|
||
| //! The number of bins in the dataset. | ||
| const size_t k; | ||
|
|
||
| //! The extended (by repeating the first k - 2 bins) data points. | ||
| MatType xs; | ||
| //! The extended (by repeating the first k - 2 bins) predictions. | ||
| PredictionsType ys; | ||
| //! The extended (by repeating the first k - 2 bins) weights. | ||
| WeightsType weights; | ||
|
|
||
| //! The size of each bin in terms of data points. | ||
| size_t binSize; | ||
|
|
||
| //! The size of each training subset in terms of data points. | ||
| size_t trainingSubsetSize; | ||
|
|
||
| //! A pointer to a model from the last run of k-fold cross-validation. | ||
| std::unique_ptr<MLAlgorithm> modelPtr; | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I forgot to follow up with this on the last simple_cv PR, and we're about to I see a couple of possibilities---you can use a bare pointer or actually hold an This is a minor comment that doesn't really affect the actual functionality of
Owner
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Initially I thought that some machine learning algorithms are not construable without any argument. But when I reviewed the mlpack code recently, the situation looked differently to me. So, I guess we can store objects themselves and use the move semantic for light-weight assignments. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think most are default-constructible, but not all of them---I don't think we've made any strict requirement on that (and I'm not sure we need to).
Owner
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If so, why the usage of standard smart pointers is bad? As an alternative, we can keep smart pointers only as a part of implementation and keep them away of any public interface. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just consistency, there is no particular reason that smart pointers are bad. If we were making tons of them, it would be faster to use bare pointers or objects (due to the reference counting of
Owner
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I guess we should revert the changes made by this commit, but CVBase methods will need to return MLAlgorithm objects themselves rather than smart pointers.
Owner
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Isn't it about There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Your call, if that's what you'd like to do feel free. Personally I think the code is cleaner without smart pointers but that is just personal preference and this is your code so it is your choice. :) My only request is that we don't use smart pointers in the public API for consistency with the rest of mlpack. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Ah sorry you are right, there is no reference counting for |
||
|
|
||
| /** | ||
| * Initialize without weights. | ||
| */ | ||
| template<typename DataArgsTupleT, | ||
| typename = typename std::enable_if< | ||
| std::tuple_size<DataArgsTupleT>::value == 2>::type> | ||
| void Init(const DataArgsTupleT& dataArgsTuple); | ||
|
|
||
| /** | ||
| * Initialize with weights. | ||
| */ | ||
| template<typename DataArgsTupleT, | ||
| typename = typename std::enable_if< | ||
| std::tuple_size<DataArgsTupleT>::value == 3>::type, | ||
| typename = void> | ||
| void Init(const DataArgsTupleT& dataArgsTuple); | ||
|
|
||
| /** | ||
| * Initialize the given destination matrix with the given source joined with | ||
| * its first k - 2 bins. | ||
| */ | ||
| template<typename SourceType, typename DestinationType> | ||
| void InitKFoldCVMat(const SourceType& source, DestinationType& destination); | ||
|
|
||
| /** | ||
| * Train and run evaluation in the case of non-weighted learning. | ||
| */ | ||
| template<typename...MLAlgorithmArgs, | ||
| bool Enabled = !Base::MIE::SupportsWeights, | ||
| typename = typename std::enable_if<Enabled>::type> | ||
| double TrainAndEvaluate(const MLAlgorithmArgs& ...mlAlgorithmArgs); | ||
|
|
||
| /** | ||
| * Train and run evaluation in the case of supporting weighted learning. | ||
| */ | ||
| template<typename...MLAlgorithmArgs, | ||
| bool Enabled = Base::MIE::SupportsWeights, | ||
| typename = typename std::enable_if<Enabled>::type, | ||
| typename = void> | ||
| double TrainAndEvaluate(const MLAlgorithmArgs& ...mlAlgorithmArgs); | ||
|
|
||
| /** | ||
| * Calculate the index of the first column of the ith validation subset. | ||
| * | ||
| * We take the ith validation subset after the ith training subset if | ||
| * i < k - 1 and before it otherwise. | ||
| */ | ||
| inline size_t ValidationSubsetFirstCol(const size_t i); | ||
|
|
||
| /** | ||
| * Get the ith training subset from a variable of a matrix type. | ||
| */ | ||
| template<typename ElementType> | ||
| inline arma::Mat<ElementType> GetTrainingSubset(arma::Mat<ElementType>& m, | ||
| const size_t i); | ||
|
|
||
| /** | ||
| * Get the ith training subset from a variable of a row type. | ||
| */ | ||
| template<typename ElementType> | ||
| inline arma::Row<ElementType> GetTrainingSubset(arma::Row<ElementType>& r, | ||
| const size_t i); | ||
|
|
||
| /** | ||
| * Get the ith validation subset from a variable of a matrix type. | ||
| */ | ||
| template<typename ElementType> | ||
| inline arma::Mat<ElementType> GetValidationSubset(arma::Mat<ElementType>& m, | ||
| const size_t i); | ||
|
|
||
| /** | ||
| * Get the ith validation subset from a variable of a row type. | ||
| */ | ||
| template<typename ElementType> | ||
| inline arma::Row<ElementType> GetValidationSubset(arma::Row<ElementType>& r, | ||
| const size_t i); | ||
| }; | ||
|
|
||
| } // namespace cv | ||
| } // namespace mlpack | ||
|
|
||
| // Include implementation | ||
| #include "k_fold_cv_impl.hpp" | ||
|
|
||
| #endif | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this way of referring someone to the other arguments of the constructor
could be a little bit confusing, and because we are not exposing the CVBase
constructors directly to the user, we can't try to force Doxygen to print all of
the CVBase constructors directly along with the KFoldCV constructors.
The only solution I could see would be to replicate the four constructors of
CVBase here. This would cause extra code duplication, but it would be for the
sake of clarity of documentation. Would you be willing to do that? I think
that that would also allow inlining of the Init() functions into the
constructors directly, so maybe the amount of extra code and functions would not
be that many...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What do you mean? The
CVBaseconstructors are public.Anyway, you have mentioned it already the second time (the first time was when you did it for the
SimpleCVconstructors). So, I guess it's pretty much important in general and to be consisted with other mlpack in particular - more important than I thought initially. In such a case I'm ready to do what you ask, even though I like the idea of avoiding code duplication. If we want to do it here, I guess it makes sense to do it inSimpleCVtoo.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Correct me if I have misunderstood---I think that because the inheritance is
private, the constructors ofCVBaseare not available inKFoldCV, and instead the single constructor ofKFoldCVis delegating toCVBase. So in the end, when a user looks at theKFoldCVdocumentation, they only see this overload, and have to dig further to find theCVBaseconstructors that are being referred to. What do you think, do you like the idea of adding more constructors for clarity, or adding more documentation to be fully clear about what should be passed in?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Even if the inheritance was public, the constructors of
CVBasewould not be available inKFoldCV. This simple code snippet illustrates the point.I don't see any way how the documentation can become significantly clearer than now. So, I guess it's easier just to add explicit constructors here and to SimpleCV for uniformity.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah---sorry that that is extra work but I think it is the best option here.