Skip to content

Commit 9a52a0d

Browse files
committed
Sample programs improved
1 parent ea6d85d commit 9a52a0d

File tree

2 files changed

+54
-37
lines changed

2 files changed

+54
-37
lines changed

Samples/DigitsCudaTest/Program.cs

Lines changed: 26 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
using NeuralNetworkNET.APIs;
88
using NeuralNetworkNET.APIs.Enums;
99
using NeuralNetworkNET.APIs.Interfaces;
10+
using NeuralNetworkNET.APIs.Interfaces.Data;
1011
using NeuralNetworkNET.APIs.Results;
1112
using NeuralNetworkNET.APIs.Structs;
1213
using NeuralNetworkNET.Helpers;
@@ -15,42 +16,37 @@
1516

1617
namespace DigitsCudaTest
1718
{
18-
class Program
19+
public class Program
1920
{
20-
static async Task Main()
21+
public static async Task Main()
2122
{
22-
// Parse the dataset and create the network
23-
(var training, var test) = DataParser.LoadDatasets();
23+
// Create the network
2424
INeuralNetwork network = NetworkManager.NewSequential(TensorInfo.CreateForGrayscaleImage(28, 28),
25-
CuDnnNetworkLayers.Convolutional(ConvolutionInfo.Default, (5, 5), 20, ActivationFunctionType.LeakyReLU),
2625
CuDnnNetworkLayers.Convolutional(ConvolutionInfo.Default, (5, 5), 20, ActivationFunctionType.Identity),
2726
CuDnnNetworkLayers.Pooling(PoolingInfo.Default, ActivationFunctionType.LeakyReLU),
28-
CuDnnNetworkLayers.Convolutional(ConvolutionInfo.Default, (3, 3), 40, ActivationFunctionType.LeakyReLU),
2927
CuDnnNetworkLayers.Convolutional(ConvolutionInfo.Default, (3, 3), 40, ActivationFunctionType.Identity),
3028
CuDnnNetworkLayers.Pooling(PoolingInfo.Default, ActivationFunctionType.LeakyReLU),
3129
CuDnnNetworkLayers.FullyConnected(125, ActivationFunctionType.LeCunTanh),
3230
CuDnnNetworkLayers.FullyConnected(64, ActivationFunctionType.LeCunTanh),
3331
CuDnnNetworkLayers.Softmax(10));
3432

35-
// Setup and start the training
33+
// Prepare the dataset
34+
(var training, var test) = DataParser.LoadDatasets();
35+
ITrainingDataset trainingData = DatasetLoader.Training(training, 400); // Batches of 400 samples
36+
ITestDataset testData = DatasetLoader.Test(test, new Progress<TrainingProgressEventArgs>(p =>
37+
{
38+
Printf($"Epoch {p.Iteration}, cost: {p.Result.Cost}, accuracy: {p.Result.Accuracy}");
39+
}));
40+
41+
// Setup and network training
3642
CancellationTokenSource cts = new CancellationTokenSource();
3743
Console.CancelKeyPress += (s, e) => cts.Cancel();
3844
TrainingSessionResult result = await NetworkManager.TrainNetworkAsync(network,
39-
DatasetLoader.Training(training, 400),
45+
trainingData,
4046
TrainingAlgorithms.Adadelta(),
4147
20, 0.5f,
42-
new Progress<BatchProgress>(p =>
43-
{
44-
Console.SetCursorPosition(0, Console.CursorTop);
45-
int n = (int)(p.Percentage * 32 / 100);
46-
char[] c = new char[32];
47-
for (int i = 0; i < 32; i++) c[i] = i <= n ? '=' : ' ';
48-
Console.Write($"[{new String(c)}] ");
49-
}),
50-
testDataset: DatasetLoader.Test(test, new Progress<TrainingProgressEventArgs>(p =>
51-
{
52-
Printf($"Epoch {p.Iteration}, cost: {p.Result.Cost}, accuracy: {p.Result.Accuracy}");
53-
})), token: cts.Token);
48+
new Progress<BatchProgress>(TrackBatchProgress),
49+
testDataset: testData, token: cts.Token);
5450

5551
// Save the training reports
5652
String
@@ -75,5 +71,15 @@ private static void Printf(String text)
7571
Console.ForegroundColor = ConsoleColor.White;
7672
Console.Write($"{text}\n");
7773
}
74+
75+
// Training monitor
76+
private static void TrackBatchProgress(BatchProgress progress)
77+
{
78+
Console.SetCursorPosition(0, Console.CursorTop);
79+
int n = (int)(progress.Percentage * 32 / 100); // 32 is the number of progress '=' characters to display
80+
char[] c = new char[32];
81+
for (int i = 0; i < 32; i++) c[i] = i <= n ? '=' : ' ';
82+
Console.Write($"[{new String(c)}] ");
83+
}
7884
}
7985
}

