huangcm
2025-07-01 676035278781360996553c427a12bf358249ebf7
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
 
// Contains classes that can execute different models/parts of a model.
 
#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_MODEL_EXECUTOR_H_
#define LIBTEXTCLASSIFIER_ANNOTATOR_MODEL_EXECUTOR_H_
 
#include <memory>
 
#include "annotator/types.h"
#include "utils/base/logging.h"
#include "utils/tensor-view.h"
#include "utils/tflite-model-executor.h"
 
namespace libtextclassifier3 {
 
// Executor for the text selection prediction and classification models.
class ModelExecutor : public TfLiteModelExecutor {
 public:
  static std::unique_ptr<ModelExecutor> FromModelSpec(
      const tflite::Model* model_spec) {
    auto model = TfLiteModelFromModelSpec(model_spec);
    if (!model) {
      return nullptr;
    }
    return std::unique_ptr<ModelExecutor>(new ModelExecutor(std::move(model)));
  }
 
  static std::unique_ptr<ModelExecutor> FromBuffer(
      const flatbuffers::Vector<uint8_t>* model_spec_buffer) {
    auto model = TfLiteModelFromBuffer(model_spec_buffer);
    if (!model) {
      return nullptr;
    }
    return std::unique_ptr<ModelExecutor>(new ModelExecutor(std::move(model)));
  }
 
  TensorView<float> ComputeLogits(const TensorView<float>& features,
                                  tflite::Interpreter* interpreter) const;
 
 protected:
  explicit ModelExecutor(std::unique_ptr<const tflite::FlatBufferModel> model)
      : TfLiteModelExecutor(std::move(model)) {}
 
  static const int kInputIndexFeatures = 0;
  static const int kOutputIndexLogits = 0;
};
 
// Executor for embedding sparse features into a dense vector.
class EmbeddingExecutor {
 public:
  virtual ~EmbeddingExecutor() {}
 
  // Embeds the sparse_features into a dense embedding and adds (+) it
  // element-wise to the dest vector.
  virtual bool AddEmbedding(const TensorView<int>& sparse_features, float* dest,
                            int dest_size) const = 0;
 
  // Returns true when the model is ready to be used, false otherwise.
  virtual bool IsReady() const { return true; }
};
 
class TFLiteEmbeddingExecutor : public EmbeddingExecutor {
 public:
  static std::unique_ptr<TFLiteEmbeddingExecutor> FromBuffer(
      const flatbuffers::Vector<uint8_t>* model_spec_buffer, int embedding_size,
      int quantization_bits,
      const Model_::EmbeddingPruningMask* embedding_pruning_mask = nullptr);
 
  // Embeds the sparse_features into a dense embedding and adds (+) it
  // element-wise to the dest vector.
  bool AddEmbedding(const TensorView<int>& sparse_features, float* dest,
                    int dest_size) const;
 
  // Auxiliary function for computing prefixes used in implementation of
  // efficient mask indexing data structure.
  void ComputePrefixCounts();
 
  // Function implementing mask indexing based on efficient data structure
  int PruneBucketId(int bucket_id) const;
 
 protected:
  explicit TFLiteEmbeddingExecutor(
      std::unique_ptr<TfLiteModelExecutor> executor, int quantization_bits,
      int num_buckets, int bytes_per_embedding, int output_embedding_size,
      const TfLiteTensor* scales, const TfLiteTensor* embeddings,
      std::unique_ptr<tflite::Interpreter> interpreter,
      const Model_::EmbeddingPruningMask* embedding_pruning_mask = nullptr);
 
  std::unique_ptr<TfLiteModelExecutor> executor_;
 
  int quantization_bits_;
  int num_buckets_ = -1;
  int bytes_per_embedding_ = -1;
  int output_embedding_size_ = -1;
  const TfLiteTensor* scales_ = nullptr;
  const TfLiteTensor* embeddings_ = nullptr;
 
  // NOTE: This interpreter is used in a read-only way (as a storage for the
  // model params), thus is still thread-safe.
  std::unique_ptr<tflite::Interpreter> interpreter_;
 
  std::vector<uint64> pruning_mask_;
  std::vector<uint16> prefix_counts_;
  int full_num_buckets_ = -1;
 
  // Index of row of embedding table corresponding to all pruned buckets.
  int pruned_row_bucket_id_ = -1;
};
 
}  // namespace libtextclassifier3
 
#endif  // LIBTEXTCLASSIFIER_ANNOTATOR_MODEL_EXECUTOR_H_