//parzen_windows.c (gcc parzen_windows.c -o parzen_windows -lm -lrt)
//nonparametric classifier based on Parzen Windows technique, 2-dimensional
//Assumes equal priors and equally-sized training data.
//Pre-bakes the training data into a 2-dimensional tree for fast searching.
//Bryan Topp <betopp@cs.unm.edu>

#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#include <time.h>
  
//Chop training set with Kd tree until this many points remain
#ifndef TREE_LEAF_SIZE
	#define TREE_LEAF_SIZE 32
#endif

//A single 2-dimensional sample point.
//main() reinterprets the flat array of doubles filled in by read_dataset()
//(x,y pairs, one pair per point) as an array of these, so the struct must stay
//exactly two doubles in this order.
typedef struct
{
	double x; //first coordinate
	double y; //second coordinate
} datapoint_t;

//One node of the search tree. Each internal node splits its points into four
//quadrants around chop_pt; "n" means below the chop coordinate, "p" means at or above it.
typedef struct tree_node_s
{
	//Children. make_tree_node() sets all four to NULL for a leaf and all four
	//to non-NULL nodes (possibly holding zero points) for an internal node.
	struct tree_node_s *nxny; //x <  chop_pt[0], y <  chop_pt[1]
	struct tree_node_s *nxpy; //x <  chop_pt[0], y >= chop_pt[1]
	struct tree_node_s *pxny; //x >= chop_pt[0], y <  chop_pt[1]
	struct tree_node_s *pxpy; //x >= chop_pt[0], y >= chop_pt[1]
	double chop_pt[2]; //split position: [0] is the X cut, [1] the Y cut (0.0 in leaves)
	
	//Every node (leaf or internal) refers to a contiguous run of points inside
	//the array handed to make_tree_node(); the node does not own a copy.
	datapoint_t *pts;
	int num_pts; //number of points in that run
} tree_node_t;


//qsort comparator: orders datapoint_t elements by ascending x.
//Returns -1 if a sorts first, 1 if b sorts first, 0 otherwise.
//A value is returned on every path: if either x is NaN (so none of <, ==, >
//holds) the elements are reported as equal instead of running off the end of
//the function.
int datacomparator_x(const void * a, const void * b)
{
	const datapoint_t *d_a = (const datapoint_t*)a;
	const datapoint_t *d_b = (const datapoint_t*)b;
	
	return (d_a->x > d_b->x) - (d_a->x < d_b->x);
}

//qsort comparator: orders datapoint_t elements by ascending y.
//Returns -1 if a sorts first, 1 if b sorts first, 0 otherwise.
//A value is returned on every path: if either y is NaN (so none of <, ==, >
//holds) the elements are reported as equal instead of running off the end of
//the function.
int datacomparator_y(const void * a, const void * b)
{
	const datapoint_t *d_a = (const datapoint_t*)a;
	const datapoint_t *d_b = (const datapoint_t*)b;
	
	return (d_a->y > d_b->y) - (d_a->y < d_b->y);
}

double quadrant_comparator_split[2] = {0};
int datacomparator_quadrant(const void *a, const void *b)
{
	datapoint_t *d_a = (datapoint_t*)a;
	datapoint_t *d_b = (datapoint_t*)b;	
	
	//different sides of the X divide are the "major" split
	if(d_a->x < quadrant_comparator_split[0] && d_b->x >= quadrant_comparator_split[0])
		return -1;
	if(d_b->x < quadrant_comparator_split[0] && d_a->x >= quadrant_comparator_split[0])
		return 1;
	
	
	//different sides of the Y divide are the minor split
	if(d_a->y < quadrant_comparator_split[1] && d_b->y >= quadrant_comparator_split[1])
		return -1;
	if(d_b->y < quadrant_comparator_split[1] && d_a->y >= quadrant_comparator_split[1])
		return 1;
	
	return 0; //same quadrant.
}

