lin
2025-07-30 fcd736bf35fd93b563e9bbf594f2aa7b62028cc9
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
 
#include "actions/ngram-model.h"
 
#include <algorithm>
 
#include "actions/feature-processor.h"
#include "utils/hash/farmhash.h"
#include "utils/strings/stringpiece.h"
 
namespace libtextclassifier3 {
namespace {
 
// An iterator to iterate over the initial tokens of the n-grams of a model.
class FirstTokenIterator
    : public std::iterator<std::random_access_iterator_tag,
                           /*value_type=*/uint32, /*difference_type=*/ptrdiff_t,
                           /*pointer=*/const uint32*,
                           /*reference=*/uint32&> {
 public:
  explicit FirstTokenIterator(const NGramLinearRegressionModel* model,
                              int index)
      : model_(model), index_(index) {}
 
  FirstTokenIterator& operator++() {
    index_++;
    return *this;
  }
  FirstTokenIterator& operator+=(ptrdiff_t dist) {
    index_ += dist;
    return *this;
  }
  ptrdiff_t operator-(const FirstTokenIterator& other_it) const {
    return index_ - other_it.index_;
  }
  uint32 operator*() const {
    const uint32 token_offset = (*model_->ngram_start_offsets())[index_];
    return (*model_->hashed_ngram_tokens())[token_offset];
  }
  int index() const { return index_; }
 
 private:
  const NGramLinearRegressionModel* model_;
  int index_;
};
 
}  // anonymous namespace
 
std::unique_ptr<NGramModel> NGramModel::Create(
    const NGramLinearRegressionModel* model, const Tokenizer* tokenizer,
    const UniLib* unilib) {
  if (model == nullptr) {
    return nullptr;
  }
  if (tokenizer == nullptr && model->tokenizer_options() == nullptr) {
    TC3_LOG(ERROR) << "No tokenizer options specified.";
    return nullptr;
  }
  return std::unique_ptr<NGramModel>(new NGramModel(model, tokenizer, unilib));
}
 
NGramModel::NGramModel(const NGramLinearRegressionModel* model,
                       const Tokenizer* tokenizer, const UniLib* unilib)
    : model_(model) {
  // Create new tokenizer if options are specified, reuse feature processor
  // tokenizer otherwise.
  if (model->tokenizer_options() != nullptr) {
    owned_tokenizer_ = CreateTokenizer(model->tokenizer_options(), unilib);
    tokenizer_ = owned_tokenizer_.get();
  } else {
    tokenizer_ = tokenizer;
  }
}
 
// Returns whether a given n-gram matches the token stream.
bool NGramModel::IsNGramMatch(const uint32* tokens, size_t num_tokens,
                              const uint32* ngram_tokens,
                              size_t num_ngram_tokens, int max_skips) const {
  int token_idx = 0, ngram_token_idx = 0, skip_remain = 0;
  for (; token_idx < num_tokens && ngram_token_idx < num_ngram_tokens;) {
    if (tokens[token_idx] == ngram_tokens[ngram_token_idx]) {
      // Token matches. Advance both and reset the skip budget.
      ++token_idx;
      ++ngram_token_idx;
      skip_remain = max_skips;
    } else if (skip_remain > 0) {
      // No match, but we have skips left, so just advance over the token.
      ++token_idx;
      skip_remain--;
    } else {
      // No match and we're out of skips. Reject.
      return false;
    }
  }
  return ngram_token_idx == num_ngram_tokens;
}
 
// Calculates the total number of skip-grams that can be created for a stream
// with the given number of tokens.
uint64 NGramModel::GetNumSkipGrams(int num_tokens, int max_ngram_length,
                                   int max_skips) {
  // Start with unigrams.
  uint64 total = num_tokens;
  for (int ngram_len = 2;
       ngram_len <= max_ngram_length && ngram_len <= num_tokens; ++ngram_len) {
    // We can easily compute the expected length of the n-gram (with skips),
    // but it doesn't account for the fact that they may be longer than the
    // input and should be pruned.
    // Instead, we iterate over the distribution of effective n-gram lengths
    // and add each length individually.
    const int num_gaps = ngram_len - 1;
    const int len_min = ngram_len;
    const int len_max = ngram_len + num_gaps * max_skips;
    const int len_mid = (len_max + len_min) / 2;
    for (int len_i = len_min; len_i <= len_max; ++len_i) {
      if (len_i > num_tokens) continue;
      const int num_configs_of_len_i =
          len_i <= len_mid ? len_i - len_min + 1 : len_max - len_i + 1;
      const int num_start_offsets = num_tokens - len_i + 1;
      total += num_configs_of_len_i * num_start_offsets;
    }
  }
  return total;
}
 
std::pair<int, int> NGramModel::GetFirstTokenMatches(uint32 token_hash) const {
  const int num_ngrams = model_->ngram_weights()->size();
  const auto start_it = FirstTokenIterator(model_, 0);
  const auto end_it = FirstTokenIterator(model_, num_ngrams);
  const int start = std::lower_bound(start_it, end_it, token_hash).index();
  const int end = std::upper_bound(start_it, end_it, token_hash).index();
  return std::make_pair(start, end);
}
 
bool NGramModel::Eval(const UnicodeText& text, float* score) const {
  const std::vector<Token> raw_tokens = tokenizer_->Tokenize(text);
 
  // If we have no tokens, then just bail early.
  if (raw_tokens.empty()) {
    if (score != nullptr) {
      *score = model_->default_token_weight();
    }
    return false;
  }
 
  // Hash the tokens.
  std::vector<uint32> tokens;
  tokens.reserve(raw_tokens.size());
  for (const Token& raw_token : raw_tokens) {
    tokens.push_back(tc3farmhash::Fingerprint32(raw_token.value.data(),
                                                raw_token.value.length()));
  }
 
  // Calculate the total number of skip-grams that can be generated for the
  // input text.
  const uint64 num_candidates = GetNumSkipGrams(
      tokens.size(), model_->max_denom_ngram_length(), model_->max_skips());
 
  // For each token, see whether it denotes the start of an n-gram in the model.
  int num_matches = 0;
  float weight_matches = 0.f;
  for (size_t start_i = 0; start_i < tokens.size(); ++start_i) {
    const std::pair<int, int> ngram_range =
        GetFirstTokenMatches(tokens[start_i]);
    for (int ngram_idx = ngram_range.first; ngram_idx < ngram_range.second;
         ++ngram_idx) {
      const uint16 ngram_tokens_begin =
          (*model_->ngram_start_offsets())[ngram_idx];
      const uint16 ngram_tokens_end =
          (*model_->ngram_start_offsets())[ngram_idx + 1];
      if (IsNGramMatch(
              /*tokens=*/tokens.data() + start_i,
              /*num_tokens=*/tokens.size() - start_i,
              /*ngram_tokens=*/model_->hashed_ngram_tokens()->data() +
                  ngram_tokens_begin,
              /*num_ngram_tokens=*/ngram_tokens_end - ngram_tokens_begin,
              /*max_skips=*/model_->max_skips())) {
        ++num_matches;
        weight_matches += (*model_->ngram_weights())[ngram_idx];
      }
    }
  }
 
  // Calculate the score.
  const int num_misses = num_candidates - num_matches;
  const float internal_score =
      (weight_matches + (model_->default_token_weight() * num_misses)) /
      num_candidates;
  if (score != nullptr) {
    *score = internal_score;
  }
  return internal_score > model_->threshold();
}
 
}  // namespace libtextclassifier3