/***************************************************************************/
/*                  Last Update   12/27/2002
/*
/*						Testing Program for
/*
/*       Learning Vector Quantization Network for Classification 
/*
/*                     Jiang Li and Dr. Manry   
/*
/*					   Signal Processing Lab
/*					    		UTA
/*
/*This is a test program for the Learning Vector Quantization Network
/*
/***************************************************************************/


/* Functional Link Network with OR Algorithm Testing Program 
				
				  By Jiang Li
				  7/29/2003

***********************************************************/

#include <ctype.h>
#include <errno.h>
#include <limits.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include <conio.h>
#include <malloc.h>

#define Outfile "FLNTestResult.txt"   /* File to store the testing result */



/* Allocate `bytes` bytes or terminate the program with a message. */
static void *checked_malloc(size_t bytes)
{
	void *p = malloc(bytes);

	if (p == NULL) {
		printf("\nMemory allocation error\n");
		exit(1);
	}
	return p;
}

/* Read one integer from the weights-file header; terminate if it is absent. */
static int read_weights_int(FILE *fp, const char *fname)
{
	int value;

	if (fscanf(fp, "%d", &value) != 1) {
		printf("Weights file %s is truncated or malformed.\n", fname);
		exit(1);
	}
	return value;
}

/* Read one whitespace-delimited numeric token from the weights file and
   convert it to float; terminate if the file ends early or the token does
   not start with a number. */
static float read_weights_value(FILE *fp, const char *fname)
{
	char token[100];
	char *end;
	double value;

	if (fscanf(fp, "%99s", token) != 1) {   /* width limit: token has 100 bytes */
		printf("Weights file %s is truncated or malformed.\n", fname);
		exit(1);
	}
	value = strtod(token, &end);
	if (end == token) {
		printf("Weights file %s contains a non-numeric value: %s\n", fname, token);
		exit(1);
	}
	return (float)value;
}

/* Read one test pattern: N floats followed, when has_class is non-zero, by an
   integer class id.
   Returns 1 when a complete pattern was read, 0 on end of file before the
   pattern starts, and -1 when the pattern is incomplete or malformed. */
static int read_pattern(FILE *fp, float *x, int N, int has_class, int *class_id)
{
	int k, rc;

	for (k = 0; k < N; k++) {
		rc = fscanf(fp, "%f", &x[k]);
		if (rc != 1)
			return (k == 0 && rc == EOF) ? 0 : -1;
	}
	if (has_class && fscanf(fp, "%d", class_id) != 1)
		return -1;
	return 1;
}

/* Test driver for a trained functional link network classifier.
 *
 * Prompts for the testing data file, whether each pattern carries a desired
 * class id, the weights file, the number of inputs N and of classes Nc.
 * The weights file is read in this order:
 *   one leading token (skipped), N, Nc, the polynomial degree d,
 *   N input means, N input standard deviations,
 *   Nc rows of L = Factorial(N, d) weights, one row per class.
 * Every pattern is normalized with those statistics, expanded by
 * Getaugmentedvector(), scored against each class row (dot product) and
 * assigned the 1-based class with the largest output (first one on ties).
 * One row per pattern is written to Outfile; when desired ids are present,
 * per-class and total error counts are reported on screen and in Outfile.
 */
