lin
2025-08-01 633231e833e21d5b8b1c00cb15aedb62b3b78e8f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
 
#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_EMBEDDING_NETWORK_PARAMS_H_
#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_EMBEDDING_NETWORK_PARAMS_H_
 
#include <string>
 
#include "lang_id/common/fel/task-context.h"
#include "lang_id/common/lite_base/float16.h"
#include "lang_id/common/lite_base/logging.h"
 
namespace libtextclassifier3 {
 
enum class QuantizationType {
  NONE = 0,
 
  // Quantization to 8 bit unsigned ints.
  UINT8,
 
  // Quantization to 4 bit unsigned ints.
  UINT4,
 
  // Quantization to 16 bit floats, the type defined in
  // lang_id/common/float16.h
  FLOAT16,
 
  // NOTE: for backward compatibility, if you add a new value to this enum, add
  // it *at the end*, such that you do not change the integer values of the
  // existing enum values.
};
 
// Converts "UINT8" -> QuantizationType::UINT8, and so on.
QuantizationType ParseQuantizationType(const string &s);
 
// API for accessing parameters for a feed-forward neural network with
// embeddings.
//
//
// In fact, we provide two APIs: a high-level (and highly-recommented) API, with
// methods named using the BigCamel notation (e.g., GetEmbeddingMatrix()) and a
// low-level API, using C-style names (e.g., softmax_num_cols()).
//
// Note: the API below is meant to allow the inference code (the class
// libtextclassifier3::mobile::EmbeddingNetwork) to use the data directly, with no need
// for transposing any matrix (which would require extra overhead on mobile
// devices).  Hence, as indicated by the comments for the API methods, some of
// the matrices below are the transposes of the corresponding matrices from the
// original proto.
class EmbeddingNetworkParams {
 public:
  virtual ~EmbeddingNetworkParams() {}
 
  // Returns true if these params are valid.  False otherwise (e.g., if the
  // underlying data is corrupted).  If is_valid() returns false, clients should
  // not call any other method on that instance of EmbeddingNetworkParams.  If
  // is_valid() returns true, then calls to the API methods below should not
  // crash *if they are called with index parameters in bounds*.  E.g., if
  // is_valid() and 0 <= i < embeddings_size(), then GetEmbeddingMatrix(i)
  // should not crash.
  virtual bool is_valid() const = 0;
 
  // **** High-level API.
 
  // Simple representation of a matrix.  This small struct that doesn't own any
  // resource intentionally supports copy / assign, to simplify our APIs.
  struct Matrix {
    // Number of rows.
    int rows = 0;
 
    // Number of columns.
    int cols = 0;
 
    QuantizationType quant_type = QuantizationType::NONE;
 
    // Pointer to matrix elements, in row-major order
    // (https://en.wikipedia.org/wiki/Row-major_order) Not owned.
    const void *elements = nullptr;
 
    // Quantization scales: one scale for each row.
    const ::libtextclassifier3::mobile::float16 *quant_scales = nullptr;
  };
 
  // Returns i-th embedding matrix.  Crashes on out of bounds indices.
  //
  // This is the transpose of the corresponding matrix from the original proto.
  Matrix GetEmbeddingMatrix(int i) const {
    CheckIndex(i, embeddings_size(), "embedding matrix");
    Matrix matrix;
    matrix.rows = embeddings_num_rows(i);
    matrix.cols = embeddings_num_cols(i);
    matrix.elements = embeddings_weights(i);
    matrix.quant_type = embeddings_quant_type(i);
    matrix.quant_scales = embeddings_quant_scales(i);
    return matrix;
  }
 
