Skip to content

Commit 73a306d

Browse files
committed
Layers JSON serialization improved, code refactoring
1 parent fa49eba commit 73a306d

File tree

5 files changed

+87
-38
lines changed

5 files changed

+87
-38
lines changed
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
using System;
2+
using System.IO;
3+
using System.Linq;
4+
using System.Runtime.CompilerServices;
5+
using System.Security.Cryptography;
6+
using System.Threading.Tasks;
7+
using JetBrains.Annotations;
8+
using NeuralNetworkNET.Extensions;
9+
10+
namespace NeuralNetworkNET.Helpers
11+
{
12+
/// <summary>
13+
/// A static class that can be used to quickly calculate hashes from array of an arbitrary <see langword="struct"/> type
14+
/// </summary>
15+
public static class Sha256
16+
{
17+
// The SHA256 hash bytes length
18+
private const int HashLength = 32;
19+
20+
/// <summary>
21+
/// Calculates an hash for the input <typeparamref name="T"/> array
22+
/// </summary>
23+
/// <typeparam name="T">The type of items in the input array</typeparam>
24+
/// <param name="array">The input array to process</param>
25+
[PublicAPI]
26+
[Pure, NotNull]
27+
public static unsafe byte[] Hash<T>([NotNull] T[] array) where T : struct
28+
{
29+
int size = Unsafe.SizeOf<T>() * array.Length;
30+
fixed (byte* p = &Unsafe.As<T, byte>(ref array[0]))
31+
using (UnmanagedMemoryStream stream = new UnmanagedMemoryStream(p, size, size, FileAccess.Read))
32+
using (SHA256 provider = SHA256.Create())
33+
{
34+
return provider.ComputeHash(stream);
35+
}
36+
}
37+
38+
/// <summary>
39+
/// Calculates an aggregate hash for the input <typeparamref name="T"/> arrays
40+
/// </summary>
41+
/// <typeparam name="T">The type of items in the input arrays</typeparam>
42+
/// <param name="arrays">The arrays to process</param>
43+
[PublicAPI]
44+
[Pure, NotNull]
45+
public static unsafe byte[] Hash<T>([NotNull, ItemNotNull] params T[][] arrays) where T : struct
46+
{
47+
// Compute the hashes in parallel
48+
if (arrays.Length == 0) return new byte[0];
49+
if (arrays.Any(v => v.Length == 0)) throw new ArgumentException("The input array can't contain empty vectors");
50+
byte[][] hashes = new byte[arrays.Length][];
51+
Parallel.For(0, arrays.Length, i => hashes[i] = Hash(arrays[i])).AssertCompleted();
52+
53+
// Merge the computed hashes into a single bytes array
54+
unchecked
55+
{
56+
byte[] result = new byte[HashLength];
57+
fixed (byte* p = result)
58+
for (int i = 0; i < HashLength; i++)
59+
{
60+
uint hash = 17;
61+
for (int j = 0; j < hashes.Length; j++)
62+
hash = hash * 31 + hashes[j][i];
63+
p[i] = (byte)(hash % byte.MaxValue);
64+
}
65+
return result;
66+
}
67+
}
68+
}
69+
}

NeuralNetwork.NET/Networks/Layers/Abstract/BatchNormalizationLayerBase.cs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
using NeuralNetworkNET.APIs.Enums;
55
using NeuralNetworkNET.APIs.Structs;
66
using NeuralNetworkNET.Extensions;
7+
using NeuralNetworkNET.Helpers;
78
using NeuralNetworkNET.Networks.Layers.Initialization;
89
using NeuralNetworkNET.SupervisedLearning.Optimization;
910
using Newtonsoft.Json;
@@ -32,6 +33,15 @@ internal abstract class BatchNormalizationLayerBase : WeightedLayerBase
3233
// The current iteration number (for the Cumulative Moving Average)
3334
private int _Iteration;
3435

36+
/// <summary>
37+
/// Gets the current CMA factor used to update the <see cref="Mu"/> and <see cref="Sigma2"/> tensors
38+
/// </summary>
39+
[JsonProperty(nameof(CumulativeMovingAverageFactor), Order = 6)]
40+
public float CumulativeMovingAverageFactor => 1f / (1 + _Iteration);
41+
42+
/// <inheritdoc/>
43+
public override String Hash => Convert.ToBase64String(Sha256.Hash(Weights, Biases, Mu, Sigma2));
44+
3545
/// <inheritdoc/>
3646
public override LayerType LayerType { get; } = LayerType.BatchNormalization;
3747

NeuralNetwork.NET/Networks/Layers/Abstract/WeightedLayerBase.cs

Lines changed: 2 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
using System;
22
using System.IO;
3-
using System.Security.Cryptography;
43
using JetBrains.Annotations;
54
using NeuralNetworkNET.APIs.Enums;
65
using NeuralNetworkNET.APIs.Interfaces;
76
using NeuralNetworkNET.APIs.Structs;
87
using NeuralNetworkNET.Extensions;
8+
using NeuralNetworkNET.Helpers;
99
using Newtonsoft.Json;
1010