int main(void)
{
	char *Infile, *WeightsFile, str[100];
	FILE *ifs, *fpOut, *fpWeights;
	int i, j, d, L, rc;
	int N, Nv, Nc, Check;
	int Out_Flag;                /* 1 if each pattern ends with a desired class id */
	int TotalError, *ErrorForEachClass, *NvForClass;
	double ErrorPercent = 0.0;   /* percentage of misclassified patterns */
	double Max_y;
	int Class_obtained;
	int ClassId = 0;             /* desired (1-based) class of the current pattern */
	float *x, *xa, *y, *mean_Input, *std_Input;
	float **W;                   /* W[class][term] output weights */

	char *get_string(char *);
	int get_int(char *, int, int);
	void Getaugmentedvector(float *, float *, int, int, int);
	int Factorial(int, int);

	/* Testing data file */
	Infile = get_string("Enter Testing File Name: ");
	ifs = fopen(Infile, "r");
	if (ifs == NULL) {
		perror(Infile);
		exit(1);
	}

	printf("\n Does the testing file has desired output? \n");
	printf("\n Choose (0) for NO \n");
	printf("\n        (1) for YES : ");
	Out_Flag = get_int("", 0, 1);

	/* Weights file */
	WeightsFile = get_string("Enter class center file name: ");
	fpWeights = fopen(WeightsFile, "r");
	if (fpWeights == NULL) {
		perror(WeightsFile);
		exit(1);
	}

	N = get_int("Enter the No. of Input : ", 1, 100);
	Nc = get_int("Enter the No. of classes : ", 1, 100);

	/* Header: a leading token (skipped), then N, Nc and the degree d. */
	if (fscanf(fpWeights, "%99s", str) != 1) {
		printf("Weights file %s is truncated or malformed.\n", WeightsFile);
		exit(1);
	}
	Check = read_weights_int(fpWeights, WeightsFile);
	if (Check != N) {
		printf("The input file's input number does not match the weights file!\n");
		exit(1);
	}
	Check = read_weights_int(fpWeights, WeightsFile);
	if (Check != Nc) {
		printf("The input file's class number does not match the weights file!\n");
		exit(1);
	}
	d = read_weights_int(fpWeights, WeightsFile);
	if (d < 0) {
		printf("Invalid polynomial degree %d in weights file.\n", d);
		exit(1);
	}

	/* Length of the augmented input vector; a value < 1 means unusable. */
	L = Factorial(N, d);
	if (L < 1) {
		printf("Polynomial degree %d is too large for %d inputs.\n", d, N);
		exit(1);
	}

	x = checked_malloc(sizeof *x * (size_t)N);
	xa = checked_malloc(sizeof *xa * (size_t)L);
	y = checked_malloc(sizeof *y * (size_t)Nc);
	mean_Input = checked_malloc(sizeof *mean_Input * (size_t)N);
	std_Input = checked_malloc(sizeof *std_Input * (size_t)N);
	ErrorForEachClass = checked_malloc(sizeof *ErrorForEachClass * (size_t)Nc);
	NvForClass = checked_malloc(sizeof *NvForClass * (size_t)Nc);
	W = checked_malloc(sizeof *W * (size_t)Nc);
	for (i = 0; i < Nc; i++)
		W[i] = checked_malloc(sizeof **W * (size_t)L);

	/* Input means, input standard deviations, then the weight rows. */
	for (i = 0; i < N; i++)
		mean_Input[i] = read_weights_value(fpWeights, WeightsFile);
	for (i = 0; i < N; i++)
		std_Input[i] = read_weights_value(fpWeights, WeightsFile);
	for (i = 0; i < Nc; i++)
		for (j = 0; j < L; j++)
			W[i][j] = read_weights_value(fpWeights, WeightsFile);
	fclose(fpWeights);

	Nv = 0;
	TotalError = 0;
	for (i = 0; i < Nc; i++) {
		ErrorForEachClass[i] = 0;
		NvForClass[i] = 0;
	}

	fpOut = fopen(Outfile, "w");
	if (fpOut == NULL) {
		perror(Outfile);
		exit(1);
	}
	fprintf(fpOut, "Index\t");
	for (i = 0; i < N; i++)
		fprintf(fpOut, "Input[%d]\t", i);
	fprintf(fpOut, "Class_Id\tDesired_Id\n\n");

	/* Single pass over the data, driven by the read result rather than
	   feof(), so trailing whitespace cannot produce a phantom pattern. */
	while ((rc = read_pattern(ifs, x, N, Out_Flag, &ClassId)) == 1) {
		Nv++;
		if (Out_Flag) {
			if (ClassId < 1 || ClassId > Nc) {
				printf("Pattern %d has class id %d outside 1..%d.\n", Nv, ClassId, Nc);
				exit(1);
			}
			NvForClass[ClassId - 1]++;
		}

		/* Normalize; only subtract the mean when the std is (near) zero. */
		for (j = 0; j < N; j++) {
			if (std_Input[j] > 1e-20)
				x[j] = (x[j] - mean_Input[j]) / std_Input[j];
			else
				x[j] = (x[j] - mean_Input[j]);
		}
		Getaugmentedvector(x, xa, d, N, L);

		/* Class outputs; the winner is the first class with the largest output. */
		for (i = 0; i < Nc; i++) {
			y[i] = 0.0f;
			for (j = 0; j < L; j++)
				y[i] += xa[j] * W[i][j];
		}
		Max_y = y[0];
		Class_obtained = 1;
		for (i = 1; i < Nc; i++) {
			if (y[i] > Max_y) {
				Max_y = y[i];
				Class_obtained = i + 1;
			}
		}

		/* Errors can only be counted when a desired class is known. */
		if (Out_Flag && ClassId != Class_obtained) {
			ErrorForEachClass[ClassId - 1]++;
			TotalError++;
		}

		/* Row: 1-based pattern index, normalized inputs, obtained/desired class. */
		fprintf(fpOut, "%d\t", Nv);
		for (j = 0; j < N; j++)
			fprintf(fpOut, "%f\t", x[j]);
		if (Out_Flag)
			fprintf(fpOut, "%d\t%d\n", Class_obtained, ClassId);
		else
			fprintf(fpOut, "%d\tN/A\n", Class_obtained);
	}
	if (rc < 0)
		printf("\n Warning: incomplete or malformed data after pattern %d was ignored.\n", Nv);
	fclose(ifs);

	if (Out_Flag && Nv > 0)
		ErrorPercent = (double)TotalError * 100.0 / (double)Nv;

	/* Output the testing result */
	fprintf(fpOut, "\n There are %d classes.\n", Nc);

	if (Out_Flag) {
		/* Class ids are 1-based, so report them that way. */
		for (i = 0; i < Nc; i++) {
			printf("\n Class %d has %d error pattern(s) out of %d patterns.\n", i + 1, ErrorForEachClass[i], NvForClass[i]);
			fprintf(fpOut, "\n Class %d has %d error pattern(s) out of %d patterns.\n", i + 1, ErrorForEachClass[i], NvForClass[i]);
		}

		printf("\n Total error patterns are : %d.\n", TotalError);
		printf("\n Error percentage is: %f.\n", ErrorPercent);

		fprintf(fpOut, "\n Total error patterns are : %d.\n", TotalError);
		fprintf(fpOut, "\n Error percentage is: %f.\n", ErrorPercent);
	}

	printf("\n Testing is completed!\n");
	fprintf(fpOut, "\n Testing is completed!\n");

	/* fclose flushes the results; report a failed write. */
	if (fclose(fpOut) != 0)
		perror(Outfile);

	for (i = 0; i < Nc; i++)
		free(W[i]);
	free(W);
	free(x);
	free(xa);
	free(y);
	free(mean_Input);
	free(std_Input);
	free(ErrorForEachClass);
	free(NvForClass);
	free(Infile);
	free(WeightsFile);
	return 0;
}