  // Returns weight matrix for i-th hidden layer.  Crashes on out of bounds
  // indices.
  //
  // This is the transpose of the corresponding matrix from the original proto.
  Matrix GetHiddenLayerMatrix(int i) const {
    CheckIndex(i, hidden_size(), "hidden layer");
    Matrix matrix;
    matrix.rows = hidden_num_rows(i);
    matrix.cols = hidden_num_cols(i);
 
    // Quantization not supported here.
    matrix.quant_type = hidden_weights_quant_type(i);
    matrix.elements = hidden_weights(i);
    return matrix;
  }
 
  // Returns bias for i-th hidden layer.  Technically a Matrix, but we expect it
  // to be a row/column vector (i.e., num rows or num cols is 1).  However, we
  // don't CHECK for that: we just provide access to underlying data.  Crashes
  // on out of bounds indices.
  Matrix GetHiddenLayerBias(int i) const {
    CheckIndex(i, hidden_bias_size(), "hidden layer bias");
    Matrix matrix;
    matrix.rows = hidden_bias_num_rows(i);
    matrix.cols = hidden_bias_num_cols(i);
 
    // Quantization not supported here.
    matrix.quant_type = QuantizationType::NONE;
    matrix.elements = hidden_bias_weights(i);
    return matrix;
  }
 
  // Returns true if a softmax layer exists.
  bool HasSoftmax() const {
    return softmax_size() == 1;
  }
 
  // Returns weight matrix for the softmax layer.  Note: should be called only
  // if HasSoftmax() is true.
  //
  // This is the transpose of the corresponding matrix from the original proto.
  Matrix GetSoftmaxMatrix() const {
    SAFTM_CHECK(HasSoftmax()) << "No softmax layer.";
    Matrix matrix;
    matrix.rows = softmax_num_rows(0);
    matrix.cols = softmax_num_cols(0);
 
    // Quantization not supported here.
    matrix.quant_type = softmax_weights_quant_type(0);
    matrix.elements = softmax_weights(0);
    return matrix;
  }
 
  // Returns bias for the softmax layer.  Technically a Matrix, but we expect it
  // to be a row/column vector (i.e., num rows or num cols is 1).  However, we
  // don't CHECK for that: we just provide access to underlying data.
  Matrix GetSoftmaxBias() const {
    SAFTM_CHECK(HasSoftmax()) << "No softmax layer.";
    Matrix matrix;
    matrix.rows = softmax_bias_num_rows(0);
    matrix.cols = softmax_bias_num_cols(0);
 
    // Quantization not supported here.
    matrix.quant_type = QuantizationType::NONE;
    matrix.elements = softmax_bias_weights(0);
    return matrix;
  }
 
  // Updates the EmbeddingNetwork-related parameters from task_context.  Returns
  // true on success, false on error.
  virtual bool UpdateTaskContextParameters(
      mobile::TaskContext *task_context) = 0;
 
  // **** Low-level API.
  //
  // * Most low-level API methods are documented by giving an equivalent
  //   function call on proto, the original proto (of type
  //   EmbeddingNetworkProto) which was used to generate the C++ code.
  //
  // * To simplify our generation code, optional proto fields of message type
  //   are treated as repeated fields with 0 or 1 instances.  As such, we have
  //   *_size() methods for such optional fields: they return 0 or 1.
  //
  // * "transpose(M)" denotes the transpose of a matrix M.
 
  // ** Access methods for repeated MatrixParams embeddings.
  //
  // Returns proto.embeddings_size().
  virtual int embeddings_size() const = 0;
 
  // Returns number of rows of transpose(proto.embeddings(i)).
  virtual int embeddings_num_rows(int i) const = 0;
 
  // Returns number of columns of transpose(proto.embeddings(i)).
  virtual int embeddings_num_cols(int i) const = 0;
 
  // Returns pointer to elements of transpose(proto.embeddings(i)), in row-major
  // order.  NOTE: for unquantized embeddings, this returns a pointer to float;
  // for quantized embeddings, this returns a pointer to uint8.
  virtual const void *embeddings_weights(int i) const = 0;
 
  virtual QuantizationType embeddings_quant_type(int i) const {
    return QuantizationType::NONE;
  }
 
