Skip to content

Commit 5b57f3d

Browse files
committed
Extract encoder class
1 parent 905daea commit 5b57f3d

File tree

4 files changed

+89
-24
lines changed

4 files changed

+89
-24
lines changed
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
#include "Encoder.h"
2+
3+
#include <cmath>
4+
#include <random>
5+
#include <span>
6+
7+
#include <rnexecutorch/models/text_to_image/Constants.h>
8+
9+
namespace rnexecutorch::models::text_to_image {
10+
11+
Encoder::Encoder(const std::string &tokenizerSource,
12+
const std::string &encoderSource,
13+
std::shared_ptr<react::CallInvoker> callInvoker)
14+
: callInvoker(callInvoker),
15+
encoder(std::make_unique<embeddings::TextEmbeddings>(
16+
encoderSource, tokenizerSource, callInvoker)) {}
17+
18+
std::vector<float> Encoder::generate(std::string input) {
19+
std::shared_ptr<OwningArrayBuffer> embeddingsText = encoder->generate(input);
20+
std::shared_ptr<OwningArrayBuffer> embeddingsUncond =
21+
encoder->generate(std::string(constants::kBosToken));
22+
23+
size_t embeddingsSize = embeddingsText->size() / sizeof(float);
24+
auto *embeddingsTextPtr = reinterpret_cast<float *>(embeddingsText->data());
25+
auto *embeddingsUncondPtr =
26+
reinterpret_cast<float *>(embeddingsUncond->data());
27+
28+
std::vector<float> embeddingsConcat;
29+
embeddingsConcat.reserve(embeddingsSize * 2);
30+
embeddingsConcat.insert(embeddingsConcat.end(), embeddingsUncondPtr,
31+
embeddingsUncondPtr + embeddingsSize);
32+
embeddingsConcat.insert(embeddingsConcat.end(), embeddingsTextPtr,
33+
embeddingsTextPtr + embeddingsSize);
34+
return embeddingsConcat;
35+
}
36+
37+
size_t Encoder::getMemoryLowerBound() const noexcept {
38+
return encoder->getMemoryLowerBound();
39+
}
40+
41+
void Encoder::unload() noexcept { encoder.reset(nullptr); }
42+
43+
} // namespace rnexecutorch::models::text_to_image
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
#pragma once
2+
3+
#include <memory>
4+
#include <string>
5+
#include <vector>
6+
7+
#include <ReactCommon/CallInvoker.h>
8+
#include <jsi/jsi.h>
9+
10+
#include <rnexecutorch/jsi/OwningArrayBuffer.h>
11+
#include <rnexecutorch/metaprogramming/ConstructorHelpers.h>
12+
13+
#include <rnexecutorch/models/embeddings/text/TextEmbeddings.h>
14+
15+
namespace rnexecutorch {
16+
namespace models::text_to_image {
17+
using namespace facebook;
18+
19+
class Encoder final {
20+
public:
21+
explicit Encoder(const std::string &tokenizerSource,
22+
const std::string &encoderSource,
23+
std::shared_ptr<react::CallInvoker> callInvoker);
24+
std::vector<float> generate(std::string input);
25+
size_t getMemoryLowerBound() const noexcept;
26+
void unload() noexcept;
27+
28+
private:
29+
std::shared_ptr<react::CallInvoker> callInvoker;
30+
std::unique_ptr<embeddings::TextEmbeddings> encoder;
31+
};
32+
} // namespace models::text_to_image
33+
34+
REGISTER_CONSTRUCTOR(models::text_to_image::Encoder, std::string, std::string,
35+
std::shared_ptr<react::CallInvoker>);
36+
} // namespace rnexecutorch

packages/react-native-executorch/common/rnexecutorch/models/text_to_image/TextToImage.cpp

Lines changed: 8 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include <span>
66

77
#include <rnexecutorch/Log.h>
8+
#include <rnexecutorch/models/text_to_image/Constants.h>
89