/*************************************************************************/

/* Prompt with title_string followed by "[low...up]" and keep asking until
 * the user enters a base-10 integer in [low_limit, up_limit].
 * Surrounding whitespace is accepted.  Empty input, trailing junk, or a
 * value that overflows long causes a re-prompt.  The range check is done
 * on the long value, so out-of-range input can no longer wrap into range
 * when narrowed to int.
 * Terminates the program if low_limit > up_limit or allocation fails.
 */
int get_int(char *title_string,int low_limit, int up_limit)
{
	 long value;
	 int error_flag;
	 char *get_string(char *);       /* get string routine */
	 char *cp, *endcp;               /* input string and parse end */
	 char *stemp;                    /* prompt string */
	 size_t stemp_size;

/* check for limit error, low may equal high but not greater */
	 if (low_limit > up_limit) {
		  printf("\nLimit error, lower > upper\n");
		  exit(1);
	 }

/* make prompt string: title + " [" + two ints (<= 11 chars each) + "..." + "]" */
	 stemp_size = strlen(title_string) + 60;
	 stemp = malloc(stemp_size);
	 if (!stemp) {
		  printf("\nString allocation error in get_int\n");
		  exit(1);
	 }
	 snprintf(stemp, stemp_size, "%s [%d...%d]", title_string, low_limit, up_limit);

/* get the string and make sure the value is valid and in range */
	 do {
		  cp = get_string(stemp);
		  errno = 0;
		  value = strtol(cp, &endcp, 10);
		  /* no digits at all, or the value does not fit in a long */
		  error_flag = (cp == endcp) || (errno == ERANGE);
		  while (isspace((unsigned char)*endcp))
			   endcp++;
		  error_flag = error_flag || (*endcp != '\0');   /* trailing junk */
		  free(cp);
	 } while (error_flag || value < low_limit || value > up_limit);

/* free temp string and return result (value is within int range here) */
	 free(stemp);
	 return (int)value;
}