//Recursive function to build the quadrant-split Kd-tree.
//
//datapoints/num_points describe a contiguous range of points. The range is
//reordered in place (it is sorted so that each quadrant's points are contiguous)
//and the returned node, and all of its descendants, point into that same array.
//
//A node becomes a leaf (all four children NULL, chop_pt zeroed) when it holds
//TREE_LEAF_SIZE points or fewer, or when splitting would not make progress.
//Otherwise it gets four children, some of which may hold zero points.
//
//Exits the program if a node cannot be allocated.
tree_node_t *make_tree_node(datapoint_t *datapoints, unsigned int num_points)
{
	//allocate space for the node
	tree_node_t *retnode = malloc(sizeof *retnode);
	if(!retnode)
	{
		printf("Out of memory while building tree\n");
		exit(-1);
	}
	
	//Store the points data for the node
	retnode->pts = datapoints;
	retnode->num_pts = num_points;
	
	//Start out as a leaf; the fields are overwritten below if we really split.
	retnode->nxny = NULL;
	retnode->nxpy = NULL;
	retnode->pxny = NULL;
	retnode->pxpy = NULL;
	retnode->chop_pt[0] = 0.0;
	retnode->chop_pt[1] = 0.0;
	
	//if we have below the threshold points, stop building the tree
	if(num_points <= TREE_LEAF_SIZE)
		return retnode;
	
	//we want the splitting point to be the median in X and Y.
	//For an even count the median lies between elements mid-1 and mid;
	//for an odd count it is element mid itself.
	const unsigned int mid = num_points / 2;
	const int even_count = (num_points % 2) == 0;
	
	//sort in X
	qsort(datapoints, num_points, sizeof(datapoint_t), datacomparator_x);
	const double chop_x = even_count
		? (datapoints[mid-1].x + datapoints[mid].x) / 2.0
		: datapoints[mid].x;
	
	//sort in Y
	qsort(datapoints, num_points, sizeof(datapoint_t), datacomparator_y);
	const double chop_y = even_count
		? (datapoints[mid-1].y + datapoints[mid].y) / 2.0
		: datapoints[mid].y;
	
	//sort points into each quadrant: nxny, nxpy, pxny, pxpy in that order.
	//The global split is only needed for the duration of this qsort, so it is
	//fine for the recursive calls below to overwrite it.
	quadrant_comparator_split[0] = chop_x;
	quadrant_comparator_split[1] = chop_y;
	qsort(datapoints, num_points, sizeof(datapoint_t), datacomparator_quadrant);
	
	//find the boundaries between the quadrants
	unsigned int xbound; //first point with x >= chop_x
	for(xbound = 0; xbound < num_points; xbound++)
	{
		if(datapoints[xbound].x >= chop_x)
			break;
	}
	
	unsigned int ybound_lowerx; //first point in the low-x half with y >= chop_y
	for(ybound_lowerx = 0; ybound_lowerx < xbound; ybound_lowerx++)
	{
		if(datapoints[ybound_lowerx].y >= chop_y)
			break;
	}
	
	unsigned int ybound_higherx; //first point in the high-x half with y >= chop_y
	for(ybound_higherx = xbound; ybound_higherx < num_points; ybound_higherx++)
	{
		if(datapoints[ybound_higherx].y >= chop_y)
			break;
	}
	
	const unsigned int count_nxny = ybound_lowerx;
	const unsigned int count_nxpy = xbound - ybound_lowerx;
	const unsigned int count_pxny = ybound_higherx - xbound;
	const unsigned int count_pxpy = num_points - ybound_higherx;
	
	//If a single quadrant would receive every point (for example a large run of
	//identical points), recursing would hand the child the very same range and
	//never terminate. Keep this node as an oversized leaf instead.
	if(count_nxny == num_points || count_nxpy == num_points ||
	   count_pxny == num_points || count_pxpy == num_points)
		return retnode;
	
	retnode->chop_pt[0] = chop_x;
	retnode->chop_pt[1] = chop_y;
	
	//Recurse; every child is strictly smaller than this node.
	retnode->nxny = make_tree_node(datapoints, count_nxny);
	retnode->nxpy = make_tree_node(datapoints+ybound_lowerx, count_nxpy);
	retnode->pxny = make_tree_node(datapoints+xbound, count_pxny);
	retnode->pxpy = make_tree_node(datapoints+ybound_higherx, count_pxpy);
	
	return retnode;
}

