/*
|
* Copyright (C) 2018 The Android Open Source Project
|
*
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
*/
|
|
#include "actions/lua-utils.h"
|
|
namespace libtextclassifier3 {
|
namespace {

// Field names of the Lua tables handled in this file. The same keys are used
// both when pushing C++ values to Lua and when reading tables back from Lua.
// Entities in an unnamed namespace already have internal linkage, so no
// `static` is needed.

// Classification / annotation fields.
constexpr const char* kTextKey = "text";
// NOTE: the identifier says "Usec", but the Lua field is named
// "parsed_time_ms_utc" and is filled from / read into `time_ms_utc`.
constexpr const char* kTimeUsecKey = "parsed_time_ms_utc";
constexpr const char* kGranularityKey = "granularity";
constexpr const char* kCollectionKey = "collection";
constexpr const char* kNameKey = "name";

// Action fields.
constexpr const char* kScoreKey = "score";
constexpr const char* kPriorityScoreKey = "priority_score";
constexpr const char* kTypeKey = "type";
constexpr const char* kResponseTextKey = "response_text";
constexpr const char* kAnnotationKey = "annotation";

// Span fields.
constexpr const char* kSpanKey = "span";
constexpr const char* kMessageKey = "message";
constexpr const char* kBeginKey = "begin";
constexpr const char* kEndKey = "end";

// Entity data fields.
constexpr const char* kClassificationKey = "classification";
constexpr const char* kSerializedEntity = "serialized_entity";
constexpr const char* kEntityKey = "entity";

}  // namespace
|
|
template <>
|
int AnnotationIterator<ClassificationResult>::Item(
|
const std::vector<ClassificationResult>* annotations, StringPiece key,
|
lua_State* state) const {
|
// Lookup annotation by collection.
|
for (const ClassificationResult& annotation : *annotations) {
|
if (key.Equals(annotation.collection)) {
|
PushAnnotation(annotation, entity_data_schema_, env_);
|
return 1;
|
}
|
}
|
TC3_LOG(ERROR) << "No annotation with collection: " << key.ToString()
|
<< " found.";
|
lua_error(state);
|
return 0;
|
}
|
|
template <>
|
int AnnotationIterator<ActionSuggestionAnnotation>::Item(
|
const std::vector<ActionSuggestionAnnotation>* annotations, StringPiece key,
|
lua_State* state) const {
|
// Lookup annotation by name.
|
for (const ActionSuggestionAnnotation& annotation : *annotations) {
|
if (key.Equals(annotation.name)) {
|
PushAnnotation(annotation, entity_data_schema_, env_);
|
return 1;
|
}
|
}
|
TC3_LOG(ERROR) << "No annotation with name: " << key.ToString() << " found.";
|
lua_error(state);
|
return 0;
|
}
|
|
void PushAnnotation(const ClassificationResult& classification,
|
const reflection::Schema* entity_data_schema,
|
LuaEnvironment* env) {
|
if (entity_data_schema == nullptr ||
|
classification.serialized_entity_data.empty()) {
|
// Empty table.
|
lua_newtable(env->state());
|
} else {
|
env->PushFlatbuffer(entity_data_schema,
|
flatbuffers::GetRoot<flatbuffers::Table>(
|
classification.serialized_entity_data.data()));
|
}
|
lua_pushinteger(env->state(),
|
classification.datetime_parse_result.time_ms_utc);
|
lua_setfield(env->state(), /*idx=*/-2, kTimeUsecKey);
|
lua_pushinteger(env->state(),
|
classification.datetime_parse_result.granularity);
|
lua_setfield(env->state(), /*idx=*/-2, kGranularityKey);
|
env->PushString(classification.collection);
|
lua_setfield(env->state(), /*idx=*/-2, kCollectionKey);
|
lua_pushnumber(env->state(), classification.score);
|
lua_setfield(env->state(), /*idx=*/-2, kScoreKey);
|
env->PushString(classification.serialized_entity_data);
|
lua_setfield(env->state(), /*idx=*/-2, kSerializedEntity);
|
}
|
|
void PushAnnotation(const ClassificationResult& classification,
|
StringPiece text,
|
const reflection::Schema* entity_data_schema,
|
LuaEnvironment* env) {
|
PushAnnotation(classification, entity_data_schema, env);
|
env->PushString(text);
|
lua_setfield(env->state(), /*idx=*/-2, kTextKey);
|
}
|
|
void PushAnnotatedSpan(
|
const AnnotatedSpan& annotated_span,
|
const AnnotationIterator<ClassificationResult>& annotation_iterator,
|
LuaEnvironment* env) {
|
lua_newtable(env->state());
|
{
|
lua_newtable(env->state());
|
lua_pushinteger(env->state(), annotated_span.span.first);
|
lua_setfield(env->state(), /*idx=*/-2, kBeginKey);
|
lua_pushinteger(env->state(), annotated_span.span.second);
|
lua_setfield(env->state(), /*idx=*/-2, kEndKey);
|
}
|
lua_setfield(env->state(), /*idx=*/-2, kSpanKey);
|
annotation_iterator.NewIterator(kClassificationKey,
|
&annotated_span.classification, env->state());
|
lua_setfield(env->state(), /*idx=*/-2, kClassificationKey);
|
}
|
|
// Reads a MessageTextSpan from the Lua table on top of the stack of `env`.
//
// Recognized fields are "message", "begin", "end" (numbers, truncated to int)
// and "text". Any other field is logged at INFO level and skipped. The table
// itself is left on the stack.
MessageTextSpan ReadSpan(LuaEnvironment* env) {
  lua_State* const state = env->state();
  MessageTextSpan result;

  // Standard lua_next() traversal: the key sits at -2 and the value at -1;
  // the value is popped after every entry so the key feeds the next step.
  for (lua_pushnil(state); lua_next(state, -2); lua_pop(state, 1)) {
    const StringPiece field = env->ReadString(-2);
    if (field.Equals(kMessageKey)) {
      result.message_index = static_cast<int>(lua_tonumber(state, -1));
    } else if (field.Equals(kBeginKey)) {
      result.span.first = static_cast<int>(lua_tonumber(state, -1));
    } else if (field.Equals(kEndKey)) {
      result.span.second = static_cast<int>(lua_tonumber(state, -1));
    } else if (field.Equals(kTextKey)) {
      result.text = env->ReadString(-1).ToString();
    } else {
      TC3_LOG(INFO) << "Unknown span field: " << field.ToString();
    }
  }
  return result;
}
|
|
// Reads a list of annotations from the Lua table on top of the stack of `env`
// and appends them to `annotations`.
//
// If the top of the stack is not a table, the error is logged, the value is
// popped and lua_error() is called. Entries of the table that are not tables
// themselves are logged and skipped. The annotations table is left on the
// stack. Returns LUA_OK on success.
int ReadAnnotations(const reflection::Schema* entity_data_schema,
                    LuaEnvironment* env,
                    std::vector<ActionSuggestionAnnotation>* annotations) {
  lua_State* const state = env->state();

  const int container_type = lua_type(state, -1);
  if (container_type != LUA_TTABLE) {
    TC3_LOG(ERROR) << "Expected annotations table, got: " << container_type;
    lua_pop(state, 1);
    lua_error(state);
    return LUA_ERRRUN;
  }

  // Read annotations: the value of each entry is popped after it is handled.
  for (lua_pushnil(state); lua_next(state, -2); lua_pop(state, 1)) {
    const int entry_type = lua_type(state, -1);
    if (entry_type == LUA_TTABLE) {
      annotations->push_back(ReadAnnotation(entity_data_schema, env));
    } else {
      TC3_LOG(ERROR) << "Expected annotation table, got: " << entry_type;
    }
  }
  return LUA_OK;
}
|
|
// Reads a single ActionSuggestionAnnotation from the Lua table on top of the
// stack of `env`.
//
// Recognized fields are "name", "span" (read with ReadSpan()) and "entity"
// (read with ReadClassificationResult() using `entity_data_schema`). Any other
// field is logged at ERROR level and skipped. The table is left on the stack.
ActionSuggestionAnnotation ReadAnnotation(
    const reflection::Schema* entity_data_schema, LuaEnvironment* env) {
  lua_State* const state = env->state();
  ActionSuggestionAnnotation result;

  // Key at -2, value at -1; the value is popped after every entry.
  for (lua_pushnil(state); lua_next(state, -2); lua_pop(state, 1)) {
    const StringPiece field = env->ReadString(-2);
    if (field.Equals(kNameKey)) {
      result.name = env->ReadString(-1).ToString();
    } else if (field.Equals(kSpanKey)) {
      result.span = ReadSpan(env);
    } else if (field.Equals(kEntityKey)) {
      result.entity = ReadClassificationResult(entity_data_schema, env);
    } else {
      TC3_LOG(ERROR) << "Unknown annotation field: " << field.ToString();
    }
  }
  return result;
}
|
|
// Reads a ClassificationResult from the Lua table on top of the stack of
// `env`.
//
// Recognized fields:
//   "collection", "score", "parsed_time_ms_utc", "granularity":
//       copied into the corresponding members of the result.
//   "serialized_entity": stored verbatim as the serialized entity data.
//   "entity": read into a flatbuffer built from `entity_data_schema` and
//       stored in serialized form as the entity data.
// Any other field is logged at INFO level and skipped. The table is left on
// the stack.
ClassificationResult ReadClassificationResult(
    const reflection::Schema* entity_data_schema, LuaEnvironment* env) {
  lua_State* const state = env->state();
  ClassificationResult result;

  // Key at -2, value at -1; the value is popped after every entry.
  for (lua_pushnil(state); lua_next(state, -2); lua_pop(state, 1)) {
    const StringPiece field = env->ReadString(-2);
    if (field.Equals(kCollectionKey)) {
      result.collection = env->ReadString(-1).ToString();
    } else if (field.Equals(kScoreKey)) {
      result.score = static_cast<float>(lua_tonumber(state, -1));
    } else if (field.Equals(kTimeUsecKey)) {
      result.datetime_parse_result.time_ms_utc =
          static_cast<int64>(lua_tonumber(state, -1));
    } else if (field.Equals(kGranularityKey)) {
      result.datetime_parse_result.granularity =
          static_cast<DatetimeGranularity>(lua_tonumber(state, -1));
    } else if (field.Equals(kSerializedEntity)) {
      result.serialized_entity_data = env->ReadString(-1).ToString();
    } else if (field.Equals(kEntityKey)) {
      // NOTE(review): `entity_data_schema` is not checked for nullptr here,
      // unlike in PushAnnotation() -- confirm callers always provide one when
      // scripts may set "entity".
      auto buffer = ReflectiveFlatbufferBuilder(entity_data_schema).NewRoot();
      env->ReadFlatbuffer(buffer.get());
      result.serialized_entity_data = buffer->Serialize();
    } else {
      TC3_LOG(INFO) << "Unknown classification result field: "
                    << field.ToString();
    }
  }
  return result;
}
|
|
void PushAnnotation(const ActionSuggestionAnnotation& annotation,
|
const reflection::Schema* entity_data_schema,
|
LuaEnvironment* env) {
|
PushAnnotation(annotation.entity, annotation.span.text, entity_data_schema,
|
env);
|
env->PushString(annotation.name);
|
lua_setfield(env->state(), /*idx=*/-2, kNameKey);
|
{
|
lua_newtable(env->state());
|
lua_pushinteger(env->state(), annotation.span.message_index);
|
lua_setfield(env->state(), /*idx=*/-2, kMessageKey);
|
lua_pushinteger(env->state(), annotation.span.span.first);
|
lua_setfield(env->state(), /*idx=*/-2, kBeginKey);
|
lua_pushinteger(env->state(), annotation.span.span.second);
|
lua_setfield(env->state(), /*idx=*/-2, kEndKey);
|
}
|
lua_setfield(env->state(), /*idx=*/-2, kSpanKey);
|
}
|
|
void PushAction(
|
const ActionSuggestion& action,
|
const reflection::Schema* entity_data_schema,
|
const AnnotationIterator<ActionSuggestionAnnotation>& annotation_iterator,
|
LuaEnvironment* env) {
|
if (entity_data_schema == nullptr || action.serialized_entity_data.empty()) {
|
// Empty table.
|
lua_newtable(env->state());
|
} else {
|
env->PushFlatbuffer(entity_data_schema,
|
flatbuffers::GetRoot<flatbuffers::Table>(
|
action.serialized_entity_data.data()));
|
}
|
env->PushString(action.type);
|
lua_setfield(env->state(), /*idx=*/-2, kTypeKey);
|
env->PushString(action.response_text);
|
lua_setfield(env->state(), /*idx=*/-2, kResponseTextKey);
|
lua_pushnumber(env->state(), action.score);
|
lua_setfield(env->state(), /*idx=*/-2, kScoreKey);
|
lua_pushnumber(env->state(), action.priority_score);
|
lua_setfield(env->state(), /*idx=*/-2, kPriorityScoreKey);
|
annotation_iterator.NewIterator(kAnnotationKey, &action.annotations,
|
env->state());
|
lua_setfield(env->state(), /*idx=*/-2, kAnnotationKey);
|
}
|
|
// Reads an ActionSuggestion from the Lua table on top of the stack of `env`.
//
// Args:
//   actions_entity_data_schema: schema used to build the action's own entity
//       data from the "entity" field.
//   annotations_entity_data_schema: schema used for the entity data of the
//       annotations listed under the "annotation" field.
//   env: Lua environment whose stack holds the action table on top.
//
// Recognized fields are "response_text", "type", "score", "priority_score",
// "annotation" and "entity". Any other field is logged at INFO level and
// skipped. The action table is left on the stack.
ActionSuggestion ReadAction(
    const reflection::Schema* actions_entity_data_schema,
    const reflection::Schema* annotations_entity_data_schema,
    LuaEnvironment* env) {
  ActionSuggestion action;
  lua_pushnil(env->state());
  // Key at -2, value at -1; the value is popped at the end of every iteration
  // so that the key is on top for the next lua_next() call.
  while (lua_next(env->state(), /*idx=*/-2)) {
    const StringPiece key = env->ReadString(/*index=*/-2);
    if (key.Equals(kResponseTextKey)) {
      action.response_text = env->ReadString(/*index=*/-1).ToString();
    } else if (key.Equals(kTypeKey)) {
      action.type = env->ReadString(/*index=*/-1).ToString();
    } else if (key.Equals(kScoreKey)) {
      action.score = static_cast<float>(lua_tonumber(env->state(), /*idx=*/-1));
    } else if (key.Equals(kPriorityScoreKey)) {
      action.priority_score =
          static_cast<float>(lua_tonumber(env->state(), /*idx=*/-1));
    } else if (key.Equals(kAnnotationKey)) {
      // Annotation entities must be built with the annotations schema, not
      // with the schema of the action's own entity data.
      ReadAnnotations(annotations_entity_data_schema, env, &action.annotations);
    } else if (key.Equals(kEntityKey)) {
      auto buffer =
          ReflectiveFlatbufferBuilder(actions_entity_data_schema).NewRoot();
      env->ReadFlatbuffer(buffer.get());
      action.serialized_entity_data = buffer->Serialize();
    } else {
      TC3_LOG(INFO) << "Unknown action field: " << key.ToString();
    }
    lua_pop(env->state(), 1);
  }
  return action;
}
|
|
// Reads a list of actions from the Lua table on top of the stack of `env` and
// appends them to `actions`.
//
// If the top of the stack is not a table, the error is logged, the value is
// popped and lua_error() is called. Entries of the table that are not tables
// themselves are logged and skipped. On success the actions table is popped
// from the stack and LUA_OK is returned.
int ReadActions(const reflection::Schema* actions_entity_data_schema,
                const reflection::Schema* annotations_entity_data_schema,
                LuaEnvironment* env, std::vector<ActionSuggestion>* actions) {
  lua_State* const state = env->state();

  const int container_type = lua_type(state, -1);
  if (container_type != LUA_TTABLE) {
    TC3_LOG(ERROR) << "Expected actions table, got: " << container_type;
    lua_pop(state, 1);
    lua_error(state);
    return LUA_ERRRUN;
  }

  // Read actions: the value of each entry is popped after it is handled.
  for (lua_pushnil(state); lua_next(state, -2); lua_pop(state, 1)) {
    const int entry_type = lua_type(state, -1);
    if (entry_type == LUA_TTABLE) {
      actions->push_back(ReadAction(actions_entity_data_schema,
                                    annotations_entity_data_schema, env));
    } else {
      TC3_LOG(ERROR) << "Expected action table, got: " << entry_type;
    }
  }

  // Remove the actions table itself.
  lua_pop(state, 1);
  return LUA_OK;
}
|
|
int ConversationIterator::Item(const std::vector<ConversationMessage>* messages,
|
const int64 pos, lua_State* state) const {
|
const ConversationMessage& message = (*messages)[pos];
|
lua_newtable(state);
|
lua_pushinteger(state, message.user_id);
|
lua_setfield(state, /*idx=*/-2, "user_id");
|
env_->PushString(message.text);
|
lua_setfield(state, /*idx=*/-2, "text");
|
lua_pushinteger(state, message.reference_time_ms_utc);
|
lua_setfield(state, /*idx=*/-2, "time_ms_utc");
|
env_->PushString(message.reference_timezone);
|
lua_setfield(state, /*idx=*/-2, "timezone");
|
annotated_span_iterator_.NewIterator("annotation", &message.annotations,
|
state);
|
lua_setfield(state, /*idx=*/-2, "annotation");
|
return 1;
|
}
|
|
} // namespace libtextclassifier3
|