Posted by: sureshamrita | August 24, 2011

A C++ implementation of k-fold cross validation

In k-fold cross validation, the given data is randomly divided into k disjoint sets of equal size n/k, where n is the total number of patterns present in the given data. The classifier is trained k times, each time with a different set held out as a validation set on which the error is measured. The estimated performance is the mean of these k errors. (Taken from Duda, Hart, and Stork – Pattern Classification, 2006, page 483.)
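As a rough illustration of the final averaging step, the estimate could be computed along the lines of the sketch below. The names crossValidationError and errorOnFold are assumptions introduced here and are not part of the code in this post; the callback is assumed to train on every fold except the given one and to return the error measured on the held-out fold.

```cpp
#include <cassert>
#include <functional>

// Hypothetical helper (not from the original post): runs k train/validate
// rounds and returns the mean of the k validation errors.
// errorOnFold(f) is assumed to train on all folds except f (numbered 1..k)
// and to return the error measured on fold f.
double crossValidationError(int k, const std::function<double(int)>& errorOnFold) {
    assert(k > 0);  // at least one fold is needed
    double sum = 0.0;
    for (int fold = 1; fold <= k; ++fold)
        sum += errorOnFold(fold);
    return sum / k;
}
```

The callback keeps the averaging independent of the classifier, so any training routine can be plugged in.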

The given code implements a k-fold cross validation with the following features.

  • The code is independent of the container in which the given data is stored.
  • The code just expects an iterator pointing to the beginning of the container and one pointing to the end (to be precise, one past the end) of the container.
  • The partitioned data can be stored in any container.
  • The code expects only output iterators pointing to the destination containers (one for training data and the other for testing data). Both must be of the same iterator type.
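To make the container independence concrete, here is a minimal stand-alone sketch that reads from a std::list and writes to a std::deque and a std::vector. The free function splitFold and the demo function are hypothetical and written only for illustration; unlike the class in this post, the sketch takes the fold labels as an argument and allows the two output iterators to have different types.

```cpp
#include <cassert>
#include <cstddef>
#include <deque>
#include <iterator>
#include <list>
#include <vector>

// Hypothetical free-function analogue of a fold split, for illustration only.
// foldOf[i] holds the fold number assigned to the i-th element of [beg, end).
template <class In, class Out1, class Out2>
void splitFold(In beg, In end, const std::vector<int>& foldOf, int foldNo,
               Out1 training, Out2 testing) {
    std::size_t i = 0;
    for (In it = beg; it != end; ++it, ++i) {
        assert(i < foldOf.size());
        if (foldOf[i] == foldNo)
            *testing++ = *it;   // element belongs to the held-out fold
        else
            *training++ = *it;  // element goes to the training set
    }
}

// Example: input from a std::list, output to a std::deque and a std::vector.
inline void splitFoldDemo(std::deque<int>& training, std::vector<int>& testing) {
    const int raw[] = {10, 20, 30, 40};
    const std::list<int> data(raw, raw + 4);
    const int labelsRaw[] = {1, 2, 1, 2};  // fixed labels, no shuffling
    const std::vector<int> foldOf(labelsRaw, labelsRaw + 4);
    splitFold(data.begin(), data.end(), foldOf, 1,
              std::back_inserter(training), std::back_inserter(testing));
}
```

Only forward traversal of the input and assignment through the output iterators are required, which is why the choice of container does not matter.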

The header file “mystrcat.h” included in the following code is described here

#ifndef _kfoldh_
#define _kfoldh_
#include <vector>
#include "/home/suresh/C++/myCodes/mystrcat.h"
#include <stdexcept>
#include <algorithm>
#include <iterator>
#include <boost/foreach.hpp>

using namespace std;

#define foreach BOOST_FOREACH

template<class In>
class Kfold {
public:
	Kfold(int k, In _beg, In _end);
	template<class Out>
	void getFold(int foldNo, Out training, Out testing);

private:
	In beg;
	In end;
	int K; //how many folds in this
	vector<int> whichFoldToGo;

};

template<class In>
Kfold<In>::Kfold(int _k, In _beg, In _end) :
		beg(_beg), end(_end), K(_k) {
	if (K <= 0)
		throw runtime_error(Mystrcat("The supplied value of K is =", K, ". One cannot create ", K, " no of folds"));

	//assign fold numbers 1..K to the elements in round-robin order
	int foldNo = 0;
	for (In i = beg; i != end; i++) {
		whichFoldToGo.push_back(++foldNo);
		if (foldNo == K)
			foldNo = 0;
	}
	//foldNo is non-zero here only when the element count is not a multiple of K
	if (foldNo != 0)
		throw runtime_error(Mystrcat("With this value of k (=", K, ")Equal division of the data is not possible"));
	//shuffle the fold numbers so that fold membership is random
	random_shuffle(whichFoldToGo.begin(), whichFoldToGo.end());
}

template<class In>
template<class Out>
void Kfold<In>::getFold(int foldNo, Out training, Out testing) {

	int k = 0;
	In i = beg;
	while (i != end) {
		if (whichFoldToGo[k++] == foldNo) {
			*testing++ = *i++;
		} else
			*training++ = *i++;
	}
}

#endif
//******************************Application Code
#include "kfold.h"
#include <iostream>
#include <iterator>
#include <vector>
using namespace std;

int main() {
	vector<int> v = { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 };

	const int folds = 5;
	Kfold<vector<int>::const_iterator> kf(folds, v.begin(), v.end());

	vector<int> test, train;

	for (int i = 0; i != folds; i++) {
		kf.getFold(i + 1, back_inserter(train), back_inserter(test));
		cout << "Fold " << i + 1 << " Training Data" << endl;
		foreach(auto x, train)
			cout << x << " ";
		cout << endl;
		cout << "Fold " << i + 1 << " Testing Data" << endl;
		foreach(auto x, test)
			cout << x << " ";
		cout << endl;

		train.clear();
		test.clear();
	}

	return 0;
}
 

The code uses the auto keyword and the initializer-list syntax from C++0x (now C++11), so please select the appropriate compiler flag to enable these features. With GCC on Linux, the flag is: -std=c++0x (or -std=c++11 on newer versions).
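A further note for newer compilers: random_shuffle, used in the constructor, was deprecated in C++14 and removed in C++17. One possible replacement for the fold assignment, using std::shuffle from C++11, is sketched below. The function makeFoldLabels and its seed parameter are assumptions added for illustration and do not appear in the original class.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <random>
#include <vector>

// Hypothetical helper (not from the original post): builds n fold labels in
// 1..k, dealt out round-robin and then shuffled with std::shuffle.
// The seed parameter is an assumption added so that runs are reproducible.
inline std::vector<int> makeFoldLabels(std::size_t n, int k, unsigned seed) {
    assert(k > 0);
    std::vector<int> labels;
    labels.reserve(n);
    for (std::size_t i = 0; i < n; ++i)
        labels.push_back(static_cast<int>(i % static_cast<std::size_t>(k)) + 1);
    std::mt19937 gen(seed);  // Mersenne Twister engine seeded by the caller
    std::shuffle(labels.begin(), labels.end(), gen);
    return labels;
}
```

Shuffling only permutes the labels, so each fold keeps exactly n/k elements whenever n is a multiple of k.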