//Dumps the cut lines of a tree to outf, one line segment per cut, each as two
//"x y" rows followed by a blank row.
//mins/maxs give the rectangle covered by this node; they are only read.
//Any point of the node lying outside that rectangle prints a "!" line on
//stdout, once per violated bound. Leaves produce no cut lines.
void write_tree_node_demo(FILE *outf, tree_node_t *node, double mins[2], double maxs[2])
{
	int idx;
	int quadrant;
	
	//sanity check: every point should sit inside the node's rectangle
	for(idx = 0; idx < node->num_pts; idx++)
	{
		const datapoint_t *p = &node->pts[idx];
		
		if(p->x < mins[0])
			printf("!\n");
		if(p->y < mins[1])
			printf("!\n");
		if(p->x > maxs[0])
			printf("!\n");
		if(p->y > maxs[1])
			printf("!\n");
	}
	
	//a node without any child is a leaf and has no cuts to draw
	if(!node->nxny && !node->nxpy && !node->pxny && !node->pxpy)
		return;
	
	const double cut_x = node->chop_pt[0];
	const double cut_y = node->chop_pt[1];
	
	//horizontal segment at the Y cut, then vertical segment at the X cut
	fprintf(outf, "%lf %lf\n%lf %lf\n\n", mins[0], cut_y, maxs[0], cut_y);
	fprintf(outf, "%lf %lf\n%lf %lf\n\n", cut_x, mins[1], cut_x, maxs[1]);
	
	//children in visiting order; bit 1 of the index selects the high-x half,
	//bit 0 the high-y half
	tree_node_t *children[4] = { node->nxny, node->nxpy, node->pxny, node->pxpy };
	
	for(quadrant = 0; quadrant < 4; quadrant++)
	{
		const int high_x = (quadrant & 2) != 0;
		const int high_y = (quadrant & 1) != 0;
		double child_mins[2];
		double child_maxs[2];
		
		child_mins[0] = high_x ? cut_x : mins[0];
		child_maxs[0] = high_x ? maxs[0] : cut_x;
		child_mins[1] = high_y ? cut_y : mins[1];
		child_maxs[1] = high_y ? maxs[1] : cut_y;
		
		write_tree_node_demo(outf, children[quadrant], child_mins, child_maxs);
	}
}


//Reads a dataset from a text file.
//
//Each record is "<int> <double> <double>": the integer is parsed and discarded,
//the two doubles are stored. On return *array_out points to a malloc'd array
//holding the doubles as consecutive pairs (2 * *size_out values) and *size_out
//is the number of records. The caller owns the array.
//
//The file is read twice: once to count records, once to store them.
//Exits the program if the file cannot be opened, if parsing stops before the
//end of the file, or if the array cannot be allocated.
void read_dataset(const char *filename, double **array_out, int *size_out)
{
	FILE *datafile = fopen(filename, "r");
	if(!datafile)
	{
		printf("Could not open %s\n", filename);
		exit(-1);
	}
	
	//Read through once to find out how many entries we have
	int num_points = 0;
	int label;
	double scratch_a, scratch_b;
	
	while(fscanf(datafile, " %d %lf %lf ", &label, &scratch_a, &scratch_b) == 3)
		num_points++;
	
	if(!feof(datafile)) //We should only finish at the end of the file
	{
		printf("error parsing file %s\n", filename);
		exit(-1);
	}
	
	printf("Read %d data points from file %s\n", num_points, filename);
	
	//Allocate at least one pair so an empty file does not turn into malloc(0),
	//whose NULL result would be indistinguishable from a real failure.
	const size_t pairs = num_points > 0 ? (size_t)num_points : 1;
	double *dataset = malloc(sizeof *dataset * 2 * pairs);
	if(!dataset)
	{
		printf("Could not allocate memory for %s\n", filename);
		exit(-1);
	}
	
	//Now actually read the points. Never store more records than were counted
	//(and allocated for) above, even if the file has grown in the meantime.
	rewind(datafile);
	int points_stored = 0;
	while(points_stored < num_points &&
	      fscanf(datafile, " %d %lf %lf ", &label,
	             &dataset[points_stored*2], &dataset[(points_stored*2)+1]) == 3)
	{
		points_stored++;
	}
	
	fclose(datafile);
	
	*size_out = points_stored;
	*array_out = dataset;
}

