Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/mlpack/core/cv/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ add_subdirectory(metrics)
set(SOURCES
cv_base.hpp
cv_base_impl.hpp
k_fold_cv.hpp
k_fold_cv_impl.hpp
meta_info_extractor.hpp
simple_cv.hpp
simple_cv_impl.hpp
Expand Down
198 changes: 198 additions & 0 deletions src/mlpack/core/cv/k_fold_cv.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
/**
* @file k_fold_cv.hpp
* @author Kirill Mishchenko
*
* k-fold cross-validation.
*
* mlpack is free software; you may redistribute it and/or modify it under the
* terms of the 3-clause BSD license. You should have received a copy of the
* 3-clause BSD license along with mlpack. If not, see
* http://www.opensource.org/licenses/BSD-3-Clause for more information.
*/
#ifndef MLPACK_CORE_CV_K_FOLD_CV_HPP
#define MLPACK_CORE_CV_K_FOLD_CV_HPP

#include <mlpack/core/cv/meta_info_extractor.hpp>
#include <mlpack/core/cv/cv_base.hpp>

namespace mlpack {
namespace cv {

/**
* The class KFoldCV implements k-fold cross-validation for regression and
* classification algorithms.
*
* To construct a KFoldCV object you need to pass the k parameter and arguments
* that specify data. For the latter see the CVBase constructors as a reference
* - the CVBase constructors take exactly the same arguments as ones that are
* supposed to be passed after the k parameter in the KFoldCV constructor.

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this way of referring someone to the other arguments of the constructor
could be a little bit confusing, and because we are not exposing the CVBase
constructors directly to the user, we can't try to force Doxygen to print all of
the CVBase constructors directly along with the KFoldCV constructors.

The only solution I could see would be to replicate the four constructors of
CVBase here. This would cause extra code duplication, but it would be for the
sake of clarity of documentation. Would you be willing to do that? I think
that that would also allow inlining of the Init() functions into the
constructors directly, so maybe the amount of extra code and functions would not
be that many...

Copy link
Copy Markdown
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

and because we are not exposing the CVBase
constructors directly to the user

What do you mean? The CVBase constructors are public.

Anyway, you have mentioned it already the second time (the first time was when you did it for the SimpleCV constructors). So, I guess it's pretty much important in general and to be consisted with other mlpack in particular - more important than I thought initially. In such a case I'm ready to do what you ask, even though I like the idea of avoiding code duplication. If we want to do it here, I guess it makes sense to do it in SimpleCV too.

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Correct me if I have misunderstood---I think that because the inheritance is private, the constructors of CVBase are not available in KFoldCV, and instead the single constructor of KFoldCV is delegating to CVBase. So in the end, when a user looks at the KFoldCV documentation, they only see this overload, and have to dig further to find the CVBase constructors that are being referred to. What do you think, do you like the idea of adding more constructors for clarity, or adding more documentation to be fully clear about what should be passed in?

Copy link
Copy Markdown
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Correct me if I have misunderstood---I think that because the inheritance is private, the constructors of CVBase are not available in KFoldCV, and instead the single constructor of KFoldCV is delegating to CVBase.

Even if the inheritance was public, the constructors of CVBase would not be available in KFoldCV. This simple code snippet illustrates the point.

class A
{
public:
    A(int a) {}
};

class B : public A
{
};

int main()
{
  B b(1); // This line will fail to be compiled.
}

What do you think, do you like the idea of adding more constructors for clarity, or adding more documentation to be fully clear about what should be passed in?

I don't see any way how the documentation can become significantly clearer than now. So, I guess it's easier just to add explicit constructors here and to SimpleCV for uniformity.

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah---sorry that that is extra work but I think it is the best option here.

*
* For example, you can run 10-fold cross-validation for SoftmaxRegression in
* the following way.
*
* @code
* // 100-point 5-dimensional random dataset.
* arma::mat data = arma::randu<arma::mat>(5, 100);
* // Random labels in the [0, 4] interval.
* arma::Row<size_t> labels =
* arma::randi<arma::Row<size_t>>(100, arma::distr_param(0, 4));
* size_t numClasses = 5;
*
* KFoldCV<SoftmaxRegression<>, Accuracy> cv(10, data, labels, numClasses);
*
* double lambda = 0.1;
* double softmaxAccuracy = cv.Evaluate(lambda);
* @endcode
*
* @tparam MLAlgorithm A machine learning algorithm.
* @tparam Metric A metric to assess the quality of a trained model.
* @tparam MatType The type of data.
* @tparam PredictionsType The type of predictions (should be passed when the
* predictions type is a template parameter in Train methods of
* MLAlgorithm).
* @tparam WeightsType The type of weights (should be passed when weighted
* learning is supported, and the weights type is a template parameter in
* Train methods of MLAlgorithm).
*/
template<typename MLAlgorithm,
typename Metric,
typename MatType = arma::mat,
typename PredictionsType =
typename MetaInfoExtractor<MLAlgorithm, MatType>::PredictionsType,
typename WeightsType =
typename MetaInfoExtractor<MLAlgorithm, MatType,
PredictionsType>::WeightsType>
class KFoldCV :
private CVBase<MLAlgorithm, MatType, PredictionsType, WeightsType>
{
public:
/**
* Construct an object for running k-fold cross-validation.
*
* @param k Number of folds (should be at least 2).
* @param args Basic constructor arguments for MLAlgortithm (see the CVBase
* constructors for reference).

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor misspelling---this should be MLAlgorithm not MLAlgortithm. :)

Copy link
Copy Markdown
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will remove this constructor eventually, so it can be skipped.

*/
template<typename... CVBaseArgs>
KFoldCV(const size_t k, const CVBaseArgs&... args);

/**
* Run k-fold cross-validation.
*
* @param args Arguments for MLAlgorithm (in addition to the passed
* ones in the constructor).
*/
template<typename... MLAlgorithmArgs>
double Evaluate(const MLAlgorithmArgs& ...args);

//! Access and modify a model from the last run of k-fold cross-validation.
MLAlgorithm& Model();

private:
//! A short alias for CVBase.
using Base = CVBase<MLAlgorithm, MatType, PredictionsType, WeightsType>;

//! The number of bins in the dataset.
const size_t k;

//! The extended (by repeating the first k - 2 bins) data points.
MatType xs;
//! The extended (by repeating the first k - 2 bins) predictions.
PredictionsType ys;
//! The extended (by repeating the first k - 2 bins) weights.
WeightsType weights;

//! The size of each bin in terms of data points.
size_t binSize;

//! The size of each training subset in terms of data points.
size_t trainingSubsetSize;

//! A pointer to a model from the last run of k-fold cross-validation.
std::unique_ptr<MLAlgorithm> modelPtr;

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I forgot to follow up with this on the last simple_cv PR, and we're about to
merge that so I don't want to bring this up again there. Nowhere else in the
mlpack codebase is std::unique_ptr<> used, so for consistency I'd prefer to
avoid the use of the class.

I see a couple of possibilities---you can use a bare pointer or actually hold an
MLAlgorithm internally and I think this will not affect the external API you
have already provided. At a quick glance, I think that both of those should be
possible; but perhaps there is a catch I have overlooked.

This is a minor comment that doesn't really affect the actual functionality of
the CV code at all (I think), so we can consider it low-priority. If you
prefer, I can submit the PR to make these changes, but it will be a few weeks
until that happens. (Maybe I will prepare it on the flight home, but that will
be August 12...)

Copy link
Copy Markdown
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Initially I thought that some machine learning algorithms are not construable without any argument. But when I reviewed the mlpack code recently, the situation looked differently to me. So, I guess we can store objects themselves and use the move semantic for light-weight assignments.

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think most are default-constructible, but not all of them---I don't think we've made any strict requirement on that (and I'm not sure we need to).

Copy link
Copy Markdown
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If so, why the usage of standard smart pointers is bad? As an alternative, we can keep smart pointers only as a part of implementation and keep them away of any public interface.

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just consistency, there is no particular reason that smart pointers are bad. If we were making tons of them, it would be faster to use bare pointers or objects (due to the reference counting of unique_ptr), but that is not a concern here. If you really want to use unique_ptr, I guess it is ok, but like you said we should keep them away from the public interface since they do not appear anywhere else in the mlpack API. Let me know what you'd like to do.

Copy link
Copy Markdown
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess we should revert the changes made by this commit, but CVBase methods will need to return MLAlgorithm objects themselves rather than smart pointers.

Copy link
Copy Markdown
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(due to the reference counting of unique_ptr)

Isn't it about shared_ptr?

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Your call, if that's what you'd like to do feel free. Personally I think the code is cleaner without smart pointers but that is just personal preference and this is your code so it is your choice. :) My only request is that we don't use smart pointers in the public API for consistency with the rest of mlpack.

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't it about shared_ptr?

Ah sorry you are right, there is no reference counting for unique_ptr.


/**
* Initialize without weights.
*/
template<typename DataArgsTupleT,
typename = typename std::enable_if<
std::tuple_size<DataArgsTupleT>::value == 2>::type>
void Init(const DataArgsTupleT& dataArgsTuple);

/**
* Initialize with weights.
*/
template<typename DataArgsTupleT,
typename = typename std::enable_if<
std::tuple_size<DataArgsTupleT>::value == 3>::type,
typename = void>
void Init(const DataArgsTupleT& dataArgsTuple);

/**
* Initialize the given destination matrix with the given source joined with
* its first k - 2 bins.
*/
template<typename SourceType, typename DestinationType>
void InitKFoldCVMat(const SourceType& source, DestinationType& destination);

/**
* Train and run evaluation in the case of non-weighted learning.
*/
template<typename...MLAlgorithmArgs,
bool Enabled = !Base::MIE::SupportsWeights,
typename = typename std::enable_if<Enabled>::type>
double TrainAndEvaluate(const MLAlgorithmArgs& ...mlAlgorithmArgs);

/**
* Train and run evaluation in the case of supporting weighted learning.
*/
template<typename...MLAlgorithmArgs,
bool Enabled = Base::MIE::SupportsWeights,
typename = typename std::enable_if<Enabled>::type,
typename = void>
double TrainAndEvaluate(const MLAlgorithmArgs& ...mlAlgorithmArgs);

/**
* Calculate the index of the first column of the ith validation subset.
*
* We take the ith validation subset after the ith training subset if
* i < k - 1 and before it otherwise.
*/
inline size_t ValidationSubsetFirstCol(const size_t i);

/**
* Get the ith training subset from a variable of a matrix type.
*/
template<typename ElementType>
inline arma::Mat<ElementType> GetTrainingSubset(arma::Mat<ElementType>& m,
const size_t i);

/**
* Get the ith training subset from a variable of a row type.
*/
template<typename ElementType>
inline arma::Row<ElementType> GetTrainingSubset(arma::Row<ElementType>& r,
const size_t i);

/**
* Get the ith validation subset from a variable of a matrix type.
*/
template<typename ElementType>
inline arma::Mat<ElementType> GetValidationSubset(arma::Mat<ElementType>& m,
const size_t i);

/**
* Get the ith validation subset from a variable of a row type.
*/
template<typename ElementType>
inline arma::Row<ElementType> GetValidationSubset(arma::Row<ElementType>& r,
const size_t i);
};

} // namespace cv
} // namespace mlpack

// Include implementation
#include "k_fold_cv_impl.hpp"

#endif
Loading