//nnn.c (gcc nnn.c -o nnn -lm -O3 -ffast-math -march=native)
//nonparametric classifier based on n-nearest-neighbors technique, 2-dimensional
//Assumes equal priors and equally-sized training data.
//Ranks training points by distance with the C library's qsort().
//Bryan Topp <betopp@cs.unm.edu>

#include <errno.h>
#include <limits.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>


//One 2-D sample (x, y).
//read_dataset() fills a flat array of doubles laid out as x0, y0, x1, y1, ...
//main() then reinterprets that array as datapoint_t elements. This struct must
//therefore stay exactly two doubles, with no other members.
typedef struct
{
	double x, y;
} datapoint_t;



//Reads a dataset from a file. Allocates and returns an array and its size.
//Reads a whitespace-separated dataset from a file.
//Each record is "<integer> <x> <y>". The integer is parsed but discarded.
//
//On return, *array_out holds a malloc()ed array of 2 * (*size_out) doubles,
//laid out as x0, y0, x1, y1, ...
//The caller owns the array and must free() it. It is non-NULL even when the
//file holds no records.
//
//Any open, read, parse or allocation failure prints a message and exits the
//program with status -1.
void read_dataset(const char *filename, double **array_out, int *size_out)
{
	FILE *datafile = fopen(filename, "r");
	if(!datafile)
	{
		printf("Could not open %s\n", filename);
		exit(-1);
	}
	
	//Pass 1: count the records so the array can be allocated in one go.
	int num_points = 0;
	int scanned;
	for(;;)
	{
		int label;
		double x, y;
		
		scanned = fscanf(datafile, " %d %lf %lf ", &label, &x, &y);
		if(scanned != 3)
			break;
		
		num_points++;
	}
	
	//fscanf() returns EOF only when input ends before the first conversion,
	//which is cleanly between records. Any other short count means garbage
	//or a truncated final record, so treat it as a parse error.
	if(scanned != EOF || ferror(datafile))
	{
		printf("error parsing file %s\n", filename);
		exit(-1);
	}
	
	printf("Read %d data points from file %s\n", num_points, filename);
	
	//Allocate at least one point, so that a NULL return always means failure
	//and never an empty malloc(0).
	size_t capacity = (num_points > 0) ? (size_t)num_points : 1;
	double *dataset = malloc(sizeof(double) * 2 * capacity);
	if(!dataset)
	{
		printf("Out of memory reading %s\n", filename);
		exit(-1);
	}
	
	//Pass 2: store the points. Never store more than pass 1 counted, so the
	//buffer cannot overflow even if the file changed in between.
	rewind(datafile);
	int stored = 0;
	while(stored < num_points)
	{
		int label;
		
		if(fscanf(datafile, " %d %lf %lf ", &label, &dataset[2 * stored], &dataset[2 * stored + 1]) != 3)
			break;
		
		stored++;
	}
	
	fclose(datafile);
	
	*size_out = stored;
	*array_out = dataset;
}


//Squared Euclidean distance between *a and *b.
//The square root is skipped: every caller in this file only compares
//distances, and squaring does not change their order.
double datadist(datapoint_t *a, datapoint_t *b)
{
	const double dx = a->x - b->x;
	const double dy = a->y - b->y;
	return (dx * dx) + (dy * dy);
}

//The point that data_comparator() measures distances from.
//qsort() gives the comparator no user-data argument, so the point is passed
//through this global. Set it before each qsort() call. Because of this
//global, the sort is not reentrant.
datapoint_t *data_comparator_evaluation_point;

//qsort() comparator that orders datapoint_t elements by increasing squared
//distance from data_comparator_evaluation_point.
//Returns <0, 0 or >0 as *a is nearer than, as near as, or farther than *b.
//Unordered (NaN) distances compare as equal instead of falling off the end
//of the function without returning a value.
int data_comparator(const void *a, const void *b)
{
	//datadist() takes non-const pointers but only reads through them.
	datapoint_t *d_a = (datapoint_t*)a;
	datapoint_t *d_b = (datapoint_t*)b;
	
	double dist_a = datadist(d_a, data_comparator_evaluation_point);
	double dist_b = datadist(d_b, data_comparator_evaluation_point);
	
	return (dist_a > dist_b) - (dist_a < dist_b);
}

