lin
2025-08-21 57113df3a0e2be01232281fad9a5f2c060567981
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
 
// Feature processing for FFModel (feed-forward SmartSelection model).
 
#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_FEATURE_PROCESSOR_H_
#define LIBTEXTCLASSIFIER_ANNOTATOR_FEATURE_PROCESSOR_H_
 
#include <map>
#include <memory>
#include <set>
#include <string>
#include <vector>
 
#include "annotator/cached-features.h"
#include "annotator/model_generated.h"
#include "annotator/types.h"
#include "utils/base/integral_types.h"
#include "utils/base/logging.h"
#include "utils/token-feature-extractor.h"
#include "utils/tokenizer.h"
#include "utils/utf8/unicodetext.h"
#include "utils/utf8/unilib.h"
 
namespace libtextclassifier3 {
 
constexpr int kInvalidLabel = -1;
 
namespace internal {
 
Tokenizer BuildTokenizer(const FeatureProcessorOptions* options,
                         const UniLib* unilib);
 
TokenFeatureExtractorOptions BuildTokenFeatureExtractorOptions(
    const FeatureProcessorOptions* options);
 
// Splits tokens that contain the selection boundary inside them.
// E.g. "foo{bar}@google.com" -> "foo", "bar", "@google.com"
void SplitTokensOnSelectionBoundaries(CodepointSpan selection,
                                      std::vector<Token>* tokens);
 
// Returns the index of token that corresponds to the codepoint span.
int CenterTokenFromClick(CodepointSpan span, const std::vector<Token>& tokens);
 
// Returns the index of token that corresponds to the middle of the  codepoint
// span.
int CenterTokenFromMiddleOfSelection(
    CodepointSpan span, const std::vector<Token>& selectable_tokens);
 
// Strips the tokens from the tokens vector that are not used for feature
// extraction because they are out of scope, or pads them so that there is
// enough tokens in the required context_size for all inferences with a click
// in relative_click_span.
void StripOrPadTokens(TokenSpan relative_click_span, int context_size,
                      std::vector<Token>* tokens, int* click_pos);
 
}  // namespace internal
 
// Converts a codepoint span to a token span in the given list of tokens.
// If snap_boundaries_to_containing_tokens is set to true, it is enough for a
// token to overlap with the codepoint range to be considered part of it.
// Otherwise it must be fully included in the range.
TokenSpan CodepointSpanToTokenSpan(
    const std::vector<Token>& selectable_tokens, CodepointSpan codepoint_span,
    bool snap_boundaries_to_containing_tokens = false);
 
// Converts a token span to a codepoint span in the given list of tokens.
CodepointSpan TokenSpanToCodepointSpan(
    const std::vector<Token>& selectable_tokens, TokenSpan token_span);
 
// Takes care of preparing features for the span prediction model.
class FeatureProcessor {
 public:
  // A cache mapping codepoint spans to embedded tokens features. An instance
  // can be provided to multiple calls to ExtractFeatures() operating on the
  // same context (the same codepoint spans corresponding to the same tokens),
  // as an optimization. Note that the tokenizations do not have to be
  // identical.
  typedef std::map<CodepointSpan, std::vector<float>> EmbeddingCache;
 
  FeatureProcessor(const FeatureProcessorOptions* options, const UniLib* unilib)
      : feature_extractor_(internal::BuildTokenFeatureExtractorOptions(options),
                           *unilib),
        options_(options),
        tokenizer_(internal::BuildTokenizer(options, unilib)) {
    MakeLabelMaps();
    if (options->supported_codepoint_ranges() != nullptr) {
      SortCodepointRanges({options->supported_codepoint_ranges()->begin(),
                           options->supported_codepoint_ranges()->end()},
                          &supported_codepoint_ranges_);
    }
    PrepareIgnoredSpanBoundaryCodepoints();
  }
 
  // Tokenizes the input string using the selected tokenization method.
  std::vector<Token> Tokenize(const std::string& text) const;
 
  // Same as above but takes UnicodeText.
  std::vector<Token> Tokenize(const UnicodeText& text_unicode) const;
 
  // Converts a label into a token span.
  bool LabelToTokenSpan(int label, TokenSpan* token_span) const;
 
  // Gets the total number of selection labels.
  int GetSelectionLabelCount() const { return label_to_selection_.size(); }
 
  // Gets the string value for given collection label.
  std::string LabelToCollection(int label) const;
 
  // Gets the total number of collections of the model.
  int NumCollections() const { return collection_to_label_.size(); }
 
  // Gets the name of the default collection.
  std::string GetDefaultCollection() const;
 
  const FeatureProcessorOptions* GetOptions() const { return options_; }
 
  // Retokenizes the context and input span, and finds the click position.
  // Depending on the options, might modify tokens (split them or remove them).
  void RetokenizeAndFindClick(const std::string& context,
                              CodepointSpan input_span,
                              bool only_use_line_with_click,
                              std::vector<Token>* tokens, int* click_pos) const;
 
  // Same as above but takes UnicodeText.
  void RetokenizeAndFindClick(const UnicodeText& context_unicode,
                              CodepointSpan input_span,
                              bool only_use_line_with_click,
                              std::vector<Token>* tokens, int* click_pos) const;
 
  // Returns true if the token span has enough supported codepoints (as defined
  // in the model config) or not and model should not run.
  bool HasEnoughSupportedCodepoints(const std::vector<Token>& tokens,
                                    TokenSpan token_span) const;
 
  // Extracts features as a CachedFeatures object that can be used for repeated
  // inference over token spans in the given context.
  bool ExtractFeatures(const std::vector<Token>& tokens, TokenSpan token_span,
                       CodepointSpan selection_span_for_feature,
                       const EmbeddingExecutor* embedding_executor,
                       EmbeddingCache* embedding_cache, int feature_vector_size,
                       std::unique_ptr<CachedFeatures>* cached_features) const;
 
