#include <math.h>
#include <omp.h>
#include <stdio.h>
#include <stdlib.h>
#include <time.h>

#include "neunet.h"

/*
 * Network dimensions and data-set sizes (defined in another translation unit).
 * main() assigns n_input, n_output and n_hidden; n_test is only read here, as
 * the number of training samples (presumably set by read_training() -- verify).
 * n_learn is shadowed by a local in main() and is never written by this file;
 * dy and dx are not referenced in this file.
 */
extern int n_hidden,n_input,n_output,n_learn,n_test,dy,dx;

/*
 * Network state, sized by the MAX_* limits (presumably from neunet.h).
 * main() fills hidden_weight/output_weight with drand48()-0.5, i.e. values in
 * [-0.5, 0.5), zeroes the hidden_* work arrays and sets the output_* work
 * arrays to 0.5. For every sample main() writes input[] and result[] before
 * calling propagate()/back_propagate(), then reads output_output[].
 */
extern double hidden_weight[MAX_hidden][MAX_input],
	      hidden_delta[MAX_hidden],
	      hidden_output[MAX_hidden],
	      hidden_data[MAX_hidden][MAX_input],
	      output_weight[MAX_output][MAX_hidden],
	      output_delta[MAX_output],
	      output_output[MAX_output],
              output_data[MAX_output][MAX_hidden],
	      input[MAX_input],result[MAX_output];

/*
 * Training set: test_input_data holds n_input ints per sample (indexed
 * k*n_input+l), test_output_data the expected output index of each sample.
 * charrow/charpos are printed for the last misclassified sample of a pass.
 */
extern int *test_input_data, *test_output_data, *charrow, *charpos;
/* Training hyper-parameters, parsed from the command line by main(). */
extern double momentum, learning_rate;

/*
 * learn: train the network on the samples in <inputdata> by on-line
 * back-propagation, then write the trained network to <outputfile>.
 *
 * Usage: learn <inputdata> <seed> <n_hidden> <learning_rate> <momentum>
 *              <n_loop> <outputfile>
 *
 * Each of the <n_loop> passes visits every training sample once, in an order
 * reshuffled with lrand48(), and prints one tab-separated line:
 *   pass  misclassified  ssq  [charrow charpos of last misclassified sample]
 * where ssq is the squared difference between result[] and output_output[]
 * summed over the pass (measured after back_propagate()). A summary line with
 * the run parameters, the last pass's ssq and the elapsed wall-clock seconds
 * is printed before the network is written.
 *
 * Exits with status -1 on bad arguments, I/O failures or allocation failure.
 */