1111
namespace NeuralNetworkNET.Networks.Layers.Abstract
@@ -23,39 +23,7 @@ internal abstract class WeightedLayerBase : NetworkLayerBase
2323
/// </summary>
2424
[NotNull]
2525
[JsonProperty(nameof(Hash), Order = 5)]
26-
public unsafe String Hash
27-
{
28-
[Pure]
29-
get
30-
{
31-
fixed (float* pw = Weights, pb = Biases)
32-
{
33-
// Use unmanaged streams to avoid copying the weights and biases
34-
int
35-
weightsSize = sizeof(float) * Weights.Length,
36-
biasesSize = sizeof(float) * Biases.Length;
37-
using (UnmanagedMemoryStream
38-
weightsStream = new UnmanagedMemoryStream((byte*)pw, weightsSize, weightsSize, FileAccess.Read),
39-
biasesStream = new UnmanagedMemoryStream((byte*)pb, biasesSize, biasesSize, FileAccess.Read))
40-
using (SHA256 provider = SHA256.Create())
41-
{
42-
// Compute the two SHA256 hashes and combine them (there isn't a way to concatenate two streams with the hash class)
43-
byte[]
44-
weightsHash = provider.ComputeHash(weightsStream),
45-
biasesHash = provider.ComputeHash(biasesStream),
46-
hash = new byte[32];
47-
unchecked
48-
{
49-
for (int i = 0; i < 32; i++)
50-
hash[i] = (byte)(17 * 31 * weightsHash[i] * 31 * biasesHash[i] % byte.MaxValue); // Trust me
51-
}
52-
53-
// Convert the final hash to a base64 string
54-
return Convert.ToBase64String(hash);
55-
}
56-
}
57-
}
58-
}
26+
public virtual String Hash => Convert.ToBase64String(Sha256.Hash(Weights, Biases));
5927

6028
/// <summary>
6129
/// Gets the weights for the current network layer

Unit/NeuralNetwork.NET.Unit/GraphNetworkTest.cs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -328,16 +328,17 @@ public void JsonMetadataSerialization1()
328328
var conv1 = root.Layer(NetworkLayers.Convolutional((5, 5), 10, ActivationType.ReLU));
329329
var pool1 = conv1.Layer(NetworkLayers.Pooling(ActivationType.Sigmoid));
330330

331-
var _1x1 = pool1.Layer(NetworkLayers.Convolutional((1, 1), 20, ActivationType.ReLU));
331+
var _1x1 = pool1.Layer(NetworkLayers.Convolutional((1, 1), 20, ActivationType.Identity));
332332
var _3x3reduce1x1 = pool1.Layer(NetworkLayers.Convolutional((1, 1), 20, ActivationType.ReLU));
333-
var _3x3 = _3x3reduce1x1.Layer(NetworkLayers.Convolutional((1, 1), 20, ActivationType.ReLU));
333+
var _3x3 = _3x3reduce1x1.Layer(NetworkLayers.Convolutional((1, 1), 20, ActivationType.Identity));
334334

335335
var split = _3x3.TrainingBranch();
336336
var fct = split.Layer(NetworkLayers.FullyConnected(100, ActivationType.LeCunTanh));
337337
_ = fct.Layer(NetworkLayers.Softmax(10));
338338

339339
var stack = _1x1.DepthConcatenation(_3x3);
340-
var fc1 = stack.Layer(NetworkLayers.FullyConnected(100, ActivationType.Sigmoid));
340+
var bn = stack.Layer(NetworkLayers.BatchNormalization(NormalizationMode.Spatial, ActivationType.ReLU));
341+
var fc1 = bn.Layer(NetworkLayers.FullyConnected(100, ActivationType.Sigmoid));
341342
_ = fc1.Layer(NetworkLayers.Softmax(10));
342343
});
343344
String json = network.SerializeMetadataAsJson();

Unit/NeuralNetwork.NET.Unit/SerializationTest.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,8 @@ public void JsonMetadataSerialization()
9090
NetworkLayers.Convolutional((10, 10), 20, ActivationType.Identity),
9191
NetworkLayers.Pooling(ActivationType.ReLU),
9292
NetworkLayers.Convolutional((10, 10), 20, ActivationType.Identity),
93-
NetworkLayers.Pooling(ActivationType.ReLU),
93+
NetworkLayers.Pooling(ActivationType.Identity),
94+
NetworkLayers.BatchNormalization(NormalizationMode.Spatial, ActivationType.ReLU),
9495
NetworkLayers.FullyConnected(125, ActivationType.Tanh),
9596
NetworkLayers.Softmax(133));
9697
String metadata1 = network.SerializeMetadataAsJson();

0 commit comments

Comments
 (0)