Skip to content

Commit 415ffe1

Browse files
committed
Minor refactoring and bug fixes
1 parent 27ad160 commit 415ffe1

File tree

6 files changed

+72
-50
lines changed

6 files changed

+72
-50
lines changed

NeuralNetwork.NET/APIs/NetworkManager.cs

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
using NeuralNetworkNET.APIs.Interfaces;
99
using NeuralNetworkNET.APIs.Interfaces.Data;
1010
using NeuralNetworkNET.APIs.Results;
11-
using NeuralNetworkNET.APIs.Settings;
1211
using NeuralNetworkNET.APIs.Structs;
1312
using NeuralNetworkNET.Extensions;
1413
using NeuralNetworkNET.Networks.Graph;
@@ -60,6 +59,11 @@ public static INeuralNetwork NewGraph(TensorInfo input, [NotNull] Action<NodeBui
6059

6160
#region Training APIs
6261

62+
/// <summary>
63+
/// Gets whether or not a neural network is currently being trained
64+
/// </summary>
65+
public static bool TrainingInProgress { get; private set; }
66+
6367
/// <summary>
6468
/// Trains a neural network with the given parameters
6569
/// </summary>
@@ -148,7 +152,7 @@ private static TrainingSessionResult TrainNetworkCore(
148152
throw new ArgumentException("The input dataset doesn't match the number of input and output features for the current network", nameof(dataset));
149153

150154
// Start the training
151-
NetworkSettings.TrainingInProgress = NetworkSettings.TrainingInProgress
155+
TrainingInProgress = TrainingInProgress
152156
? throw new InvalidOperationException("Can't train two networks at the same time") // This would cause problems with cuDNN
153157
: true;
154158
TrainingSessionResult result = NetworkTrainer.TrainNetwork(
@@ -158,7 +162,7 @@ private static TrainingSessionResult TrainNetworkCore(
158162
validationDataset as ValidationDataset,
159163
testDataset as TestDataset,
160164
token);
161-
NetworkSettings.TrainingInProgress = false;
165+
TrainingInProgress = false;
162166
return result;
163167
}
164168
}

NeuralNetwork.NET/APIs/Settings/NetworkSettings.cs

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -35,15 +35,5 @@ public static AccuracyTester AccuracyTester
3535
get => _AccuracyTester;
3636
set => _AccuracyTester = value ?? throw new ArgumentNullException(nameof(AccuracyTester), "The input delegate can't be null");
3737
}
38-
39-
/// <summary>
40-
/// Gets whether or not a neural network is currently being trained
41-
/// </summary>
42-
public static bool TrainingInProgress { get; internal set; }
43-
44-
/// <summary>
45-
/// Gets whether or not a neural network is currently processing the training samples through backpropagation (as opposed to evaluating them)
46-
/// </summary>
47-
internal static bool BackpropagationInProgress { get; set; }
4838
}
4939
}

NeuralNetwork.NET/APIs/Structs/Tensor.cs

