Skip to content

Commit 43c0453

Browse files
authored
Merge branch 'dev' into feature_images-loading
2 parents 163ac79 + 9a52a0d commit 43c0453

File tree

2 files changed

+54
-37
lines changed

2 files changed

+54
-37
lines changed

Samples/DigitsCudaTest/Program.cs

Lines changed: 26 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
using NeuralNetworkNET.APIs;
88
using NeuralNetworkNET.APIs.Enums;
99
using NeuralNetworkNET.APIs.Interfaces;
10+
using NeuralNetworkNET.APIs.Interfaces.Data;
1011
using NeuralNetworkNET.APIs.Results;
1112
using NeuralNetworkNET.APIs.Structs;
1213
using NeuralNetworkNET.Helpers;
@@ -16,42 +17,37 @@
1617

1718
namespace DigitsCudaTest
1819
{
19-
class Program
20+
public class Program
2021
{
21-
static async Task Main()
22+
public static async Task Main()
2223
{
23-
// Parse the dataset and create the network
24-
(var training, var test) = DataParser.LoadDatasets();
24+
// Create the network
2525
INeuralNetwork network = NetworkManager.NewSequential(TensorInfo.Image<Alpha8>(28, 28),
26-
CuDnnNetworkLayers.Convolutional(ConvolutionInfo.Default, (5, 5), 20, ActivationFunctionType.LeakyReLU),
2726
CuDnnNetworkLayers.Convolutional(ConvolutionInfo.Default, (5, 5), 20, ActivationFunctionType.Identity),
2827
CuDnnNetworkLayers.Pooling(PoolingInfo.Default, ActivationFunctionType.LeakyReLU),
29-
CuDnnNetworkLayers.Convolutional(ConvolutionInfo.Default, (3, 3), 40, ActivationFunctionType.LeakyReLU),
3028
CuDnnNetworkLayers.Convolutional(ConvolutionInfo.Default, (3, 3), 40, ActivationFunctionType.Identity),
3129
CuDnnNetworkLayers.Pooling(PoolingInfo.Default, ActivationFunctionType.LeakyReLU),
3230
CuDnnNetworkLayers.FullyConnected(125, ActivationFunctionType.LeCunTanh),
3331
CuDnnNetworkLayers.FullyConnected(64, ActivationFunctionType.LeCunTanh),
3432
CuDnnNetworkLayers.Softmax(10));
3533

36-
// Setup and start the training
34+
// Prepare the dataset
35+
(var training, var test) = DataParser.LoadDatasets();
36+
ITrainingDataset trainingData = DatasetLoader.Training(training, 400); // Batches of 400 samples
37+
ITestDataset testData = DatasetLoader.Test(test, new Progress<TrainingProgressEventArgs>(p =>
38+
{
39+
Printf($"Epoch {p.Iteration}, cost: {p.Result.Cost}, accuracy: {p.Result.Accuracy}");
40+
}));
41+
42+
// Setup and network training
3743
CancellationTokenSource cts = new CancellationTokenSource();
3844
Console.CancelKeyPress += (s, e) => cts.Cancel();
3945
TrainingSessionResult result = await NetworkManager.TrainNetworkAsync(network,
40-
DatasetLoader.Training(training, 400),
46+
trainingData,
4147
TrainingAlgorithms.Adadelta(),
4248
20, 0.5f,
43-
new Progress<BatchProgress>(p =>
44-
{
45-
Console.SetCursorPosition(0, Console.CursorTop);
46-
int n = (int)(p.Percentage * 32 / 100);
47-
char[] c = new char[32];
48-
for (int i = 0; i < 32; i++) c[i] = i <= n ? '=' : ' ';
49-
Console.Write($"[{new String(c)}] ");
50-
}),
51-
testDataset: DatasetLoader.Test(test, new Progress<TrainingProgressEventArgs>(p =>
52-
{
53-
Printf($"Epoch {p.Iteration}, cost: {p.Result.Cost}, accuracy: {p.Result.Accuracy}");
54-
})), token: cts.Token);
49+
new Progress<BatchProgress>(TrackBatchProgress),
50+
testDataset: testData, token: cts.Token);
5551

5652
// Save the training reports
5753
String
@@ -76,5 +72,15 @@ private static void Printf(String text)
7672
Console.ForegroundColor = ConsoleColor.White;
7773
Console.Write($"{text}\n");
7874
}
75+
76+
// Training monitor
77+
private static void TrackBatchProgress(BatchProgress progress)
78+
{
79+
Console.SetCursorPosition(0, Console.CursorTop);
80+
int n = (int)(progress.Percentage * 32 / 100); // 32 is the number of progress '=' characters to display
81+
char[] c = new char[32];
82+
for (int i = 0; i < 32; i++) c[i] = i <= n ? '=' : ' ';
83+
Console.Write($"[{new String(c)}] ");
84+
}
7985
}
8086
}

