diff --git a/README.md b/README.md
index a7a35d7..cb29bb6 100644
--- a/README.md
+++ b/README.md
@@ -6,6 +6,7 @@ Neural Network JavaScript library for Coding Train tutorials
Here are some demos running directly in the browser:
* [XOR problem](https://codingtrain.github.io/Toy-Neural-Network-JS/examples/xor/)
* [Handwritten digit recognition](https://codingtrain.github.io/Toy-Neural-Network-JS/examples/mnist/)
+* [Multiple hidden layers](https://maksuel.github.io/Toy-Neural-Network-JS/examples/doodle_classification/)
## To-Do List
@@ -20,7 +21,7 @@ Here are some demos running directly in the browser:
* only use testing data
* [ ] Support for saving / restoring network (see [#50](https://github.com/CodingTrain/Toy-Neural-Network-JS/pull/50))
* [ ] Support for different activation functions (see [#45](https://github.com/CodingTrain/Toy-Neural-Network-JS/pull/45), [#62](https://github.com/CodingTrain/Toy-Neural-Network-JS/pull/62))
-* [ ] Support for multiple hidden layers (see [#61](https://github.com/CodingTrain/Toy-Neural-Network-JS/pull/61))
+* [x] Support for multiple hidden layers (see [#107](https://github.com/CodingTrain/Toy-Neural-Network-JS/pull/107))
* [ ] Support for neuro-evolution
* [ ] play flappy bird (many players at once).
* [ ] play pong (many game simulations at once)
diff --git a/examples/doodle_classification/index.html b/examples/doodle_classification/index.html
index 1419153..863920d 100644
--- a/examples/doodle_classification/index.html
+++ b/examples/doodle_classification/index.html
@@ -12,6 +12,7 @@
+
diff --git a/examples/doodle_classification/sketch.js b/examples/doodle_classification/sketch.js
index 57ec727..f8e87da 100644
--- a/examples/doodle_classification/sketch.js
+++ b/examples/doodle_classification/sketch.js
@@ -15,6 +15,8 @@ let rainbows = {};
let nn;
+let visualization;
+
function preload() {
catsData = loadBytes('data/cats1000.bin');
trainsData = loadBytes('data/trains1000.bin');
@@ -31,8 +33,8 @@ function setup() {
prepareData(rainbows, rainbowsData, RAINBOW);
prepareData(trains, trainsData, TRAIN);
- // Making the neural network
- nn = new NeuralNetwork(784, 64, 3);
+ // Making the neural network (multi-hidden layers)
+ nn = new NeuralNetwork(784, 256, 64, 3);
// Randomizing the data
let training = [];
@@ -95,6 +97,8 @@ function setup() {
// let percent = testAll(testing);
// console.log("% Correct: " + percent);
// }
+
+ visualization = new Visualization(nn);
}
diff --git a/examples/doodle_classification/visualization.js b/examples/doodle_classification/visualization.js
new file mode 100644
index 0000000..29c9832
--- /dev/null
+++ b/examples/doodle_classification/visualization.js
@@ -0,0 +1,36 @@
+class Visualization {
+
+ static _check(nn) {
+ if(typeof p5 !== 'function') {
+ throw new Error('Need to include p5js');
+ } else if(!nn instanceof NeuralNetwork) {
+ throw new Error('Need a instance of NeuralNetwork');
+ }
+ }
+
+ static _getLayers(nn) {
+ let layers = [];
+
+ for(let layer of nn.layers) {
+ layers.push(layer.nodes);
+ }
+
+ return layers;
+ }
+
+ static graphics(nn) {
+ this._check(nn);
+
+ let layers = this._getLayers(nn);
+
+ let w = floor(layers.length * 20);
+ let h = floor(layers[0] * 20);
+
+ let graphics = createGraphics(w,h);
+
+ console.log(graphics.width, graphics.height);
+
+ return graphics;
+
+ }
+}
\ No newline at end of file
diff --git a/lib/nn.js b/lib/nn.js
index 0260280..a1b8c73 100644
--- a/lib/nn.js
+++ b/lib/nn.js
@@ -19,162 +19,285 @@ let tanh = new ActivationFunction(
class NeuralNetwork {
- // TODO: document what a, b, c are
- constructor(a, b, c) {
- if (a instanceof NeuralNetwork) {
- this.input_nodes = a.input_nodes;
- this.hidden_nodes = a.hidden_nodes;
- this.output_nodes = a.output_nodes;
-
- this.weights_ih = a.weights_ih.copy();
- this.weights_ho = a.weights_ho.copy();
-
- this.bias_h = a.bias_h.copy();
- this.bias_o = a.bias_o.copy();
- } else {
- this.input_nodes = a;
- this.hidden_nodes = b;
- this.output_nodes = c;
+ /**
+ * Constructor method.
+ *
+ * The user can enter with parameters (integer) to create a new NeuralNetwork,
+ * where: the first parameter represents the number of inputs, the second
+ * (or more) parameter represents the number of hidden nodes and the last
+ * parameter represents the number of outputs of the network.
+ * The user can copy an instance of NeuralNetwork by passing the same as an
+ * argument to the constructor.
+ *
+ * @param {NeuralNetwork|Array|...Integer} args (Rest parameters)
+ */
+ constructor(...args) {
+
+ if(args.length === 1 && args[0] instanceof NeuralNetwork) {
+
+ let layers = [];
+
+ for(let layer of args[0].layers) {
+ layers.push(layer.nodes);
+ }
- this.weights_ih = new Matrix(this.hidden_nodes, this.input_nodes);
- this.weights_ho = new Matrix(this.output_nodes, this.hidden_nodes);
- this.weights_ih.randomize();
- this.weights_ho.randomize();
+ this._build(layers, args[0].learningRate);
- this.bias_h = new Matrix(this.hidden_nodes, 1);
- this.bias_o = new Matrix(this.output_nodes, 1);
- this.bias_h.randomize();
- this.bias_o.randomize();
- }
+ for(let i = 0; i < this.connections.length; i++) {
+ this.connections[i].weights = args[0].connections[i].weights.copy();
+ this.connections[i].bias = args[0].connections[i].bias.copy();
+ }
- // TODO: copy these as well
- this.setLearningRate();
- this.setActivationFunction();
+ } else if(args.length === 1 && Array.isArray(args[0])) {
+ this._build(args[0]);
- }
+ } else if(args.length === 3 && Number.isInteger(args[0]) &&
+ Array.isArray(args[1]) && Number.isInteger(args[2])) {
- predict(input_array) {
+ let layers = [];
- // Generating the Hidden Outputs
- let inputs = Matrix.fromArray(input_array);
- let hidden = Matrix.multiply(this.weights_ih, inputs);
- hidden.add(this.bias_h);
- // activation function!
- hidden.map(this.activation_function.func);
+ layers.push(args[0]);
- // Generating the output's output!
- let output = Matrix.multiply(this.weights_ho, hidden);
- output.add(this.bias_o);
- output.map(this.activation_function.func);
+ for(let arg of args[1]) {
+ layers.push(arg);
+ }
- // Sending back to the caller!
- return output.toArray();
- }
+ layers.push(args[2]);
+
+ this._build(layers);
+
+ } else if(args.length >= 3) {
+
+ this._build(args);
+
+ } else {
- setLearningRate(learning_rate = 0.1) {
- this.learning_rate = learning_rate;
+ throw new Error('Invalid arguments. Read the documentation!');
+ }
}
- setActivationFunction(func = sigmoid) {
- this.activation_function = func;
+ // PRIVATE
+ _build(layers, learningRate = 0.1) {
+
+ if(!Array.isArray(layers)) {
+ throw new Error('Must be array of nodes.');
+ } else if(!layers.every( value => Number.isInteger(value) )) {
+ throw new Error('All arguments must be integer.');
+ }
+
+ this.learningRate = learningRate;
+ this.activationFunction = sigmoid;
+
+ this.layers = [];
+ this.connections = [];
+
+ for(let nodes of layers) {
+ this.layers.push({
+ nodes: nodes
+ });
+ }
+
+ for(let i = 0; i < this.layers.length - 1; i++) {
+ let primaryNodes = this.layers[i].nodes,
+ secondaryNodes = this.layers[i+1].nodes;
+ this.connections.push({
+ weights: new Matrix(secondaryNodes, primaryNodes).randomize(),
+ bias: new Matrix(secondaryNodes, 1).randomize()
+ });
+ }
}
- train(input_array, target_array) {
- // Generating the Hidden Outputs
- let inputs = Matrix.fromArray(input_array);
- let hidden = Matrix.multiply(this.weights_ih, inputs);
- hidden.add(this.bias_h);
- // activation function!
- hidden.map(this.activation_function.func);
+ _walk(inputs) {
- // Generating the output's output!
- let outputs = Matrix.multiply(this.weights_ho, hidden);
- outputs.add(this.bias_o);
- outputs.map(this.activation_function.func);
+ this.inputs = inputs;
- // Convert array to matrix object
- let targets = Matrix.fromArray(target_array);
+ for(let i = 0; i < this.connections.length; i++) {
+ this.layers[i+1].results = Matrix.multiply(
+ this.connections[i].weights,
+ i === 0 ? this.inputs : this.layers[i].results
+ )
+ .add(this.connections[i].bias)
+ .map(this.activationFunction.func);
+ }
- // Calculate the error
- // ERROR = TARGETS - OUTPUTS
- let output_errors = Matrix.subtract(targets, outputs);
+ return this.outputs;
+ }
- // let gradient = outputs * (1 - outputs);
- // Calculate gradient
- let gradients = Matrix.map(outputs, this.activation_function.dfunc);
- gradients.multiply(output_errors);
- gradients.multiply(this.learning_rate);
+ // PUBLIC
+ train(inputs, targets) {
+ this._walk(inputs);
- // Calculate deltas
- let hidden_T = Matrix.transpose(hidden);
- let weight_ho_deltas = Matrix.multiply(gradients, hidden_T);
+ // TODO: Handle errors
+ if(targets.length !== this.outputsNodes) {
+ throw new Error('ERROR: Target array size.');
+ }
- // Adjust the weights by deltas
- this.weights_ho.add(weight_ho_deltas);
- // Adjust the bias by its deltas (which is just the gradients)
- this.bias_o.add(gradients);
+ // Backpropagation
+ for(let i = this._outputsIndex; i > 0; i--) {
- // Calculate the hidden layer errors
- let who_t = Matrix.transpose(this.weights_ho);
- let hidden_errors = Matrix.multiply(who_t, output_errors);
+ if(i === this._outputsIndex) {
- // Calculate hidden gradient
- let hidden_gradient = Matrix.map(hidden, this.activation_function.dfunc);
- hidden_gradient.multiply(hidden_errors);
- hidden_gradient.multiply(this.learning_rate);
+ this.layers[i].errors = Matrix.subtract(
+ Matrix.fromArray(targets),
+ this.outputs
+ );
- // Calcuate input->hidden deltas
- let inputs_T = Matrix.transpose(inputs);
- let weight_ih_deltas = Matrix.multiply(hidden_gradient, inputs_T);
+ } else {
- this.weights_ih.add(weight_ih_deltas);
- // Adjust the bias by its deltas (which is just the gradients)
- this.bias_h.add(hidden_gradient);
+ this.layers[i].errors = Matrix.multiply(
+ Matrix.transpose(
+ this.connections[i].weights
+ ),
+ this.layers[i+1].errors
+ );
+ }
- // outputs.print();
- // targets.print();
- // error.print();
- }
+ let gradients = Matrix.map(
+ this.layers[i].results,
+ this.activationFunction.dfunc
+ )
+ .multiply(this.layers[i].errors)
+ .multiply(this.learningRate);
- serialize() {
- return JSON.stringify(this);
- }
+ let deltas = Matrix.multiply(
+ gradients,
+ Matrix.transpose(i === 1 ? this.inputs : this.layers[i-1].results)
+ );
- static deserialize(data) {
- if (typeof data == 'string') {
- data = JSON.parse(data);
+ this.connections[i-1].weights.add(deltas);
+ this.connections[i-1].bias.add(gradients);
}
- let nn = new NeuralNetwork(data.input_nodes, data.hidden_nodes, data.output_nodes);
- nn.weights_ih = Matrix.deserialize(data.weights_ih);
- nn.weights_ho = Matrix.deserialize(data.weights_ho);
- nn.bias_h = Matrix.deserialize(data.bias_h);
- nn.bias_o = Matrix.deserialize(data.bias_o);
- nn.learning_rate = data.learning_rate;
- return nn;
}
+ predict(inputs) {
+ return this._walk(inputs).toArray();
+ }
- // Adding function for neuro-evolution
copy() {
return new NeuralNetwork(this);
}
+ serialize() {
+ let nn = this.copy();
+ delete nn._activationFunction;
+
+ for(let layer of nn.layers) {
+ delete layer.matrix;
+ delete layer.results;
+ delete layer.errors;
+ }
+
+ return JSON.stringify(nn);
+ }
+
mutate(rate) {
- function mutate(val) {
- if (Math.random() < rate) {
+
+ if(Number(rate) !== rate || !(0 < rate && rate <= 1)) {
+ throw new Error('Mutate rate must be a number > 0 and <= 1.');
+ }
+
+ let mutate = value => {
+ if(Math.random() < rate) {
return Math.random() * 1000 - 1;
} else {
- return val;
+ return value;
}
+ };
+
+ for(let connection of this.connections) {
+ connection.weights.map(mutate);
+ connection.bias.map(mutate);
}
- this.weights_ih.map(mutate);
- this.weights_ho.map(mutate);
- this.bias_h.map(mutate);
- this.bias_o.map(mutate);
}
+ // SETTERS
+ set inputs(inputs) {
+ if(!Array.isArray(inputs)) {
+ throw new Error('Inputs must be array.');
+ } else if(inputs.length !== this.inputsNodes) {
+ throw new Error('Inputs size.');
+ } else if(!inputs.every( val => true )) { // TODO: Check if is float number
+ throw new Error('Inputs value must be a number.');
+ }
+
+ this.layers[this._inputsIndex].matrix = Matrix.fromArray(inputs);
+ }
+
+ set learningRate(rate) {
+
+ if(Number(rate) !== rate) {
+ throw new Error('Learning rate must be a number');
+ }
+
+ this._learningRate = rate;
+ }
+
+ set activationFunction(func) {
+
+ if(!func instanceof ActivationFunction) {
+ throw new Error('Activation function must be a instance of ActivationFunction.');
+ }
+ this._activationFunction = func;
+ }
+
+ // GETTERS
+ get _inputsIndex() {
+ return 0;
+ }
+
+ get _outputsIndex() {
+ return this.layers.length - 1;
+ }
+
+ get inputsNodes() {
+ return this.layers[this._inputsIndex].nodes;
+ }
+
+ get inputs() {
+ return this.layers[this._inputsIndex].matrix;
+ }
+
+ get outputsNodes() {
+ return this.layers[this._outputsIndex].nodes;
+ }
+
+ get outputs() {
+ return this.layers[this._outputsIndex].results;
+ }
+
+ get learningRate() {
+ return this._learningRate;
+ }
+
+ get activationFunction() {
+ return this._activationFunction;
+ }
+
+ // STATIC
+ static deserialize(data) {
+ if (typeof data == 'string') {
+ data = JSON.parse(data);
+ }
+
+ let args = [null];
+
+ for(let layer of data.layers) {
+ args.push(layer.nodes);
+ }
+
+ let nn = new (Function.prototype.bind.apply(this, args));
+
+ nn.learningRate = data._learningRate;
+
+ for(let i = 0; i < nn.connections.length; i++) {
+ nn.connections[i].weights = Matrix.deserialize(data.connections[i].weights);
+ nn.connections[i].bias = Matrix.deserialize(data.connections[i].bias);
+ }
+
+ return nn;
+ }
}