910
namespace rnexecutorch::models::text_to_image {
1011

@@ -20,8 +21,8 @@ TextToImage::TextToImage(const std::string &tokenizerSource,
2021
scheduler(std::make_unique<Scheduler>(
2122
schedulerBetaStart, schedulerBetaEnd, schedulerNumTrainTimesteps,
2223
schedulerStepsOffset, callInvoker)),
23-
encoder(std::make_unique<embeddings::TextEmbeddings>(
24-
encoderSource, tokenizerSource, callInvoker)),
24+
encoder(std::make_unique<Encoder>(tokenizerSource, encoderSource,
25+
callInvoker)),
2526
unet(std::make_unique<UNet>(unetSource, imageSize, numChannels,
2627
callInvoker)),
2728
decoder(std::make_unique<Decoder>(decoderSource, imageSize, numChannels,
@@ -30,21 +31,7 @@ TextToImage::TextToImage(const std::string &tokenizerSource,
3031
std::shared_ptr<OwningArrayBuffer>
3132
TextToImage::generate(std::string input, size_t numInferenceSteps,
3233
std::shared_ptr<jsi::Function> callback) {
33-
std::shared_ptr<OwningArrayBuffer> embeddingsText = encoder->generate(input);
34-
std::shared_ptr<OwningArrayBuffer> embeddingsUncond =
35-
encoder->generate(std::string(constants::kBosToken));
36-
37-
size_t embeddingsSize = embeddingsText->size() / sizeof(float);
38-
auto *embeddingsTextPtr = reinterpret_cast<float *>(embeddingsText->data());
39-
auto *embeddingsUncondPtr =
40-
reinterpret_cast<float *>(embeddingsUncond->data());
41-
42-
std::vector<float> embeddingsConcat;
43-
embeddingsConcat.reserve(embeddingsSize * 2);
44-
embeddingsConcat.insert(embeddingsConcat.end(), embeddingsUncondPtr,
45-
embeddingsUncondPtr + embeddingsSize);
46-
embeddingsConcat.insert(embeddingsConcat.end(), embeddingsTextPtr,
47-
embeddingsTextPtr + embeddingsSize);
34+
std::vector<float> embeddings = encoder->generate(input);
4835

4936
constexpr int32_t latentDownsample = 8;
5037
int32_t latentsSize = std::floor(modelImageSize / latentDownsample);
@@ -73,7 +60,7 @@ TextToImage::generate(std::string input, size_t numInferenceSteps,
7360
log(LOG_LEVEL::Debug, "Step:", t, "/", numInferenceSteps);
7461

7562
std::vector<float> noisePred =
76-
unet->generate(latents, timesteps[t], embeddingsConcat);
63+
unet->generate(latents, timesteps[t], embeddings);
7764

7865
size_t noiseSize = noisePred.size() / 2;
7966
std::span<const float> noisePredSpan{noisePred};
@@ -127,9 +114,9 @@ size_t TextToImage::getMemoryLowerBound() const noexcept {
127114
}
128115

129116
void TextToImage::unload() noexcept {
130-
encoder.reset(nullptr);
131-
unet.reset(nullptr);
132-
decoder.reset(nullptr);
117+
encoder->unload();
118+
unet->unload();
119+
decoder->unload();
133120
}
134121

135122
} // namespace rnexecutorch::models::text_to_image

packages/react-native-executorch/common/rnexecutorch/models/text_to_image/TextToImage.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,8 @@
1010
#include <rnexecutorch/jsi/OwningArrayBuffer.h>
1111
#include <rnexecutorch/metaprogramming/ConstructorHelpers.h>
1212

13-
#include <rnexecutorch/models/embeddings/text/TextEmbeddings.h>
14-
#include <rnexecutorch/models/text_to_image/Constants.h>
1513
#include <rnexecutorch/models/text_to_image/Decoder.h>
14+
#include <rnexecutorch/models/text_to_image/Encoder.h>
1615
#include <rnexecutorch/models/text_to_image/Scheduler.h>
1716
#include <rnexecutorch/models/text_to_image/UNet.h>
1817

@@ -50,7 +49,7 @@ class TextToImage final {
5049

5150
std::shared_ptr<react::CallInvoker> callInvoker;
5251
std::unique_ptr<Scheduler> scheduler;
53-
std::unique_ptr<embeddings::TextEmbeddings> encoder;
52+
std::unique_ptr<Encoder> encoder;
5453
std::unique_ptr<UNet> unet;
5554
std::unique_ptr<Decoder> decoder;
5655
};

0 commit comments

Comments
 (0)