//Classifies *eval by voting among its num_neighbors nearest training points,
//drawn from both classes combined.
//
//c1 and c2 are the class-1 and class-2 training sets, with c1_num and c2_num
//points. As a side effect, both arrays are SORTED IN PLACE by distance from
//*eval.
//
//Returns (class-2 votes - class-1 votes):
//  positive -> class 2 wins
//  negative -> class 1 wins
//A class-1 and a class-2 point at exactly the same distance each get a vote,
//and together they use up a single neighbor step.
//On a tied vote, the farthest neighbor is dropped and the vote is retaken,
//down to zero neighbors. If every such vote ties, 1 (class 2) is returned,
//so the result is never 0.
//If num_neighbors exceeds the number of training points available, the vote
//uses only the points that exist.
int nnn_vote(datapoint_t *eval, datapoint_t *c1, int c1_num, datapoint_t *c2, int c2_num, int num_neighbors)
{
	data_comparator_evaluation_point = eval;
	//Skip trivial sorts, which also keeps negative counts away from qsort's size_t.
	if(c1_num > 1)
		qsort(c1, (size_t)c1_num, sizeof(datapoint_t), data_comparator);
	if(c2_num > 1)
		qsort(c2, (size_t)c2_num, sizeof(datapoint_t), data_comparator);
	
	//Merge-walk both sorted lists, nearest first.
	//After k steps, c2_count - c1_count is exactly the vote with k neighbors.
	//Remembering the last non-zero difference therefore gives the same result
	//as "drop one neighbor and retry on a tie", without re-sorting per retry.
	int c1_count = 0;
	int c2_count = 0;
	int decided = 1; //returned if every prefix ties ("give up")
	
	for(int step = 0; step < num_neighbors; step++)
	{
		int c1_left = c1_count < c1_num;
		int c2_left = c2_count < c2_num;
		
		if(!c1_left && !c2_left)
			break; //fewer training points than requested neighbors
		
		if(!c2_left)
			c1_count++; //only class 1 has points left
		else if(!c1_left)
			c2_count++; //only class 2 has points left
		else
		{
			//Compare in double: narrowing to float would turn nearly-equal
			//distances into false ties.
			double c1_radius_squared = datadist(eval, &c1[c1_count]);
			double c2_radius_squared = datadist(eval, &c2[c2_count]);
			
			if(c1_radius_squared < c2_radius_squared)
				c1_count++; //next neighbor is from C1
			else if(c1_radius_squared > c2_radius_squared)
				c2_count++; //next neighbor is from C2
			else //equidistant: one vote for each class
			{
				c1_count++;
				c2_count++;
			}
		}
		
		if(c1_count != c2_count)
			decided = c2_count - c1_count;
	}
	
	return decided;
}

int main(int argc, const char **argv)
{
	if(argc != 6)
	{
		printf("usage: %s <training 1> <training 2> <testing 1> <testing 2> <how many neighbors>\n", argv[0]);
		exit(-1);
	}
	
	//Read datasets.
	datapoint_t *c1_training;
	int c1_training_size;
	
	datapoint_t *c2_training;
	int c2_training_size;
	
	datapoint_t *c1_testing;
	int c1_testing_size;
	
	datapoint_t *c2_testing;
	int c2_testing_size;
	
	read_dataset(argv[1], (double**)(&c1_training), &c1_training_size);
	read_dataset(argv[2], (double**)(&c2_training), &c2_training_size);
	read_dataset(argv[3], (double**)(&c1_testing), &c1_testing_size);
	read_dataset(argv[4], (double**)(&c2_testing), &c2_testing_size);
	
	int neighbors = atoi(argv[5]);
	printf("Voting with %d nearest neighbors across both classes\n", neighbors);
	if(neighbors < 1)
	{
		printf("neighbors must be >= 1!\n");
		exit(-1);
	}
	
	//Evaluate.
	int errors_classifying_c1 = 0;
	int errors_classifying_c2 = 0;
	
	//Classify each testing point in C1
	int testpt;
	for(testpt = 0; testpt < c1_testing_size; testpt++)
	{
		int result = nnn_vote(&(c1_testing[testpt]), c1_training, c1_training_size, c2_training, c2_training_size, neighbors);
		if(result > 0)
			errors_classifying_c1++;
	}
	
	//Classify each testing point in C2
	for(testpt = 0; testpt < c2_testing_size; testpt++)
	{
		int result = nnn_vote(&(c2_testing[testpt]), c1_training, c1_training_size, c2_training, c2_training_size, neighbors);
		if(result < 0)
			errors_classifying_c2++;
	}
	
	printf("Errors classifying C1 testing set: %d / %d (%lf)\n", errors_classifying_c1, c1_training_size, (double)errors_classifying_c1 / (double)c1_training_size);
	printf("Errors classifying C2 testing set: %d / %d (%lf)\n", errors_classifying_c2, c2_training_size, (double)errors_classifying_c2 / (double)c2_training_size);
	printf("Overall error rate: %d / %d (%lf)\n", errors_classifying_c1+errors_classifying_c2, c1_training_size+c2_training_size,  (double)(errors_classifying_c1+errors_classifying_c2) / (double)(c1_training_size+c2_training_size));	
	
	
	#ifdef GRID_OUTPUT
	FILE *gridout = fopen("nnn.grid", "w");
	double gx, gy;
	for(gy=-8.0;gy<8.0;gy += .03125)
	{
		for(gx=-8.0;gx<8.0;gx += .03125)
		{
			datapoint_t gridpoint;
			gridpoint.x = gx;
			gridpoint.y = gy;
	
			int result = nnn_vote(&gridpoint, c1_training, c1_training_size, c2_training, c2_training_size, neighbors);
			
			//assume equal priors.
			//look for the smaller radius!
			if(result > 0)
				fprintf(gridout, "%lf %lf %d\n", gx, gy, 2);
			else if(result < 0)
				fprintf(gridout, "%lf %lf %d\n", gx, gy, 1);
			else 
				fprintf(gridout, "%lf %lf %d\n", gx, gy, 0);
		}
	}
	#endif
	
	return 0;
}