/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "actions/actions-suggestions.h"

#include <memory>

#include "actions/lua-actions.h"
#include "actions/types.h"
#include "actions/zlib-utils.h"
#include "utils/base/logging.h"
#include "utils/flatbuffers.h"
#include "utils/lua-utils.h"
#include "utils/regex-match.h"
#include "utils/strings/split.h"
#include "utils/strings/stringpiece.h"
#include "utils/utf8/unicodetext.h"
#include "utils/zlib/zlib_regex.h"
#include "tensorflow/lite/string_util.h"

namespace libtextclassifier3 {

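// The action type names below are initialized through immediately-invoked
// lambdas that intentionally leak a heap-allocated std::string. This is a
// common idiom for globals with non-trivial destructors: the strings stay
// alive for the whole process lifetime and never run a destructor, which
// sidesteps static destruction order issues.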
const std::string& ActionsSuggestions::kViewCalendarType =
    *[]() { return new std::string("view_calendar"); }();
const std::string& ActionsSuggestions::kViewMapType =
    *[]() { return new std::string("view_map"); }();
const std::string& ActionsSuggestions::kTrackFlightType =
    *[]() { return new std::string("track_flight"); }();
const std::string& ActionsSuggestions::kOpenUrlType =
    *[]() { return new std::string("open_url"); }();
const std::string& ActionsSuggestions::kSendSmsType =
    *[]() { return new std::string("send_sms"); }();
const std::string& ActionsSuggestions::kCallPhoneType =
    *[]() { return new std::string("call_phone"); }();
const std::string& ActionsSuggestions::kSendEmailType =
    *[]() { return new std::string("send_email"); }();
const std::string& ActionsSuggestions::kShareLocation =
    *[]() { return new std::string("share_location"); }();

namespace {

const ActionsModel* LoadAndVerifyModel(const uint8_t* addr, int size) {
  flatbuffers::Verifier verifier(addr, size);
  if (VerifyActionsModelBuffer(verifier)) {
    return GetActionsModel(addr);
  } else {
    return nullptr;
  }
}

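// Returns the value of the field at `field_offset` from the overlay table if
// the overlay is present, and `default_value` otherwise. Callers pass the
// generated flatbuffer vtable offsets (e.g.
// TriggeringPreconditions::VT_MIN_INPUT_LENGTH) so that only fields that are
// explicitly set in the overlay override the model defaults.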
template <typename T>
T ValueOrDefault(const flatbuffers::Table* values, const int32 field_offset,
                 const T default_value) {
  if (values == nullptr) {
    return default_value;
  }
  return values->GetField<T>(field_offset, default_value);
}

// Returns the number of (tail) messages of a conversation to consider.
int NumMessagesToConsider(const Conversation& conversation,
                          const int max_conversation_history_length) {
  return ((max_conversation_history_length < 0 ||
           conversation.messages.size() < max_conversation_history_length)
              ? conversation.messages.size()
              : max_conversation_history_length);
}

}  // namespace

std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromUnownedBuffer(
    const uint8_t* buffer, const int size, const UniLib* unilib,
    const std::string& triggering_preconditions_overlay) {
  auto actions = std::unique_ptr<ActionsSuggestions>(new ActionsSuggestions());
  const ActionsModel* model = LoadAndVerifyModel(buffer, size);
  if (model == nullptr) {
    return nullptr;
  }
  actions->model_ = model;
  actions->SetOrCreateUnilib(unilib);
  actions->triggering_preconditions_overlay_buffer_ =
      triggering_preconditions_overlay;
  if (!actions->ValidateAndInitialize()) {
    return nullptr;
  }
  return actions;
}

std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromScopedMmap(
    std::unique_ptr<libtextclassifier3::ScopedMmap> mmap, const UniLib* unilib,
    const std::string& triggering_preconditions_overlay) {
  if (!mmap->handle().ok()) {
    TC3_VLOG(1) << "Mmap failed.";
    return nullptr;
  }
  const ActionsModel* model = LoadAndVerifyModel(
      reinterpret_cast<const uint8_t*>(mmap->handle().start()),
      mmap->handle().num_bytes());
  if (!model) {
    TC3_LOG(ERROR) << "Model verification failed.";
    return nullptr;
  }
  auto actions = std::unique_ptr<ActionsSuggestions>(new ActionsSuggestions());
  actions->model_ = model;
  actions->mmap_ = std::move(mmap);
  actions->SetOrCreateUnilib(unilib);
  actions->triggering_preconditions_overlay_buffer_ =
      triggering_preconditions_overlay;
  if (!actions->ValidateAndInitialize()) {
    return nullptr;
  }
  return actions;
}

std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromScopedMmap(
    std::unique_ptr<libtextclassifier3::ScopedMmap> mmap,
    std::unique_ptr<UniLib> unilib,
    const std::string& triggering_preconditions_overlay) {
  if (!mmap->handle().ok()) {
    TC3_VLOG(1) << "Mmap failed.";
    return nullptr;
  }
  const ActionsModel* model = LoadAndVerifyModel(
      reinterpret_cast<const uint8_t*>(mmap->handle().start()),
      mmap->handle().num_bytes());
  if (!model) {
    TC3_LOG(ERROR) << "Model verification failed.";
    return nullptr;
  }
  auto actions = std::unique_ptr<ActionsSuggestions>(new ActionsSuggestions());
  actions->model_ = model;
  actions->mmap_ = std::move(mmap);
  actions->owned_unilib_ = std::move(unilib);
  actions->unilib_ = actions->owned_unilib_.get();
  actions->triggering_preconditions_overlay_buffer_ =
      triggering_preconditions_overlay;
  if (!actions->ValidateAndInitialize()) {
    return nullptr;
  }
  return actions;
}

std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromFileDescriptor(
    const int fd, const int offset, const int size, const UniLib* unilib,
    const std::string& triggering_preconditions_overlay) {
  std::unique_ptr<libtextclassifier3::ScopedMmap> mmap;
  if (offset >= 0 && size >= 0) {
    mmap.reset(new libtextclassifier3::ScopedMmap(fd, offset, size));
  } else {
    mmap.reset(new libtextclassifier3::ScopedMmap(fd));
  }
  return FromScopedMmap(std::move(mmap), unilib,
                        triggering_preconditions_overlay);
}

std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromFileDescriptor(
    const int fd, const int offset, const int size,
    std::unique_ptr<UniLib> unilib,
    const std::string& triggering_preconditions_overlay) {
  std::unique_ptr<libtextclassifier3::ScopedMmap> mmap;
  if (offset >= 0 && size >= 0) {
    mmap.reset(new libtextclassifier3::ScopedMmap(fd, offset, size));
  } else {
    mmap.reset(new libtextclassifier3::ScopedMmap(fd));
  }
  return FromScopedMmap(std::move(mmap), std::move(unilib),
                        triggering_preconditions_overlay);
}

std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromFileDescriptor(
    const int fd, const UniLib* unilib,
    const std::string& triggering_preconditions_overlay) {
  std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
      new libtextclassifier3::ScopedMmap(fd));
  return FromScopedMmap(std::move(mmap), unilib,
                        triggering_preconditions_overlay);
}

std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromFileDescriptor(
    const int fd, std::unique_ptr<UniLib> unilib,
    const std::string& triggering_preconditions_overlay) {
  std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
      new libtextclassifier3::ScopedMmap(fd));
  return FromScopedMmap(std::move(mmap), std::move(unilib),
                        triggering_preconditions_overlay);
}

std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromPath(
    const std::string& path, const UniLib* unilib,
    const std::string& triggering_preconditions_overlay) {
  std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
      new libtextclassifier3::ScopedMmap(path));
  return FromScopedMmap(std::move(mmap), unilib,
                        triggering_preconditions_overlay);
}

std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromPath(
    const std::string& path, std::unique_ptr<UniLib> unilib,
    const std::string& triggering_preconditions_overlay) {
  std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
      new libtextclassifier3::ScopedMmap(path));
  return FromScopedMmap(std::move(mmap), std::move(unilib),
                        triggering_preconditions_overlay);
}

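// A minimal usage sketch (not part of the original source): the model path
// and conversation content below are hypothetical, and the argument defaults
// are assumed to match the declarations in actions-suggestions.h.
//
//   std::unique_ptr<ActionsSuggestions> actions =
//       ActionsSuggestions::FromPath("/path/to/actions_suggestions.model",
//                                    /*unilib=*/nullptr);
//   if (actions != nullptr) {
//     const Conversation conversation = {{{/*user_id=*/1, "Where are you?"}}};
//     const ActionsSuggestionsResponse response =
//         actions->SuggestActions(conversation);
//     for (const ActionSuggestion& action : response.actions) {
//       // Inspect action.type, action.response_text, action.score, ...
//     }
//   }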
void ActionsSuggestions::SetOrCreateUnilib(const UniLib* unilib) {
  if (unilib != nullptr) {
    unilib_ = unilib;
  } else {
    owned_unilib_.reset(new UniLib);
    unilib_ = owned_unilib_.get();
  }
}

bool ActionsSuggestions::ValidateAndInitialize() {
  if (model_ == nullptr) {
    TC3_LOG(ERROR) << "No model specified.";
    return false;
  }

  if (model_->smart_reply_action_type() == nullptr) {
    TC3_LOG(ERROR) << "No smart reply action type specified.";
    return false;
  }

  if (!InitializeTriggeringPreconditions()) {
    TC3_LOG(ERROR) << "Could not initialize preconditions.";
    return false;
  }

  if (model_->locales() &&
      !ParseLocales(model_->locales()->c_str(), &locales_)) {
    TC3_LOG(ERROR) << "Could not parse model supported locales.";
    return false;
  }

  if (model_->tflite_model_spec() != nullptr) {
    model_executor_ = TfLiteModelExecutor::FromBuffer(
        model_->tflite_model_spec()->tflite_model());
    if (!model_executor_) {
      TC3_LOG(ERROR) << "Could not initialize model executor.";
      return false;
    }
  }

  if (model_->annotation_actions_spec() != nullptr &&
      model_->annotation_actions_spec()->annotation_mapping() != nullptr) {
    for (const AnnotationActionsSpec_::AnnotationMapping* mapping :
         *model_->annotation_actions_spec()->annotation_mapping()) {
      annotation_entity_types_.insert(mapping->annotation_collection()->str());
    }
  }

  std::unique_ptr<ZlibDecompressor> decompressor = ZlibDecompressor::Instance();
  if (!InitializeRules(decompressor.get())) {
    TC3_LOG(ERROR) << "Could not initialize rules.";
    return false;
  }

  if (model_->actions_entity_data_schema() != nullptr) {
    entity_data_schema_ = LoadAndVerifyFlatbuffer<reflection::Schema>(
        model_->actions_entity_data_schema()->Data(),
        model_->actions_entity_data_schema()->size());
    if (entity_data_schema_ == nullptr) {
      TC3_LOG(ERROR) << "Could not load entity data schema data.";
      return false;
    }

    entity_data_builder_.reset(
        new ReflectiveFlatbufferBuilder(entity_data_schema_));
  } else {
    entity_data_schema_ = nullptr;
  }

  std::string actions_script;
  if (GetUncompressedString(model_->lua_actions_script(),
                            model_->compressed_lua_actions_script(),
                            decompressor.get(), &actions_script) &&
      !actions_script.empty()) {
    if (!Compile(actions_script, &lua_bytecode_)) {
      TC3_LOG(ERROR) << "Could not precompile lua actions snippet.";
      return false;
    }
  }

  if (!(ranker_ = ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
            model_->ranking_options(), decompressor.get(),
            model_->smart_reply_action_type()->str()))) {
    TC3_LOG(ERROR) << "Could not create an action suggestions ranker.";
    return false;
  }

  // Create the feature processor if specified.
  const ActionsTokenFeatureProcessorOptions* options =
      model_->feature_processor_options();
  if (options != nullptr) {
    if (options->tokenizer_options() == nullptr) {
      TC3_LOG(ERROR) << "No tokenizer options specified.";
      return false;
    }

    feature_processor_.reset(new ActionsFeatureProcessor(options, unilib_));
    embedding_executor_ = TFLiteEmbeddingExecutor::FromBuffer(
        options->embedding_model(), options->embedding_size(),
        options->embedding_quantization_bits());

    if (embedding_executor_ == nullptr) {
      TC3_LOG(ERROR) << "Could not initialize embedding executor.";
      return false;
    }

    // Cache the embeddings of the padding, start and end tokens.
    if (!EmbedTokenId(options->padding_token_id(), &embedded_padding_token_) ||
        !EmbedTokenId(options->start_token_id(), &embedded_start_token_) ||
        !EmbedTokenId(options->end_token_id(), &embedded_end_token_)) {
      TC3_LOG(ERROR) << "Could not precompute token embeddings.";
      return false;
    }
    token_embedding_size_ = feature_processor_->GetTokenEmbeddingSize();
  }

  // Create the low confidence model if specified.
  if (model_->low_confidence_ngram_model() != nullptr) {
    ngram_model_ = NGramModel::Create(model_->low_confidence_ngram_model(),
                                      feature_processor_ == nullptr
                                          ? nullptr
                                          : feature_processor_->tokenizer(),
                                      unilib_);
    if (ngram_model_ == nullptr) {
      TC3_LOG(ERROR) << "Could not create ngram linear regression model.";
      return false;
    }
  }

  return true;
}

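// Merges the serialized TriggeringPreconditions overlay (if any) with the
// defaults baked into the model: each field that is explicitly set in the
// overlay wins, everything else falls back to model_->preconditions().
//
// A sketch of how a caller might construct such an overlay, assuming the
// flatbuffer object API is generated for the TriggeringPreconditions schema:
//
//   TriggeringPreconditionsT overlay;
//   overlay.min_smart_reply_triggering_score = 0.5f;
//   flatbuffers::FlatBufferBuilder builder;
//   builder.Finish(TriggeringPreconditions::Pack(builder, &overlay));
//   const std::string overlay_buffer(
//       reinterpret_cast<const char*>(builder.GetBufferPointer()),
//       builder.GetSize());
//   // Pass `overlay_buffer` as `triggering_preconditions_overlay` to one of
//   // the factory methods above.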
bool ActionsSuggestions::InitializeTriggeringPreconditions() {
  triggering_preconditions_overlay_ =
      LoadAndVerifyFlatbuffer<TriggeringPreconditions>(
          triggering_preconditions_overlay_buffer_);

  if (triggering_preconditions_overlay_ == nullptr &&
      !triggering_preconditions_overlay_buffer_.empty()) {
    TC3_LOG(ERROR) << "Could not load triggering preconditions overlay.";
    return false;
  }
  const flatbuffers::Table* overlay =
      reinterpret_cast<const flatbuffers::Table*>(
          triggering_preconditions_overlay_);
  const TriggeringPreconditions* defaults = model_->preconditions();
  if (defaults == nullptr) {
    TC3_LOG(ERROR) << "No triggering conditions specified.";
    return false;
  }

  preconditions_.min_smart_reply_triggering_score = ValueOrDefault(
      overlay, TriggeringPreconditions::VT_MIN_SMART_REPLY_TRIGGERING_SCORE,
      defaults->min_smart_reply_triggering_score());
  preconditions_.max_sensitive_topic_score = ValueOrDefault(
      overlay, TriggeringPreconditions::VT_MAX_SENSITIVE_TOPIC_SCORE,
      defaults->max_sensitive_topic_score());
  preconditions_.suppress_on_sensitive_topic = ValueOrDefault(
      overlay, TriggeringPreconditions::VT_SUPPRESS_ON_SENSITIVE_TOPIC,
      defaults->suppress_on_sensitive_topic());
  preconditions_.min_input_length =
      ValueOrDefault(overlay, TriggeringPreconditions::VT_MIN_INPUT_LENGTH,
                     defaults->min_input_length());
  preconditions_.max_input_length =
      ValueOrDefault(overlay, TriggeringPreconditions::VT_MAX_INPUT_LENGTH,
                     defaults->max_input_length());
  preconditions_.min_locale_match_fraction = ValueOrDefault(
      overlay, TriggeringPreconditions::VT_MIN_LOCALE_MATCH_FRACTION,
      defaults->min_locale_match_fraction());
  preconditions_.handle_missing_locale_as_supported = ValueOrDefault(
      overlay, TriggeringPreconditions::VT_HANDLE_MISSING_LOCALE_AS_SUPPORTED,
      defaults->handle_missing_locale_as_supported());
  preconditions_.handle_unknown_locale_as_supported = ValueOrDefault(
      overlay, TriggeringPreconditions::VT_HANDLE_UNKNOWN_LOCALE_AS_SUPPORTED,
      defaults->handle_unknown_locale_as_supported());
  preconditions_.suppress_on_low_confidence_input = ValueOrDefault(
      overlay, TriggeringPreconditions::VT_SUPPRESS_ON_LOW_CONFIDENCE_INPUT,
      defaults->suppress_on_low_confidence_input());
  preconditions_.diversification_distance_threshold = ValueOrDefault(
      overlay, TriggeringPreconditions::VT_DIVERSIFICATION_DISTANCE_THRESHOLD,
      defaults->diversification_distance_threshold());
  preconditions_.confidence_threshold =
      ValueOrDefault(overlay, TriggeringPreconditions::VT_CONFIDENCE_THRESHOLD,
                     defaults->confidence_threshold());
  preconditions_.empirical_probability_factor = ValueOrDefault(
      overlay, TriggeringPreconditions::VT_EMPIRICAL_PROBABILITY_FACTOR,
      defaults->empirical_probability_factor());
  preconditions_.min_reply_score_threshold = ValueOrDefault(
      overlay, TriggeringPreconditions::VT_MIN_REPLY_SCORE_THRESHOLD,
      defaults->min_reply_score_threshold());

  return true;
}

bool ActionsSuggestions::EmbedTokenId(const int32 token_id,
                                      std::vector<float>* embedding) const {
  return feature_processor_->AppendFeatures(
      {token_id},
      /*dense_features=*/{}, embedding_executor_.get(), embedding);
}

bool ActionsSuggestions::InitializeRules(ZlibDecompressor* decompressor) {
  if (model_->rules() != nullptr) {
    if (!InitializeRules(decompressor, model_->rules(), &rules_)) {
      TC3_LOG(ERROR) << "Could not initialize action rules.";
      return false;
    }
  }

  if (model_->low_confidence_rules() != nullptr) {
    if (!InitializeRules(decompressor, model_->low_confidence_rules(),
                         &low_confidence_rules_)) {
      TC3_LOG(ERROR) << "Could not initialize low confidence rules.";
      return false;
    }
  }

  // Extend with the rules provided by the overlay.
  // NOTE: The rules from the original model are *not* cleared.
  if (triggering_preconditions_overlay_ != nullptr &&
      triggering_preconditions_overlay_->low_confidence_rules() != nullptr) {
    // These rules are optionally compressed, but independently of the model
    // rules.
    std::unique_ptr<ZlibDecompressor> overlay_decompressor =
        ZlibDecompressor::Instance();
    if (overlay_decompressor == nullptr) {
      TC3_LOG(ERROR) << "Could not initialize decompressor for overlay rules.";
      return false;
    }
    if (!InitializeRules(
            overlay_decompressor.get(),
            triggering_preconditions_overlay_->low_confidence_rules(),
            &low_confidence_rules_)) {
      TC3_LOG(ERROR)
          << "Could not initialize low confidence rules from overlay.";
      return false;
    }
  }

  return true;
}

bool ActionsSuggestions::InitializeRules(
    ZlibDecompressor* decompressor, const RulesModel* rules,
    std::vector<CompiledRule>* compiled_rules) const {
  for (const RulesModel_::Rule* rule : *rules->rule()) {
    std::unique_ptr<UniLib::RegexPattern> compiled_pattern =
        UncompressMakeRegexPattern(
            *unilib_, rule->pattern(), rule->compressed_pattern(),
            rules->lazy_regex_compilation(), decompressor);
    if (compiled_pattern == nullptr) {
      TC3_LOG(ERROR) << "Failed to load rule pattern.";
      return false;
    }

    // Check whether the rule additionally constrains the output.
    std::unique_ptr<UniLib::RegexPattern> compiled_output_pattern;
    if (rule->output_pattern() != nullptr ||
        rule->compressed_output_pattern() != nullptr) {
      compiled_output_pattern = UncompressMakeRegexPattern(
          *unilib_, rule->output_pattern(), rule->compressed_output_pattern(),
          rules->lazy_regex_compilation(), decompressor);
      if (compiled_output_pattern == nullptr) {
        TC3_LOG(ERROR) << "Failed to load rule output pattern.";
        return false;
      }
    }

    compiled_rules->emplace_back(rule, std::move(compiled_pattern),
                                 std::move(compiled_output_pattern));
  }

  return true;
}

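// A message is considered low confidence if either the ngram model or any of
// the low confidence regex rules fire on its text. Rules that also specify an
// output pattern cannot be decided on the input alone; their ids are recorded
// in `post_check_rules` and applied to the generated replies later in
// FilterConfidenceOutput().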
bool ActionsSuggestions::IsLowConfidenceInput(
    const Conversation& conversation, const int num_messages,
    std::vector<int>* post_check_rules) const {
  for (int i = 1; i <= num_messages; i++) {
    const std::string& message =
        conversation.messages[conversation.messages.size() - i].text;
    const UnicodeText message_unicode(
        UTF8ToUnicodeText(message, /*do_copy=*/false));

    // Run ngram linear regression model.
    if (ngram_model_ != nullptr) {
      if (ngram_model_->Eval(message_unicode)) {
        return true;
      }
    }

    // Run the regex based rules.
    for (int low_confidence_rule = 0;
         low_confidence_rule < low_confidence_rules_.size();
         low_confidence_rule++) {
      const CompiledRule& rule = low_confidence_rules_[low_confidence_rule];
      const std::unique_ptr<UniLib::RegexMatcher> matcher =
          rule.pattern->Matcher(message_unicode);
      int status = UniLib::RegexMatcher::kNoError;
      if (matcher->Find(&status) && status == UniLib::RegexMatcher::kNoError) {
        // Rule only applies to input-output pairs, so defer the check.
        if (rule.output_pattern != nullptr) {
          post_check_rules->push_back(low_confidence_rule);
          continue;
        }
        return true;
      }
    }
  }
  return false;
}

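// Removes smart reply suggestions whose response text matches the output
// pattern of any rule that was deferred in IsLowConfidenceInput(). Actions
// without response text are always kept.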
bool ActionsSuggestions::FilterConfidenceOutput(
    const std::vector<int>& post_check_rules,
    std::vector<ActionSuggestion>* actions) const {
  if (post_check_rules.empty() || actions->empty()) {
    return true;
  }
  std::vector<ActionSuggestion> filtered_text_replies;
  for (const ActionSuggestion& action : *actions) {
    if (action.response_text.empty()) {
      filtered_text_replies.push_back(action);
      continue;
    }
    bool passes_post_check = true;
    const UnicodeText text_reply_unicode(
        UTF8ToUnicodeText(action.response_text, /*do_copy=*/false));
    for (const int rule_id : post_check_rules) {
      const std::unique_ptr<UniLib::RegexMatcher> matcher =
          low_confidence_rules_[rule_id].output_pattern->Matcher(
              text_reply_unicode);
      if (matcher == nullptr) {
        TC3_LOG(ERROR) << "Could not create matcher for post check rule.";
        return false;
      }
      int status = UniLib::RegexMatcher::kNoError;
      if (matcher->Find(&status) || status != UniLib::RegexMatcher::kNoError) {
        passes_post_check = false;
        break;
      }
    }
    if (passes_post_check) {
      filtered_text_replies.push_back(action);
    }
  }
  *actions = std::move(filtered_text_replies);
  return true;
}

ActionSuggestion ActionsSuggestions::SuggestionFromSpec(
    const ActionSuggestionSpec* action, const std::string& default_type,
    const std::string& default_response_text,
    const std::string& default_serialized_entity_data,
    const float default_score, const float default_priority_score) const {
  ActionSuggestion suggestion;
  suggestion.score = action != nullptr ? action->score() : default_score;
  suggestion.priority_score =
      action != nullptr ? action->priority_score() : default_priority_score;
  suggestion.type = action != nullptr && action->type() != nullptr
                        ? action->type()->str()
                        : default_type;
  suggestion.response_text =
      action != nullptr && action->response_text() != nullptr
          ? action->response_text()->str()
          : default_response_text;
  suggestion.serialized_entity_data =
      action != nullptr && action->serialized_entity_data() != nullptr
          ? action->serialized_entity_data()->str()
          : default_serialized_entity_data;
  return suggestion;
}

std::vector<std::vector<Token>> ActionsSuggestions::Tokenize(
    const std::vector<std::string>& context) const {
  std::vector<std::vector<Token>> tokens;
  tokens.reserve(context.size());
  for (const std::string& message : context) {
    tokens.push_back(feature_processor_->tokenizer()->Tokenize(message));
  }
  return tokens;
}

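// The TensorFlow Lite model can consume token embeddings in two layouts:
// EmbedTokensPerMessage() produces a [num_messages, max_tokens_per_message]
// grid where every message is padded to the longest message in the
// conversation, while EmbedAndFlattenTokens() produces a single flat token
// stream with explicit start/end sentinel tokens per message and a global
// token budget trimmed from the front.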
bool ActionsSuggestions::EmbedTokensPerMessage(
    const std::vector<std::vector<Token>>& tokens,
    std::vector<float>* embeddings, int* max_num_tokens_per_message) const {
  const int num_messages = tokens.size();
  *max_num_tokens_per_message = 0;
  for (int i = 0; i < num_messages; i++) {
    const int num_message_tokens = tokens[i].size();
    if (num_message_tokens > *max_num_tokens_per_message) {
      *max_num_tokens_per_message = num_message_tokens;
    }
  }

  if (model_->feature_processor_options()->min_num_tokens_per_message() >
      *max_num_tokens_per_message) {
    *max_num_tokens_per_message =
        model_->feature_processor_options()->min_num_tokens_per_message();
  }
  if (model_->feature_processor_options()->max_num_tokens_per_message() > 0 &&
      *max_num_tokens_per_message >
          model_->feature_processor_options()->max_num_tokens_per_message()) {
    *max_num_tokens_per_message =
        model_->feature_processor_options()->max_num_tokens_per_message();
  }

  // Embed all tokens, padding each message to the maximum number of tokens
  // in any message of the conversation.
  // If a maximum number of tokens per message is specified in the model
  // config, tokens at the beginning of a message are dropped if they don't
  // fit within the limit.
  for (int i = 0; i < num_messages; i++) {
    const int start =
        std::max<int>(tokens[i].size() - *max_num_tokens_per_message, 0);
    for (int pos = start; pos < tokens[i].size(); pos++) {
      if (!feature_processor_->AppendTokenFeatures(
              tokens[i][pos], embedding_executor_.get(), embeddings)) {
        TC3_LOG(ERROR) << "Could not run token feature extractor.";
        return false;
      }
    }
    // Add padding.
    for (int k = tokens[i].size(); k < *max_num_tokens_per_message; k++) {
      embeddings->insert(embeddings->end(), embedded_padding_token_.begin(),
                         embedded_padding_token_.end());
    }
  }

  return true;
}

bool ActionsSuggestions::EmbedAndFlattenTokens(
    const std::vector<std::vector<Token>> tokens,
    std::vector<float>* embeddings, int* total_token_count) const {
  const int num_messages = tokens.size();
  int start_message = 0;
  int message_token_offset = 0;

  // If a maximum model input length is specified, we need to check how
  // much we need to trim at the start.
  const int max_num_total_tokens =
      model_->feature_processor_options()->max_num_total_tokens();
  if (max_num_total_tokens > 0) {
    int total_tokens = 0;
    start_message = num_messages - 1;
    for (; start_message >= 0; start_message--) {
      // Tokens of the message + start and end token.
      const int num_message_tokens = tokens[start_message].size() + 2;
      total_tokens += num_message_tokens;

      // Check whether we exhausted the budget.
      if (total_tokens >= max_num_total_tokens) {
        message_token_offset = total_tokens - max_num_total_tokens;
        break;
      }
    }
  }

  // Add embeddings.
  *total_token_count = 0;
  for (int i = start_message; i < num_messages; i++) {
    if (message_token_offset == 0) {
      ++(*total_token_count);
      // Add `start message` token.
      embeddings->insert(embeddings->end(), embedded_start_token_.begin(),
                         embedded_start_token_.end());
    }

    for (int pos = std::max(0, message_token_offset - 1);
         pos < tokens[i].size(); pos++) {
      ++(*total_token_count);
      if (!feature_processor_->AppendTokenFeatures(
              tokens[i][pos], embedding_executor_.get(), embeddings)) {
        TC3_LOG(ERROR) << "Could not run token feature extractor.";
        return false;
      }
    }

    // Add `end message` token.
    ++(*total_token_count);
    embeddings->insert(embeddings->end(), embedded_end_token_.begin(),
                       embedded_end_token_.end());

    // Reset for the subsequent messages.
    message_token_offset = 0;
  }

  // Add optional padding.
  const int min_num_total_tokens =
      model_->feature_processor_options()->min_num_total_tokens();
  for (; *total_token_count < min_num_total_tokens; ++(*total_token_count)) {
    embeddings->insert(embeddings->end(), embedded_padding_token_.begin(),
                       embedded_padding_token_.end());
  }

  return true;
}

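// Resizes the interpreter's input tensors (when the model spec asks for
// dynamic resizing) so they match the actual conversation length and token
// counts, then allocates the tensors.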
bool ActionsSuggestions::AllocateInput(const int conversation_length,
                                       const int max_tokens,
                                       const int total_token_count,
                                       tflite::Interpreter* interpreter) const {
  if (model_->tflite_model_spec()->resize_inputs()) {
    if (model_->tflite_model_spec()->input_context() >= 0) {
      interpreter->ResizeInputTensor(
          interpreter->inputs()[model_->tflite_model_spec()->input_context()],
          {1, conversation_length});
    }
    if (model_->tflite_model_spec()->input_user_id() >= 0) {
      interpreter->ResizeInputTensor(
          interpreter->inputs()[model_->tflite_model_spec()->input_user_id()],
          {1, conversation_length});
    }
    if (model_->tflite_model_spec()->input_time_diffs() >= 0) {
      interpreter->ResizeInputTensor(
          interpreter
              ->inputs()[model_->tflite_model_spec()->input_time_diffs()],
          {1, conversation_length});
    }
    if (model_->tflite_model_spec()->input_num_tokens() >= 0) {
      interpreter->ResizeInputTensor(
          interpreter
              ->inputs()[model_->tflite_model_spec()->input_num_tokens()],
          {conversation_length, 1});
    }
    if (model_->tflite_model_spec()->input_token_embeddings() >= 0) {
      interpreter->ResizeInputTensor(
          interpreter
              ->inputs()[model_->tflite_model_spec()->input_token_embeddings()],
          {conversation_length, max_tokens, token_embedding_size_});
    }
    if (model_->tflite_model_spec()->input_flattened_token_embeddings() >= 0) {
      interpreter->ResizeInputTensor(
          interpreter->inputs()[model_->tflite_model_spec()
                                    ->input_flattened_token_embeddings()],
          {1, total_token_count});
    }
  }

  return interpreter->AllocateTensors() == kTfLiteOk;
}

bool ActionsSuggestions::SetupModelInput(
    const std::vector<std::string>& context, const std::vector<int>& user_ids,
    const std::vector<float>& time_diffs, const int num_suggestions,
    const float confidence_threshold, const float diversification_distance,
    const float empirical_probability_factor,
    tflite::Interpreter* interpreter) const {
  // Compute token embeddings.
  std::vector<std::vector<Token>> tokens;
  std::vector<float> token_embeddings;
  std::vector<float> flattened_token_embeddings;
  int max_tokens = 0;
  int total_token_count = 0;
  if (model_->tflite_model_spec()->input_num_tokens() >= 0 ||
      model_->tflite_model_spec()->input_token_embeddings() >= 0 ||
      model_->tflite_model_spec()->input_flattened_token_embeddings() >= 0) {
    if (feature_processor_ == nullptr) {
      TC3_LOG(ERROR) << "No feature processor specified.";
      return false;
    }

    // Tokenize the messages in the conversation.
    tokens = Tokenize(context);
    if (model_->tflite_model_spec()->input_token_embeddings() >= 0) {
      if (!EmbedTokensPerMessage(tokens, &token_embeddings, &max_tokens)) {
        TC3_LOG(ERROR) << "Could not extract token features.";
        return false;
      }
    }
    if (model_->tflite_model_spec()->input_flattened_token_embeddings() >= 0) {
      if (!EmbedAndFlattenTokens(tokens, &flattened_token_embeddings,
                                 &total_token_count)) {
        TC3_LOG(ERROR) << "Could not extract token features.";
        return false;
      }
    }
  }

  if (!AllocateInput(context.size(), max_tokens, total_token_count,
                     interpreter)) {
    TC3_LOG(ERROR) << "TensorFlow Lite model allocation failed.";
    return false;
  }
  if (model_->tflite_model_spec()->input_context() >= 0) {
    model_executor_->SetInput<std::string>(
        model_->tflite_model_spec()->input_context(), context, interpreter);
  }
  if (model_->tflite_model_spec()->input_context_length() >= 0) {
    model_executor_->SetInput<int>(
        model_->tflite_model_spec()->input_context_length(), context.size(),
        interpreter);
  }
  if (model_->tflite_model_spec()->input_user_id() >= 0) {
    model_executor_->SetInput<int>(model_->tflite_model_spec()->input_user_id(),
                                   user_ids, interpreter);
  }
  if (model_->tflite_model_spec()->input_num_suggestions() >= 0) {
    model_executor_->SetInput<int>(
        model_->tflite_model_spec()->input_num_suggestions(), num_suggestions,
        interpreter);
  }
  if (model_->tflite_model_spec()->input_time_diffs() >= 0) {
    model_executor_->SetInput<float>(
        model_->tflite_model_spec()->input_time_diffs(), time_diffs,
        interpreter);
  }
  if (model_->tflite_model_spec()->input_diversification_distance() >= 0) {
    model_executor_->SetInput<float>(
        model_->tflite_model_spec()->input_diversification_distance(),
        diversification_distance, interpreter);
  }
  if (model_->tflite_model_spec()->input_confidence_threshold() >= 0) {
    model_executor_->SetInput<float>(
        model_->tflite_model_spec()->input_confidence_threshold(),
        confidence_threshold, interpreter);
  }
  if (model_->tflite_model_spec()->input_empirical_probability_factor() >= 0) {
    model_executor_->SetInput<float>(
        model_->tflite_model_spec()->input_empirical_probability_factor(),
        empirical_probability_factor, interpreter);
  }
  if (model_->tflite_model_spec()->input_num_tokens() >= 0) {
    std::vector<int> num_tokens_per_message(tokens.size());
    for (int i = 0; i < tokens.size(); i++) {
      num_tokens_per_message[i] = tokens[i].size();
    }
    model_executor_->SetInput<int>(
        model_->tflite_model_spec()->input_num_tokens(), num_tokens_per_message,
        interpreter);
  }
  if (model_->tflite_model_spec()->input_token_embeddings() >= 0) {
    model_executor_->SetInput<float>(
        model_->tflite_model_spec()->input_token_embeddings(), token_embeddings,
        interpreter);
  }
  if (model_->tflite_model_spec()->input_flattened_token_embeddings() >= 0) {
    model_executor_->SetInput<float>(
        model_->tflite_model_spec()->input_flattened_token_embeddings(),
        flattened_token_embeddings, interpreter);
  }
  return true;
}

bool ActionsSuggestions::ReadModelOutput(
    tflite::Interpreter* interpreter, const ActionSuggestionOptions& options,
    ActionsSuggestionsResponse* response) const {
  // Read the sensitivity and triggering score predictions.
  if (model_->tflite_model_spec()->output_triggering_score() >= 0) {
    const TensorView<float>& triggering_score =
        model_executor_->OutputView<float>(
            model_->tflite_model_spec()->output_triggering_score(),
            interpreter);
    if (!triggering_score.is_valid() || triggering_score.size() == 0) {
      TC3_LOG(ERROR) << "Could not compute triggering score.";
      return false;
    }
    response->triggering_score = triggering_score.data()[0];
    response->output_filtered_min_triggering_score =
        (response->triggering_score <
         preconditions_.min_smart_reply_triggering_score);
  }
  if (model_->tflite_model_spec()->output_sensitive_topic_score() >= 0) {
    const TensorView<float>& sensitive_topic_score =
        model_executor_->OutputView<float>(
            model_->tflite_model_spec()->output_sensitive_topic_score(),
            interpreter);
    if (!sensitive_topic_score.is_valid() ||
        sensitive_topic_score.dim(0) != 1) {
      TC3_LOG(ERROR) << "Could not compute sensitive topic score.";
      return false;
    }
    response->sensitivity_score = sensitive_topic_score.data()[0];
    response->output_filtered_sensitivity =
        (response->sensitivity_score >
         preconditions_.max_sensitive_topic_score);
  }

  // Suppress the remaining model outputs if the conversation is sensitive.
  if (response->output_filtered_sensitivity) {
    return true;
  }

  // Read the smart reply predictions.
  if (!response->output_filtered_min_triggering_score &&
      model_->tflite_model_spec()->output_replies() >= 0) {
    const std::vector<tflite::StringRef> replies =
        model_executor_->Output<tflite::StringRef>(
            model_->tflite_model_spec()->output_replies(), interpreter);
    TensorView<float> scores = model_executor_->OutputView<float>(
        model_->tflite_model_spec()->output_replies_scores(), interpreter);
    for (int i = 0; i < replies.size(); i++) {
      if (replies[i].len == 0) continue;
      const float score = scores.data()[i];
      if (score < preconditions_.min_reply_score_threshold) {
        continue;
      }
      response->actions.push_back({std::string(replies[i].str, replies[i].len),
                                   model_->smart_reply_action_type()->str(),
                                   score});
    }
  }

  // Read the action suggestions.
  if (model_->tflite_model_spec()->output_actions_scores() >= 0) {
    const TensorView<float> actions_scores = model_executor_->OutputView<float>(
        model_->tflite_model_spec()->output_actions_scores(), interpreter);
    for (int i = 0; i < model_->action_type()->Length(); i++) {
      const ActionTypeOptions* action_type = model_->action_type()->Get(i);
      // Skip disabled action classes, such as the default other category.
      if (!action_type->enabled()) {
        continue;
      }
      const float score = actions_scores.data()[i];
      if (score < action_type->min_triggering_score()) {
        continue;
      }
      ActionSuggestion suggestion =
          SuggestionFromSpec(action_type->action(),
                             /*default_type=*/action_type->name()->str());
      suggestion.score = score;
      response->actions.push_back(suggestion);
    }
  }

  return true;
}

bool ActionsSuggestions::SuggestActionsFromModel(
    const Conversation& conversation, const int num_messages,
    const ActionSuggestionOptions& options,
    ActionsSuggestionsResponse* response,
    std::unique_ptr<tflite::Interpreter>* interpreter) const {
  TC3_CHECK_LE(num_messages, conversation.messages.size());

  if (!model_executor_) {
    return true;
  }
  *interpreter = model_executor_->CreateInterpreter();

  if (!*interpreter) {
    TC3_LOG(ERROR) << "Could not build TensorFlow Lite interpreter for the "
                      "actions suggestions model.";
    return false;
  }

  std::vector<std::string> context;
  std::vector<int> user_ids;
  std::vector<float> time_diffs;
  context.reserve(num_messages);
  user_ids.reserve(num_messages);
  time_diffs.reserve(num_messages);

  // Gather the last `num_messages` messages from the conversation.
  int64 last_message_reference_time_ms_utc = 0;
  const float second_in_ms = 1000;
  for (int i = conversation.messages.size() - num_messages;
       i < conversation.messages.size(); i++) {
    const ConversationMessage& message = conversation.messages[i];
    context.push_back(message.text);
    user_ids.push_back(message.user_id);

    float time_diff_secs = 0;
    if (message.reference_time_ms_utc != 0 &&
        last_message_reference_time_ms_utc != 0) {
      time_diff_secs = std::max(0.0f, (message.reference_time_ms_utc -
                                       last_message_reference_time_ms_utc) /
                                          second_in_ms);
    }
    if (message.reference_time_ms_utc != 0) {
      last_message_reference_time_ms_utc = message.reference_time_ms_utc;
    }
    time_diffs.push_back(time_diff_secs);
  }

  if (!SetupModelInput(context, user_ids, time_diffs,
                       /*num_suggestions=*/model_->num_smart_replies(),
                       preconditions_.confidence_threshold,
                       preconditions_.diversification_distance_threshold,
                       preconditions_.empirical_probability_factor,
                       interpreter->get())) {
    TC3_LOG(ERROR) << "Failed to set up input for the TensorFlow Lite model.";
    return false;
  }

  if ((*interpreter)->Invoke() != kTfLiteOk) {
    TC3_LOG(ERROR) << "Failed to invoke the TensorFlow Lite interpreter.";
    return false;
  }

  return ReadModelOutput(interpreter->get(), options, response);
}

AnnotationOptions ActionsSuggestions::AnnotationOptionsForMessage(
    const ConversationMessage& message) const {
  AnnotationOptions options;
  options.detected_text_language_tags = message.detected_text_language_tags;
  options.reference_time_ms_utc = message.reference_time_ms_utc;
  options.reference_timezone = message.reference_timezone;
  options.annotation_usecase =
      model_->annotation_actions_spec()->annotation_usecase();
  options.is_serialized_entity_data_enabled =
      model_->annotation_actions_spec()->is_serialized_entity_data_enabled();
  options.entity_types = annotation_entity_types_;
  return options;
}

void ActionsSuggestions::SuggestActionsFromAnnotations(
    const Conversation& conversation, const ActionSuggestionOptions& options,
    const Annotator* annotator, std::vector<ActionSuggestion>* actions) const {
  if (model_->annotation_actions_spec() == nullptr ||
      model_->annotation_actions_spec()->annotation_mapping() == nullptr ||
      model_->annotation_actions_spec()->annotation_mapping()->size() == 0) {
    return;
  }

  // Create actions based on the annotations.
  const int max_from_any_person =
      model_->annotation_actions_spec()->max_history_from_any_person();
  const int max_from_last_person =
      model_->annotation_actions_spec()->max_history_from_last_person();
  const int last_person = conversation.messages.back().user_id;

  int num_messages_last_person = 0;
  int num_messages_any_person = 0;
  bool all_from_last_person = true;
  for (int message_index = conversation.messages.size() - 1; message_index >= 0;
       message_index--) {
    const ConversationMessage& message = conversation.messages[message_index];
    std::vector<AnnotatedSpan> annotations = message.annotations;

    // Update how many messages we have processed from the last person in the
    // conversation and from any person in the conversation.
    num_messages_any_person++;
    if (all_from_last_person && message.user_id == last_person) {
      num_messages_last_person++;
    } else {
      all_from_last_person = false;
    }

    if (num_messages_any_person > max_from_any_person &&
        (!all_from_last_person ||
         num_messages_last_person > max_from_last_person)) {
      break;
    }

    if (message.user_id == kLocalUserId) {
      if (model_->annotation_actions_spec()->only_until_last_sent()) {
        break;
      }
      if (!model_->annotation_actions_spec()->include_local_user_messages()) {
        continue;
      }
    }

    if (annotations.empty() && annotator != nullptr) {
      annotations = annotator->Annotate(message.text,
                                        AnnotationOptionsForMessage(message));
    }
    std::vector<ActionSuggestionAnnotation> action_annotations;
    action_annotations.reserve(annotations.size());
    for (const AnnotatedSpan& annotation : annotations) {
      if (annotation.classification.empty()) {
        continue;
      }

      const ClassificationResult& classification_result =
          annotation.classification[0];

      ActionSuggestionAnnotation action_annotation;
      action_annotation.span = {
          message_index, annotation.span,
          UTF8ToUnicodeText(message.text, /*do_copy=*/false)
              .UTF8Substring(annotation.span.first, annotation.span.second)};
      action_annotation.entity = classification_result;
      action_annotation.name = classification_result.collection;
      action_annotations.push_back(action_annotation);
    }

    if (model_->annotation_actions_spec()->deduplicate_annotations()) {
      // Create actions only for deduplicated annotations.
      for (const int annotation_id :
           DeduplicateAnnotations(action_annotations)) {
        SuggestActionsFromAnnotation(
            message_index, action_annotations[annotation_id], actions);
      }
    } else {
      // Create actions for all annotations.
      for (const ActionSuggestionAnnotation& annotation : action_annotations) {
        SuggestActionsFromAnnotation(message_index, annotation, actions);
      }
    }
  }
}

void ActionsSuggestions::SuggestActionsFromAnnotation(
    const int message_index, const ActionSuggestionAnnotation& annotation,
    std::vector<ActionSuggestion>* actions) const {
  for (const AnnotationActionsSpec_::AnnotationMapping* mapping :
       *model_->annotation_actions_spec()->annotation_mapping()) {
    if (annotation.entity.collection ==
        mapping->annotation_collection()->str()) {
      if (annotation.entity.score < mapping->min_annotation_score()) {
        continue;
      }
      ActionSuggestion suggestion = SuggestionFromSpec(mapping->action());
      if (mapping->use_annotation_score()) {
        suggestion.score = annotation.entity.score;
      }

      // Set the annotation text as an (additional) entity data field.
      if (mapping->entity_field() != nullptr) {
        std::unique_ptr<ReflectiveFlatbuffer> entity_data =
            entity_data_builder_->NewRoot();
        TC3_CHECK(entity_data != nullptr);

        // Merge the existing static entity data.
        if (!suggestion.serialized_entity_data.empty()) {
          entity_data->MergeFromSerializedFlatbuffer(
              StringPiece(suggestion.serialized_entity_data.c_str(),
                          suggestion.serialized_entity_data.size()));
        }

        entity_data->ParseAndSet(mapping->entity_field(), annotation.span.text);
        suggestion.serialized_entity_data = entity_data->Serialize();
      }

      suggestion.annotations = {annotation};
      actions->push_back(suggestion);
    }
  }
}

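// Deduplicates annotations by their (name, matched text) pair, keeping the
// instance with the highest classification score. Returns the indices of the
// surviving annotations in the input vector.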
std::vector<int> ActionsSuggestions::DeduplicateAnnotations(
    const std::vector<ActionSuggestionAnnotation>& annotations) const {
  std::map<std::pair<std::string, std::string>, int> deduplicated_annotations;

  for (int i = 0; i < annotations.size(); i++) {
    const std::pair<std::string, std::string> key = {annotations[i].name,
                                                     annotations[i].span.text};
    auto entry = deduplicated_annotations.find(key);
    if (entry != deduplicated_annotations.end()) {
      // Keep the annotation with the higher score.
      if (annotations[entry->second].entity.score <
          annotations[i].entity.score) {
        entry->second = i;
      }
      continue;
    }
    deduplicated_annotations.insert(entry, {key, i});
  }

  std::vector<int> result;
  result.reserve(deduplicated_annotations.size());
  for (const auto& key_and_annotation : deduplicated_annotations) {
    result.push_back(key_and_annotation.second);
  }
  return result;
}

bool ActionsSuggestions::FillAnnotationFromMatchGroup(
    const UniLib::RegexMatcher* matcher,
    const RulesModel_::Rule_::RuleActionSpec_::RuleCapturingGroup* group,
    const int message_index, ActionSuggestionAnnotation* annotation) const {
  if (group->annotation_name() != nullptr ||
      group->annotation_type() != nullptr) {
    int status = UniLib::RegexMatcher::kNoError;
    const CodepointSpan span = {matcher->Start(group->group_id(), &status),
                                matcher->End(group->group_id(), &status)};
    std::string text =
        matcher->Group(group->group_id(), &status).ToUTF8String();
    if (status != UniLib::RegexMatcher::kNoError) {
      TC3_LOG(ERROR) << "Could not extract span from rule capturing group.";
      return false;
    }

    // The capturing group was not part of the match.
    if (span.first == kInvalidIndex || span.second == kInvalidIndex) {
      return false;
    }
    annotation->span.span = span;
    annotation->span.message_index = message_index;
    annotation->span.text = text;
    if (group->annotation_name() != nullptr) {
      annotation->name = group->annotation_name()->str();
    }
    if (group->annotation_type() != nullptr) {
      annotation->entity.collection = group->annotation_type()->str();
    }
  }
  return true;
}

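// Runs the compiled regex rules against the last message of the conversation.
// Every match can contribute in three ways, depending on the rule spec: a
// suggested action, entity data filled from capturing groups, and text
// replies generated directly from capturing group matches.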
bool ActionsSuggestions::SuggestActionsFromRules(
    const Conversation& conversation,
    std::vector<ActionSuggestion>* actions) const {
  // Create actions based on rules checking the last message.
  const int message_index = conversation.messages.size() - 1;
  const std::string& message = conversation.messages.back().text;
  const UnicodeText message_unicode(
      UTF8ToUnicodeText(message, /*do_copy=*/false));
  for (const CompiledRule& rule : rules_) {
    const std::unique_ptr<UniLib::RegexMatcher> matcher =
        rule.pattern->Matcher(message_unicode);
    int status = UniLib::RegexMatcher::kNoError;
    while (matcher->Find(&status) && status == UniLib::RegexMatcher::kNoError) {
      for (const RulesModel_::Rule_::RuleActionSpec* rule_action :
           *rule.rule->actions()) {
        const ActionSuggestionSpec* action = rule_action->action();
        std::vector<ActionSuggestionAnnotation> annotations;

        bool sets_entity_data = false;
        std::unique_ptr<ReflectiveFlatbuffer> entity_data =
            entity_data_builder_ != nullptr ? entity_data_builder_->NewRoot()
                                            : nullptr;

        // Set static entity data.
        if (action != nullptr && action->serialized_entity_data() != nullptr) {
          TC3_CHECK(entity_data != nullptr);
          sets_entity_data = true;
          entity_data->MergeFromSerializedFlatbuffer(
              StringPiece(action->serialized_entity_data()->c_str(),
                          action->serialized_entity_data()->size()));
        }

        // Add entity data from rule capturing groups.
        if (rule_action->capturing_group() != nullptr) {
          for (const RulesModel_::Rule_::RuleActionSpec_::RuleCapturingGroup*
                   group : *rule_action->capturing_group()) {
            if (group->entity_field() != nullptr) {
              TC3_CHECK(entity_data != nullptr);
              sets_entity_data = true;
              if (!SetFieldFromCapturingGroup(
                      group->group_id(), group->entity_field(), matcher.get(),
                      entity_data.get())) {
                TC3_LOG(ERROR)
                    << "Could not set entity data from rule capturing group.";
                return false;
              }
            }

            // Create a text annotation for the group span.
            ActionSuggestionAnnotation annotation;
            if (FillAnnotationFromMatchGroup(matcher.get(), group,
                                             message_index, &annotation)) {
              annotations.push_back(annotation);
            }

            // Create a text reply.
            if (group->text_reply() != nullptr) {
              int status = UniLib::RegexMatcher::kNoError;
              const std::string group_text =
                  matcher->Group(group->group_id(), &status).ToUTF8String();
              if (status != UniLib::RegexMatcher::kNoError) {
                TC3_LOG(ERROR) << "Could not get text from capturing group.";
                return false;
              }
              if (group_text.empty()) {
                // The group was not part of the match, ignore and continue.
                continue;
              }
              actions->push_back(SuggestionFromSpec(
                  group->text_reply(),
                  /*default_type=*/model_->smart_reply_action_type()->str(),
                  /*default_response_text=*/group_text));
            }
          }
        }

        if (action != nullptr) {
          ActionSuggestion suggestion = SuggestionFromSpec(action);
          suggestion.annotations = annotations;
          if (sets_entity_data) {
            suggestion.serialized_entity_data = entity_data->Serialize();
          }
          actions->push_back(suggestion);
        }
      }
    }
  }
  return true;
}

bool ActionsSuggestions::SuggestActionsFromLua(
    const Conversation& conversation, const TfLiteModelExecutor* model_executor,
    const tflite::Interpreter* interpreter,
    const reflection::Schema* annotation_entity_data_schema,
    std::vector<ActionSuggestion>* actions) const {
  if (lua_bytecode_.empty()) {
    return true;
  }

  auto lua_actions = LuaActionsSuggestions::CreateLuaActionsSuggestions(
      lua_bytecode_, conversation, model_executor, model_->tflite_model_spec(),
      interpreter, entity_data_schema_, annotation_entity_data_schema);
  if (lua_actions == nullptr) {
    TC3_LOG(ERROR) << "Could not create lua actions.";
    return false;
  }
  return lua_actions->SuggestActions(actions);
}

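// Orchestrates the full suggestion pipeline: annotation-based actions first,
// then the locale, input length, and low confidence prechecks, then the
// TensorFlow Lite model, the Lua script, and the regex rules, and finally the
// low confidence post-check on the generated replies.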
bool ActionsSuggestions::GatherActionsSuggestions(
    const Conversation& conversation, const Annotator* annotator,
    const ActionSuggestionOptions& options,
    ActionsSuggestionsResponse* response) const {
  if (conversation.messages.empty()) {
    return true;
  }

  const int num_messages = NumMessagesToConsider(
      conversation, model_->max_conversation_history_length());

  if (num_messages <= 0) {
    TC3_LOG(INFO) << "No messages provided for actions suggestions.";
    return false;
  }

  SuggestActionsFromAnnotations(conversation, options, annotator,
                                &response->actions);

  int input_text_length = 0;
  int num_matching_locales = 0;
  for (int i = conversation.messages.size() - num_messages;
       i < conversation.messages.size(); i++) {
    input_text_length += conversation.messages[i].text.length();
    std::vector<Locale> message_languages;
    if (!ParseLocales(conversation.messages[i].detected_text_language_tags,
                      &message_languages)) {
      continue;
    }
    if (Locale::IsAnyLocaleSupported(
            message_languages, locales_,
            preconditions_.handle_unknown_locale_as_supported)) {
      ++num_matching_locales;
    }
  }

  // Bail out if we are provided with too little or too much input.
  if (input_text_length < preconditions_.min_input_length ||
      (preconditions_.max_input_length >= 0 &&
       input_text_length > preconditions_.max_input_length)) {
    TC3_LOG(INFO) << "Too much or not enough input for inference.";
    return true;
  }

  // Bail out if the text does not look like it can be handled by the model.
  const float matching_fraction =
      static_cast<float>(num_matching_locales) / num_messages;
  if (matching_fraction < preconditions_.min_locale_match_fraction) {
    TC3_LOG(INFO) << "Not enough locale matches.";
    response->output_filtered_locale_mismatch = true;
    return true;
  }

  std::vector<int> post_check_rules;
  if (preconditions_.suppress_on_low_confidence_input &&
      IsLowConfidenceInput(conversation, num_messages, &post_check_rules)) {
    response->output_filtered_low_confidence = true;
    return true;
  }

  std::unique_ptr<tflite::Interpreter> interpreter;
  if (!SuggestActionsFromModel(conversation, num_messages, options, response,
                               &interpreter)) {
    TC3_LOG(ERROR) << "Could not run model.";
    return false;
  }

  // Suppress all predictions if the conversation was deemed sensitive.
  if (preconditions_.suppress_on_sensitive_topic &&
      response->output_filtered_sensitivity) {
    return true;
  }

  if (!SuggestActionsFromLua(
          conversation, model_executor_.get(), interpreter.get(),
          annotator != nullptr ? annotator->entity_data_schema() : nullptr,
          &response->actions)) {
    TC3_LOG(ERROR) << "Could not suggest actions from script.";
    return false;
  }

  if (!SuggestActionsFromRules(conversation, &response->actions)) {
    TC3_LOG(ERROR) << "Could not suggest actions from rules.";
    return false;
  }

  if (preconditions_.suppress_on_low_confidence_input &&
      !FilterConfidenceOutput(post_check_rules, &response->actions)) {
    TC3_LOG(ERROR) << "Could not post-check actions.";
    return false;
  }

  return true;
}

ActionsSuggestionsResponse ActionsSuggestions::SuggestActions(
    const Conversation& conversation, const Annotator* annotator,
    const ActionSuggestionOptions& options) const {
  ActionsSuggestionsResponse response;
  if (!GatherActionsSuggestions(conversation, annotator, options, &response)) {
    TC3_LOG(ERROR) << "Could not gather actions suggestions.";
    response.actions.clear();
  } else if (!ranker_->RankActions(conversation, &response, entity_data_schema_,
                                   annotator != nullptr
                                       ? annotator->entity_data_schema()
                                       : nullptr)) {
    TC3_LOG(ERROR) << "Could not rank actions.";
    response.actions.clear();
  }
  return response;
}

ActionsSuggestionsResponse ActionsSuggestions::SuggestActions(
    const Conversation& conversation,
    const ActionSuggestionOptions& options) const {
  return SuggestActions(conversation, /*annotator=*/nullptr, options);
}

const ActionsModel* ActionsSuggestions::model() const { return model_; }
const reflection::Schema* ActionsSuggestions::entity_data_schema() const {
  return entity_data_schema_;
}

const ActionsModel* ViewActionsModel(const void* buffer, int size) {
  if (buffer == nullptr) {
    return nullptr;
  }
  return LoadAndVerifyModel(reinterpret_cast<const uint8_t*>(buffer), size);
}

}  // namespace libtextclassifier3