/*
|
* Copyright (C) 2018 The Android Open Source Project
|
*
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
*/
|
|
#include "actions/actions-suggestions.h"
|
|
#include <fstream>
|
#include <iterator>
|
#include <memory>
|
|
#include "actions/actions_model_generated.h"
|
#include "actions/test_utils.h"
|
#include "actions/zlib-utils.h"
|
#include "annotator/collections.h"
|
#include "annotator/types.h"
|
#include "utils/flatbuffers.h"
|
#include "utils/flatbuffers_generated.h"
|
#include "utils/hash/farmhash.h"
|
#include "gmock/gmock.h"
|
#include "gtest/gtest.h"
|
#include "flatbuffers/flatbuffers.h"
|
#include "flatbuffers/reflection.h"
|
|
namespace libtextclassifier3 {
|
namespace {
|
using testing::_;
|
|
// File name of the actions model that most tests in this file load.
constexpr char kModelFileName[] = "actions_suggestions_test.model";
|
// File name of the model loaded by LoadHashGramTestModel().
constexpr char kHashGramModelFileName[] =
|
"actions_suggestions_test.hashgram.model";
|
|
// Reads the whole content of `file_name` and returns it as a string.
//
// The stream is opened in binary mode: the tests use this helper to load
// serialized model buffers, and a text-mode stream may translate or truncate
// bytes on some platforms. If the file cannot be opened the result is an empty
// string.
std::string ReadFile(const std::string& file_name) {
  std::ifstream file_stream(file_name, std::ios::in | std::ios::binary);
  return std::string(std::istreambuf_iterator<char>(file_stream), {});
}
|
|
// Returns the prefix that is prepended to the test model file names.
// It is the empty string here, so the file names are used unchanged.
std::string GetModelPath() {
  return std::string();
}
|
|
class ActionsSuggestionsTest : public testing::Test {
|
protected:
|
ActionsSuggestionsTest() : INIT_UNILIB_FOR_TESTING(unilib_) {}
|
std::unique_ptr<ActionsSuggestions> LoadTestModel() {
|
return ActionsSuggestions::FromPath(GetModelPath() + kModelFileName,
|
&unilib_);
|
}
|
std::unique_ptr<ActionsSuggestions> LoadHashGramTestModel() {
|
return ActionsSuggestions::FromPath(GetModelPath() + kHashGramModelFileName,
|
&unilib_);
|
}
|
UniLib unilib_;
|
};
|
|
// Loading the test model must yield a non-null ActionsSuggestions instance.
TEST_F(ActionsSuggestionsTest, InstantiateActionSuggestions) {
  const std::unique_ptr<ActionsSuggestions> loaded_model = LoadTestModel();
  EXPECT_THAT(loaded_model, testing::NotNull());
}
|
|
// Checks the number of suggestions the test model produces for a single
// incoming message.
TEST_F(ActionsSuggestionsTest, SuggestActions) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions = LoadTestModel();
  // Stop here with a test failure (instead of dereferencing a null pointer)
  // if the model could not be loaded.
  ASSERT_TRUE(actions_suggestions);
  const ActionsSuggestionsResponse& response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "Where are you?", /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{}, /*locales=*/"en"}}});
  EXPECT_EQ(response.actions.size(), 3 /* share_location + 2 smart replies*/);
}
|
|
// A message tagged with the locale "zz" must not produce any suggestion.
TEST_F(ActionsSuggestionsTest, SuggestNoActionsForUnknownLocale) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions = LoadTestModel();
  // Stop here with a test failure (instead of dereferencing a null pointer)
  // if the model could not be loaded.
  ASSERT_TRUE(actions_suggestions);
  const ActionsSuggestionsResponse& response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "Where are you?", /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{}, /*locales=*/"zz"}}});
  EXPECT_THAT(response.actions, testing::IsEmpty());
}
|
|
// An "address" annotation on the message must make `view_map` the first
// suggested action, carrying the annotation's score.
TEST_F(ActionsSuggestionsTest, SuggestActionsFromAnnotations) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions = LoadTestModel();
  // Stop here with a test failure (instead of dereferencing a null pointer)
  // if the model could not be loaded.
  ASSERT_TRUE(actions_suggestions);
  AnnotatedSpan annotation;
  annotation.span = {11, 15};
  annotation.classification = {ClassificationResult("address", 1.0)};
  const ActionsSuggestionsResponse& response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "are you at home?",
             /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{annotation},
             /*locales=*/"en"}}});
  ASSERT_GE(response.actions.size(), 1);
  EXPECT_EQ(response.actions.front().type, "view_map");
  EXPECT_EQ(response.actions.front().score, 1.0);
}
|
|
// Replaces the model's annotation mapping with a single "address" ->
// "save_location" mapping that writes into the `location` entity field, then
// checks the first suggested action and its serialized entity data.
TEST_F(ActionsSuggestionsTest, SuggestActionsFromAnnotationsWithEntityData) {
  const std::string actions_model_string =
      ReadFile(GetModelPath() + kModelFileName);
  std::unique_ptr<ActionsModelT> actions_model =
      UnPackActionsModel(actions_model_string.c_str());
  SetTestEntityDataSchema(actions_model.get());

  // Set custom actions from annotations config.
  actions_model->annotation_actions_spec->annotation_mapping.clear();
  actions_model->annotation_actions_spec->annotation_mapping.emplace_back(
      new AnnotationActionsSpec_::AnnotationMappingT);
  AnnotationActionsSpec_::AnnotationMappingT* mapping =
      actions_model->annotation_actions_spec->annotation_mapping.back().get();
  mapping->annotation_collection = "address";
  mapping->action.reset(new ActionSuggestionSpecT);
  mapping->action->type = "save_location";
  mapping->action->score = 1.0;
  mapping->action->priority_score = 2.0;
  mapping->entity_field.reset(new FlatbufferFieldPathT);
  mapping->entity_field->field.emplace_back(new FlatbufferFieldT);
  mapping->entity_field->field.back()->field_name = "location";

  flatbuffers::FlatBufferBuilder builder;
  FinishActionsModelBuffer(builder,
                           ActionsModel::Pack(builder, actions_model.get()));
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      ActionsSuggestions::FromUnownedBuffer(
          reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
          builder.GetSize(), &unilib_);
  // Stop here with a test failure (instead of dereferencing a null pointer)
  // if the modified model was rejected.
  ASSERT_TRUE(actions_suggestions);

  AnnotatedSpan annotation;
  annotation.span = {11, 15};
  annotation.classification = {ClassificationResult("address", 1.0)};
  const ActionsSuggestionsResponse& response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "are you at home?",
             /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{annotation},
             /*locales=*/"en"}}});
  ASSERT_GE(response.actions.size(), 1);
  EXPECT_EQ(response.actions.front().type, "save_location");
  EXPECT_EQ(response.actions.front().score, 1.0);

  // Check that the `location` entity field holds the text from the address
  // annotation.
  const flatbuffers::Table* entity =
      flatbuffers::GetAnyRoot(reinterpret_cast<const unsigned char*>(
          response.actions.front().serialized_entity_data.data()));
  const flatbuffers::String* location =
      entity->GetPointer<const flatbuffers::String*>(/*field=*/6);
  // The field is absent (null) if no entity data was written; fail instead of
  // dereferencing it.
  ASSERT_TRUE(location != nullptr);
  EXPECT_EQ(location->str(), "home");
}
|
|
// Two "flight" annotations and one "email" annotation are passed in; the test
// expects `track_flight` with the higher flight score (3.0) first and
// `send_email` second.
TEST_F(ActionsSuggestionsTest, SuggestActionsFromDuplicatedAnnotations) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions = LoadTestModel();
  // Stop here with a test failure (instead of dereferencing a null pointer)
  // if the model could not be loaded.
  ASSERT_TRUE(actions_suggestions);
  AnnotatedSpan flight_annotation;
  flight_annotation.span = {11, 15};
  flight_annotation.classification = {ClassificationResult("flight", 2.5)};
  AnnotatedSpan flight_annotation2;
  flight_annotation2.span = {35, 39};
  flight_annotation2.classification = {ClassificationResult("flight", 3.0)};
  AnnotatedSpan email_annotation;
  // NOTE(review): the message below is 57 characters long and "test@test.com"
  // occupies [43, 56), so this span lies outside the text. It is kept as in
  // the original test -- confirm whether that is intentional.
  email_annotation.span = {55, 68};
  email_annotation.classification = {ClassificationResult("email", 2.0)};

  const ActionsSuggestionsResponse& response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1,
             "call me at LX38 or send message to LX38 or test@test.com.",
             /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/
             {flight_annotation, flight_annotation2, email_annotation},
             /*locales=*/"en"}}});

  ASSERT_GE(response.actions.size(), 2);
  EXPECT_EQ(response.actions[0].type, "track_flight");
  EXPECT_EQ(response.actions[0].score, 3.0);
  EXPECT_EQ(response.actions[1].type, "send_email");
  EXPECT_EQ(response.actions[1].score, 2.0);
}
|
|
// With `deduplicate_annotations` switched off, both "flight" annotations are
// expected to yield their own `track_flight` action (ordered by score),
// followed by `send_email`.
TEST_F(ActionsSuggestionsTest, SuggestActionsAnnotationsNoDeduplication) {
  const std::string actions_model_string =
      ReadFile(GetModelPath() + kModelFileName);
  std::unique_ptr<ActionsModelT> actions_model =
      UnPackActionsModel(actions_model_string.c_str());
  // Disable deduplication.
  actions_model->annotation_actions_spec->deduplicate_annotations = false;
  flatbuffers::FlatBufferBuilder builder;
  FinishActionsModelBuffer(builder,
                           ActionsModel::Pack(builder, actions_model.get()));
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      ActionsSuggestions::FromUnownedBuffer(
          reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
          builder.GetSize(), &unilib_);
  // Stop here with a test failure (instead of dereferencing a null pointer)
  // if the modified model was rejected.
  ASSERT_TRUE(actions_suggestions);
  AnnotatedSpan flight_annotation;
  flight_annotation.span = {11, 15};
  flight_annotation.classification = {ClassificationResult("flight", 2.5)};
  AnnotatedSpan flight_annotation2;
  flight_annotation2.span = {35, 39};
  flight_annotation2.classification = {ClassificationResult("flight", 3.0)};
  AnnotatedSpan email_annotation;
  // NOTE(review): the message below is 57 characters long and "test@test.com"
  // occupies [43, 56), so this span lies outside the text. It is kept as in
  // the original test -- confirm whether that is intentional.
  email_annotation.span = {55, 68};
  email_annotation.classification = {ClassificationResult("email", 2.0)};

  const ActionsSuggestionsResponse& response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1,
             "call me at LX38 or send message to LX38 or test@test.com.",
             /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/
             {flight_annotation, flight_annotation2, email_annotation},
             /*locales=*/"en"}}});

  ASSERT_GE(response.actions.size(), 3);
  EXPECT_EQ(response.actions[0].type, "track_flight");
  EXPECT_EQ(response.actions[0].score, 3.0);
  EXPECT_EQ(response.actions[1].type, "track_flight");
  EXPECT_EQ(response.actions[1].score, 2.5);
  EXPECT_EQ(response.actions[2].type, "send_email");
  EXPECT_EQ(response.actions[2].score, 2.0);
}
|
|
// Shared driver for the annotation-history tests.
//
// Loads the test model, lets `set_config_fn` modify the unpacked model, raises
// `min_smart_reply_triggering_score` to 1.0, repacks the model and runs
// SuggestActions on a fixed four-message conversation: three messages with an
// "email" annotation (from the local user, user 2 and user 1) followed by one
// message from user 1 with a "flight" annotation.
//
// `unilib` is forwarded to ActionsSuggestions::FromUnownedBuffer.
// Returns the response of SuggestActions. If the model cannot be instantiated
// a test failure is recorded and a default-constructed response is returned.
ActionsSuggestionsResponse TestSuggestActionsFromAnnotations(
    const std::function<void(ActionsModelT*)>& set_config_fn,
    const UniLib* unilib = nullptr) {
  const std::string actions_model_string =
      ReadFile(GetModelPath() + kModelFileName);
  std::unique_ptr<ActionsModelT> actions_model =
      UnPackActionsModel(actions_model_string.c_str());

  // Set custom config.
  set_config_fn(actions_model.get());

  // Disable smart reply for easier testing.
  actions_model->preconditions->min_smart_reply_triggering_score = 1.0;

  flatbuffers::FlatBufferBuilder builder;
  FinishActionsModelBuffer(builder,
                           ActionsModel::Pack(builder, actions_model.get()));
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      ActionsSuggestions::FromUnownedBuffer(
          reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
          builder.GetSize(), unilib);
  // This helper returns a value, so ASSERT_* cannot be used here. Record the
  // failure explicitly rather than dereferencing a null pointer below.
  if (actions_suggestions == nullptr) {
    ADD_FAILURE() << "Failed to create ActionsSuggestions from the test model.";
    return ActionsSuggestionsResponse();
  }

  AnnotatedSpan flight_annotation;
  flight_annotation.span = {15, 19};
  flight_annotation.classification = {ClassificationResult("flight", 2.0)};
  AnnotatedSpan email_annotation;
  email_annotation.span = {0, 16};
  email_annotation.classification = {ClassificationResult("email", 1.0)};

  return actions_suggestions->SuggestActions(
      {{{/*user_id=*/ActionsSuggestions::kLocalUserId,
         "hehe@android.com",
         /*reference_time_ms_utc=*/0,
         /*reference_timezone=*/"Europe/Zurich",
         /*annotations=*/
         {email_annotation},
         /*locales=*/"en"},
        {/*user_id=*/2,
         "yoyo@android.com",
         /*reference_time_ms_utc=*/0,
         /*reference_timezone=*/"Europe/Zurich",
         /*annotations=*/
         {email_annotation},
         /*locales=*/"en"},
        {/*user_id=*/1,
         "test@android.com",
         /*reference_time_ms_utc=*/0,
         /*reference_timezone=*/"Europe/Zurich",
         /*annotations=*/
         {email_annotation},
         /*locales=*/"en"},
        {/*user_id=*/1,
         "I am on flight LX38.",
         /*reference_time_ms_utc=*/0,
         /*reference_timezone=*/"Europe/Zurich",
         /*annotations=*/
         {flight_annotation},
         /*locales=*/"en"}}});
}
|
|
// With both history limits set to 1 and local-user messages excluded, exactly
// one action (`track_flight`) is expected.
TEST_F(ActionsSuggestionsTest, SuggestActionsWithAnnotationsOnlyLastMessage) {
  const ActionsSuggestionsResponse response = TestSuggestActionsFromAnnotations(
      [](ActionsModelT* actions_model) {
        actions_model->annotation_actions_spec->include_local_user_messages =
            false;
        actions_model->annotation_actions_spec->only_until_last_sent = true;
        actions_model->annotation_actions_spec->max_history_from_any_person = 1;
        actions_model->annotation_actions_spec->max_history_from_last_person =
            1;
      },
      &unilib_);
  // Assert (rather than expect) the size so that the indexing below can never
  // read past the end of `actions`.
  ASSERT_EQ(response.actions.size(), 1);
  EXPECT_EQ(response.actions[0].type, "track_flight");
}
|
|
// With `max_history_from_last_person` raised to 3 (any-person limit 1), two
// actions are expected: `track_flight` followed by `send_email`.
TEST_F(ActionsSuggestionsTest, SuggestActionsWithAnnotationsOnlyLastPerson) {
  const ActionsSuggestionsResponse response = TestSuggestActionsFromAnnotations(
      [](ActionsModelT* actions_model) {
        actions_model->annotation_actions_spec->include_local_user_messages =
            false;
        actions_model->annotation_actions_spec->only_until_last_sent = true;
        actions_model->annotation_actions_spec->max_history_from_any_person = 1;
        actions_model->annotation_actions_spec->max_history_from_last_person =
            3;
      },
      &unilib_);
  // Assert (rather than expect) the size so that the indexing below can never
  // read past the end of `actions`.
  ASSERT_EQ(response.actions.size(), 2);
  EXPECT_EQ(response.actions[0].type, "track_flight");
  EXPECT_EQ(response.actions[1].type, "send_email");
}
|
|
// With `max_history_from_any_person` set to 2 (last-person limit 1), two
// actions are expected: `track_flight` followed by `send_email`.
TEST_F(ActionsSuggestionsTest, SuggestActionsWithAnnotationsFromAny) {
  const ActionsSuggestionsResponse response = TestSuggestActionsFromAnnotations(
      [](ActionsModelT* actions_model) {
        actions_model->annotation_actions_spec->include_local_user_messages =
            false;
        actions_model->annotation_actions_spec->only_until_last_sent = true;
        actions_model->annotation_actions_spec->max_history_from_any_person = 2;
        actions_model->annotation_actions_spec->max_history_from_last_person =
            1;
      },
      &unilib_);
  // Assert (rather than expect) the size so that the indexing below can never
  // read past the end of `actions`.
  ASSERT_EQ(response.actions.size(), 2);
  EXPECT_EQ(response.actions[0].type, "track_flight");
  EXPECT_EQ(response.actions[1].type, "send_email");
}
|
|
// With `max_history_from_any_person` set to 3 (last-person limit 1), three
// actions are expected: `track_flight` followed by two `send_email`.
TEST_F(ActionsSuggestionsTest,
       SuggestActionsWithAnnotationsFromAnyManyMessages) {
  const ActionsSuggestionsResponse response = TestSuggestActionsFromAnnotations(
      [](ActionsModelT* actions_model) {
        actions_model->annotation_actions_spec->include_local_user_messages =
            false;
        actions_model->annotation_actions_spec->only_until_last_sent = true;
        actions_model->annotation_actions_spec->max_history_from_any_person = 3;
        actions_model->annotation_actions_spec->max_history_from_last_person =
            1;
      },
      &unilib_);
  // Assert (rather than expect) the size so that the indexing below can never
  // read past the end of `actions`.
  ASSERT_EQ(response.actions.size(), 3);
  EXPECT_EQ(response.actions[0].type, "track_flight");
  EXPECT_EQ(response.actions[1].type, "send_email");
  EXPECT_EQ(response.actions[2].type, "send_email");
}
|
|
// With `max_history_from_any_person` set to 5 but local-user messages still
// excluded, three actions are expected: `track_flight` and two `send_email`.
TEST_F(ActionsSuggestionsTest,
       SuggestActionsWithAnnotationsFromAnyManyMessagesButNotLocalUser) {
  const ActionsSuggestionsResponse response = TestSuggestActionsFromAnnotations(
      [](ActionsModelT* actions_model) {
        actions_model->annotation_actions_spec->include_local_user_messages =
            false;
        actions_model->annotation_actions_spec->only_until_last_sent = true;
        actions_model->annotation_actions_spec->max_history_from_any_person = 5;
        actions_model->annotation_actions_spec->max_history_from_last_person =
            1;
      },
      &unilib_);
  // Assert (rather than expect) the size so that the indexing below can never
  // read past the end of `actions`.
  ASSERT_EQ(response.actions.size(), 3);
  EXPECT_EQ(response.actions[0].type, "track_flight");
  EXPECT_EQ(response.actions[1].type, "send_email");
  EXPECT_EQ(response.actions[2].type, "send_email");
}
|
|
// With local-user messages included and `only_until_last_sent` disabled, four
// actions are expected: `track_flight` followed by three `send_email`.
TEST_F(ActionsSuggestionsTest,
       SuggestActionsWithAnnotationsFromAnyManyMessagesAlsoFromLocalUser) {
  const ActionsSuggestionsResponse response = TestSuggestActionsFromAnnotations(
      [](ActionsModelT* actions_model) {
        actions_model->annotation_actions_spec->include_local_user_messages =
            true;
        actions_model->annotation_actions_spec->only_until_last_sent = false;
        actions_model->annotation_actions_spec->max_history_from_any_person = 5;
        actions_model->annotation_actions_spec->max_history_from_last_person =
            1;
      },
      &unilib_);
  // Assert (rather than expect) the size so that the indexing below can never
  // read past the end of `actions`.
  ASSERT_EQ(response.actions.size(), 4);
  EXPECT_EQ(response.actions[0].type, "track_flight");
  EXPECT_EQ(response.actions[1].type, "send_email");
  EXPECT_EQ(response.actions[2].type, "send_email");
  EXPECT_EQ(response.actions[3].type, "send_email");
}
|
|
void TestSuggestActionsWithThreshold(
|
const std::function<void(ActionsModelT*)>& set_value_fn,
|
const UniLib* unilib = nullptr, const int expected_size = 0,
|
const std::string& preconditions_overwrite = "") {
|
const std::string actions_model_string =
|
ReadFile(GetModelPath() + kModelFileName);
|
std::unique_ptr<ActionsModelT> actions_model =
|
UnPackActionsModel(actions_model_string.c_str());
|
set_value_fn(actions_model.get());
|
flatbuffers::FlatBufferBuilder builder;
|
FinishActionsModelBuffer(builder,
|
ActionsModel::Pack(builder, actions_model.get()));
|
std::unique_ptr<ActionsSuggestions> actions_suggestions =
|
ActionsSuggestions::FromUnownedBuffer(
|
reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
|
builder.GetSize(), unilib, preconditions_overwrite);
|
ASSERT_TRUE(actions_suggestions);
|
const ActionsSuggestionsResponse& response =
|
actions_suggestions->SuggestActions(
|
{{{/*user_id=*/1, "I have the low-ground. Where are you?",
|
/*reference_time_ms_utc=*/0,
|
/*reference_timezone=*/"Europe/Zurich",
|
/*annotations=*/{}, /*locales=*/"en"}}});
|
EXPECT_LE(response.actions.size(), expected_size);
|
}
|
|
// Sets `min_smart_reply_triggering_score` to 1.0 and expects at most one
// action in the response.
TEST_F(ActionsSuggestionsTest, SuggestActionsWithTriggeringScore) {
  const auto raise_triggering_score = [](ActionsModelT* model) {
    model->preconditions->min_smart_reply_triggering_score = 1.0;
  };
  TestSuggestActionsWithThreshold(
      raise_triggering_score, &unilib_,
      /*expected_size=*/1 /*no smart reply, only actions*/);
}
|
|
// Sets `min_reply_score_threshold` to 1.0 and expects at most one action in
// the response.
TEST_F(ActionsSuggestionsTest, SuggestActionsWithMinReplyScore) {
  const auto raise_reply_score_threshold = [](ActionsModelT* model) {
    model->preconditions->min_reply_score_threshold = 1.0;
  };
  TestSuggestActionsWithThreshold(
      raise_reply_score_threshold, &unilib_,
      /*expected_size=*/1 /*no smart reply, only actions*/);
}
|
|
// Sets `max_sensitive_topic_score` to 0.0 and expects at most four actions in
// the response.
TEST_F(ActionsSuggestionsTest, SuggestActionsWithSensitiveTopicScore) {
  const auto zero_sensitive_topic_score = [](ActionsModelT* model) {
    model->preconditions->max_sensitive_topic_score = 0.0;
  };
  TestSuggestActionsWithThreshold(
      zero_sensitive_topic_score, &unilib_,
      /*expected_size=*/4 /* no sensitive prediction in test model*/);
}
|
|
// Sets `max_input_length` to 0 and relies on the helper's default
// `expected_size` of 0, i.e. no actions are expected.
TEST_F(ActionsSuggestionsTest, SuggestActionsWithMaxInputLength) {
  const auto zero_max_input_length = [](ActionsModelT* model) {
    model->preconditions->max_input_length = 0;
  };
  TestSuggestActionsWithThreshold(zero_max_input_length, &unilib_);
}
|
|
// Sets `min_input_length` to 100 and relies on the helper's default
// `expected_size` of 0, i.e. no actions are expected.
TEST_F(ActionsSuggestionsTest, SuggestActionsWithMinInputLength) {
  const auto large_min_input_length = [](ActionsModelT* model) {
    model->preconditions->min_input_length = 100;
  };
  TestSuggestActionsWithThreshold(large_min_input_length, &unilib_);
}
|
|
// Leaves the model itself unchanged and instead passes serialized
// TriggeringPreconditions with `max_input_length` = 0 as the overwrite; no
// actions are expected.
TEST_F(ActionsSuggestionsTest, SuggestActionsWithPreconditionsOverwrite) {
  TriggeringPreconditionsT overwrite;
  overwrite.max_input_length = 0;

  flatbuffers::FlatBufferBuilder fbb;
  fbb.Finish(TriggeringPreconditions::Pack(fbb, &overwrite));
  const std::string serialized_overwrite(
      reinterpret_cast<const char*>(fbb.GetBufferPointer()), fbb.GetSize());

  // Keep model untouched.
  const auto leave_model_unchanged = [](ActionsModelT* actions_model) {};
  TestSuggestActionsWithThreshold(leave_model_unchanged, &unilib_,
                                  /*expected_size=*/0, serialized_overwrite);
}
|
|
#ifdef TC3_UNILIB_ICU
|
// Enables `suppress_on_low_confidence_input` and installs a single
// low-confidence rule with the pattern "low-ground"; relies on the helper's
// default `expected_size` of 0, i.e. no actions are expected.
TEST_F(ActionsSuggestionsTest, SuggestActionsLowConfidence) {
  const auto install_low_confidence_rule = [](ActionsModelT* model) {
    model->preconditions->suppress_on_low_confidence_input = true;
    model->low_confidence_rules.reset(new RulesModelT);
    model->low_confidence_rules->rule.emplace_back(new RulesModel_::RuleT);
    RulesModel_::RuleT* const low_confidence_rule =
        model->low_confidence_rules->rule.back().get();
    low_confidence_rule->pattern = "low-ground";
  };
  TestSuggestActionsWithThreshold(install_low_confidence_rule, &unilib_);
}
|
|
// Installs a triggering rule with two text replies plus a low-confidence rule
// that carries both an input pattern and an output pattern, and checks that
// the first returned action is the "General Kenobi!" reply.
TEST_F(ActionsSuggestionsTest, SuggestActionsLowConfidenceInputOutput) {
  const std::string serialized_model =
      ReadFile(GetModelPath() + kModelFileName);
  std::unique_ptr<ActionsModelT> model =
      UnPackActionsModel(serialized_model.c_str());

  // Add custom triggering rule.
  model->rules.reset(new RulesModelT());
  model->rules->rule.emplace_back(new RulesModel_::RuleT);
  RulesModel_::RuleT* const reply_rule = model->rules->rule.back().get();
  reply_rule->pattern = "^(?i:hello\\s(there))$";

  // Appends a `text_reply` action with the given response text to the rule.
  const auto add_text_reply = [reply_rule](const char* reply_text) {
    std::unique_ptr<RulesModel_::Rule_::RuleActionSpecT> spec(
        new RulesModel_::Rule_::RuleActionSpecT);
    spec->action.reset(new ActionSuggestionSpecT);
    spec->action->type = "text_reply";
    spec->action->response_text = reply_text;
    spec->action->score = 1.0f;
    spec->action->priority_score = 1.0f;
    reply_rule->actions.push_back(std::move(spec));
  };
  add_text_reply("General Desaster!");
  add_text_reply("General Kenobi!");

  // Add input-output low confidence rule.
  model->preconditions->suppress_on_low_confidence_input = true;
  model->low_confidence_rules.reset(new RulesModelT);
  model->low_confidence_rules->rule.emplace_back(new RulesModel_::RuleT);
  RulesModel_::RuleT* const low_confidence_rule =
      model->low_confidence_rules->rule.back().get();
  low_confidence_rule->pattern = "hello";
  low_confidence_rule->output_pattern = "(?i:desaster)";

  flatbuffers::FlatBufferBuilder fbb;
  FinishActionsModelBuffer(fbb, ActionsModel::Pack(fbb, model.get()));
  const uint8_t* const model_buffer =
      reinterpret_cast<const uint8_t*>(fbb.GetBufferPointer());
  std::unique_ptr<ActionsSuggestions> suggestions =
      ActionsSuggestions::FromUnownedBuffer(model_buffer, fbb.GetSize(),
                                            &unilib_);
  ASSERT_TRUE(suggestions);

  const ActionsSuggestionsResponse result = suggestions->SuggestActions(
      {{{/*user_id=*/1, "hello there",
         /*reference_time_ms_utc=*/0,
         /*reference_timezone=*/"Europe/Zurich",
         /*annotations=*/{}, /*locales=*/"en"}}});
  ASSERT_GE(result.actions.size(), 1);
  EXPECT_EQ(result.actions[0].response_text, "General Kenobi!");
}
|
|
// Same scenario as the input/output low-confidence test, except that the
// low-confidence rule is supplied through serialized TriggeringPreconditions
// passed as the preconditions overwrite instead of being stored in the model.
TEST_F(ActionsSuggestionsTest,
       SuggestActionsLowConfidenceInputOutputOverwrite) {
  const std::string serialized_model =
      ReadFile(GetModelPath() + kModelFileName);
  std::unique_ptr<ActionsModelT> model =
      UnPackActionsModel(serialized_model.c_str());
  model->low_confidence_rules.reset();

  // Add custom triggering rule.
  model->rules.reset(new RulesModelT());
  model->rules->rule.emplace_back(new RulesModel_::RuleT);
  RulesModel_::RuleT* const reply_rule = model->rules->rule.back().get();
  reply_rule->pattern = "^(?i:hello\\s(there))$";

  // Appends a `text_reply` action with the given response text to the rule.
  const auto add_text_reply = [reply_rule](const char* reply_text) {
    std::unique_ptr<RulesModel_::Rule_::RuleActionSpecT> spec(
        new RulesModel_::Rule_::RuleActionSpecT);
    spec->action.reset(new ActionSuggestionSpecT);
    spec->action->type = "text_reply";
    spec->action->response_text = reply_text;
    spec->action->score = 1.0f;
    spec->action->priority_score = 1.0f;
    reply_rule->actions.push_back(std::move(spec));
  };
  add_text_reply("General Desaster!");
  add_text_reply("General Kenobi!");

  // Add custom triggering rule via overwrite.
  model->preconditions->low_confidence_rules.reset();
  TriggeringPreconditionsT overwrite;
  overwrite.suppress_on_low_confidence_input = true;
  overwrite.low_confidence_rules.reset(new RulesModelT);
  overwrite.low_confidence_rules->rule.emplace_back(new RulesModel_::RuleT);
  RulesModel_::RuleT* const low_confidence_rule =
      overwrite.low_confidence_rules->rule.back().get();
  low_confidence_rule->pattern = "hello";
  low_confidence_rule->output_pattern = "(?i:desaster)";

  flatbuffers::FlatBufferBuilder overwrite_fbb;
  overwrite_fbb.Finish(TriggeringPreconditions::Pack(overwrite_fbb, &overwrite));
  const std::string serialized_overwrite(
      reinterpret_cast<const char*>(overwrite_fbb.GetBufferPointer()),
      overwrite_fbb.GetSize());

  flatbuffers::FlatBufferBuilder model_fbb;
  FinishActionsModelBuffer(model_fbb,
                           ActionsModel::Pack(model_fbb, model.get()));
  const uint8_t* const model_buffer =
      reinterpret_cast<const uint8_t*>(model_fbb.GetBufferPointer());
  std::unique_ptr<ActionsSuggestions> suggestions =
      ActionsSuggestions::FromUnownedBuffer(model_buffer, model_fbb.GetSize(),
                                            &unilib_, serialized_overwrite);

  ASSERT_TRUE(suggestions);
  const ActionsSuggestionsResponse result = suggestions->SuggestActions(
      {{{/*user_id=*/1, "hello there",
         /*reference_time_ms_utc=*/0,
         /*reference_timezone=*/"Europe/Zurich",
         /*annotations=*/{}, /*locales=*/"en"}}});
  ASSERT_GE(result.actions.size(), 1);
  EXPECT_EQ(result.actions[0].response_text, "General Kenobi!");
}
|
#endif
|
|
// Sets `max_sensitive_topic_score` to 0.0 with `suppress_on_sensitive_topic`
// enabled and expects no actions, even though the message carries an address
// annotation. The test returns early when the model's
// `output_sensitive_topic_score` is negative.
TEST_F(ActionsSuggestionsTest, SuppressActionsFromAnnotationsOnSensitiveTopic) {
  const std::string actions_model_string =
      ReadFile(GetModelPath() + kModelFileName);
  std::unique_ptr<ActionsModelT> actions_model =
      UnPackActionsModel(actions_model_string.c_str());

  // Don't test if no sensitivity score is produced
  if (actions_model->tflite_model_spec->output_sensitive_topic_score < 0) {
    return;
  }

  actions_model->preconditions->max_sensitive_topic_score = 0.0;
  actions_model->preconditions->suppress_on_sensitive_topic = true;
  flatbuffers::FlatBufferBuilder builder;
  FinishActionsModelBuffer(builder,
                           ActionsModel::Pack(builder, actions_model.get()));
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      ActionsSuggestions::FromUnownedBuffer(
          reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
          builder.GetSize(), &unilib_);
  // Stop here with a test failure (instead of dereferencing a null pointer)
  // if the modified model was rejected.
  ASSERT_TRUE(actions_suggestions);
  AnnotatedSpan annotation;
  annotation.span = {11, 15};
  annotation.classification = {
      ClassificationResult(Collections::Address(), 1.0)};
  const ActionsSuggestionsResponse& response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "are you at home?",
             /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{annotation},
             /*locales=*/"en"}}});
  EXPECT_THAT(response.actions, testing::IsEmpty());
}
|
|
// Raises `max_conversation_history_length` to 10 and checks that a two-message
// conversation whose last message carries an address annotation yields
// `view_map` as the first action.
TEST_F(ActionsSuggestionsTest, SuggestActionsWithLongerConversation) {
  const std::string actions_model_string =
      ReadFile(GetModelPath() + kModelFileName);
  std::unique_ptr<ActionsModelT> actions_model =
      UnPackActionsModel(actions_model_string.c_str());

  // Allow a larger conversation context.
  actions_model->max_conversation_history_length = 10;

  flatbuffers::FlatBufferBuilder builder;
  FinishActionsModelBuffer(builder,
                           ActionsModel::Pack(builder, actions_model.get()));
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      ActionsSuggestions::FromUnownedBuffer(
          reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
          builder.GetSize(), &unilib_);
  // Stop here with a test failure (instead of dereferencing a null pointer)
  // if the modified model was rejected.
  ASSERT_TRUE(actions_suggestions);
  AnnotatedSpan annotation;
  annotation.span = {11, 15};
  annotation.classification = {
      ClassificationResult(Collections::Address(), 1.0)};
  const ActionsSuggestionsResponse& response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/ActionsSuggestions::kLocalUserId, "hi, how are you?",
             /*reference_time_ms_utc=*/10000,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{}, /*locales=*/"en"},
            {/*user_id=*/1, "good! are you at home?",
             /*reference_time_ms_utc=*/15000,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{annotation},
             /*locales=*/"en"}}});
  ASSERT_GE(response.actions.size(), 1);
  EXPECT_EQ(response.actions[0].type, "view_map");
  EXPECT_EQ(response.actions[0].score, 1.0);
}
|
|
// A flight annotation must produce a `track_flight` action that carries the
// originating annotation (message index and span) along with it.
TEST_F(ActionsSuggestionsTest, CreateActionsFromClassificationResult) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions = LoadTestModel();
  // Stop here with a test failure (instead of dereferencing a null pointer)
  // if the model could not be loaded.
  ASSERT_TRUE(actions_suggestions);
  AnnotatedSpan annotation;
  annotation.span = {8, 12};
  annotation.classification = {
      ClassificationResult(Collections::Flight(), 1.0)};

  const ActionsSuggestionsResponse& response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "I'm on LX38?",
             /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{annotation},
             /*locales=*/"en"}}});

  ASSERT_GE(response.actions.size(), 2);
  EXPECT_EQ(response.actions[0].type, "track_flight");
  EXPECT_EQ(response.actions[0].score, 1.0);
  // Assert (rather than expect) the size so that `annotations[0]` below can
  // never read past the end of the vector.
  ASSERT_EQ(response.actions[0].annotations.size(), 1);
  EXPECT_EQ(response.actions[0].annotations[0].span.message_index, 0);
  EXPECT_EQ(response.actions[0].annotations[0].span.span, annotation.span);
}
|
|
#ifdef TC3_UNILIB_ICU
|
// Installs a regex rule that creates a `text_reply` action, maps two capturing
// groups to the entity fields `greeting` and `location`, attaches pre-built
// entity data with `person` set, and then checks the response text and the
// three string fields of the resulting serialized entity data.
TEST_F(ActionsSuggestionsTest, CreateActionsFromRules) {
  const std::string actions_model_string =
      ReadFile(GetModelPath() + kModelFileName);
  std::unique_ptr<ActionsModelT> actions_model =
      UnPackActionsModel(actions_model_string.c_str());
  ASSERT_TRUE(DecompressActionsModel(actions_model.get()));

  actions_model->rules.reset(new RulesModelT());
  actions_model->rules->rule.emplace_back(new RulesModel_::RuleT);
  RulesModel_::RuleT* rule = actions_model->rules->rule.back().get();
  rule->pattern = "^(?i:hello\\s(there))$";
  rule->actions.emplace_back(new RulesModel_::Rule_::RuleActionSpecT);
  rule->actions.back()->action.reset(new ActionSuggestionSpecT);
  ActionSuggestionSpecT* action = rule->actions.back()->action.get();
  action->type = "text_reply";
  action->response_text = "General Kenobi!";
  action->score = 1.0f;
  action->priority_score = 1.0f;

  // Set capturing groups for entity data.
  rule->actions.back()->capturing_group.emplace_back(
      new RulesModel_::Rule_::RuleActionSpec_::RuleCapturingGroupT);
  RulesModel_::Rule_::RuleActionSpec_::RuleCapturingGroupT* greeting_group =
      rule->actions.back()->capturing_group.back().get();
  greeting_group->group_id = 0;
  greeting_group->entity_field.reset(new FlatbufferFieldPathT);
  greeting_group->entity_field->field.emplace_back(new FlatbufferFieldT);
  greeting_group->entity_field->field.back()->field_name = "greeting";
  rule->actions.back()->capturing_group.emplace_back(
      new RulesModel_::Rule_::RuleActionSpec_::RuleCapturingGroupT);
  RulesModel_::Rule_::RuleActionSpec_::RuleCapturingGroupT* location_group =
      rule->actions.back()->capturing_group.back().get();
  location_group->group_id = 1;
  location_group->entity_field.reset(new FlatbufferFieldPathT);
  location_group->entity_field->field.emplace_back(new FlatbufferFieldT);
  location_group->entity_field->field.back()->field_name = "location";

  // Set test entity data schema.
  SetTestEntityDataSchema(actions_model.get());

  // Use meta data to generate custom serialized entity data.
  ReflectiveFlatbufferBuilder entity_data_builder(
      flatbuffers::GetRoot<reflection::Schema>(
          actions_model->actions_entity_data_schema.data()));
  std::unique_ptr<ReflectiveFlatbuffer> entity_data =
      entity_data_builder.NewRoot();
  entity_data->Set("person", "Kenobi");
  action->serialized_entity_data = entity_data->Serialize();

  flatbuffers::FlatBufferBuilder builder;
  FinishActionsModelBuffer(builder,
                           ActionsModel::Pack(builder, actions_model.get()));
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      ActionsSuggestions::FromUnownedBuffer(
          reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
          builder.GetSize(), &unilib_);
  // Stop here with a test failure (instead of dereferencing a null pointer)
  // if the modified model was rejected.
  ASSERT_TRUE(actions_suggestions);

  const ActionsSuggestionsResponse& response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "hello there", /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{}, /*locales=*/"en"}}});
  // Assert (rather than expect) the size so that `actions[0]` below can never
  // read past the end of the vector.
  ASSERT_GE(response.actions.size(), 1);
  EXPECT_EQ(response.actions[0].response_text, "General Kenobi!");

  // Check entity data.
  // NOTE(review): the field offsets 4, 6 and 8 presumably correspond to the
  // `greeting`, `location` and `person` fields of the test schema -- confirm
  // against SetTestEntityDataSchema.
  const flatbuffers::Table* entity =
      flatbuffers::GetAnyRoot(reinterpret_cast<const unsigned char*>(
          response.actions[0].serialized_entity_data.data()));
  EXPECT_EQ(entity->GetPointer<const flatbuffers::String*>(/*field=*/4)->str(),
            "hello there");
  EXPECT_EQ(entity->GetPointer<const flatbuffers::String*>(/*field=*/6)->str(),
            "there");
  EXPECT_EQ(entity->GetPointer<const flatbuffers::String*>(/*field=*/8)->str(),
            "Kenobi");
}
|
|
// Installs a regex rule whose capturing group 1 is configured with a
// `text_reply` spec, and checks that the first returned action has the
// response text "STOP" for a message containing "reply STOP to".
TEST_F(ActionsSuggestionsTest, CreatesTextRepliesFromRules) {
  const std::string actions_model_string =
      ReadFile(GetModelPath() + kModelFileName);
  std::unique_ptr<ActionsModelT> actions_model =
      UnPackActionsModel(actions_model_string.c_str());
  ASSERT_TRUE(DecompressActionsModel(actions_model.get()));

  actions_model->rules.reset(new RulesModelT());
  actions_model->rules->rule.emplace_back(new RulesModel_::RuleT);
  RulesModel_::RuleT* rule = actions_model->rules->rule.back().get();
  rule->pattern = "(?i:reply (stop|quit|end) (?:to|for) )";
  rule->actions.emplace_back(new RulesModel_::Rule_::RuleActionSpecT);

  // Set capturing groups for entity data.
  rule->actions.back()->capturing_group.emplace_back(
      new RulesModel_::Rule_::RuleActionSpec_::RuleCapturingGroupT);
  RulesModel_::Rule_::RuleActionSpec_::RuleCapturingGroupT* code_group =
      rule->actions.back()->capturing_group.back().get();
  code_group->group_id = 1;
  code_group->text_reply.reset(new ActionSuggestionSpecT);
  code_group->text_reply->score = 1.0f;
  code_group->text_reply->priority_score = 1.0f;

  flatbuffers::FlatBufferBuilder builder;
  FinishActionsModelBuffer(builder,
                           ActionsModel::Pack(builder, actions_model.get()));
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      ActionsSuggestions::FromUnownedBuffer(
          reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
          builder.GetSize(), &unilib_);
  // Stop here with a test failure (instead of dereferencing a null pointer)
  // if the modified model was rejected.
  ASSERT_TRUE(actions_suggestions);

  const ActionsSuggestionsResponse& response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1,
             "visit test.com or reply STOP to cancel your subscription",
             /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{}, /*locales=*/"en"}}});
  // Assert (rather than expect) the size so that `actions[0]` below can never
  // read past the end of the vector.
  ASSERT_GE(response.actions.size(), 1);
  EXPECT_EQ(response.actions[0].response_text, "STOP");
}
|
|
// Verifies that adding a rule which emits the same share-location action as
// the model already suggests does not change the number of suggested actions.
TEST_F(ActionsSuggestionsTest, DeduplicateActions) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions = LoadTestModel();
  // Fail cleanly instead of dereferencing a null suggester below.
  ASSERT_TRUE(actions_suggestions != nullptr);
  ActionsSuggestionsResponse response = actions_suggestions->SuggestActions(
      {{{/*user_id=*/1, "Where are you?", /*reference_time_ms_utc=*/0,
         /*reference_timezone=*/"Europe/Zurich",
         /*annotations=*/{}, /*locales=*/"en"}}});

  // Check that the location sharing model triggered.
  bool has_location_sharing_action = false;
  // Iterate by const reference; iterating by value copied every suggestion.
  for (const ActionSuggestion& action : response.actions) {
    if (action.type == ActionsSuggestions::kShareLocation) {
      has_location_sharing_action = true;
      break;
    }
  }
  EXPECT_TRUE(has_location_sharing_action);
  const int num_actions = response.actions.size();

  // Add custom rule for location sharing.
  const std::string actions_model_string =
      ReadFile(GetModelPath() + kModelFileName);
  std::unique_ptr<ActionsModelT> actions_model =
      UnPackActionsModel(actions_model_string.c_str());
  ASSERT_TRUE(DecompressActionsModel(actions_model.get()));

  actions_model->rules.reset(new RulesModelT());
  actions_model->rules->rule.emplace_back(new RulesModel_::RuleT);
  actions_model->rules->rule.back()->pattern = "^(?i:where are you[.?]?)$";
  actions_model->rules->rule.back()->actions.emplace_back(
      new RulesModel_::Rule_::RuleActionSpecT);
  actions_model->rules->rule.back()->actions.back()->action.reset(
      new ActionSuggestionSpecT);
  ActionSuggestionSpecT* action =
      actions_model->rules->rule.back()->actions.back()->action.get();
  action->score = 1.0f;
  action->type = ActionsSuggestions::kShareLocation;

  // Re-pack the modified model and load it from the in-memory buffer.
  flatbuffers::FlatBufferBuilder builder;
  FinishActionsModelBuffer(builder,
                           ActionsModel::Pack(builder, actions_model.get()));
  actions_suggestions = ActionsSuggestions::FromUnownedBuffer(
      reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
      builder.GetSize(), &unilib_);
  ASSERT_TRUE(actions_suggestions != nullptr);

  // The same conversation must yield as many actions as before the rule was
  // added.
  response = actions_suggestions->SuggestActions(
      {{{/*user_id=*/1, "Where are you?", /*reference_time_ms_utc=*/0,
         /*reference_timezone=*/"Europe/Zurich",
         /*annotations=*/{}, /*locales=*/"en"}}});
  EXPECT_EQ(response.actions.size(), num_actions);
}
|
|
// Verifies that when a custom rule with a higher priority score matches the
// same message as the flight annotation, the rule's action is ranked first.
TEST_F(ActionsSuggestionsTest, DeduplicateConflictingActions) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions = LoadTestModel();
  // Fail cleanly instead of dereferencing a null suggester below.
  ASSERT_TRUE(actions_suggestions != nullptr);
  AnnotatedSpan annotation;
  annotation.span = {7, 11};
  annotation.classification = {
      ClassificationResult(Collections::Flight(), 1.0)};
  ActionsSuggestionsResponse response = actions_suggestions->SuggestActions(
      {{{/*user_id=*/1, "I'm on LX38",
         /*reference_time_ms_utc=*/0,
         /*reference_timezone=*/"Europe/Zurich",
         /*annotations=*/{annotation},
         /*locales=*/"en"}}});

  // Check that the flight tracking action is suggested first.
  // Abort when there is no suggestion so that actions[0] is never read out of
  // bounds.
  ASSERT_GE(response.actions.size(), 1);
  EXPECT_EQ(response.actions[0].type, "track_flight");

  // Add custom rule.
  const std::string actions_model_string =
      ReadFile(GetModelPath() + kModelFileName);
  std::unique_ptr<ActionsModelT> actions_model =
      UnPackActionsModel(actions_model_string.c_str());
  ASSERT_TRUE(DecompressActionsModel(actions_model.get()));

  actions_model->rules.reset(new RulesModelT());
  actions_model->rules->rule.emplace_back(new RulesModel_::RuleT);
  RulesModel_::RuleT* rule = actions_model->rules->rule.back().get();
  rule->pattern = "^(?i:I'm on ([a-z0-9]+))$";
  rule->actions.emplace_back(new RulesModel_::Rule_::RuleActionSpecT);
  rule->actions.back()->action.reset(new ActionSuggestionSpecT);
  ActionSuggestionSpecT* action = rule->actions.back()->action.get();
  action->score = 1.0f;
  action->priority_score = 2.0f;
  action->type = "test_code";
  // Expose the first capturing group of the pattern as a "code" annotation.
  rule->actions.back()->capturing_group.emplace_back(
      new RulesModel_::Rule_::RuleActionSpec_::RuleCapturingGroupT);
  RulesModel_::Rule_::RuleActionSpec_::RuleCapturingGroupT* code_group =
      rule->actions.back()->capturing_group.back().get();
  code_group->group_id = 1;
  code_group->annotation_name = "code";
  code_group->annotation_type = "code";

  // Re-pack the modified model and load it from the in-memory buffer.
  flatbuffers::FlatBufferBuilder builder;
  FinishActionsModelBuffer(builder,
                           ActionsModel::Pack(builder, actions_model.get()));
  actions_suggestions = ActionsSuggestions::FromUnownedBuffer(
      reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
      builder.GetSize(), &unilib_);
  ASSERT_TRUE(actions_suggestions != nullptr);

  // With the rule in place, its action is expected to come first.
  response = actions_suggestions->SuggestActions(
      {{{/*user_id=*/1, "I'm on LX38",
         /*reference_time_ms_utc=*/0,
         /*reference_timezone=*/"Europe/Zurich",
         /*annotations=*/{annotation},
         /*locales=*/"en"}}});
  ASSERT_GE(response.actions.size(), 1);
  EXPECT_EQ(response.actions[0].type, "test_code");
}
|
#endif
|
|
// Verifies that two "address" annotations yield two "view_map" actions that
// are ordered by descending score.
TEST_F(ActionsSuggestionsTest, SuggestActionsRanking) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions = LoadTestModel();
  // Fail cleanly instead of dereferencing a null suggester below.
  ASSERT_TRUE(actions_suggestions != nullptr);
  std::vector<AnnotatedSpan> annotations(2);
  annotations[0].span = {11, 15};
  annotations[0].classification = {ClassificationResult("address", 1.0)};
  annotations[1].span = {19, 23};
  annotations[1].classification = {ClassificationResult("address", 2.0)};
  const ActionsSuggestionsResponse& response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "are you at home or work?",
             /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/annotations,
             /*locales=*/"en"}}});
  // Abort when fewer than two actions come back so that actions[0] and
  // actions[1] are never read out of bounds.
  ASSERT_GE(response.actions.size(), 2);
  // The higher scored annotation is expected first.
  EXPECT_EQ(response.actions[0].type, "view_map");
  EXPECT_EQ(response.actions[0].score, 2.0);
  EXPECT_EQ(response.actions[1].type, "view_map");
  EXPECT_EQ(response.actions[1].score, 1.0);
}
|
|
// Checks VisitActionsModel with a visitor that reports whether it was handed a
// model: the visit is expected to yield true for the test model file and false
// for a path that does not exist.
TEST_F(ActionsSuggestionsTest, VisitActionsModel) {
  // Visitor shared by both checks; true iff `model` is non-null.
  const auto has_model = [](const ActionsModel* model) {
    return model != nullptr;
  };

  EXPECT_TRUE(
      VisitActionsModel<bool>(GetModelPath() + kModelFileName, has_model));
  EXPECT_FALSE(VisitActionsModel<bool>(
      GetModelPath() + "non_existing_model.fb", has_model));
}
|
|
// Runs the hashgram test model on three single-message conversations and
// checks the suggested action types: none for "hello", "share_location" for
// "where are you" and "share_contact" for "do you know johns number".
TEST_F(ActionsSuggestionsTest, SuggestActionsWithHashGramModel) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      LoadHashGramTestModel();
  ASSERT_TRUE(actions_suggestions != nullptr);

  // Suggests actions for a conversation consisting of one message from user 1
  // with the given text; all other message fields are the same in every case.
  const auto suggest_for = [&actions_suggestions](const std::string& text) {
    return actions_suggestions->SuggestActions(
        {{{/*user_id=*/1, text,
           /*reference_time_ms_utc=*/0,
           /*reference_timezone=*/"Europe/Zurich",
           /*annotations=*/{},
           /*locales=*/"en"}}});
  };

  {
    const ActionsSuggestionsResponse response = suggest_for("hello");
    EXPECT_THAT(response.actions, testing::IsEmpty());
  }
  {
    const ActionsSuggestionsResponse response = suggest_for("where are you");
    EXPECT_THAT(response.actions,
                testing::ElementsAre(testing::Field(&ActionSuggestion::type,
                                                    "share_location")));
  }
  {
    const ActionsSuggestionsResponse response =
        suggest_for("do you know johns number");
    EXPECT_THAT(response.actions,
                testing::ElementsAre(testing::Field(&ActionSuggestion::type,
                                                    "share_contact")));
  }
}
|
|
// Test class to expose token embedding methods for testing.
|
class TestingMessageEmbedder : private ActionsSuggestions {
|
public:
|
explicit TestingMessageEmbedder(const ActionsModel* model);
|
|
using ActionsSuggestions::EmbedAndFlattenTokens;
|
using ActionsSuggestions::EmbedTokensPerMessage;
|
|
protected:
|
// EmbeddingExecutor that always returns features based on
|
// the id of the sparse features.
|
class FakeEmbeddingExecutor : public EmbeddingExecutor {
|
public:
|
bool AddEmbedding(const TensorView<int>& sparse_features, float* dest,
|
const int dest_size) const override {
|
TC3_CHECK_GE(dest_size, 1);
|
EXPECT_EQ(sparse_features.size(), 1);
|
dest[0] = sparse_features.data()[0];
|
return true;
|
}
|
};
|
};
|
|
// Sets up the embedder from `model`: installs a feature processor built from
// the model's feature processor options (without a unilib), installs the fake
// embedding executor, embeds the padding, start and end token ids, and caches
// the token embedding size, which is expected to be 1.
TestingMessageEmbedder::TestingMessageEmbedder(const ActionsModel* model) {
  model_ = model;
  const ActionsTokenFeatureProcessorOptions* feature_options =
      model->feature_processor_options();
  feature_processor_.reset(
      new ActionsFeatureProcessor(feature_options, /*unilib=*/nullptr));
  embedding_executor_.reset(new FakeEmbeddingExecutor());

  // Embed the special tokens once up front.
  EXPECT_TRUE(EmbedTokenId(feature_options->padding_token_id(),
                           &embedded_padding_token_));
  EXPECT_TRUE(EmbedTokenId(feature_options->start_token_id(),
                           &embedded_start_token_));
  EXPECT_TRUE(
      EmbedTokenId(feature_options->end_token_id(), &embedded_end_token_));

  token_embedding_size_ = feature_processor_->GetTokenEmbeddingSize();
  EXPECT_EQ(token_embedding_size_, 1);
}
|
|
class EmbeddingTest : public testing::Test {
|
protected:
|
EmbeddingTest() {
|
model_.feature_processor_options.reset(
|
new ActionsTokenFeatureProcessorOptionsT);
|
options_ = model_.feature_processor_options.get();
|
options_->chargram_orders = {1};
|
options_->num_buckets = 1000;
|
options_->embedding_size = 1;
|
options_->start_token_id = 0;
|
options_->end_token_id = 1;
|
options_->padding_token_id = 2;
|
options_->tokenizer_options.reset(new ActionsTokenizerOptionsT);
|
}
|
|
TestingMessageEmbedder CreateTestingMessageEmbedder() {
|
flatbuffers::FlatBufferBuilder builder;
|
FinishActionsModelBuffer(builder, ActionsModel::Pack(builder, &model_));
|
buffer_ = builder.ReleaseBufferPointer();
|
return TestingMessageEmbedder(
|
flatbuffers::GetRoot<ActionsModel>(buffer_.data()));
|
}
|
|
flatbuffers::DetachedBuffer buffer_;
|
ActionsModelT model_;
|
ActionsTokenFeatureProcessorOptionsT* options_;
|
};
|
|
// With the fixture's options left untouched, the three tokens of the single
// message are expected to produce exactly three embeddings, in token order.
TEST_F(EmbeddingTest, EmbedsTokensPerMessageWithNoBounds) {
  const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
  std::vector<std::vector<Token>> tokens = {
      {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)}};
  std::vector<float> embeddings;
  int max_num_tokens_per_message = 0;

  EXPECT_TRUE(embedder.EmbedTokensPerMessage(tokens, &embeddings,
                                             &max_num_tokens_per_message));

  EXPECT_EQ(max_num_tokens_per_message, 3);
  // Abort on a size mismatch so the indexing below never reads out of bounds.
  ASSERT_EQ(embeddings.size(), 3);
  // Each embedding is expected to equal the token's fingerprint modulo the
  // number of buckets.
  EXPECT_THAT(embeddings[0],
              testing::FloatEq(tc3farmhash::Fingerprint64("a", 1) %
                               options_->num_buckets));
  EXPECT_THAT(embeddings[1],
              testing::FloatEq(tc3farmhash::Fingerprint64("b", 1) %
                               options_->num_buckets));
  EXPECT_THAT(embeddings[2],
              testing::FloatEq(tc3farmhash::Fingerprint64("c", 1) %
                               options_->num_buckets));
}
|
|
// With min_num_tokens_per_message = 5, the three token embeddings are expected
// to be followed by two padding token embeddings.
TEST_F(EmbeddingTest, EmbedsTokensPerMessageWithPadding) {
  options_->min_num_tokens_per_message = 5;
  const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
  std::vector<std::vector<Token>> tokens = {
      {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)}};
  std::vector<float> embeddings;
  int max_num_tokens_per_message = 0;

  EXPECT_TRUE(embedder.EmbedTokensPerMessage(tokens, &embeddings,
                                             &max_num_tokens_per_message));

  EXPECT_EQ(max_num_tokens_per_message, 5);
  // Abort on a size mismatch so the indexing below never reads out of bounds.
  ASSERT_EQ(embeddings.size(), 5);
  EXPECT_THAT(embeddings[0],
              testing::FloatEq(tc3farmhash::Fingerprint64("a", 1) %
                               options_->num_buckets));
  EXPECT_THAT(embeddings[1],
              testing::FloatEq(tc3farmhash::Fingerprint64("b", 1) %
                               options_->num_buckets));
  EXPECT_THAT(embeddings[2],
              testing::FloatEq(tc3farmhash::Fingerprint64("c", 1) %
                               options_->num_buckets));
  // The remaining slots hold the padding token.
  EXPECT_THAT(embeddings[3], testing::FloatEq(options_->padding_token_id));
  EXPECT_THAT(embeddings[4], testing::FloatEq(options_->padding_token_id));
}
|
|
// With max_num_tokens_per_message = 2, only the last two tokens ("b" and "c")
// of the three-token message are expected to be embedded.
TEST_F(EmbeddingTest, EmbedsTokensPerMessageDropsAtBeginning) {
  options_->max_num_tokens_per_message = 2;
  const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
  std::vector<std::vector<Token>> tokens = {
      {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)}};
  std::vector<float> embeddings;
  int max_num_tokens_per_message = 0;

  EXPECT_TRUE(embedder.EmbedTokensPerMessage(tokens, &embeddings,
                                             &max_num_tokens_per_message));

  EXPECT_EQ(max_num_tokens_per_message, 2);
  // Abort on a size mismatch so the indexing below never reads out of bounds.
  ASSERT_EQ(embeddings.size(), 2);
  EXPECT_THAT(embeddings[0],
              testing::FloatEq(tc3farmhash::Fingerprint64("b", 1) %
                               options_->num_buckets));
  EXPECT_THAT(embeddings[1],
              testing::FloatEq(tc3farmhash::Fingerprint64("c", 1) %
                               options_->num_buckets));
}
|
|
// With two messages of three and two tokens, the reported maximum is expected
// to be three and the shorter message is expected to be followed by one
// padding token embedding.
TEST_F(EmbeddingTest, EmbedsTokensPerMessageWithMultipleMessagesNoBounds) {
  const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
  std::vector<std::vector<Token>> tokens = {
      {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)},
      {Token("d", 0, 1), Token("e", 2, 3)}};
  std::vector<float> embeddings;
  int max_num_tokens_per_message = 0;

  EXPECT_TRUE(embedder.EmbedTokensPerMessage(tokens, &embeddings,
                                             &max_num_tokens_per_message));

  EXPECT_EQ(max_num_tokens_per_message, 3);
  // The checks below read embeddings[0..5]; abort if fewer values were
  // produced so the indexing never reads out of bounds.
  ASSERT_GE(embeddings.size(), 6);
  EXPECT_THAT(embeddings[0],
              testing::FloatEq(tc3farmhash::Fingerprint64("a", 1) %
                               options_->num_buckets));
  EXPECT_THAT(embeddings[1],
              testing::FloatEq(tc3farmhash::Fingerprint64("b", 1) %
                               options_->num_buckets));
  EXPECT_THAT(embeddings[2],
              testing::FloatEq(tc3farmhash::Fingerprint64("c", 1) %
                               options_->num_buckets));
  EXPECT_THAT(embeddings[3],
              testing::FloatEq(tc3farmhash::Fingerprint64("d", 1) %
                               options_->num_buckets));
  EXPECT_THAT(embeddings[4],
              testing::FloatEq(tc3farmhash::Fingerprint64("e", 1) %
                               options_->num_buckets));
  // The second message is one token shorter, so its last slot is padding.
  EXPECT_THAT(embeddings[5], testing::FloatEq(options_->padding_token_id));
}
|
|
// With the fixture's options left untouched, flattening a three-token message
// is expected to yield start token, the three tokens, then end token.
TEST_F(EmbeddingTest, EmbedsFlattenedTokensWithNoBounds) {
  const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
  std::vector<std::vector<Token>> tokens = {
      {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)}};
  std::vector<float> embeddings;
  int total_token_count = 0;

  EXPECT_TRUE(
      embedder.EmbedAndFlattenTokens(tokens, &embeddings, &total_token_count));

  EXPECT_EQ(total_token_count, 5);
  // Abort on a size mismatch so the indexing below never reads out of bounds.
  ASSERT_EQ(embeddings.size(), 5);
  EXPECT_THAT(embeddings[0], testing::FloatEq(options_->start_token_id));
  EXPECT_THAT(embeddings[1],
              testing::FloatEq(tc3farmhash::Fingerprint64("a", 1) %
                               options_->num_buckets));
  EXPECT_THAT(embeddings[2],
              testing::FloatEq(tc3farmhash::Fingerprint64("b", 1) %
                               options_->num_buckets));
  EXPECT_THAT(embeddings[3],
              testing::FloatEq(tc3farmhash::Fingerprint64("c", 1) %
                               options_->num_buckets));
  EXPECT_THAT(embeddings[4], testing::FloatEq(options_->end_token_id));
}
|
|
// With min_num_total_tokens = 7, the five flattened embeddings (start, three
// tokens, end) are expected to be followed by two padding token embeddings.
TEST_F(EmbeddingTest, EmbedsFlattenedTokensWithPadding) {
  options_->min_num_total_tokens = 7;
  const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
  std::vector<std::vector<Token>> tokens = {
      {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)}};
  std::vector<float> embeddings;
  int total_token_count = 0;

  EXPECT_TRUE(
      embedder.EmbedAndFlattenTokens(tokens, &embeddings, &total_token_count));

  EXPECT_EQ(total_token_count, 7);
  // Abort on a size mismatch so the indexing below never reads out of bounds.
  ASSERT_EQ(embeddings.size(), 7);
  EXPECT_THAT(embeddings[0], testing::FloatEq(options_->start_token_id));
  EXPECT_THAT(embeddings[1],
              testing::FloatEq(tc3farmhash::Fingerprint64("a", 1) %
                               options_->num_buckets));
  EXPECT_THAT(embeddings[2],
              testing::FloatEq(tc3farmhash::Fingerprint64("b", 1) %
                               options_->num_buckets));
  EXPECT_THAT(embeddings[3],
              testing::FloatEq(tc3farmhash::Fingerprint64("c", 1) %
                               options_->num_buckets));
  EXPECT_THAT(embeddings[4], testing::FloatEq(options_->end_token_id));
  // The remaining slots hold the padding token.
  EXPECT_THAT(embeddings[5], testing::FloatEq(options_->padding_token_id));
  EXPECT_THAT(embeddings[6], testing::FloatEq(options_->padding_token_id));
}
|
|
// With max_num_total_tokens = 3, only the last three flattened embeddings are
// expected to remain: "b", "c" and the end token.
TEST_F(EmbeddingTest, EmbedsFlattenedTokensDropsAtBeginning) {
  options_->max_num_total_tokens = 3;
  const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
  std::vector<std::vector<Token>> tokens = {
      {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)}};
  std::vector<float> embeddings;
  int total_token_count = 0;

  EXPECT_TRUE(
      embedder.EmbedAndFlattenTokens(tokens, &embeddings, &total_token_count));

  EXPECT_EQ(total_token_count, 3);
  // Abort on a size mismatch so the indexing below never reads out of bounds.
  ASSERT_EQ(embeddings.size(), 3);
  EXPECT_THAT(embeddings[0],
              testing::FloatEq(tc3farmhash::Fingerprint64("b", 1) %
                               options_->num_buckets));
  EXPECT_THAT(embeddings[1],
              testing::FloatEq(tc3farmhash::Fingerprint64("c", 1) %
                               options_->num_buckets));
  EXPECT_THAT(embeddings[2], testing::FloatEq(options_->end_token_id));
}
|
|
// With two messages of three and two tokens, each message is expected to be
// wrapped in its own start and end token, giving nine embeddings in total.
TEST_F(EmbeddingTest, EmbedsFlattenedTokensWithMultipleMessagesNoBounds) {
  const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
  std::vector<std::vector<Token>> tokens = {
      {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)},
      {Token("d", 0, 1), Token("e", 2, 3)}};
  std::vector<float> embeddings;
  int total_token_count = 0;

  EXPECT_TRUE(
      embedder.EmbedAndFlattenTokens(tokens, &embeddings, &total_token_count));

  EXPECT_EQ(total_token_count, 9);
  // Abort on a size mismatch so the indexing below never reads out of bounds.
  ASSERT_EQ(embeddings.size(), 9);
  // First message.
  EXPECT_THAT(embeddings[0], testing::FloatEq(options_->start_token_id));
  EXPECT_THAT(embeddings[1],
              testing::FloatEq(tc3farmhash::Fingerprint64("a", 1) %
                               options_->num_buckets));
  EXPECT_THAT(embeddings[2],
              testing::FloatEq(tc3farmhash::Fingerprint64("b", 1) %
                               options_->num_buckets));
  EXPECT_THAT(embeddings[3],
              testing::FloatEq(tc3farmhash::Fingerprint64("c", 1) %
                               options_->num_buckets));
  EXPECT_THAT(embeddings[4], testing::FloatEq(options_->end_token_id));
  // Second message.
  EXPECT_THAT(embeddings[5], testing::FloatEq(options_->start_token_id));
  EXPECT_THAT(embeddings[6],
              testing::FloatEq(tc3farmhash::Fingerprint64("d", 1) %
                               options_->num_buckets));
  EXPECT_THAT(embeddings[7],
              testing::FloatEq(tc3farmhash::Fingerprint64("e", 1) %
                               options_->num_buckets));
  EXPECT_THAT(embeddings[8], testing::FloatEq(options_->end_token_id));
}
|
|
// With max_num_total_tokens = 7 and two three-token messages, only the last
// seven flattened embeddings are expected to remain: "c" and the end token of
// the first message, followed by the complete second message.
TEST_F(EmbeddingTest,
       EmbedsFlattenedTokensWithMultipleMessagesDropsAtBeginning) {
  options_->max_num_total_tokens = 7;
  const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
  std::vector<std::vector<Token>> tokens = {
      {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)},
      {Token("d", 0, 1), Token("e", 2, 3), Token("f", 4, 5)}};
  std::vector<float> embeddings;
  int total_token_count = 0;

  EXPECT_TRUE(
      embedder.EmbedAndFlattenTokens(tokens, &embeddings, &total_token_count));

  EXPECT_EQ(total_token_count, 7);
  // Abort on a size mismatch so the indexing below never reads out of bounds.
  ASSERT_EQ(embeddings.size(), 7);
  // Tail of the first message.
  EXPECT_THAT(embeddings[0],
              testing::FloatEq(tc3farmhash::Fingerprint64("c", 1) %
                               options_->num_buckets));
  EXPECT_THAT(embeddings[1], testing::FloatEq(options_->end_token_id));
  // Complete second message.
  EXPECT_THAT(embeddings[2], testing::FloatEq(options_->start_token_id));
  EXPECT_THAT(embeddings[3],
              testing::FloatEq(tc3farmhash::Fingerprint64("d", 1) %
                               options_->num_buckets));
  EXPECT_THAT(embeddings[4],
              testing::FloatEq(tc3farmhash::Fingerprint64("e", 1) %
                               options_->num_buckets));
  EXPECT_THAT(embeddings[5],
              testing::FloatEq(tc3farmhash::Fingerprint64("f", 1) %
                               options_->num_buckets));
  EXPECT_THAT(embeddings[6], testing::FloatEq(options_->end_token_id));
}
|
|
} // namespace
|
} // namespace libtextclassifier3
|