Samples/DigitsTest/Program.cs

Lines changed: 28 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -3,39 +3,40 @@
33
using MnistDatasetToolkit;
44
using NeuralNetworkNET.APIs;
55
using NeuralNetworkNET.APIs.Interfaces;
6+
using NeuralNetworkNET.APIs.Interfaces.Data;
67
using NeuralNetworkNET.APIs.Results;
78
using NeuralNetworkNET.APIs.Structs;
89
using NeuralNetworkNET.Networks.Activations;
910
using NeuralNetworkNET.SupervisedLearning.Optimization.Progress;
1011

1112
namespace DigitsTest
1213
{
13-
class Program
14+
public class Program
1415
{
15-
static async Task Main()
16+
public static async Task Main()
1617
{
17-
(var training, var test) = DataParser.LoadDatasets();
18+
// Create the network
1819
INeuralNetwork network = NetworkManager.NewSequential(TensorInfo.CreateForGrayscaleImage(28, 28),
1920
NetworkLayers.Convolutional((5, 5), 20, ActivationFunctionType.Identity),
2021
NetworkLayers.Pooling(ActivationFunctionType.LeakyReLU),
2122
NetworkLayers.FullyConnected(100, ActivationFunctionType.LeCunTanh),
2223
NetworkLayers.Softmax(10));
24+
25+
// Prepare the dataset
26+
(var training, var test) = DataParser.LoadDatasets();
27+
ITrainingDataset trainingData = DatasetLoader.Training(training, 100); // Batches of 100 samples
28+
ITestDataset testData = DatasetLoader.Test(test, new Progress<TrainingProgressEventArgs>(p =>
29+
{
30+
Printf($"Epoch {p.Iteration}, cost: {p.Result.Cost}, accuracy: {p.Result.Accuracy}");
31+
}));
32+
33+
// Train the network
2334
TrainingSessionResult result = await NetworkManager.TrainNetworkAsync(network,
24-
DatasetLoader.Training(training, 100),
25-
TrainingAlgorithms.Adadelta(),
35+
trainingData,
36+
TrainingAlgorithms.Adadelta(),
2637
60, 0.5f,
27-
new Progress<BatchProgress>(p =>
28-
{
29-
Console.SetCursorPosition(0, Console.CursorTop);
30-
int n = (int)(p.Percentage * 32 / 100);
31-
char[] c = new char[32];
32-
for (int i = 0; i < 32; i++) c[i] = i <= n ? '=' : ' ';
33-
Console.Write($"[{new String(c)}] ");
34-
}),
35-
testDataset: DatasetLoader.Test(test, new Progress<TrainingProgressEventArgs>(p =>
36-
{
37-
Printf($"Epoch {p.Iteration}, cost: {p.Result.Cost}, accuracy: {p.Result.Accuracy}");
38-
})));
38+
new Progress<BatchProgress>(TrackBatchProgress),
39+
testDataset: testData);
3940
Printf($"Stop reason: {result.StopReason}, elapsed time: {result.TrainingTime}");
4041
Console.ReadKey();
4142
}
@@ -48,5 +49,15 @@ private static void Printf(String text)
4849
Console.ForegroundColor = ConsoleColor.White;
4950
Console.Write($"{text}\n");
5051
}
52+
53+
// Training monitor
54+
private static void TrackBatchProgress(BatchProgress progress)
55+
{
56+
Console.SetCursorPosition(0, Console.CursorTop);
57+
int n = (int)(progress.Percentage * 32 / 100); // 32 is the number of progress '=' characters to display
58+
char[] c = new char[32];
59+
for (int i = 0; i < 32; i++) c[i] = i <= n ? '=' : ' ';
60+
Console.Write($"[{new String(c)}] ");
61+
}
5162
}
5263
}

0 commit comments

Comments
 (0)