11using System ;
2+ using System . Reflection ;
23using JetBrains . Annotations ;
34using Microsoft . VisualStudio . TestTools . UnitTesting ;
45using NeuralNetworkNET . APIs . Enums ;
@@ -46,9 +47,17 @@ private static void TestForward(NetworkLayerBase cpu, NetworkLayerBase gpu, int
4647 Tensor . Free ( x , z_cpu , a_cpu , z_gpu , a_gpu ) ;
4748 }
4849
50+ // Sets the static property that signals whenever the backpropagation pass is being executed (needed for some layer types)
51+ private static void SetBackpropagationProperty ( bool value )
52+ {
53+ PropertyInfo property = typeof ( NetworkTrainer ) . GetProperty ( nameof ( NetworkTrainer . BackpropagationInProgress ) , BindingFlags . Static | BindingFlags . Public ) ;
54+ if ( property == null ) throw new InvalidOperationException ( "Couldn't find the target property" ) ;
55+ property . SetValue ( null , value ) ;
56+ }
57+
4958 private static void TestBackward ( WeightedLayerBase cpu , WeightedLayerBase gpu , int samples )
5059 {
51- NetworkTrainer . BackpropagationInProgress = true ;
60+ SetBackpropagationProperty ( true ) ;
5261 Tensor
5362 x = CreateRandomTensor ( samples , cpu . InputInfo . Size ) ,
5463 dy = CreateRandomTensor ( samples , cpu . OutputInfo . Size ) ;
@@ -62,12 +71,12 @@ private static void TestBackward(WeightedLayerBase cpu, WeightedLayerBase gpu, i
6271 Assert . IsTrue ( dJdw_cpu . ContentEquals ( dJdw_gpu , 1e-4f , 1e-5f ) ) ;
6372 Assert . IsTrue ( dJdb_cpu . ContentEquals ( dJdb_gpu , 1e-4f , 1e-5f ) ) ; // The cuDNN ConvolutionBackwardBias is not always as precise as the CPU version
6473 Tensor . Free ( x , dy , dx1 , dx2 , z_cpu , a_cpu , z_gpu , a_gpu , dJdw_cpu , dJdb_cpu , dJdw_gpu , dJdb_gpu ) ;
65- NetworkTrainer . BackpropagationInProgress = false ;
74+ SetBackpropagationProperty ( false ) ;
6675 }
6776
6877 private static unsafe void TestBackward ( OutputLayerBase cpu , OutputLayerBase gpu , float [ , ] y )
6978 {
70- NetworkTrainer . BackpropagationInProgress = true ;
79+ SetBackpropagationProperty ( true ) ;
7180 int n = y . GetLength ( 0 ) ;
7281 fixed ( float * p = y )
7382 {
@@ -86,7 +95,7 @@ private static unsafe void TestBackward(OutputLayerBase cpu, OutputLayerBase gpu
8695 Assert . IsTrue ( dJdb_cpu . ContentEquals ( dJdb_gpu , 1e-4f , 1e-5f ) ) ;
8796 Tensor . Free ( x , dy , dx1 , dx2 , z_cpu , a_cpu , z_gpu , a_gpu , dJdw_cpu , dJdw_gpu , dJdb_cpu , dJdb_gpu ) ;
8897 }
89- NetworkTrainer . BackpropagationInProgress = false ;
98+ SetBackpropagationProperty ( false ) ;
9099 }
91100
92101 #endregion
0 commit comments