|
1 | 1 | let nn; |
2 | | -let training_data = [{ |
3 | | - inputs: [0, 0], |
4 | | - targets: [0] |
5 | | -}, { |
6 | | - inputs: [1, 0], |
7 | | - targets: [1] |
8 | | -}, { |
9 | | - inputs: [0, 1], |
10 | | - targets: [1] |
11 | | -}, { |
12 | | - inputs: [1, 1], |
13 | | - targets: [0] |
14 | | -}]; |
15 | | - |
16 | 2 | let lr_slider; |
17 | 3 |
|
| 4 | +let training_data = [{ |
| 5 | + inputs: [0, 0], |
| 6 | + outputs: [0] |
| 7 | + }, |
| 8 | + { |
| 9 | + inputs: [0, 1], |
| 10 | + outputs: [1] |
| 11 | + }, |
| 12 | + { |
| 13 | + inputs: [1, 0], |
| 14 | + outputs: [1] |
| 15 | + }, |
| 16 | + { |
| 17 | + inputs: [1, 1], |
| 18 | + outputs: [0] |
| 19 | + } |
| 20 | +]; |
| 21 | + |
18 | 22 | function setup() { |
19 | 23 | createCanvas(400, 400); |
20 | | - nn = new NeuralNetwork(2, 2, 1); |
21 | | - lr_slider = createSlider(0.01, 0.1, 0.05, 0.01); |
| 24 | + nn = new NeuralNetwork(2, 4, 1); |
| 25 | + lr_slider = createSlider(0.01, 0.5, 0.1, 0.01); |
| 26 | + |
22 | 27 | } |
23 | 28 |
|
24 | 29 | function draw() { |
25 | 30 | background(0); |
26 | 31 |
|
27 | | - nn.learning_rate = lr_slider.value(); |
28 | | - |
29 | | - for (let i = 0; i < 5000; i++) { |
| 32 | + for (let i = 0; i < 10; i++) { |
30 | 33 | let data = random(training_data); |
31 | | - nn.train(data.inputs, data.targets); |
| 34 | + nn.train(data.inputs, data.outputs); |
32 | 35 | } |
33 | 36 |
|
34 | | - let resolution = 20; |
35 | | - let cols = floor(width / resolution); |
36 | | - let rows = floor(height / resolution); |
| 37 | + nn.setLearningRate(lr_slider.value()); |
37 | 38 |
|
| 39 | + let resolution = 10; |
| 40 | + let cols = width / resolution; |
| 41 | + let rows = height / resolution; |
38 | 42 | for (let i = 0; i < cols; i++) { |
39 | 43 | for (let j = 0; j < rows; j++) { |
40 | | - let x = i * resolution; |
41 | | - let y = j * resolution; |
42 | | - let input_1 = i / (cols - 1); |
43 | | - let input_2 = j / (rows - 1); |
44 | | - let output = nn.predict([input_1, input_2]); |
45 | | - let col = output[0] * 255; |
46 | | - fill(col); |
| 44 | + let x1 = i / cols; |
| 45 | + let x2 = j / rows; |
| 46 | + let inputs = [x1, x2]; |
| 47 | + let y = nn.predict(inputs); |
47 | 48 | noStroke(); |
48 | | - rect(x, y, resolution, resolution); |
| 49 | + fill(y * 255); |
| 50 | + rect(i * resolution, j * resolution, resolution, resolution); |
49 | 51 | } |
50 | 52 | } |
51 | 53 |
|
|
0 commit comments