Commit 905daea

Manage input inside UNet
1 parent 4d09ed4 commit 905daea

5 files changed: +21 -22 lines changed

packages/react-native-executorch/common/rnexecutorch/models/text_to_image/Decoder.cpp

Lines changed: 1 addition & 1 deletion
@@ -14,7 +14,7 @@ Decoder::Decoder(const std::string &modelSource, int32_t modelImageSize,
     : BaseModel(modelSource, callInvoker), modelImageSize(modelImageSize),
       numChannels(numChannels) {}
 
-std::vector<float> Decoder::generate(std::vector<float> &input) {
+std::vector<float> Decoder::generate(std::vector<float> &input) const {
   constexpr int32_t latentDownsample = 8;
   const int32_t latentsImageSize =
       std::floor(modelImageSize / latentDownsample);

packages/react-native-executorch/common/rnexecutorch/models/text_to_image/Decoder.h

Lines changed: 1 addition & 1 deletion
@@ -17,7 +17,7 @@ class Decoder final : public BaseModel {
   explicit Decoder(const std::string &modelSource, int32_t modelImageSize,
                    int32_t numChannels,
                    std::shared_ptr<react::CallInvoker> callInvoker);
-  std::vector<float> generate(std::vector<float> &input);
+  std::vector<float> generate(std::vector<float> &input) const;
 
 private:
   int32_t modelImageSize;

packages/react-native-executorch/common/rnexecutorch/models/text_to_image/TextToImage.cpp

Lines changed: 4 additions & 8 deletions
@@ -47,9 +47,9 @@ TextToImage::generate(std::string input, size_t numInferenceSteps,
                                        embeddingsTextPtr + embeddingsSize);
 
   constexpr int32_t latentDownsample = 8;
-  int32_t latentsWidth = std::floor(modelImageSize / latentDownsample);
-  int32_t latentsSize = numChannels * latentsWidth * latentsWidth;
-  std::vector<float> latents(latentsSize);
+  int32_t latentsSize = std::floor(modelImageSize / latentDownsample);
+  int32_t latentsImageSize = numChannels * latentsSize * latentsSize;
+  std::vector<float> latents(latentsImageSize);
   std::random_device rd;
   std::mt19937 gen(rd());
   std::normal_distribution<float> dist(0.0, 1.0);
@@ -71,13 +71,9 @@ TextToImage::generate(std::string input, size_t numInferenceSteps,
       return postprocess({});
     }
     log(LOG_LEVEL::Debug, "Step:", t, "/", numInferenceSteps);
-    std::vector<float> latentsConcat;
-    latentsConcat.reserve(2 * latentsSize);
-    latentsConcat.insert(latentsConcat.end(), latents.begin(), latents.end());
-    latentsConcat.insert(latentsConcat.end(), latents.begin(), latents.end());
 
     std::vector<float> noisePred =
-        unet->generate(latentsConcat, timesteps[t], embeddingsConcat);
+        unet->generate(latents, timesteps[t], embeddingsConcat);
 
     size_t noiseSize = noisePred.size() / 2;
     std::span<const float> noisePredSpan{noisePred};
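
Note: with the batch duplication moved into UNet::generate, the pipeline now passes a single latent buffer per step but still gets back a doubled noise prediction, which is why the caller keeps splitting at noisePred.size() / 2. The combination step itself is not part of this diff; the sketch below shows only the standard classifier-free guidance update, and the applyGuidance helper and guidanceScale parameter are hypothetical, not names from this codebase.

#include <cstddef>
#include <span>
#include <vector>

// Hypothetical helper: combines the two halves of the UNet output using the
// standard classifier-free guidance formula. Which half is the unconditional
// prediction depends on how embeddingsConcat was assembled upstream.
std::vector<float> applyGuidance(std::span<const float> noisePred,
                                 float guidanceScale) {
  const std::size_t noiseSize = noisePred.size() / 2;
  std::span<const float> uncond = noisePred.first(noiseSize);
  std::span<const float> cond = noisePred.subspan(noiseSize);

  std::vector<float> guided(noiseSize);
  for (std::size_t i = 0; i < noiseSize; ++i) {
    // guided = uncond + scale * (cond - uncond)
    guided[i] = uncond[i] + guidanceScale * (cond[i] - uncond[i]);
  }
  return guided;
}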

packages/react-native-executorch/common/rnexecutorch/models/text_to_image/UNet.cpp

Lines changed: 13 additions & 10 deletions
@@ -8,26 +8,29 @@ using namespace executorch::extension;
 
 UNet::UNet(const std::string &modelSource, int32_t modelImageSize,
            int32_t numChannels, std::shared_ptr<react::CallInvoker> callInvoker)
-    : BaseModel(modelSource, callInvoker), modelImageSize(modelImageSize),
-      numChannels(numChannels) {}
+    : BaseModel(modelSource, callInvoker), numChannels(numChannels) {
+  constexpr int32_t latentDownsample = 8;
+  latentsSize = std::floor(modelImageSize / latentDownsample);
+}
 
 std::vector<float> UNet::generate(std::vector<float> &latents, int32_t timestep,
-                                  std::vector<float> &embeddings) {
-  constexpr int32_t latentDownsample = 8;
-  const int32_t latentsImageSize =
-      std::floor(modelImageSize / latentDownsample);
-  std::vector<int32_t> latentsShape = {2, numChannels, latentsImageSize,
-                                       latentsImageSize};
+                                  std::vector<float> &embeddings) const {
+  std::vector<float> latentsConcat;
+  latentsConcat.reserve(2 * latentsSize);
+  latentsConcat.insert(latentsConcat.end(), latents.begin(), latents.end());
+  latentsConcat.insert(latentsConcat.end(), latents.begin(), latents.end());
+
+  std::vector<int32_t> latentsShape = {2, numChannels, latentsSize,
+                                       latentsSize};
   std::vector<int32_t> timestepShape = {1};
   std::vector<int32_t> embeddingsShape = {2, 77, 768};
 
-  // TODO change after reexporting the model
   std::vector<int64_t> timestepData = {static_cast<int64_t>(timestep)};
   auto timestepTensor =
       make_tensor_ptr(timestepShape, timestepData.data(), ScalarType::Long);
 
   auto latentsTensor =
-      make_tensor_ptr(latentsShape, latents.data(), ScalarType::Float);
+      make_tensor_ptr(latentsShape, latentsConcat.data(), ScalarType::Float);
   auto embeddingsTensor =
       make_tensor_ptr(embeddingsShape, embeddings.data(), ScalarType::Float);
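
Note: a minimal sketch of how a caller would drive the reworked UNet after this change. The include path, model file name, image size, channel count, and the assumption that an embeddings buffer of shape {2, 77, 768} is already concatenated are all illustrative and not taken from this diff; only the constructor and generate signatures come from UNet.h.

#include <cstdint>
#include <memory>
#include <vector>
#include "rnexecutorch/models/text_to_image/UNet.h" // assumed include root

// Sketch only: one denoising step with the new API.
std::vector<float> runSingleStep(
    std::shared_ptr<react::CallInvoker> callInvoker) {
  constexpr int32_t modelImageSize = 512; // assumed export resolution
  constexpr int32_t numChannels = 4;      // assumed latent channel count
  constexpr int32_t latentsSize = modelImageSize / 8;

  models::text_to_image::UNet unet("unet.pte", modelImageSize, numChannels,
                                   callInvoker);

  // A single latent batch; UNet::generate now duplicates it internally
  // for the unconditional/text-conditioned pair.
  std::vector<float> latents(numChannels * latentsSize * latentsSize, 0.0f);
  std::vector<float> embeddings(2 * 77 * 768, 0.0f); // already concatenated

  // Returns a doubled noise prediction; callers split it as in TextToImage.cpp.
  return unet.generate(latents, /*timestep=*/999, embeddings);
}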

packages/react-native-executorch/common/rnexecutorch/models/text_to_image/UNet.h

Lines changed: 2 additions & 2 deletions
@@ -18,11 +18,11 @@ class UNet final : public BaseModel {
            int32_t numChannels,
            std::shared_ptr<react::CallInvoker> callInvoker);
   std::vector<float> generate(std::vector<float> &latents, int32_t timestep,
-                              std::vector<float> &embeddings);
+                              std::vector<float> &embeddings) const;
 
 private:
-  int32_t modelImageSize;
   int32_t numChannels;
+  int32_t latentsSize;
 };
 } // namespace models::text_to_image