int main(int argc, char **argv)
{
  /* POSIX rand48 family; declared here because strict -std modes may hide
     them in <stdlib.h>. */
  double drand48(void);
  long lrand48(void);
  void srand48(long);

  int i, j, k, l, m, nbad, lastbad, seed;
  /* NOTE(review): this local shadows the file-scope extern n_learn, so the
     global is never set by this program -- confirm no other module reads it. */
  int n_learn;
  time_t time1, time2;
  long dtime;
  double ssq = 0.0;   /* initialised: it is printed even if no pass runs */
  double diff;
  int *pinput;
  FILE *IN, *OUT;

  if (argc < 8)
  {
    printf("Usage: learn <inputdata> <seed> <n_hidden> <learning_rate> <momentum> <n_loop> <outputfile>\n");
    exit(-1);
  }

  n_input = 200;
  n_output = 64;

  /* Refuse to run with unparsed (uninitialised) parameters. */
  if (sscanf(argv[2], "%d", &seed) != 1 ||
      sscanf(argv[3], "%d", &n_hidden) != 1 ||
      sscanf(argv[4], "%lg", &learning_rate) != 1 ||
      sscanf(argv[5], "%lg", &momentum) != 1 ||
      sscanf(argv[6], "%d", &n_learn) != 1)
  {
    fprintf(stderr, "learn: <seed>, <n_hidden>, <learning_rate>, <momentum> and <n_loop> must be numbers.\n");
    exit(-1);
  }

  if (n_input > MAX_input)
  {
    printf("<n_input> too big (limit is %d).\n", MAX_input);
    exit(-1);
  }

  if (n_hidden < 1)
  {
    printf("<n_hidden> must be at least 1.\n");
    exit(-1);
  }

  if (n_hidden > MAX_hidden)
  {
    printf("<n_hidden> too big (limit is %d).\n", MAX_hidden);
    exit(-1);
  }

  if (n_output > MAX_output)
  {
    printf("<n_output> too big (limit is %d).\n", MAX_output);
    exit(-1);
  }

  /* Read the training set. */
  IN = fopen(argv[1], "r");
  if (IN == NULL)
  {
    perror(argv[1]);
    exit(-1);
  }
  read_training(IN);
  fclose(IN);

  srand48(seed);

  /* Random weights in [-0.5, 0.5); work arrays start at 0.0 / 0.5. */
  for (i = 0; i < n_hidden; i++)
  {
    for (j = 0; j < n_input; j++)
    {
      hidden_weight[i][j] = drand48() - 0.5;
      hidden_data[i][j] = 0.0;
    }
    hidden_delta[i] = 0.0;
    hidden_output[i] = 0.0;
  }

  for (i = 0; i < n_output; i++)
  {
    for (j = 0; j < n_hidden; j++)
    {
      output_weight[i][j] = drand48() - 0.5;
      output_data[i][j] = 0.5;
    }
    output_delta[i] = 0.5;
    output_output[i] = 0.5;
  }

  /* Sample visiting order; malloc(0) may legitimately return NULL. */
  pinput = malloc(sizeof *pinput * (size_t) n_test);
  if (pinput == NULL && n_test > 0)
  {
    fprintf(stderr, "learn: cannot allocate %d sample indices\n", n_test);
    exit(-1);
  }

  time(&time1);

  for (i = 0; i < n_learn; i++)
  {
    /* Fisher-Yates shuffle of the sample indices for this pass. */
    for (j = 0; j < n_test; j++) pinput[j] = j;
    for (j = 0; j < n_test; j++)
    {
      k = lrand48() % (n_test - j) + j;
      l = pinput[j];
      pinput[j] = pinput[k];
      pinput[k] = l;
    }

    ssq = 0;
    nbad = 0;
    lastbad = -1;

    for (j = 0; j < n_test; j++)
    {
      k = pinput[j];
      /* Scale inputs so 0..255 maps to 0.1..0.9 (presumably 8-bit values). */
      for (l = 0; l < n_input; l++) input[l] = test_input_data[k * n_input + l] / 318.75 + 0.1;
      /* Target: 0.9 on the expected output index, 0.1 everywhere else. */
      for (l = 0; l < n_output; l++) result[l] = (test_output_data[k] == l) ? 0.9 : 0.1;

      propagate();

      /* Prediction is the most active output; count misclassifications. */
      m = 0;
      for (l = 1; l < n_output; l++)
      {
        if (output_output[l] > output_output[m]) m = l;
      }
      if (test_output_data[k] != m)
      {
        nbad++;
        lastbad = k;
      }

      back_propagate();

      for (l = 0; l < n_output; l++)
      {
        diff = result[l] - output_output[l];
        ssq += diff * diff;
      }
    }
    printf("%d\t%d\t%lg", i, nbad, ssq);
    if (lastbad >= 0)
    {
      printf("\t%d\t%d\n", charrow[lastbad], charpos[lastbad]);
    }
    else
    {
      printf("\n");
    }

    fflush(stdout);
  }
  time(&time2);
  dtime = (long) difftime(time2, time1);

  free(pinput);
  pinput = NULL;

  printf("%d %d %d %lg %lg %d %lg %ld\n",
         n_input, n_hidden, n_output, learning_rate, momentum,
         n_learn, ssq, dtime);

  OUT = fopen(argv[7], "w");
  if (OUT == NULL)
  {
    perror(argv[7]);
    exit(-1);
  }
  write_network(OUT);
  /* fclose flushes; a failure means the network file may be incomplete. */
  if (fclose(OUT) != 0)
  {
    perror(argv[7]);
    exit(-1);
  }
  return 0;
}