  // Fills selection_label_spans with CodepointSpans that correspond to the
  // selection labels. The CodepointSpans are based on the codepoint ranges of
  // given tokens.
  bool SelectionLabelSpans(
      VectorSpan<Token> tokens,
      std::vector<CodepointSpan>* selection_label_spans) const;
 
  int DenseFeaturesCount() const {
    return feature_extractor_.DenseFeaturesCount();
  }
 
  int EmbeddingSize() const { return options_->embedding_size(); }
 
  // Splits context to several segments.
  std::vector<UnicodeTextRange> SplitContext(
      const UnicodeText& context_unicode) const;
 
  // Strips boundary codepoints from the span in context and returns the new
  // start and end indices. If the span comprises entirely of boundary
  // codepoints, the first index of span is returned for both indices.
  CodepointSpan StripBoundaryCodepoints(const std::string& context,
                                        CodepointSpan span) const;
 
  // Same as above but takes UnicodeText.
  CodepointSpan StripBoundaryCodepoints(const UnicodeText& context_unicode,
                                        CodepointSpan span) const;
 
  // Same as above but takes a pair of iterators for the span, for efficiency.
  CodepointSpan StripBoundaryCodepoints(
      const UnicodeText::const_iterator& span_begin,
      const UnicodeText::const_iterator& span_end, CodepointSpan span) const;
 
  // Same as above, but takes an optional buffer for saving the modified value.
  // As an optimization, returns pointer to 'value' if nothing was stripped, or
  // pointer to 'buffer' if something was stripped.
  const std::string& StripBoundaryCodepoints(const std::string& value,
                                             std::string* buffer) const;
 
 protected:
  // Returns the class id corresponding to the given string collection
  // identifier. There is a catch-all class id that the function returns for
  // unknown collections.
  int CollectionToLabel(const std::string& collection) const;
 
  // Prepares mapping from collection names to labels.
  void MakeLabelMaps();
 
  // Gets the number of spannable tokens for the model.
  //
  // Spannable tokens are those tokens of context, which the model predicts
  // selection spans over (i.e., there is 1:1 correspondence between the output
  // classes of the model and each of the spannable tokens).
  int GetNumContextTokens() const { return options_->context_size() * 2 + 1; }
 
  // Converts a label into a span of codepoint indices corresponding to it
  // given output_tokens.
  bool LabelToSpan(int label, const VectorSpan<Token>& output_tokens,
                   CodepointSpan* span) const;
 
  // Converts a span to the corresponding label given output_tokens.
  bool SpanToLabel(const std::pair<CodepointIndex, CodepointIndex>& span,
                   const std::vector<Token>& output_tokens, int* label) const;
 
  // Converts a token span to the corresponding label.
  int TokenSpanToLabel(const std::pair<TokenIndex, TokenIndex>& span) const;
 
  // Returns the ratio of supported codepoints to total number of codepoints in
  // the given token span.
  float SupportedCodepointsRatio(const TokenSpan& token_span,
                                 const std::vector<Token>& tokens) const;
 
  void PrepareIgnoredSpanBoundaryCodepoints();
 
  // Counts the number of span boundary codepoints. If count_from_beginning is
  // True, the counting will start at the span_start iterator (inclusive) and at
  // maximum end at span_end (exclusive). If count_from_beginning is True, the
  // counting will start from span_end (exclusive) and end at span_start
  // (inclusive).
  int CountIgnoredSpanBoundaryCodepoints(
      const UnicodeText::const_iterator& span_start,
      const UnicodeText::const_iterator& span_end,
      bool count_from_beginning) const;
 
  // Finds the center token index in tokens vector, using the method defined
  // in options_.
  int FindCenterToken(CodepointSpan span,
                      const std::vector<Token>& tokens) const;
 
  // Removes all tokens from tokens that are not on a line (defined by calling
  // SplitContext on the context) to which span points.
  void StripTokensFromOtherLines(const std::string& context, CodepointSpan span,
                                 std::vector<Token>* tokens) const;
 
  // Same as above but takes UnicodeText.
  void StripTokensFromOtherLines(const UnicodeText& context_unicode,
                                 CodepointSpan span,
                                 std::vector<Token>* tokens) const;
 
  // Extracts the features of a token and appends them to the output vector.
  // Uses the embedding cache to to avoid re-extracting the re-embedding the
  // sparse features for the same token.
  bool AppendTokenFeaturesWithCache(const Token& token,
                                    CodepointSpan selection_span_for_feature,
                                    const EmbeddingExecutor* embedding_executor,
                                    EmbeddingCache* embedding_cache,
                                    std::vector<float>* output_features) const;
 
 protected:
  const TokenFeatureExtractor feature_extractor_;
 
  // Codepoint ranges that define what codepoints are supported by the model.
  // NOTE: Must be sorted.
  std::vector<CodepointRangeStruct> supported_codepoint_ranges_;
 
 private:
  // Set of codepoints that will be stripped from beginning and end of
  // predicted spans.
  std::set<int32> ignored_span_boundary_codepoints_;
 
  const FeatureProcessorOptions* const options_;
 
  // Mapping between token selection spans and labels ids.
  std::map<TokenSpan, int> selection_to_label_;
  std::vector<TokenSpan> label_to_selection_;
 
  // Mapping between collections and labels.
  std::map<std::string, int> collection_to_label_;
 
  Tokenizer tokenizer_;
};
 
}  // namespace libtextclassifier3
 
#endif  // LIBTEXTCLASSIFIER_ANNOTATOR_FEATURE_PROCESSOR_H_