Skip to content

Commit 1dc5b34

Browse files
authored
fix: abstract image preprocessing in C++ computer vision (#376)
## Description

Abstract image loading into an ExecuTorch tensor in the C++ computer vision native code.

### Type of change

- [x] Bug fix (non-breaking change which fixes an issue)
- [ ] New feature (non-breaking change which adds functionality)
- [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected)
- [ ] Documentation update (improves or adds clarity to existing documentation)

### Tested on

- [x] iOS
- [x] Android

### Related issues

#374

### Checklist

- [x] I have performed a self-review of my code
- [ ] I have commented my code, particularly in hard-to-understand areas
- [ ] I have updated the documentation accordingly
- [x] My changes generate no new warnings
1 parent 211757d commit 1dc5b34

File tree

11 files changed

+51
-59
lines changed

11 files changed

+51
-59
lines changed

apps/computer-vision/ios/Podfile.lock

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2356,7 +2356,7 @@ SPEC CHECKSUMS:
23562356
React-logger: 8edfcedc100544791cd82692ca5a574240a16219
23572357
React-Mapbuffer: c3f4b608e4a59dd2f6a416ef4d47a14400194468
23582358
React-microtasksnativemodule: 054f34e9b82f02bd40f09cebd4083828b5b2beb6
2359-
react-native-executorch: 30047a5076fa3c91119618147627d895d87af51b
2359+
react-native-executorch: 53f918e0e3905243cc39d2d1a9df018bcd49c77b
23602360
react-native-image-picker: 8a3f16000e794f5381a7fe47bb48fd8d06741e47
23612361
react-native-safe-area-context: 562163222d999b79a51577eda2ea8ad2c32b4d06
23622362
react-native-skia: b6cb66e99a953dae6880348c92cfb20a76d90b4f

packages/react-native-executorch/common/rnexecutorch/data_processing/ImageProcessing.cpp

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -111,20 +111,43 @@ cv::Mat readImage(const std::string &imageURI) {
111111
throw std::runtime_error("Read image error: invalid argument");
112112
}
113113

114-
cv::cvtColor(image, image, cv::COLOR_BGR2RGB);
115114
return image;
116115
}
117116

118-
TensorPtr getTensorFromMatrix(const std::vector<int32_t> &sizes,
117+
TensorPtr getTensorFromMatrix(const std::vector<int32_t> &tensorDims,
119118
const cv::Mat &matrix) {
120119
std::vector<float> inputVector = colorMatToVector(matrix);
121-
return executorch::extension::make_tensor_ptr(sizes, inputVector);
120+
return executorch::extension::make_tensor_ptr(tensorDims, inputVector);
122121
}
123122

124123
cv::Mat getMatrixFromTensor(cv::Size size, const Tensor &tensor) {
125124
auto resultData = static_cast<const float *>(tensor.const_data_ptr());
126125
return bufferToColorMat(std::span<const float>(resultData, tensor.numel()),
127126
size);
128127
}
128+
129+
std::pair<TensorPtr, cv::Size>
130+
readImageToTensor(const std::string &path,
131+
const std::vector<int32_t> &tensorDims) {
132+
cv::Mat input = imageprocessing::readImage(path);
133+
cv::Size imageSize = input.size();
134+
135+
if (tensorDims.size() < 2) {
136+
char errorMessage[100];
137+
std::snprintf(errorMessage, sizeof(errorMessage),
138+
"Unexpected tensor size, expected at least 2 dimentions "
139+
"but got: %zu.",
140+
tensorDims.size());
141+
throw std::runtime_error(errorMessage);
142+
}
143+
cv::Size tensorSize = cv::Size(tensorDims[tensorDims.size() - 1],
144+
tensorDims[tensorDims.size() - 2]);
145+
146+
cv::resize(input, input, tensorSize);
147+
148+
cv::cvtColor(input, input, cv::COLOR_BGR2RGB);
149+
150+
return {imageprocessing::getTensorFromMatrix(tensorDims, input), imageSize};
151+
}
129152
} // namespace imageprocessing
130153
} // namespace rnexecutorch

packages/react-native-executorch/common/rnexecutorch/data_processing/ImageProcessing.h

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
#pragma once
22

3-
#include <executorch/extension/tensor/tensor.h>
4-
#include <executorch/extension/tensor/tensor_ptr.h>
5-
#include <opencv2/opencv.hpp>
3+
#include <optional>
64
#include <span>
75
#include <string>
86
#include <vector>
97

