Skip to content

Commit 71a6dc1

Browse files
authored
Merge pull request #58 from Sergio0694/feature_images-loading
Feature images loading
2 parents 9a52a0d + 43c0453 commit 71a6dc1

File tree

15 files changed

+261
-274
lines changed

15 files changed

+261
-274
lines changed

NeuralNetwork.NET/APIs/DatasetLoader.cs

Lines changed: 95 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,12 @@
44
using JetBrains.Annotations;
55
using NeuralNetworkNET.APIs.Interfaces.Data;
66
using NeuralNetworkNET.Extensions;
7+
using NeuralNetworkNET.Helpers;
78
using NeuralNetworkNET.SupervisedLearning.Data;
89
using NeuralNetworkNET.SupervisedLearning.Optimization.Parameters;
910
using NeuralNetworkNET.SupervisedLearning.Optimization.Progress;
11+
using SixLabors.ImageSharp;
12+
using SixLabors.ImageSharp.PixelFormats;
1013

1114
namespace NeuralNetworkNET.APIs
1215
{
@@ -30,12 +33,12 @@ public static class DatasetLoader
3033
/// <summary>
3134
/// Creates a new <see cref="ITrainingDataset"/> instance to train a network from the input collection, with the specified batch size
3235
/// </summary>
33-
/// <param name="data">The source collection to use to build the training dataset</param>
36+
/// <param name="data">The source collection to use to build the training dataset, where the samples will be extracted from the input <see cref="Func{TResult}"/> instances in parallel</param>
3437
/// <param name="size">The desired dataset batch size</param>
3538
[PublicAPI]
3639
[Pure, NotNull]
3740
[CollectionAccess(CollectionAccessType.Read)]
38-
public static ITrainingDataset Training([NotNull] IEnumerable<Func<(float[] X, float[] Y)>> data, int size) => BatchesCollection.From(data, size);
41+
public static ITrainingDataset Training([NotNull, ItemNotNull] IEnumerable<Func<(float[] X, float[] Y)>> data, int size) => BatchesCollection.From(data, size);
3942

4043
/// <summary>
4144
/// Creates a new <see cref="ITrainingDataset"/> instance to train a network from the input matrices, with the specified batch size
@@ -47,6 +50,34 @@ public static class DatasetLoader
4750
[CollectionAccess(CollectionAccessType.Read)]
4851
public static ITrainingDataset Training((float[,] X, float[,] Y) data, int size) => BatchesCollection.From(data, size);
4952

53+
/// <summary>
54+
/// Creates a new <see cref="ITrainingDataset"/> instance to train a network from the input data, where each input sample is an image in a specified format
55+
/// </summary>
56+
/// <typeparam name="TPixel">The type of image pixels. It must be either <see cref="Alpha8"/>, <see cref="Rgb24"/> or <see cref="Argb32"/></typeparam>
57+
/// <param name="data">A list of <see cref="ValueTuple{T1, T2}"/> items, where the first element is the image path and the second is a vector with the expected outputs</param>
58+
/// <param name="size">The desired dataset batch size</param>
59+
/// <param name="modify">An optional <see cref="Action{T}"/> to modify each sample image when loading the dataset</param>
60+
[PublicAPI]
61+
[Pure, NotNull]
62+
[CollectionAccess(CollectionAccessType.Read)]
63+
public static ITrainingDataset Training<TPixel>([NotNull] IEnumerable<(String X, float[] Y)> data, int size, [CanBeNull] Action<IImageProcessingContext<TPixel>> modify = null)
64+
where TPixel : struct, IPixel<TPixel>
65+
=> BatchesCollection.From(data.Select<(String X, float[] Y), Func<(float[], float[])>>(xy => () => (ImageLoader.Load(xy.X, modify), xy.Y)), size);
66+
67+
/// <summary>
68+
/// Creates a new <see cref="ITrainingDataset"/> instance to train a network from the input data, where each input sample is an image in a specified format
69+
/// </summary>
70+
/// <typeparam name="TPixel">The type of image pixels. It must be either <see cref="Alpha8"/>, <see cref="Rgb24"/> or <see cref="Argb32"/></typeparam>
71+
/// <param name="data">A list of <see cref="ValueTuple{T1, T2}"/> items, where the first element is the image path and the second is a <see cref="Func{TResult}"/> returning a vector with the expected outputs</param>
72+
/// <param name="size">The desired dataset batch size</param>
73+
/// <param name="modify">An optional <see cref="Action{T}"/> to modify each sample image when loading the dataset</param>
74+
[PublicAPI]
75+
[Pure, NotNull]
76+
[CollectionAccess(CollectionAccessType.Read)]
77+
public static ITrainingDataset Training<TPixel>([NotNull] IEnumerable<(String X, Func<float[]> Y)> data, int size, [CanBeNull] Action<IImageProcessingContext<TPixel>> modify = null)
78+
where TPixel : struct, IPixel<TPixel>
79+
=> BatchesCollection.From(data.Select<(String X, Func<float[]> Y), Func<(float[], float[])>>(xy => () => (ImageLoader.Load(xy.X, modify), xy.Y())), size);
80+
5081
#endregion
5182

5283
#region Validation
@@ -66,13 +97,13 @@ public static IValidationDataset Validation([NotNull] IEnumerable<(float[] X, fl
6697
/// <summary>
6798
/// Creates a new <see cref="IValidationDataset"/> instance to validate a network accuracy from the input collection
6899
/// </summary>
69-
/// <param name="data">The source collection to use to build the validation dataset</param>
100+
/// <param name="data">The source collection to use to build the validation dataset, where the samples will be extracted from the input <see cref="Func{TResult}"/> instances in parallel</param>
70101
/// <param name="tolerance">The desired tolerance to test the network for convergence</param>
71102
/// <param name="epochs">The epochs interval to consider when testing the network for convergence</param>
72103
[PublicAPI]
73104
[Pure, NotNull]
74105
[CollectionAccess(CollectionAccessType.Read)]
75-
public static IValidationDataset Validation([NotNull] IEnumerable<Func<(float[] X, float[] Y)>> data, float tolerance = 1e-2f, int epochs = 5)
106+
public static IValidationDataset Validation([NotNull, ItemNotNull] IEnumerable<Func<(float[] X, float[] Y)>> data, float tolerance = 1e-2f, int epochs = 5)
76107
=> Validation(data.AsParallel().Select(f => f()), tolerance, epochs);
77108

78109
/// <summary>
@@ -86,6 +117,36 @@ public static IValidationDataset Validation([NotNull] IEnumerable<Func<(float[]
86117
[CollectionAccess(CollectionAccessType.Read)]
87118
public static IValidationDataset Validation((float[,] X, float[,] Y) data, float tolerance = 1e-2f, int epochs = 5) => new ValidationDataset(data, tolerance, epochs);
88119

120+
/// <summary>
121+
/// Creates a new <see cref="IValidationDataset"/> instance to validate a network accuracy from the input collection
122+
/// </summary>
123+
/// <typeparam name="TPixel">The type of image pixels. It must be either <see cref="Alpha8"/>, <see cref="Rgb24"/> or <see cref="Argb32"/></typeparam>
124+
/// <param name="data">A list of <see cref="ValueTuple{T1, T2}"/> items, where the first element is the image path and the second is a vector with the expected outputs</param>
125+
/// <param name="tolerance">The desired tolerance to test the network for convergence</param>
126+
/// <param name="epochs">The epochs interval to consider when testing the network for convergence</param>
127+
/// <param name="modify">An optional <see cref="Action{T}"/> to modify each sample image when loading the dataset</param>
128+
[PublicAPI]
129+
[Pure, NotNull]
130+
[CollectionAccess(CollectionAccessType.Read)]
131+
public static IValidationDataset Validation<TPixel>([NotNull] IEnumerable<(String X, float[] Y)> data, float tolerance = 1e-2f, int epochs = 5, [CanBeNull] Action<IImageProcessingContext<TPixel>> modify = null)
132+
where TPixel : struct, IPixel<TPixel>
133+
=> Validation(data.Select<(String X, float[] Y), Func<(float[], float[])>>(xy => () => (ImageLoader.Load(xy.X, modify), xy.Y)).AsParallel(), tolerance, epochs);
134+
135+
/// <summary>
136+
/// Creates a new <see cref="IValidationDataset"/> instance to validate a network accuracy from the input collection
137+
/// </summary>
138+
/// <typeparam name="TPixel">The type of image pixels. It must be either <see cref="Alpha8"/>, <see cref="Rgb24"/> or <see cref="Argb32"/></typeparam>
139+
/// <param name="data">A list of <see cref="ValueTuple{T1, T2}"/> items, where the first element is the image path and the second is a <see cref="Func{TResult}"/> returning a vector with the expected outputs</param>
140+
/// <param name="tolerance">The desired tolerance to test the network for convergence</param>
141+
/// <param name="epochs">The epochs interval to consider when testing the network for convergence</param>
142+
/// <param name="modify">An optional <see cref="Action{T}"/> to modify each sample image when loading the dataset</param>
143+
[PublicAPI]
144+
[Pure, NotNull]
145+
[CollectionAccess(CollectionAccessType.Read)]
146+
public static IValidationDataset Validation<TPixel>([NotNull] IEnumerable<(String X, Func<float[]> Y)> data, float tolerance = 1e-2f, int epochs = 5, [CanBeNull] Action<IImageProcessingContext<TPixel>> modify = null)
147+
where TPixel : struct, IPixel<TPixel>
148+
=> Validation(data.Select<(String X, Func<float[]> Y), Func<(float[], float[])>>(xy => () => (ImageLoader.Load(xy.X, modify), xy.Y())).AsParallel(), tolerance, epochs);
149+
89150
#endregion
90151

91152
#region Test
@@ -104,12 +165,12 @@ public static ITestDataset Test([NotNull] IEnumerable<(float[] X, float[] Y)> da
104165
/// <summary>
105166
/// Creates a new <see cref="ITestDataset"/> instance to test a network from the input collection
106167
/// </summary>
107-
/// <param name="data">The source collection to use to build the test dataset</param>
168+
/// <param name="data">The source collection to use to build the test dataset, where the samples will be extracted from the input <see cref="Func{TResult}"/> instances in parallel</param>
108169
/// <param name="progress">The optional progress callback to use</param>
109170
[PublicAPI]
110171
[Pure, NotNull]
111172
[CollectionAccess(CollectionAccessType.Read)]
112-
public static ITestDataset Test([NotNull] IEnumerable<Func<(float[] X, float[] Y)>> data, [CanBeNull] IProgress<TrainingProgressEventArgs> progress = null)
173+
public static ITestDataset Test([NotNull, ItemNotNull] IEnumerable<Func<(float[] X, float[] Y)>> data, [CanBeNull] IProgress<TrainingProgressEventArgs> progress = null)
113174
=> Test(data.AsParallel().Select(f => f()), progress);
114175

115176
/// <summary>
@@ -122,6 +183,34 @@ public static ITestDataset Test([NotNull] IEnumerable<Func<(float[] X, float[] Y
122183
[CollectionAccess(CollectionAccessType.Read)]
123184
public static ITestDataset Test((float[,] X, float[,] Y) data, [CanBeNull] IProgress<TrainingProgressEventArgs> progress = null) => new TestDataset(data, progress);
124185

186+
/// <summary>
187+
/// Creates a new <see cref="ITestDataset"/> instance to test a network from the input collection
188+
/// </summary>
189+
/// <typeparam name="TPixel">The type of image pixels. It must be either <see cref="Alpha8"/>, <see cref="Rgb24"/> or <see cref="Argb32"/></typeparam>
190+
/// <param name="data">A list of <see cref="ValueTuple{T1, T2}"/> items, where the first element is the image path and the second is a vector with the expected outputs</param>
191+
/// <param name="progress">The optional progress callback to use</param>
192+
/// <param name="modify">An optional <see cref="Action{T}"/> to modify each sample image when loading the dataset</param>
193+
[PublicAPI]
194+
[Pure, NotNull]
195+
[CollectionAccess(CollectionAccessType.Read)]
196+
public static ITestDataset Test<TPixel>([NotNull] IEnumerable<(String X, float[] Y)> data, [CanBeNull] IProgress<TrainingProgressEventArgs> progress = null, [CanBeNull] Action<IImageProcessingContext<TPixel>> modify = null)
197+
where TPixel : struct, IPixel<TPixel>
198+
=> Test(data.Select<(String X, float[] Y), Func<(float[], float[])>>(xy => () => (ImageLoader.Load(xy.X, modify), xy.Y)).AsParallel(), progress);
199+
200+
/// <summary>
201+
/// Creates a new <see cref="ITestDataset"/> instance to test a network from the input collection
202+
/// </summary>
203+
/// <typeparam name="TPixel">The type of image pixels. It must be either <see cref="Alpha8"/>, <see cref="Rgb24"/> or <see cref="Argb32"/></typeparam>
204+
/// <param name="data">A list of <see cref="ValueTuple{T1, T2}"/> items, where the first element is the image path and the second is a <see cref="Func{TResult}"/> returning a vector with the expected outputs</param>
205+
/// <param name="progress">The optional progress callback to use</param>
206+
/// <param name="modify">An optional <see cref="Action{T}"/> to modify each sample image when loading the dataset</param>
207+
[PublicAPI]
208+
[Pure, NotNull]
209+
[CollectionAccess(CollectionAccessType.Read)]
210+
public static ITestDataset Test<TPixel>([NotNull] IEnumerable<(String X, Func<float[]> Y)> data, [CanBeNull] IProgress<TrainingProgressEventArgs> progress = null, [CanBeNull] Action<IImageProcessingContext<TPixel>> modify = null)
211+
where TPixel : struct, IPixel<TPixel>
212+
=> Test(data.Select<(String X, Func<float[]> Y), Func<(float[], float[])>>(xy => () => (ImageLoader.Load(xy.X, modify), xy.Y())).AsParallel(), progress);
213+
125214
#endregion
126215
}
127216
}

NeuralNetwork.NET/APIs/Structs/TensorInfo.cs

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
using System;
44
using System.Diagnostics;
55
using System.Runtime.CompilerServices;
6+
using SixLabors.ImageSharp.PixelFormats;
67

78
namespace NeuralNetworkNET.APIs.Structs
89
{
@@ -67,30 +68,38 @@ internal TensorInfo(int height, int width, int channels)
6768
}
6869

6970
/// <summary>
70-
/// Creates a new <see cref="TensorInfo"/> instance for an RGB image
71+
/// Creates a new <see cref="TensorInfo"/> instance for a linear network layer, without keeping track of spatial info
7172
/// </summary>
72-
/// <param name="height">The height of the input image</param>
73-
/// <param name="width">The width of the input image</param>
73+
/// <param name="size">The input size</param>
7474
[PublicAPI]
7575
[Pure]
76-
public static TensorInfo CreateForRgbImage(int height, int width) => new TensorInfo(height, width, 3);
76+
public static TensorInfo Linear(int size) => new TensorInfo(1, 1, size);
7777

7878
/// <summary>
79-
/// Creates a new <see cref="TensorInfo"/> instance for a grayscale image
79+
/// Creates a new <see cref="TensorInfo"/> instance for an image with a user-defined pixel type
8080
/// </summary>
81+
/// <typeparam name="TPixel">The type of image pixels. It must be either <see cref="Alpha8"/>, <see cref="Rgb24"/> or <see cref="Argb32"/></typeparam>
8182
/// <param name="height">The height of the input image</param>
8283
/// <param name="width">The width of the input image</param>
8384
[PublicAPI]
8485
[Pure]
85-
public static TensorInfo CreateForGrayscaleImage(int height, int width) => new TensorInfo(height, width, 1);
86+
public static TensorInfo Image<TPixel>(int height, int width) where TPixel : struct, IPixel<TPixel>
87+
{
88+
if (typeof(TPixel) == typeof(Alpha8)) return new TensorInfo(height, width, 1);
89+
if (typeof(TPixel) == typeof(Rgb24)) return new TensorInfo(height, width, 3);
90+
if (typeof(TPixel) == typeof(Argb32)) return new TensorInfo(height, width, 4);
91+
throw new InvalidOperationException($"The {typeof(TPixel).Name} pixel format isn't currently supported");
92+
}
8693

8794
/// <summary>
88-
/// Creates a new <see cref="TensorInfo"/> instance for a linear network layer, without keeping track of spatial info
95+
/// Creates a new <see cref="TensorInfo"/> instance for with a custom 3D shape
8996
/// </summary>
90-
/// <param name="size">The input size</param>
97+
/// <param name="height">The input volume height</param>
98+
/// <param name="width">The input volume width</param>
99+
/// <param name="channels">The number of channels in the input volume</param>
91100
[PublicAPI]
92101
[Pure]
93-
public static TensorInfo CreateLinear(int size) => new TensorInfo(1, 1, size);
102+
public static TensorInfo Volume(int height, int width, int channels) => new TensorInfo(height, width, channels);
94103

95104
#endregion
96105

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
using System;
2+
using JetBrains.Annotations;
3+
using SixLabors.ImageSharp;
4+
using SixLabors.ImageSharp.Advanced;
5+
using SixLabors.ImageSharp.PixelFormats;
6+
7+
namespace NeuralNetworkNET.Helpers
8+
{
9+
/// <summary>
10+
/// A static class with some helper methods to quickly load a sample from a target image file
11+
/// </summary>
12+
internal static class ImageLoader
13+
{
14+
/// <summary>
15+
/// Loads the target image and applies the requested changes, then converts it to a dataset sample
16+
/// </summary>
17+
/// <param name="path">The path of the image to load</param>
18+
/// <param name="modify">The optional changes to apply to the image</param>
19+
[Pure, NotNull]
20+
public static float[] Load<TPixel>([NotNull] String path, [CanBeNull] Action<IImageProcessingContext<TPixel>> modify) where TPixel : struct, IPixel<TPixel>
21+
{
22+
using (Image<TPixel> image = Image.Load<TPixel>(path))
23+
{
24+
if (modify != null) image.Mutate(modify);
25+
if (typeof(TPixel) == typeof(Alpha8)) return Load(image as Image<Alpha8>);
26+
if (typeof(TPixel) == typeof(Rgb24)) return Load(image as Image<Rgb24>);
27+
if (typeof(TPixel) == typeof(Argb32)) return Load(image as Image<Argb32>);
28+
throw new InvalidOperationException($"The {typeof(TPixel).Name} pixel format isn't currently supported");
29+
}
30+
}
31+
32+
#region Loaders
33+
34+
// Loads an RGBA32 image
35+
[Pure, NotNull]
36+
private static unsafe float[] Load(Image<Argb32> image)
37+
{
38+
int resolution = image.Height * image.Width;
39+
float[] sample = new float[resolution * 4];
40+
fixed (Argb32* p0 = &image.DangerousGetPinnableReferenceToPixelBuffer())
41+
fixed (float* psample = sample)
42+
{
43+
for (int i = 0; i < resolution; i++)
44+
{
45+
Argb32* pxy = p0 + i;
46+
psample[i] = pxy->A / 255f;
47+
psample[i + resolution] = pxy->R / 255f;
48+
psample[i + 2 * resolution] = pxy->G / 255f;
49+
psample[i + 3 * resolution] = pxy->B / 255f;
50+
}
51+
}
52+
return sample;
53+
}
54+
55+
// Loads an RGBA24 image
56+
[Pure, NotNull]
57+
private static unsafe float[] Load(Image<Rgb24> image)
58+
{
59+
int resolution = image.Height * image.Width;
60+
float[] sample = new float[resolution * 3];
61+
fixed (Rgb24* p0 = &image.DangerousGetPinnableReferenceToPixelBuffer())
62+
fixed (float* psample = sample)
63+
{
64+
for (int i = 0; i < resolution; i++)
65+
{
66+
Rgb24* pxy = p0 + i;
67+
psample[i] = pxy->R / 255f;
68+
psample[i + resolution] = pxy->G / 255f;
69+
psample[i + 2 * resolution] = pxy->B / 255f;
70+
}
71+
}
72+
return sample;
73+
}
74+
75+
// Loads an ALPHA8 image
76+
[Pure, NotNull]
77+
private static unsafe float[] Load(Image<Alpha8> image)
78+
{
79+
int resolution = image.Height * image.Width;
80+
float[] sample = new float[resolution];
81+
fixed (Alpha8* p0 = &image.DangerousGetPinnableReferenceToPixelBuffer())
82+
fixed (float* psample = sample)
83+
for (int i = 0; i < resolution; i++)
84+
psample[i] = p0[i].PackedValue / 255f;
85+
return sample;
86+
}
87+
88+
#endregion
89+
}
90+
}

0 commit comments

Comments
 (0)