//knnn.c (gcc knnn.c -o knnn -lm -O3 -ffast-math -march=native)
//nonparametric classifier based on Kn-nearest-neighbors technique, 2-dimensional
//Assumes equal priors and equally-sized training data.
//Based on quickselect over distance to evaluation point.
//Bryan Topp <betopp@cs.unm.edu>

#include <stdio.h>
#include <stdlib.h>
#include <math.h>


//A single 2-dimensional sample point.
//main() obtains arrays of these by casting the buffer of doubles that
//read_dataset() fills with consecutive x,y pairs, so this struct has to stay
//exactly two doubles, x first and y second.
typedef struct
{
	double x; //first coordinate
	double y; //second coordinate
} datapoint_t;

//Reads a dataset from a text file made of whitespace-separated records "<int> <x> <y>".
//The leading integer of every record is parsed and then discarded.
//
//filename:  path of the file to read.
//array_out: receives a heap-allocated array holding the x,y values back to back
//           (2 doubles per point). The caller owns it and releases it with free().
//size_out:  receives the number of points stored in the array.
//
//Prints a message and terminates the process with exit(-1) if the file cannot be
//opened, if it contains anything other than complete records, or if memory runs out.
void read_dataset(const char *filename, double **array_out, int *size_out)
{
	FILE *datafile = fopen(filename, "r");
	if(!datafile)
	{
		printf("Could not open %s\n", filename);
		exit(-1);
	}
	
	//First pass: count the records so the array can be sized exactly.
	int num_points = 0;
	int fields;
	while(1)
	{
		int label;
		double x, y;
		
		fields = fscanf(datafile, " %d %lf %lf ", &label, &x, &y);
		if(fields != 3)
			break;
		
		num_points++;
	}
	
	//A clean end of input makes fscanf return EOF. Any other value means it stopped
	//on garbage or on a record that was cut short (feof alone cannot tell the latter
	//apart from a clean end, which would silently drop the partial record).
	if(fields != EOF || ferror(datafile))
	{
		printf("error parsing file %s\n", filename);
		fclose(datafile);
		exit(-1);
	}
	
	printf("Read %d data points from file %s\n", num_points, filename);
	
	//calloc guards the count*size multiplication against overflow. Always request at
	//least one slot so that an empty file still yields a valid, freeable pointer.
	size_t slots = (num_points > 0) ? (size_t)num_points : 1;
	double *dataset = calloc(slots, 2 * sizeof *dataset);
	if(!dataset)
	{
		printf("Out of memory reading %s\n", filename);
		fclose(datafile);
		exit(-1);
	}
	
	//Second pass: actually store the points. The loop is bounded by the count from the
	//first pass so the buffer cannot be overrun even if the file grew in the meantime.
	rewind(datafile);
	int stored = 0;
	while(stored < num_points)
	{
		int label;
		
		if(fscanf(datafile, " %d %lf %lf ", &label, &(dataset[stored*2]), &(dataset[(stored*2)+1])) != 3)
			break;
		
		stored++;
	}
	
	fclose(datafile);
	
	*size_out = stored;
	*array_out = dataset;
}


//Squared Euclidean distance between points a and b.
//No square root is taken; callers only compare or scale the result.
double datadist(datapoint_t *a, datapoint_t *b)
{
	double dx = a->x - b->x;
	double dy = a->y - b->y;
	
	return (dx*dx)+(dy*dy);
}

//Exchanges the contents of the two points. Passing the same pointer twice is harmless.
void swap_data(datapoint_t *a, datapoint_t *b)
{
	datapoint_t held = *b;
	*b = *a;
	*a = held;
}


//Quickselect helper: partitions data[first..last] (inclusive) around data[pivot],
//ordering by squared distance to eval.
//On return, every element strictly closer to eval than the pivot element sits before
//the pivot and every other element sits after it. Returns the pivot's final index.
int qsel_partition(datapoint_t *eval, datapoint_t *data, int first, int last, int pivot)
{
	const double threshold = datadist(eval, &(data[pivot]));
	int boundary = first; //next slot for an element closer than the pivot
	int scan = first;
	
	//Park the pivot at the end so the scan below never touches it.
	swap_data(&(data[pivot]), &(data[last]));
	
	while(scan < last)
	{
		if(datadist(eval, &(data[scan])) < threshold)
		{
			swap_data(&(data[boundary]), &(data[scan]));
			boundary++;
		}
		scan++;
	}
	
	//Drop the pivot into its final place, right after the closer elements.
	swap_data(&(data[last]), &(data[boundary]));
	
	return boundary;
}

//Returns the squared distance from eval to the point that would occupy index n if
//data[first..last] (inclusive) were sorted by distance to eval.
//n is an index into data, so n == first selects the nearest point of the range.
//The caller must ensure first <= n <= last; the search relies on it to stay in range.
//Reorders data in place (quickselect with a random pivot taken from rand()).
double nth_nearest_sqdistance(datapoint_t *eval, datapoint_t *data, int first, int last, int n)
{
	int lo = first;
	int hi = last;
	
	//Narrow [lo, hi] around index n until a pivot lands on n or one element remains.
	while(lo < hi)
	{
		int span = (hi-lo)+1;
		int settled = qsel_partition(eval, data, lo, hi, lo+(rand()%span));
		
		if(settled == n) //found it
			return datadist(eval, &(data[n]));
		
		if(n < settled) //target lies among the closer elements
			hi = settled-1;
		else //target lies among the farther elements
			lo = settled+1;
	}
	
	return datadist(eval, &(data[lo]));
}


