/********************************************************************** * 2-bit XOR problem. * * gcc -o xor xor.c -lm * xor 5 0.2 0.8 * * The first argument is the number of hidden units, the second is * the learning constant and the third is the momentum constant. So * the example runs with 5 hidden units, a learning constant * of 0.2 and momentum of 0.8. * * Set the debugging #define to control the output you get. **********************************************************************/ #include #include #include /* for setting random number seed */ #define MAXNH 100 /* maximum number of hidden units */ #define debugging 2 /* set to 1 to see debugging messages */ /* set to 2 to print just errors (for plotting) */ #define RN ((float) random() / (float)((1 << 31) - 1)) /*** function prototypes ***/ void calc_output(float x[2], int nh, float wh[MAXNH][3], float wo[MAXNH+1], float yh[MAXNH], float *yo); void update_weights(float x[2], int nh, float wh[MAXNH][3], float wo[MAXNH+1], float yh[MAXNH], float yo, float error, float rate, float mom, float dwh[MAXNH][3], float dwo[MAXNH+1]); void print_nets(float x[2], int nh, float wh[MAXNH][3], float wo[MAXNH+1], float yh[MAXNH], float yo, float error); float logistic(float x); float Dlogistic(float y); /**********************************************************************/ main(int argc, char *argv[]) { float wh[MAXNH][3]; /* hidden units' weights */ float wo[MAXNH+1]; /* output unit's weights */ float x[2]; /* input vector */ float yh[MAXNH]; /* hidden units' outputs */ float yo; /* output unit's output */ int nh = 5; /* number of hidden units */ float rate = 0.2; /* learning rate with default value*/ float mom = 0.8; /* momentum rate with default value*/ float init_weight_range = 1.0 /*0.1*/; /* range of initial weights (+- this value) */ float dwh[MAXNH][3]; /* previous weight change, for momentum */ float dwo[MAXNH+1]; /* previous weight change, for momentum */ /* Define input and target patterns */ int patterns[4][2] = {0,0, 0,1, 1,0, 1,1}; int targets[4] = {0, 1, 1, 0}; int pattern, target, ncorrect, epoch, u, i; float error, epoch_error; /* Initialize the pseudo-random number generator by current time. */ srandom(time(NULL)); /* If arguments present, process them */ if (argc > 1) nh = atoi(argv[1]); if (argc > 2) rate = atof(argv[2]); if (argc > 3) mom = atof(argv[3]); /* Initialize weights to small random values. */ for (u=0; u