Lines changed: 27 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -95,8 +95,8 @@ private Tensor(IntPtr ptr, int entities, int length)
9595
/// <summary>
9696
/// Creates a new instance with the specified shape
9797
/// </summary>
98-
/// <param name="n">The height of the matrix</param>
99-
/// <param name="chw">The width of the matrix</param>
98+
/// <param name="n">The height of the <see cref="Tensor"/></param>
99+
/// <param name="chw">The width of the <see cref="Tensor"/></param>
100100
/// <param name="tensor">The resulting instance</param>
101101
[MethodImpl(MethodImplOptions.AggressiveInlining)]
102102
public static void New(int n, int chw, out Tensor tensor)
@@ -108,8 +108,8 @@ public static void New(int n, int chw, out Tensor tensor)
108108
/// <summary>
109109
/// Creates a new instance with the specified shape and initializes the allocated memory to 0s
110110
/// </summary>
111-
/// <param name="n">The height of the matrix</param>
112-
/// <param name="chw">The width of the matrix</param>
111+
/// <param name="n">The height of the <see cref="Tensor"/></param>
112+
/// <param name="chw">The width of the <see cref="Tensor"/></param>
113113
/// <param name="tensor">The resulting instance</param>
114114
[MethodImpl(MethodImplOptions.AggressiveInlining)]
115115
public static unsafe void NewZeroed(int n, int chw, out Tensor tensor)
@@ -124,8 +124,8 @@ public static unsafe void NewZeroed(int n, int chw, out Tensor tensor)
124124
/// Creates a new instance by wrapping the input pointer
125125
/// </summary>
126126
/// <param name="p">The target memory area</param>
127-
/// <param name="n">The height of the final matrix</param>
128-
/// <param name="chw">The width of the final matrix</param>
127+
/// <param name="n">The height of the final <see cref="Tensor"/></param>
128+
/// <param name="chw">The width of the final <see cref="Tensor"/></param>
129129
/// <param name="tensor">The resulting instance</param>
130130
[MethodImpl(MethodImplOptions.AggressiveInlining)]
131131
public static unsafe void Reshape(float* p, int n, int chw, out Tensor tensor)
@@ -152,8 +152,8 @@ public static unsafe void Reshape(float* p, int n, int chw, out Tensor tensor)
152152
/// Creates a new instance by copying the contents at the given memory location and reshaping it to the desired size
153153
/// </summary>
154154
/// <param name="p">The target memory area to copy</param>
155-
/// <param name="n">The height of the final matrix</param>
156-
/// <param name="chw">The width of the final matrix</param>
155+
/// <param name="n">The height of the final <see cref="Tensor"/></param>
156+
/// <param name="chw">The width of the final <see cref="Tensor"/></param>
157157
/// <param name="tensor">The resulting instance</param>
158158
[MethodImpl(MethodImplOptions.AggressiveInlining)]
159159
public static unsafe void From(float* p, int n, int chw, out Tensor tensor)
@@ -179,8 +179,8 @@ public static unsafe void From([NotNull] float[,] m, out Tensor tensor)
179179
/// Creates a new instance by copying the contents of the input vector and reshaping it to the desired size
180180
/// </summary>
181181
/// <param name="v">The input vector to copy</param>
182-
/// <param name="n">The height of the final matrix</param>
183-
/// <param name="chw">The width of the final matrix</param>
182+
/// <param name="n">The height of the final <see cref="Tensor"/></param>
183+
/// <param name="chw">The width of the final <see cref="Tensor"/></param>
184184
/// <param name="tensor">The resulting instance</param>
185185
[MethodImpl(MethodImplOptions.AggressiveInlining)]
186186
public static unsafe void From([NotNull] float[] v, int n, int chw, out Tensor tensor)
@@ -197,8 +197,8 @@ public static unsafe void From([NotNull] float[] v, int n, int chw, out Tensor t
197197
/// <summary>
198198
/// Creates a new instance by wrapping the current memory area
199199
/// </summary>
200-
/// <param name="n">The height of the final matrix</param>
201-
/// <param name="chw">The width of the final matrix</param>
200+
/// <param name="n">The height of the final <see cref="Tensor"/></param>
201+
/// <param name="chw">The width of the final <see cref="Tensor"/></param>
202202
/// <param name="tensor">The resulting instance</param>
203203
[MethodImpl(MethodImplOptions.AggressiveInlining)]
204204
public void Reshape(int n, int chw, out Tensor tensor)
@@ -223,19 +223,31 @@ public void Reshape(int n, int chw, out Tensor tensor)
223223
public bool MatchShape(int entities, int length) => Entities == entities && Length == length;
224224

225225
/// <summary>
226-
/// Overwrites the contents of the current matrix with the input matrix
226+
/// Overwrites the contents of the current instance with the input <see cref="Tensor"/>
227227
/// </summary>
228228
/// <param name="tensor">The input <see cref="Tensor"/> to copy</param>
229229
[MethodImpl(MethodImplOptions.AggressiveInlining)]
230230
public unsafe void Overwrite(in Tensor tensor)
231231
{
232-
if (tensor.Entities != Entities || tensor.Length != Length) throw new ArgumentException("The input matrix doesn't have the same size as the target");
232+
if (tensor.Entities != Entities || tensor.Length != Length) throw new ArgumentException("The input tensor doesn't have the same size as the target");
233233
int bytes = sizeof(float) * Size;
234234
Buffer.MemoryCopy(tensor, this, bytes, bytes);
235235
}
236236

237237
/// <summary>
238-
/// Duplicates the current instance to an output <see cref="Tensor"/> matrix
238+
/// Overwrites the contents of the current <see cref="Tensor"/> with the input array
239+
/// </summary>
240+
/// <param name="array">The input array to copy</param>
241+
[MethodImpl(MethodImplOptions.AggressiveInlining)]
242+
public unsafe void Overwrite([NotNull] float[] array)
243+
{
244+
if (array.Length != Size) throw new ArgumentException("The input array doesn't have the same size as the target");
245+
int bytes = sizeof(float) * Size;
246+
fixed (float* p = array) Buffer.MemoryCopy(p, this, bytes, bytes);
247+
}
248+
249+
/// <summary>
250+
/// Duplicates the current instance to an output <see cref="Tensor"/>
239251
/// </summary>
240252
/// <param name="tensor">The output tensor</param>
241253
[MethodImpl(MethodImplOptions.AggressiveInlining)]

NeuralNetwork.NET/Networks/Layers/Abstract/BatchNormalizationLayerBase.cs

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@
22
using System.IO;
33
using JetBrains.Annotations;
44
using NeuralNetworkNET.APIs.Enums;
5-
using NeuralNetworkNET.APIs.Settings;
65
using NeuralNetworkNET.APIs.Structs;
76
using NeuralNetworkNET.Extensions;
87
using NeuralNetworkNET.Networks.Layers.Initialization;
8+
using NeuralNetworkNET.SupervisedLearning.Optimization;
99
using Newtonsoft.Json;
1010

1111
namespace NeuralNetworkNET.Networks.Layers.Abstract
@@ -21,13 +21,13 @@ internal abstract class BatchNormalizationLayerBase : WeightedLayerBase
2121
/// The cached mu tensor
2222
/// </summary>
2323
[NotNull]
24-
protected float[] Mu;
24+
public float[] Mu { get; }
2525

2626
/// <summary>
2727
/// The cached sigma^2 tensor
2828
/// </summary>
2929
[NotNull]
30-
protected readonly float[] Sigma2;
30+
public float[] Sigma2 { get; }
3131

3232
// The current iteration number (for the Cumulative Moving Average)
3333
private int _Iteration;
@@ -60,6 +60,7 @@ protected BatchNormalizationLayerBase(in TensorInfo shape, NormalizationMode mod
6060
break;
6161
default: throw new ArgumentOutOfRangeException("Invalid batch normalization mode");
6262
}
63+
Sigma2.AsSpan().Fill(1);
6364
NormalizationMode = mode;
6465
}
6566