//Counts the training points under `space` that fall strictly inside the square
//window of side `windowlen` centred on (x_center, y_center).
//Leaves are scanned point by point; internal nodes only descend into the
//quadrants the window can reach, judged against the node's chop point.
int count_in_window(tree_node_t *space, double windowlen, double x_center, double y_center)
{
	//edges of the window
	const double x_low  = x_center - (windowlen/2.0);
	const double x_high = x_center + (windowlen/2.0);
	const double y_low  = y_center - (windowlen/2.0);
	const double y_high = y_center + (windowlen/2.0);
	
	int total = 0;
	
	if(space->nxny || space->nxpy || space->pxny || space->pxpy) //internal node
	{
		//which halves on each axis can the window overlap?
		const int reach_low_x  = space->chop_pt[0] >= x_low;
		const int reach_high_x = space->chop_pt[0] <= x_high;
		const int reach_low_y  = space->chop_pt[1] >= y_low;
		const int reach_high_y = space->chop_pt[1] <= y_high;
		
		if(reach_low_x && reach_low_y)
			total += count_in_window(space->nxny, windowlen, x_center, y_center);
		if(reach_low_x && reach_high_y)
			total += count_in_window(space->nxpy, windowlen, x_center, y_center);
		if(reach_high_x && reach_low_y)
			total += count_in_window(space->pxny, windowlen, x_center, y_center);
		if(reach_high_x && reach_high_y)
			total += count_in_window(space->pxpy, windowlen, x_center, y_center);
		
		return total;
	}
	
	//leaf node: test every point it holds
	const datapoint_t *cur = space->pts;
	const datapoint_t *const end = space->pts + space->num_pts;
	
	for(; cur < end; cur++)
	{
		//a point on or beyond any edge is outside the window
		if(cur->x >= x_high || cur->x <= x_low)
			continue;
		if(cur->y >= y_high || cur->y <= y_low)
			continue;
		
		total++;
	}
	
	return total;
}

