77using NeuralNetworkNET . APIs ;
88using NeuralNetworkNET . APIs . Enums ;
99using NeuralNetworkNET . APIs . Interfaces ;
10+ using NeuralNetworkNET . APIs . Interfaces . Data ;
1011using NeuralNetworkNET . APIs . Results ;
1112using NeuralNetworkNET . APIs . Structs ;
1213using NeuralNetworkNET . Helpers ;
1617
1718namespace DigitsCudaTest
1819{
19- class Program
20+ public class Program
2021 {
21- static async Task Main ( )
22+ public static async Task Main ( )
2223 {
23- // Parse the dataset and create the network
24- ( var training , var test ) = DataParser . LoadDatasets ( ) ;
24+ // Create the network
2525 INeuralNetwork network = NetworkManager . NewSequential ( TensorInfo . Image < Alpha8 > ( 28 , 28 ) ,
26- CuDnnNetworkLayers . Convolutional ( ConvolutionInfo . Default , ( 5 , 5 ) , 20 , ActivationFunctionType . LeakyReLU ) ,
2726 CuDnnNetworkLayers . Convolutional ( ConvolutionInfo . Default , ( 5 , 5 ) , 20 , ActivationFunctionType . Identity ) ,
2827 CuDnnNetworkLayers . Pooling ( PoolingInfo . Default , ActivationFunctionType . LeakyReLU ) ,
29- CuDnnNetworkLayers . Convolutional ( ConvolutionInfo . Default , ( 3 , 3 ) , 40 , ActivationFunctionType . LeakyReLU ) ,
3028 CuDnnNetworkLayers . Convolutional ( ConvolutionInfo . Default , ( 3 , 3 ) , 40 , ActivationFunctionType . Identity ) ,
3129 CuDnnNetworkLayers . Pooling ( PoolingInfo . Default , ActivationFunctionType . LeakyReLU ) ,
3230 CuDnnNetworkLayers . FullyConnected ( 125 , ActivationFunctionType . LeCunTanh ) ,
3331 CuDnnNetworkLayers . FullyConnected ( 64 , ActivationFunctionType . LeCunTanh ) ,
3432 CuDnnNetworkLayers . Softmax ( 10 ) ) ;
3533
36- // Setup and start the training
34+ // Prepare the dataset
35+ ( var training , var test ) = DataParser . LoadDatasets ( ) ;
36+ ITrainingDataset trainingData = DatasetLoader . Training ( training , 400 ) ; // Batches of 400 samples
37+ ITestDataset testData = DatasetLoader . Test ( test , new Progress < TrainingProgressEventArgs > ( p =>
38+ {
39+ Printf ( $ "Epoch { p . Iteration } , cost: { p . Result . Cost } , accuracy: { p . Result . Accuracy } ") ;
40+ } ) ) ;
41+
42+ // Setup and network training
3743 CancellationTokenSource cts = new CancellationTokenSource ( ) ;
3844 Console . CancelKeyPress += ( s , e ) => cts . Cancel ( ) ;
3945 TrainingSessionResult result = await NetworkManager . TrainNetworkAsync ( network ,
40- DatasetLoader . Training ( training , 400 ) ,
46+ trainingData ,
4147 TrainingAlgorithms . Adadelta ( ) ,
4248 20 , 0.5f ,
43- new Progress < BatchProgress > ( p =>
44- {
45- Console . SetCursorPosition ( 0 , Console . CursorTop ) ;
46- int n = ( int ) ( p . Percentage * 32 / 100 ) ;
47- char [ ] c = new char [ 32 ] ;
48- for ( int i = 0 ; i < 32 ; i ++ ) c [ i ] = i <= n ? '=' : ' ' ;
49- Console . Write ( $ "[{ new String ( c ) } ] ") ;
50- } ) ,
51- testDataset : DatasetLoader . Test ( test , new Progress < TrainingProgressEventArgs > ( p =>
52- {
53- Printf ( $ "Epoch { p . Iteration } , cost: { p . Result . Cost } , accuracy: { p . Result . Accuracy } ") ;
54- } ) ) , token : cts . Token ) ;
49+ new Progress < BatchProgress > ( TrackBatchProgress ) ,
50+ testDataset : testData , token : cts . Token ) ;
5551
5652 // Save the training reports
5753 String
@@ -76,5 +72,15 @@ private static void Printf(String text)
7672 Console . ForegroundColor = ConsoleColor . White ;
7773 Console . Write ( $ "{ text } \n ") ;
7874 }
75+
76+ // Training monitor
77+ private static void TrackBatchProgress ( BatchProgress progress )
78+ {
79+ Console . SetCursorPosition ( 0 , Console . CursorTop ) ;
80+ int n = ( int ) ( progress . Percentage * 32 / 100 ) ; // 32 is the number of progress '=' characters to display
81+ char [ ] c = new char [ 32 ] ;
82+ for ( int i = 0 ; i < 32 ; i ++ ) c [ i ] = i <= n ? '=' : ' ' ;
83+ Console . Write ( $ "[{ new String ( c ) } ] ") ;
84+ }
7985 }
8086}
0 commit comments