@@ -80,7 +81,7 @@ protected BatchNormalizationLayerBase(in TensorInfo shape, NormalizationMode mod
8081
/// <inheritdoc/>
8182
public override void Forward(in Tensor x, out Tensor z, out Tensor a)
8283
{
83-
if (NetworkSettings.BackpropagationInProgress) ForwardTraining(1f / (1 + _Iteration++), x, out z, out a);
84+
if (NetworkTrainer.BackpropagationInProgress) ForwardTraining(1f / (1 + _Iteration++), x, out z, out a);
8485
else ForwardInference(x, out z, out a);
8586
}
8687

NeuralNetwork.NET/SupervisedLearning/Optimization/NetworkTrainer.cs

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
using NeuralNetworkNET.APIs.Enums;
66
using NeuralNetworkNET.APIs.Interfaces;
77
using NeuralNetworkNET.APIs.Results;
8-
using NeuralNetworkNET.APIs.Settings;
98
using NeuralNetworkNET.Extensions;
109
using NeuralNetworkNET.Networks.Implementations;
1110
using NeuralNetworkNET.Services;
@@ -83,6 +82,21 @@ public static TrainingSessionResult TrainNetwork(
8382
return Optimize(network, batches, epochs, dropout, optimizer, batchProgress, trainingProgress, validationDataset, testDataset, token);
8483
}
8584

85+
/// <summary>
86+
/// Gets whether or not a neural network is currently processing the training samples through backpropagation (as opposed to evaluating them)
87+
/// </summary>
88+
public static bool BackpropagationInProgress
89+
{
90+
get;
91+
92+
// The setter is private in release builds; in DEBUG builds the private modifier is dropped so the unit tests can toggle this flag directly
93+
#if DEBUG
94+
set;
95+
#else
96+
private set;
97+
#endif
98+
}
99+
86100
/// <summary>
87101
/// Trains the target <see cref="SequentialNetwork"/> using the input algorithm
88102
/// </summary>
@@ -123,18 +137,18 @@ TrainingSessionResult PrepareResult(TrainingStopReason reason, int loops)
123137
miniBatches.CrossShuffle();
124138

125139
// Gradient descent over the current batches
126-
NetworkSettings.BackpropagationInProgress = true;
140+
BackpropagationInProgress = true;
127141
for (int j = 0; j < miniBatches.BatchesCount; j++)
128142
{
129143
if (token.IsCancellationRequested)
130144
{
131-
NetworkSettings.BackpropagationInProgress = true;
145+
BackpropagationInProgress = false;
132146
return PrepareResult(TrainingStopReason.TrainingCanceled, i);
133147
}
134148
network.Backpropagate(miniBatches.Batches[j], dropout, updater);
135149
batchMonitor?.NotifyCompletedBatch(miniBatches.Batches[j].X.GetLength(0));
136150
}
137-
NetworkSettings.BackpropagationInProgress = true;
151+
BackpropagationInProgress = false;
138152
batchMonitor?.Reset();
139153
if (network.IsInNumericOverflow) return PrepareResult(TrainingStopReason.NumericOverflow, i);
140154

Unit/NeuralNetwork.NET.Cuda.Unit/CuDnnLayersTest.cs

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
1-
using JetBrains.Annotations;
1+
using System;
2+
using JetBrains.Annotations;
23
using Microsoft.VisualStudio.TestTools.UnitTesting;
34
using NeuralNetworkNET.APIs.Enums;
4-
using NeuralNetworkNET.APIs.Settings;
55
using NeuralNetworkNET.APIs.Structs;
66
using NeuralNetworkNET.Extensions;
77
using NeuralNetworkNET.Helpers;
88
using NeuralNetworkNET.Networks.Layers.Abstract;
99
using NeuralNetworkNET.Networks.Layers.Cpu;
1010
using NeuralNetworkNET.Networks.Layers.Cuda;
1111
using NeuralNetworkNET.Networks.Layers.Initialization;
12+
using NeuralNetworkNET.SupervisedLearning.Optimization;
1213

1314
namespace NeuralNetworkNET.Cuda.Unit
1415
{
@@ -47,7 +48,7 @@ private static void TestForward(NetworkLayerBase cpu, NetworkLayerBase gpu, int
4748

4849
private static void TestBackward(WeightedLayerBase cpu, WeightedLayerBase gpu, int samples)
4950
{
50-
NetworkSettings.TrainingInProgress = true;
51+
NetworkTrainer.BackpropagationInProgress = true;
5152
Tensor
5253
x = CreateRandomTensor(samples, cpu.InputInfo.Size),
5354
dy = CreateRandomTensor(samples, cpu.OutputInfo.Size);
@@ -61,12 +62,12 @@ private static void TestBackward(WeightedLayerBase cpu, WeightedLayerBase gpu, i
6162
Assert.IsTrue(dJdw_cpu.ContentEquals(dJdw_gpu, 1e-4f, 1e-5f));
6263
Assert.IsTrue(dJdb_cpu.ContentEquals(dJdb_gpu, 1e-4f, 1e-5f)); // The cuDNN ConvolutionBackwardBias is not always as precise as the CPU version
6364
Tensor.Free(x, dy, dx1, dx2, z_cpu, a_cpu, z_gpu, a_gpu, dJdw_cpu, dJdb_cpu, dJdw_gpu, dJdb_gpu);
64-
NetworkSettings.TrainingInProgress = false;
65+
NetworkTrainer.BackpropagationInProgress = false;
6566
}
6667

6768
private static unsafe void TestBackward(OutputLayerBase cpu, OutputLayerBase gpu, float[,] y)
6869
{
69-
NetworkSettings.TrainingInProgress = true;
70+
NetworkTrainer.BackpropagationInProgress = true;
7071
int n = y.GetLength(0);
7172
fixed (float* p = y)
7273
{
@@ -85,7 +86,7 @@ private static unsafe void TestBackward(OutputLayerBase cpu, OutputLayerBase gpu
8586
Assert.IsTrue(dJdb_cpu.ContentEquals(dJdb_gpu, 1e-4f, 1e-5f));
8687
Tensor.Free(x, dy, dx1, dx2, z_cpu, a_cpu, z_gpu, a_gpu, dJdw_cpu, dJdw_gpu, dJdb_cpu, dJdb_gpu);
8788
}
88-
NetworkSettings.TrainingInProgress = false;
89+
NetworkTrainer.BackpropagationInProgress = false;
8990
}
9091

9192
#endregion
@@ -164,36 +165,36 @@ public void ConvolutionBackward()
164165
[TestMethod]
165166
public void PerActivationBatchNormalizationForward()
166167
{
167-
WeightedLayerBase
168+
BatchNormalizationLayerBase
168169
cpu = new BatchNormalizationLayer(TensorInfo.Linear(250), NormalizationMode.PerActivation, ActivationType.ReLU),
169-
gpu = new CuDnnBatchNormalizationLayer(cpu.InputInfo, NormalizationMode.PerActivation, cpu.Weights, cpu.Biases, new float[250], new float[250], cpu.ActivationType);
170+
gpu = new CuDnnBatchNormalizationLayer(cpu.InputInfo, NormalizationMode.PerActivation, cpu.Weights, cpu.Biases, cpu.Mu.AsSpan().Copy(), cpu.Sigma2.AsSpan().Copy(), cpu.ActivationType);
170171
TestForward(cpu, gpu, 400);
171172
}
172173

173174
[TestMethod]
174175
public void PerActivationBatchNormalizationBackward()
175176
{
176-
WeightedLayerBase
177+
BatchNormalizationLayerBase
177178
cpu = new BatchNormalizationLayer(TensorInfo.Linear(250), NormalizationMode.PerActivation, ActivationType.ReLU),
178-
gpu = new CuDnnBatchNormalizationLayer(cpu.InputInfo, NormalizationMode.PerActivation, cpu.Weights, cpu.Biases, new float[250], new float[250], cpu.ActivationType);
179+
gpu = new CuDnnBatchNormalizationLayer(cpu.InputInfo, NormalizationMode.PerActivation, cpu.Weights, cpu.Biases, cpu.Mu.AsSpan().Copy(), cpu.Sigma2.AsSpan().Copy(), cpu.ActivationType);
179180
TestBackward(cpu, gpu, 400);
180181
}
181182

182183
[TestMethod]
183184
public void SpatialBatchNormalizationForward()
184185
{
185-
WeightedLayerBase
186+
BatchNormalizationLayerBase
186187
cpu = new BatchNormalizationLayer(TensorInfo.Volume(12, 12, 13), NormalizationMode.Spatial, ActivationType.ReLU),
187-
gpu = new CuDnnBatchNormalizationLayer(cpu.InputInfo, NormalizationMode.Spatial, cpu.Weights, cpu.Biases, new float[13], new float[13], cpu.ActivationType);
188+
gpu = new CuDnnBatchNormalizationLayer(cpu.InputInfo, NormalizationMode.Spatial, cpu.Weights, cpu.Biases, cpu.Mu.AsSpan().Copy(), cpu.Sigma2.AsSpan().Copy(), cpu.ActivationType);
188189
TestForward(cpu, gpu, 400);
189190
}
190191

191192
[TestMethod]
192193
public void SpatialBatchNormalizationBackward()
193194
{
194-
WeightedLayerBase
195+
BatchNormalizationLayerBase
195196
cpu = new BatchNormalizationLayer(TensorInfo.Volume(12, 12, 13), NormalizationMode.Spatial, ActivationType.ReLU),
196-
gpu = new CuDnnBatchNormalizationLayer(cpu.InputInfo, NormalizationMode.Spatial, cpu.Weights, cpu.Biases, new float[13], new float[13], cpu.ActivationType);
197+
gpu = new CuDnnBatchNormalizationLayer(cpu.InputInfo, NormalizationMode.Spatial, cpu.Weights, cpu.Biases, cpu.Mu.AsSpan().Copy(), cpu.Sigma2.AsSpan().Copy(), cpu.ActivationType);
197198
TestBackward(cpu, gpu, 400);
198199
}
199200

0 commit comments

Comments
 (0)