  virtual const ::libtextclassifier3::mobile::float16 *embeddings_quant_scales(
      int i) const {
    return nullptr;
  }
 
  // ** Access methods for repeated MatrixParams hidden.
  //
  // Returns embedding_network_proto.hidden_size().
  virtual int hidden_size() const = 0;
 
  // Returns embedding_network_proto.hidden(i).rows().
  virtual int hidden_num_rows(int i) const = 0;
 
  // Returns embedding_network_proto.hidden(i).rows().
  virtual int hidden_num_cols(int i) const = 0;
 
  // Returns quantization mode for the weights of the i-th hidden layer.
  virtual QuantizationType hidden_weights_quant_type(int i) const {
    return QuantizationType::NONE;
  }
 
  // Returns pointer to beginning of array of floats with all values from
  // embedding_network_proto.hidden(i).
  virtual const void *hidden_weights(int i) const = 0;
 
  // ** Access methods for repeated MatrixParams hidden_bias.
  //
  // Returns proto.hidden_bias_size().
  virtual int hidden_bias_size() const = 0;
 
  // Returns number of rows of proto.hidden_bias(i).
  virtual int hidden_bias_num_rows(int i) const = 0;
 
  // Returns number of columns of proto.hidden_bias(i).
  virtual int hidden_bias_num_cols(int i) const = 0;
 
  // Returns pointer to elements of proto.hidden_bias(i), in row-major order.
  virtual const void *hidden_bias_weights(int i) const = 0;
 
  // ** Access methods for optional MatrixParams softmax.
  //
  // Returns 1 if proto has optional field softmax, 0 otherwise.
  virtual int softmax_size() const = 0;
 
  // Returns number of rows of transpose(proto.softmax()).
  virtual int softmax_num_rows(int i) const = 0;
 
  // Returns number of columns of transpose(proto.softmax()).
  virtual int softmax_num_cols(int i) const = 0;
 
  // Returns quantization mode for the softmax weights.
  virtual QuantizationType softmax_weights_quant_type(int i) const {
    return QuantizationType::NONE;
  }
 
  // Returns pointer to elements of transpose(proto.softmax()), in row-major
  // order.
  virtual const void *softmax_weights(int i) const = 0;
 
  // ** Access methods for optional MatrixParams softmax_bias.
  //
  // Returns 1 if proto has optional field softmax_bias, 0 otherwise.
  virtual int softmax_bias_size() const = 0;
 
  // Returns number of rows of proto.softmax_bias().
  virtual int softmax_bias_num_rows(int i) const = 0;
 
  // Returns number of columns of proto.softmax_bias().
  virtual int softmax_bias_num_cols(int i) const = 0;
 
  // Returns pointer to elements of proto.softmax_bias(), in row-major order.
  virtual const void *softmax_bias_weights(int i) const = 0;
 
  // ** Access methods for repeated int32 embedding_num_features.
  //
  // Returns proto.embedding_num_features_size().
  virtual int embedding_num_features_size() const = 0;
 
  // Returns proto.embedding_num_features(i).
  virtual int embedding_num_features(int i) const = 0;
 
  // ** Access methods for is_precomputed
  //
  // Returns proto.has_is_precomputed().
  virtual bool has_is_precomputed() const = 0;
 
  // Returns proto.is_precomputed().
  virtual bool is_precomputed() const = 0;
 
 protected:
  void CheckIndex(int index, int size, const string &description) const {
    SAFTM_CHECK_GE(index, 0)
        << "Out-of-range index for " << description << ": " << index;
    SAFTM_CHECK_LT(index, size)
        << "Out-of-range index for " << description << ": " << index;
  }
};  // class EmbeddingNetworkParams
 
}  // namespace nlp_saft
 
#endif  // NLP_SAFT_COMPONENTS_COMMON_MOBILE_EMBEDDING_NETWORK_PARAMS_H_