int main(int argc, const char **argv)
{
	if(argc != 6)
	{
		printf("usage: %s <training 1> <training 2> <testing 1> <testing 2> <window length>\n", argv[0]);
		exit(-1);
	}
	
	//Read datasets.
	datapoint_t *c1_training;
	int c1_training_size;
	
	datapoint_t *c2_training;
	int c2_training_size;
	
	datapoint_t *c1_testing;
	int c1_testing_size;
	
	datapoint_t *c2_testing;
	int c2_testing_size;
	
	read_dataset(argv[1], (double**)(&c1_training), &c1_training_size);
	read_dataset(argv[2], (double**)(&c2_training), &c2_training_size);
	read_dataset(argv[3], (double**)(&c1_testing), &c1_testing_size);
	read_dataset(argv[4], (double**)(&c2_testing), &c2_testing_size);
	
	//Build the training sets 
	tree_node_t *c1_training_root = make_tree_node(c1_training, c1_training_size);
	tree_node_t *c2_training_root = make_tree_node(c2_training, c2_training_size);
	
	printf("Trees built\n");
	
	//Spit out the trees 
	FILE *demo_file_1 = fopen("kd_tree_1.lines", "w");
	FILE *demo_file_2 = fopen("kd_tree_2.lines", "w");
	if(!demo_file_1 || !demo_file_2)
	{
		printf("failed to open tree dump files for writing\n");
		exit(-1);
	}
	
	
	
	double defmins1[2] = {-8,-8};
	double defmaxs1[2] = {8,8};
	
	write_tree_node_demo(demo_file_1, c1_training_root, defmins1, defmaxs1);
	
	double defmins2[2] = {-8,-8}; //these get clobbered so we need them again
	double defmaxs2[2] = {8,8};
	
	write_tree_node_demo(demo_file_2, c2_training_root, defmins2, defmaxs2);
	
	
	//Evaluate.
	int errors_classifying_c1 = 0;
	int errors_classifying_c2 = 0;
	
	double windowlen = atof(argv[5]);
	printf("Using window length %lf\n", windowlen);
	
	struct timespec start_time, end_time;
	clock_gettime(CLOCK_REALTIME, &start_time);
	
	//Classify each testing point in C1
	int testpt;
	for(testpt = 0; testpt < c1_testing_size; testpt++)
	{
		int c1_training_count = count_in_window(c1_training_root, windowlen, c1_testing[testpt].x, c1_testing[testpt].y);
		int c2_training_count = count_in_window(c2_training_root, windowlen, c1_testing[testpt].x, c1_testing[testpt].y);
		
		double c1_normalized = (double)c1_training_count / (double)c1_training_size;
		double c2_normalized = (double)c2_training_count / (double)c2_training_size;
		
		//assume equal priors.
		if(c2_normalized > c1_normalized)
			errors_classifying_c1++;
	}
	
	//Classify each testing point in C2
	for(testpt = 0; testpt < c2_testing_size; testpt++)
	{
		int c1_training_count = count_in_window(c1_training_root, windowlen, c2_testing[testpt].x, c2_testing[testpt].y);
		int c2_training_count = count_in_window(c2_training_root, windowlen, c2_testing[testpt].x, c2_testing[testpt].y);
		
		double c1_normalized = (double)c1_training_count / (double)c1_training_size;
		double c2_normalized = (double)c2_training_count / (double)c2_training_size;
		
		//assume equal priors.
		if(c2_normalized <= c1_normalized)
			errors_classifying_c2++;
	}	
	
	clock_gettime(CLOCK_REALTIME, &end_time);
	printf("Time to classify 10000 points: %lf microseconds\n", (((double)(end_time.tv_sec - start_time.tv_sec)*1000000.0)) + ((double)(end_time.tv_nsec - start_time.tv_nsec) / 1000.0));
	
	printf("Errors classifying C1 testing set: %d / %d (%lf)\n", errors_classifying_c1, c1_training_size, (double)errors_classifying_c1 / (double)c1_training_size);
	printf("Errors classifying C2 testing set: %d / %d (%lf)\n", errors_classifying_c2, c2_training_size, (double)errors_classifying_c2 / (double)c2_training_size);
	printf("Overall error rate: %d / %d (%lf)\n", errors_classifying_c1+errors_classifying_c2, c1_training_size+c2_training_size,  (double)(errors_classifying_c1+errors_classifying_c2) / (double)(c1_training_size+c2_training_size));
	
	
	#ifdef GRID_OUTPUT
	FILE *gridout = fopen("pw.grid", "w");
	double gx, gy;
	for(gy=-8.0;gy<8.0;gy += .03125)
	{
		for(gx=-8.0;gx<8.0;gx += .03125)
		{
			int c1_training_count = count_in_window(c1_training_root, windowlen, gx, gy);
			int c2_training_count = count_in_window(c2_training_root, windowlen, gx, gy);
			
			double c1_normalized = (double)c1_training_count / (double)c1_training_size;
			double c2_normalized = (double)c2_training_count / (double)c2_training_size;
			
			//assume equal priors.
			if(c2_normalized > c1_normalized)
				fprintf(gridout, "%lf %lf %d\n", gx, gy, 2);
			else if(c2_normalized < c1_normalized)
				fprintf(gridout, "%lf %lf %d\n", gx, gy, 1);
			else 
				fprintf(gridout, "%lf %lf %d\n", gx, gy, 0);
		}
	}
	#endif
	return 0;
}