/*****************************************************************************/
/* Print title_string as a prompt and read one line from stdin.
 * Returns a heap buffer of 80 bytes (the caller frees it) holding at most
 * 79 characters of the line, without the "\n"/"\r\n" terminator.  Any
 * excess characters on an overlong line are read and discarded, so they do
 * not leak into the next prompt.
 * Terminates the program on allocation failure or end of input.
 */
char *get_string(char *title_string)
{
	 char *alpha;                            /* result string pointer */
	 int had_newline;
	 int c;

	 alpha = malloc(80);
	 if (!alpha) {
		  printf("\nString allocation error in get_string\n");
		  exit(1);
	 }
	 printf(" %s ", title_string);
	 fflush(stdout);                         /* prompt has no newline */

	 /* fgets is bounded, unlike the removed gets() */
	 if (fgets(alpha, 80, stdin) == NULL) {
		  printf("\nUnexpected end of input\n");
		  exit(1);
	 }
	 had_newline = (strchr(alpha, '\n') != NULL);
	 alpha[strcspn(alpha, "\r\n")] = '\0';
	 if (!had_newline) {
		  /* line longer than the buffer: drop the rest of it */
		  while ((c = getchar()) != '\n' && c != EOF)
			   ;
	 }

	 return(alpha);
}

/****************************************************************************/


/************************************************************************************/
/* Despite its name, returns the binomial coefficient C(N+d, d); main uses it
 * as the length of the augmented input vector.
 * Computed incrementally as C(N+i, i) = C(N+i-1, i-1) * (N+i) / i.  Every
 * division is exact, and intermediates stay far smaller than the full
 * falling product the naive formula needs.
 * Returns 1 when d <= 0, and -1 when the result does not fit in an int.
 * Assumes N >= 0.
 */
int Factorial(int N, int d)
{
	long long result = 1;
	int i;

	for (i = 1; i <= d; i++) {
		/* result <= INT_MAX and N+i < 2^32, so the product fits in long long */
		result = result * ((long long)N + i) / i;
		if (result > INT_MAX)
			return -1;
	}
	return (int)result;
}


/*******************************************************************************/
/* Write, depth-first, every product prod * x[t] * ... with non-increasing
   factor indices t <= max_idx, up to total degree d, into xa starting at
   position r.  Writes stop once r reaches cap.  Returns the next free
   position. */
static int augment_terms(const float *x, float *xa, int r, int cap,
                         float prod, int max_idx, int degree, int d)
{
	int t;
	float term;

	for (t = 0; t <= max_idx && r < cap; t++) {
		term = prod * x[t];
		xa[r++] = term;
		if (degree < d)
			r = augment_terms(x, xa, r, cap, term, t, degree + 1, d);
	}
	return r;
}

/* Build the polynomial augmented vector of x (length N) into xa (length L).
 * For each i, x[i] is followed by its higher-degree products
 * x[i]*x[j] (j <= i), x[i]*x[j]*x[k] (k <= j), ... up to degree d.  This is
 * the same order the former hard-coded loops produced for d <= 4; it now
 * works for any d.  A constant 1.0 bias is appended last.  With
 * L == C(N+d, d) every slot is filled.  At most L entries are ever written,
 * so d == 0 yields just the bias.
 */
void Getaugmentedvector(float *x, float *xa, int d, int N, int L)
{
	int r = 0;

	if (L < 1)
		return;
	/* reserve the last slot for the bias */
	if (d >= 1)
		r = augment_terms(x, xa, 0, L - 1, 1.0f, N - 1, 1, d);
	xa[r] = 1.0f;
}
/*********************************************************************************/