8+
#include <executorch/extension/tensor/tensor.h>
9+
#include <executorch/extension/tensor/tensor_ptr.h>
10+
11+
#include <opencv2/opencv.hpp>
12+
1013
namespace rnexecutorch::imageprocessing {
1114
using executorch::aten::Tensor;
1215
using executorch::extension::TensorPtr;
@@ -21,9 +24,15 @@ std::vector<float> colorMatToVector(const cv::Mat &mat);
2124
cv::Mat bufferToColorMat(const std::span<const float> &buffer,
2225
cv::Size matSize);
2326
std::string saveToTempFile(const cv::Mat &image);
27+
/// @brief Read image in a BGR format to a cv::Mat
2428
cv::Mat readImage(const std::string &imageURI);
25-
TensorPtr getTensorFromMatrix(const std::vector<int32_t> &sizes,
29+
TensorPtr getTensorFromMatrix(const std::vector<int32_t> &tensorDims,
2630
const cv::Mat &mat);
2731
cv::Mat getMatrixFromTensor(cv::Size size, const Tensor &tensor);
32+
/// @brief Read image, resize it and copy it to an ET tensor to store it.
33+
/// @return Returns a tensor pointer and the original size of the image.
34+
std::pair<TensorPtr, cv::Size>
35+
readImageToTensor(const std::string &path,
36+
const std::vector<int32_t> &tensorDims);
2837

2938
} // namespace rnexecutorch::imageprocessing

packages/react-native-executorch/common/rnexecutorch/models/classification/Classification.cpp

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,10 @@ Classification::Classification(const std::string &modelSource,
3030

3131
std::unordered_map<std::string_view, float>
3232
Classification::forward(std::string imageSource) {
33-
auto tensor = preprocess(imageSource);
33+
auto inputTensor =
34+
imageprocessing::readImageToTensor(imageSource, getInputShape()[0]).first;
3435

35-
auto forwardResult = forwardET(tensor);
36+
auto forwardResult = forwardET(inputTensor);
3637
if (!forwardResult.ok()) {
3738
throw std::runtime_error(
3839
"Failed to forward, error: " +
@@ -42,13 +43,6 @@ Classification::forward(std::string imageSource) {
4243
return postprocess(forwardResult->at(0).toTensor());
4344
}
4445

45-
TensorPtr Classification::preprocess(const std::string &imageSource) {
46-
cv::Mat image = imageprocessing::readImage(imageSource);
47-
cv::resize(image, image, modelImageSize);
48-
49-
return imageprocessing::getTensorFromMatrix(getInputShape()[0], image);
50-
}
51-
5246
std::unordered_map<std::string_view, float>
5347
Classification::postprocess(const Tensor &tensor) {
5448
std::span<const float> resultData(

packages/react-native-executorch/common/rnexecutorch/models/classification/Classification.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@ class Classification : public BaseModel {
1818
std::unordered_map<std::string_view, float> forward(std::string imageSource);
1919

2020
private:
21-
TensorPtr preprocess(const std::string &imageSource);
2221
std::unordered_map<std::string_view, float> postprocess(const Tensor &tensor);
2322

2423
cv::Size modelImageSize{0, 0};

packages/react-native-executorch/common/rnexecutorch/models/image_segmentation/ImageSegmentation.cpp

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,8 @@ std::shared_ptr<jsi::Object>
3737
ImageSegmentation::forward(std::string imageSource,
3838
std::set<std::string, std::less<>> classesOfInterest,
3939
bool resize) {
40-
auto [inputTensor, originalSize] = preprocess(imageSource);
40+
auto [inputTensor, originalSize] =
41+
imageprocessing::readImageToTensor(imageSource, getInputShape()[0]);
4142

4243
auto forwardResult = forwardET(inputTensor);
4344
if (!forwardResult.ok()) {
@@ -50,19 +51,6 @@ ImageSegmentation::forward(std::string imageSource,
5051
classesOfInterest, resize);
5152
}
5253

53-
std::pair<TensorPtr, cv::Size>
54-
ImageSegmentation::preprocess(const std::string &imageSource) {
55-
cv::Mat input = imageprocessing::readImage(imageSource);
56-
cv::Size inputSize = input.size();
57-
58-
cv::resize(input, input, modelImageSize);
59-
60-
std::vector<float> inputVector = imageprocessing::colorMatToVector(input);
61-
return {
62-
executorch::extension::make_tensor_ptr(getInputShape()[0], inputVector),
63-
inputSize};
64-
}
65-
6654
std::shared_ptr<jsi::Object> ImageSegmentation::postprocess(
6755
const Tensor &tensor, cv::Size originalSize,
6856
std::set<std::string, std::less<>> classesOfInterest, bool resize) {

packages/react-native-executorch/common/rnexecutorch/models/image_segmentation/ImageSegmentation.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@ class ImageSegmentation : public BaseModel {
2626
std::set<std::string, std::less<>> classesOfInterest, bool resize);
2727

2828
private:
29-
std::pair<TensorPtr, cv::Size> preprocess(const std::string &imageSource);
3029
std::shared_ptr<jsi::Object>
3130
postprocess(const Tensor &tensor, cv::Size originalSize,
3231
std::set<std::string, std::less<>> classesOfInterest,

packages/react-native-executorch/common/rnexecutorch/models/object_detection/ObjectDetection.cpp

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -25,16 +25,6 @@ ObjectDetection::ObjectDetection(
2525
modelInputShape[modelInputShape.size() - 2]);
2626
}
2727

28-
std::pair<TensorPtr, cv::Size>
29-
ObjectDetection::preprocess(const std::string &imageSource) {
30-
cv::Mat image = imageprocessing::readImage(imageSource);
31-
auto originalSize = image.size();
32-
cv::resize(image, image, modelImageSize);
33-
34-
return {imageprocessing::getTensorFromMatrix(getInputShape()[0], image),
35-
originalSize};
36-
}
37-
3828
std::vector<Detection>
3929
ObjectDetection::postprocess(const std::vector<EValue> &tensors,
4030
cv::Size originalSize, double detectionThreshold) {
@@ -77,9 +67,10 @@ ObjectDetection::postprocess(const std::vector<EValue> &tensors,
7767

7868
std::vector<Detection> ObjectDetection::forward(std::string imageSource,
7969
double detectionThreshold) {
80-
auto [tensor, originalSize] = preprocess(imageSource);
70+
auto [inputTensor, originalSize] =
71+
imageprocessing::readImageToTensor(imageSource, getInputShape()[0]);
8172

82-
auto forwardResult = forwardET(tensor);
73+
auto forwardResult = forwardET(inputTensor);
8374
if (!forwardResult.ok()) {
8475
throw std::runtime_error(
8576
"Failed to forward, error: " +

packages/react-native-executorch/common/rnexecutorch/models/object_detection/ObjectDetection.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@ class ObjectDetection : public BaseModel {
2121
double detectionThreshold);
2222

2323
private:
24-
std::pair<TensorPtr, cv::Size> preprocess(const std::string &imageSource);
2524
std::vector<Detection> postprocess(const std::vector<EValue> &tensors,
2625
cv::Size originalSize,
2726
double detectionThreshold);

packages/react-native-executorch/common/rnexecutorch/models/style_transfer/StyleTransfer.cpp

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -33,16 +33,6 @@ StyleTransfer::StyleTransfer(const std::string &modelSource,
3333
modelInputShape[modelInputShape.size() - 2]);
3434
}
3535

36-
std::pair<TensorPtr, cv::Size>
37-
StyleTransfer::preprocess(const std::string &imageSource) {
38-
cv::Mat image = imageprocessing::readImage(imageSource);
39-
auto originalSize = image.size();
40-
cv::resize(image, image, modelImageSize);
41-
42-
return {imageprocessing::getTensorFromMatrix(getInputShape()[0], image),
43-
originalSize};
44-
}
45-
4636
std::string StyleTransfer::postprocess(const Tensor &tensor,
4737
cv::Size originalSize) {
4838
cv::Mat mat = imageprocessing::getMatrixFromTensor(modelImageSize, tensor);
@@ -52,9 +42,10 @@ std::string StyleTransfer::postprocess(const Tensor &tensor,
5242
}
5343

5444
std::string StyleTransfer::forward(std::string imageSource) {
55-
auto [tensor, originalSize] = preprocess(imageSource);
45+
auto [inputTensor, originalSize] =
46+
imageprocessing::readImageToTensor(imageSource, getInputShape()[0]);
5647

57-
auto forwardResult = forwardET(tensor);
48+
auto forwardResult = forwardET(inputTensor);
5849
if (!forwardResult.ok()) {
5950
throw std::runtime_error(
6051
"Failed to forward, error: " +

0 commit comments

Comments
 (0)