Samples/DigitsTest/Program.cs

Lines changed: 28 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
using MnistDatasetToolkit;
44
using NeuralNetworkNET.APIs;
55
using NeuralNetworkNET.APIs.Interfaces;
6+
using NeuralNetworkNET.APIs.Interfaces.Data;
67
using NeuralNetworkNET.APIs.Results;
78
using NeuralNetworkNET.APIs.Structs;
89
using NeuralNetworkNET.Networks.Activations;
@@ -11,32 +12,32 @@
1112

1213
namespace DigitsTest
1314
{
14-
class Program
15+
public class Program
1516
{
16-
static async Task Main()
17+
public static async Task Main()
1718
{
18-
(var training, var test) = DataParser.LoadDatasets();
19+
// Create the network
1920
INeuralNetwork network = NetworkManager.NewSequential(TensorInfo.Image<Alpha8>(28, 28),
2021
NetworkLayers.Convolutional((5, 5), 20, ActivationFunctionType.Identity),
2122
NetworkLayers.Pooling(ActivationFunctionType.LeakyReLU),
2223
NetworkLayers.FullyConnected(100, ActivationFunctionType.LeCunTanh),
2324
NetworkLayers.Softmax(10));
25+
26+
// Prepare the dataset
27+
(var training, var test) = DataParser.LoadDatasets();
28+
ITrainingDataset trainingData = DatasetLoader.Training(training, 100); // Batches of 100 samples
29+
ITestDataset testData = DatasetLoader.Test(test, new Progress<TrainingProgressEventArgs>(p =>
30+
{
31+
Printf($"Epoch {p.Iteration}, cost: {p.Result.Cost}, accuracy: {p.Result.Accuracy}");
32+
}));
33+
34+
// Train the network
2435
TrainingSessionResult result = await NetworkManager.TrainNetworkAsync(network,
25-
DatasetLoader.Training(training, 100),
26-
TrainingAlgorithms.Adadelta(),
36+
trainingData,
37+
TrainingAlgorithms.Adadelta(),
2738
60, 0.5f,
28-
new Progress<BatchProgress>(p =>
29-
{
30-
Console.SetCursorPosition(0, Console.CursorTop);
31-
int n = (int)(p.Percentage * 32 / 100);
32-
char[] c = new char[32];
33-
for (int i = 0; i < 32; i++) c[i] = i <= n ? '=' : ' ';
34-
Console.Write($"[{new String(c)}] ");
35-
}),
36-
testDataset: DatasetLoader.Test(test, new Progress<TrainingProgressEventArgs>(p =>
37-
{
38-
Printf($"Epoch {p.Iteration}, cost: {p.Result.Cost}, accuracy: {p.Result.Accuracy}");
39-
})));
39+
new Progress<BatchProgress>(TrackBatchProgress),
40+
testDataset: testData);
4041
Printf($"Stop reason: {result.StopReason}, elapsed time: {result.TrainingTime}");
4142
Console.ReadKey();
4243
}
@@ -49,5 +50,15 @@ private static void Printf(String text)
4950
Console.ForegroundColor = ConsoleColor.White;
5051
Console.Write($"{text}\n");
5152
}
53+
54+
// Training monitor
55+
private static void TrackBatchProgress(BatchProgress progress)
56+
{
57+
Console.SetCursorPosition(0, Console.CursorTop);
58+
int n = (int)(progress.Percentage * 32 / 100); // 32 is the number of progress '=' characters to display
59+
char[] c = new char[32];
60+
for (int i = 0; i < 32; i++) c[i] = i <= n ? '=' : ' ';
61+
Console.Write($"[{new String(c)}] ");
62+
}
5263
}
5364
}

0 commit comments

Comments
 (0)