77using NeuralNetworkNET . APIs ;
88using NeuralNetworkNET . APIs . Enums ;
99using NeuralNetworkNET . APIs . Interfaces ;
10+ using NeuralNetworkNET . APIs . Interfaces . Data ;
1011using NeuralNetworkNET . APIs . Results ;
1112using NeuralNetworkNET . APIs . Structs ;
1213using NeuralNetworkNET . Helpers ;
1516
1617namespace DigitsCudaTest
1718{
18- class Program
19+ public class Program
1920 {
20- static async Task Main ( )
21+ public static async Task Main ( )
2122 {
22- // Parse the dataset and create the network
23- ( var training , var test ) = DataParser . LoadDatasets ( ) ;
23+ // Create the network
2424 INeuralNetwork network = NetworkManager . NewSequential ( TensorInfo . CreateForGrayscaleImage ( 28 , 28 ) ,
25- CuDnnNetworkLayers . Convolutional ( ConvolutionInfo . Default , ( 5 , 5 ) , 20 , ActivationFunctionType . LeakyReLU ) ,
2625 CuDnnNetworkLayers . Convolutional ( ConvolutionInfo . Default , ( 5 , 5 ) , 20 , ActivationFunctionType . Identity ) ,
2726 CuDnnNetworkLayers . Pooling ( PoolingInfo . Default , ActivationFunctionType . LeakyReLU ) ,
28- CuDnnNetworkLayers . Convolutional ( ConvolutionInfo . Default , ( 3 , 3 ) , 40 , ActivationFunctionType . LeakyReLU ) ,
2927 CuDnnNetworkLayers . Convolutional ( ConvolutionInfo . Default , ( 3 , 3 ) , 40 , ActivationFunctionType . Identity ) ,
3028 CuDnnNetworkLayers . Pooling ( PoolingInfo . Default , ActivationFunctionType . LeakyReLU ) ,
3129 CuDnnNetworkLayers . FullyConnected ( 125 , ActivationFunctionType . LeCunTanh ) ,
3230 CuDnnNetworkLayers . FullyConnected ( 64 , ActivationFunctionType . LeCunTanh ) ,
3331 CuDnnNetworkLayers . Softmax ( 10 ) ) ;
3432
35- // Setup and start the training
33+ // Prepare the dataset
34+ ( var training , var test ) = DataParser . LoadDatasets ( ) ;
35+ ITrainingDataset trainingData = DatasetLoader . Training ( training , 400 ) ; // Batches of 400 samples
36+ ITestDataset testData = DatasetLoader . Test ( test , new Progress < TrainingProgressEventArgs > ( p =>
37+ {
38+ Printf ( $ "Epoch { p . Iteration } , cost: { p . Result . Cost } , accuracy: { p . Result . Accuracy } ") ;
39+ } ) ) ;
40+
41+ // Setup and network training
3642 CancellationTokenSource cts = new CancellationTokenSource ( ) ;
3743 Console . CancelKeyPress += ( s , e ) => cts . Cancel ( ) ;
3844 TrainingSessionResult result = await NetworkManager . TrainNetworkAsync ( network ,
39- DatasetLoader . Training ( training , 400 ) ,
45+ trainingData ,
4046 TrainingAlgorithms . Adadelta ( ) ,
4147 20 , 0.5f ,
42- new Progress < BatchProgress > ( p =>
43- {
44- Console . SetCursorPosition ( 0 , Console . CursorTop ) ;
45- int n = ( int ) ( p . Percentage * 32 / 100 ) ;
46- char [ ] c = new char [ 32 ] ;
47- for ( int i = 0 ; i < 32 ; i ++ ) c [ i ] = i <= n ? '=' : ' ' ;
48- Console . Write ( $ "[{ new String ( c ) } ] ") ;
49- } ) ,
50- testDataset : DatasetLoader . Test ( test , new Progress < TrainingProgressEventArgs > ( p =>
51- {
52- Printf ( $ "Epoch { p . Iteration } , cost: { p . Result . Cost } , accuracy: { p . Result . Accuracy } ") ;
53- } ) ) , token : cts . Token ) ;
48+ new Progress < BatchProgress > ( TrackBatchProgress ) ,
49+ testDataset : testData , token : cts . Token ) ;
5450
5551 // Save the training reports
5652 String
@@ -75,5 +71,15 @@ private static void Printf(String text)
7571 Console . ForegroundColor = ConsoleColor . White ;
7672 Console . Write ( $ "{ text } \n ") ;
7773 }
74+
75+ // Training monitor
76+ private static void TrackBatchProgress ( BatchProgress progress )
77+ {
78+ Console . SetCursorPosition ( 0 , Console . CursorTop ) ;
79+ int n = ( int ) ( progress . Percentage * 32 / 100 ) ; // 32 is the number of progress '=' characters to display
80+ char [ ] c = new char [ 32 ] ;
81+ for ( int i = 0 ; i < 32 ; i ++ ) c [ i ] = i <= n ? '=' : ' ' ;
82+ Console . Write ( $ "[{ new String ( c ) } ] ") ;
83+ }
7884 }
7985}
0 commit comments