int main(int argc, const char **argv)
{
	if(argc != 6)
	{
		printf("usage: %s <training 1> <training 2> <testing 1> <testing 2> <how many neighbors>\n", argv[0]);
		exit(-1);
	}
	
	//Read datasets.
	datapoint_t *c1_training;
	int c1_training_size;
	
	datapoint_t *c2_training;
	int c2_training_size;
	
	datapoint_t *c1_testing;
	int c1_testing_size;
	
	datapoint_t *c2_testing;
	int c2_testing_size;
	
	read_dataset(argv[1], (double**)(&c1_training), &c1_training_size);
	read_dataset(argv[2], (double**)(&c2_training), &c2_training_size);
	read_dataset(argv[3], (double**)(&c1_testing), &c1_testing_size);
	read_dataset(argv[4], (double**)(&c2_testing), &c2_testing_size);
	
	int neighbors = atoi(argv[5]);
	printf("Looking for volume containing %d nearest neighbors\n", neighbors);
	if(neighbors < 1)
	{
		printf("neighbors must be >= 1!\n");
		exit(-1);
	}
	
	//Evaluate.
	int errors_classifying_c1 = 0;
	int errors_classifying_c2 = 0;
	
	//Classify each testing point in C1
	int testpt;
	for(testpt = 0; testpt < c1_testing_size; testpt++)
	{
		double c1_radius_squared = nth_nearest_sqdistance(&(c1_testing[testpt]), c1_training, 0, c1_training_size-1, neighbors-1);
		double c2_radius_squared = nth_nearest_sqdistance(&(c1_testing[testpt]), c2_training, 0, c2_training_size-1, neighbors-1);
		
		double c1_normalized = (double)c1_radius_squared * (double)c1_training_size;
		double c2_normalized = (double)c2_radius_squared * (double)c2_training_size;
		
		//assume equal priors.
		//look for the smaller radius!
		if(c2_normalized < c1_normalized)
			errors_classifying_c1++;
	}
	
	//Classify each testing point in C2
	for(testpt = 0; testpt < c2_testing_size; testpt++)
	{
		double c1_radius_squared = nth_nearest_sqdistance(&(c2_testing[testpt]), c1_training, 0, c1_training_size-1, neighbors-1);
		double c2_radius_squared = nth_nearest_sqdistance(&(c2_testing[testpt]), c2_training, 0, c2_training_size-1, neighbors-1);
		
		double c1_normalized = (double)c1_radius_squared * (double)c1_training_size;
		double c2_normalized = (double)c2_radius_squared * (double)c2_training_size;
		
		//assume equal priors.
		if(c2_normalized >= c1_normalized)
			errors_classifying_c2++;
	}
	
	printf("Errors classifying C1 testing set: %d / %d (%lf)\n", errors_classifying_c1, c1_training_size, (double)errors_classifying_c1 / (double)c1_training_size);
	printf("Errors classifying C2 testing set: %d / %d (%lf)\n", errors_classifying_c2, c2_training_size, (double)errors_classifying_c2 / (double)c2_training_size);
	printf("Overall error rate: %d / %d (%lf)\n", errors_classifying_c1+errors_classifying_c2, c1_training_size+c2_training_size,  (double)(errors_classifying_c1+errors_classifying_c2) / (double)(c1_training_size+c2_training_size));	
	
	
	#ifdef GRID_OUTPUT
	FILE *gridout = fopen("knnn.grid", "w");
	double gx, gy;
	for(gy=-8.0;gy<8.0;gy += .03125)
	{
		for(gx=-8.0;gx<8.0;gx += .03125)
		{
			datapoint_t gridpoint;
			gridpoint.x = gx;
			gridpoint.y = gy;
	
			double c1_radius_squared = nth_nearest_sqdistance(&(gridpoint), c1_training, 0, c1_training_size-1, neighbors-1);
			double c2_radius_squared = nth_nearest_sqdistance(&(gridpoint), c2_training, 0, c2_training_size-1, neighbors-1);
			
			double c1_normalized = (double)c1_radius_squared * (double)c1_training_size;
			double c2_normalized = (double)c2_radius_squared * (double)c2_training_size;
			
			//assume equal priors.
			//look for the smaller radius!
			if(c2_normalized < c1_normalized)
				fprintf(gridout, "%lf %lf %d\n", gx, gy, 2);
			else if(c2_normalized > c1_normalized)
				fprintf(gridout, "%lf %lf %d\n", gx, gy, 1);
			else 
				fprintf(gridout, "%lf %lf %d\n", gx, gy, 0);
		}
	}
	#endif
	
	return 0;
}