/*
|
* Copyright (C) 2018 The Android Open Source Project
|
*
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
*/
|
|
// JNI wrapper for actions.
|
|
#include "actions/actions_jni.h"
|
|
#include <jni.h>
|
#include <map>
|
#include <type_traits>
|
#include <vector>
|
|
#include "actions/actions-suggestions.h"
|
#include "annotator/annotator.h"
|
#include "annotator/annotator_jni_common.h"
|
#include "utils/base/integral_types.h"
|
#include "utils/intents/intent-generator.h"
|
#include "utils/intents/jni.h"
|
#include "utils/java/jni-cache.h"
|
#include "utils/java/scoped_local_ref.h"
|
#include "utils/java/string_utils.h"
|
#include "utils/memory/mmap.h"
|
|
using libtextclassifier3::ActionsSuggestions;
|
using libtextclassifier3::ActionsSuggestionsResponse;
|
using libtextclassifier3::ActionSuggestion;
|
using libtextclassifier3::ActionSuggestionOptions;
|
using libtextclassifier3::Annotator;
|
using libtextclassifier3::Conversation;
|
using libtextclassifier3::IntentGenerator;
|
using libtextclassifier3::ScopedLocalRef;
|
using libtextclassifier3::ToStlString;
|
|
// When using the Java's ICU, UniLib needs to be instantiated with a JavaVM
|
// pointer from JNI. When using a standard ICU the pointer is not needed and the
|
// objects are instantiated implicitly.
|
#ifdef TC3_UNILIB_JAVAICU
|
using libtextclassifier3::UniLib;
|
#endif
|
|
namespace libtextclassifier3 {
|
|
namespace {
|
|
// Cached state for model inference.
|
// Keeps a jni cache, intent generator and model instance so that they don't
|
// have to be recreated for each call.
|
class ActionsSuggestionsJniContext {
|
public:
|
static ActionsSuggestionsJniContext* Create(
|
const std::shared_ptr<libtextclassifier3::JniCache>& jni_cache,
|
std::unique_ptr<ActionsSuggestions> model) {
|
if (jni_cache == nullptr || model == nullptr) {
|
return nullptr;
|
}
|
std::unique_ptr<IntentGenerator> intent_generator =
|
IntentGenerator::Create(model->model()->android_intent_options(),
|
model->model()->resources(), jni_cache);
|
std::unique_ptr<RemoteActionTemplatesHandler> template_handler =
|
libtextclassifier3::RemoteActionTemplatesHandler::Create(jni_cache);
|
|
if (intent_generator == nullptr || template_handler == nullptr) {
|
return nullptr;
|
}
|
|
return new ActionsSuggestionsJniContext(jni_cache, std::move(model),
|
std::move(intent_generator),
|
std::move(template_handler));
|
}
|
|
std::shared_ptr<libtextclassifier3::JniCache> jni_cache() const {
|
return jni_cache_;
|
}
|
|
ActionsSuggestions* model() const { return model_.get(); }
|
|
IntentGenerator* intent_generator() const { return intent_generator_.get(); }
|
|
RemoteActionTemplatesHandler* template_handler() const {
|
return template_handler_.get();
|
}
|
|
private:
|
ActionsSuggestionsJniContext(
|
const std::shared_ptr<libtextclassifier3::JniCache>& jni_cache,
|
std::unique_ptr<ActionsSuggestions> model,
|
std::unique_ptr<IntentGenerator> intent_generator,
|
std::unique_ptr<RemoteActionTemplatesHandler> template_handler)
|
: jni_cache_(jni_cache),
|
model_(std::move(model)),
|
intent_generator_(std::move(intent_generator)),
|
template_handler_(std::move(template_handler)) {}
|
|
std::shared_ptr<libtextclassifier3::JniCache> jni_cache_;
|
std::unique_ptr<ActionsSuggestions> model_;
|
std::unique_ptr<IntentGenerator> intent_generator_;
|
std::unique_ptr<RemoteActionTemplatesHandler> template_handler_;
|
};
|
|
ActionSuggestionOptions FromJavaActionSuggestionOptions(JNIEnv* env,
|
jobject joptions) {
|
ActionSuggestionOptions options = ActionSuggestionOptions::Default();
|
return options;
|
}
|
|
jobjectArray ActionSuggestionsToJObjectArray(
|
JNIEnv* env, const ActionsSuggestionsJniContext* context,
|
jobject app_context,
|
const reflection::Schema* annotations_entity_data_schema,
|
const std::vector<ActionSuggestion>& action_result,
|
const Conversation& conversation, const jstring device_locales,
|
const bool generate_intents) {
|
const ScopedLocalRef<jclass> result_class(
|
env->FindClass(TC3_PACKAGE_PATH TC3_ACTIONS_CLASS_NAME_STR
|
"$ActionSuggestion"),
|
env);
|
if (!result_class) {
|
TC3_LOG(ERROR) << "Couldn't find ActionSuggestion class.";
|
return nullptr;
|
}
|
|
const jmethodID result_class_constructor = env->GetMethodID(
|
result_class.get(), "<init>",
|
"(Ljava/lang/String;Ljava/lang/String;F[L" TC3_PACKAGE_PATH
|
TC3_NAMED_VARIANT_CLASS_NAME_STR
|
";[B[L" TC3_PACKAGE_PATH TC3_REMOTE_ACTION_TEMPLATE_CLASS_NAME_STR ";)V");
|
const jobjectArray results =
|
env->NewObjectArray(action_result.size(), result_class.get(), nullptr);
|
for (int i = 0; i < action_result.size(); i++) {
|
jobject extras = nullptr;
|
|
const reflection::Schema* actions_entity_data_schema =
|
context->model()->entity_data_schema();
|
if (actions_entity_data_schema != nullptr &&
|
!action_result[i].serialized_entity_data.empty()) {
|
extras = context->template_handler()->EntityDataAsNamedVariantArray(
|
actions_entity_data_schema, action_result[i].serialized_entity_data);
|
}
|
|
jbyteArray serialized_entity_data = nullptr;
|
if (!action_result[i].serialized_entity_data.empty()) {
|
serialized_entity_data =
|
env->NewByteArray(action_result[i].serialized_entity_data.size());
|
env->SetByteArrayRegion(
|
serialized_entity_data, 0,
|
action_result[i].serialized_entity_data.size(),
|
reinterpret_cast<const jbyte*>(
|
action_result[i].serialized_entity_data.data()));
|
}
|
|
jobject remote_action_templates_result = nullptr;
|
if (generate_intents) {
|
std::vector<RemoteActionTemplate> remote_action_templates;
|
if (context->intent_generator()->GenerateIntents(
|
device_locales, action_result[i], conversation, app_context,
|
actions_entity_data_schema, annotations_entity_data_schema,
|
&remote_action_templates)) {
|
remote_action_templates_result =
|
context->template_handler()->RemoteActionTemplatesToJObjectArray(
|
remote_action_templates);
|
}
|
}
|
|
ScopedLocalRef<jstring> reply = context->jni_cache()->ConvertToJavaString(
|
action_result[i].response_text);
|
|
ScopedLocalRef<jobject> result(env->NewObject(
|
result_class.get(), result_class_constructor, reply.get(),
|
env->NewStringUTF(action_result[i].type.c_str()),
|
static_cast<jfloat>(action_result[i].score), extras,
|
serialized_entity_data, remote_action_templates_result));
|
env->SetObjectArrayElement(results, i, result.get());
|
}
|
return results;
|
}
|
|
ConversationMessage FromJavaConversationMessage(JNIEnv* env, jobject jmessage) {
|
if (!jmessage) {
|
return {};
|
}
|
|
const ScopedLocalRef<jclass> message_class(
|
env->FindClass(TC3_PACKAGE_PATH TC3_ACTIONS_CLASS_NAME_STR
|
"$ConversationMessage"),
|
env);
|
const std::pair<bool, jobject> status_or_text = CallJniMethod0<jobject>(
|
env, jmessage, message_class.get(), &JNIEnv::CallObjectMethod, "getText",
|
"Ljava/lang/String;");
|
const std::pair<bool, int32> status_or_user_id =
|
CallJniMethod0<int32>(env, jmessage, message_class.get(),
|
&JNIEnv::CallIntMethod, "getUserId", "I");
|
const std::pair<bool, int64> status_or_reference_time = CallJniMethod0<int64>(
|
env, jmessage, message_class.get(), &JNIEnv::CallLongMethod,
|
"getReferenceTimeMsUtc", "J");
|
const std::pair<bool, jobject> status_or_reference_timezone =
|
CallJniMethod0<jobject>(env, jmessage, message_class.get(),
|
&JNIEnv::CallObjectMethod, "getReferenceTimezone",
|
"Ljava/lang/String;");
|
const std::pair<bool, jobject> status_or_detected_text_language_tags =
|
CallJniMethod0<jobject>(
|
env, jmessage, message_class.get(), &JNIEnv::CallObjectMethod,
|
"getDetectedTextLanguageTags", "Ljava/lang/String;");
|
if (!status_or_text.first || !status_or_user_id.first ||
|
!status_or_detected_text_language_tags.first ||
|
!status_or_reference_time.first || !status_or_reference_timezone.first) {
|
return {};
|
}
|
|
ConversationMessage message;
|
message.text = ToStlString(env, static_cast<jstring>(status_or_text.second));
|
message.user_id = status_or_user_id.second;
|
message.reference_time_ms_utc = status_or_reference_time.second;
|
message.reference_timezone = ToStlString(
|
env, static_cast<jstring>(status_or_reference_timezone.second));
|
message.detected_text_language_tags = ToStlString(
|
env, static_cast<jstring>(status_or_detected_text_language_tags.second));
|
return message;
|
}
|
|
Conversation FromJavaConversation(JNIEnv* env, jobject jconversation) {
|
if (!jconversation) {
|
return {};
|
}
|
|
const ScopedLocalRef<jclass> conversation_class(
|
env->FindClass(TC3_PACKAGE_PATH TC3_ACTIONS_CLASS_NAME_STR
|
"$Conversation"),
|
env);
|
|
const std::pair<bool, jobject> status_or_messages = CallJniMethod0<jobject>(
|
env, jconversation, conversation_class.get(), &JNIEnv::CallObjectMethod,
|
"getConversationMessages",
|
"[L" TC3_PACKAGE_PATH TC3_ACTIONS_CLASS_NAME_STR "$ConversationMessage;");
|
|
if (!status_or_messages.first) {
|
return {};
|
}
|
|
const jobjectArray jmessages =
|
reinterpret_cast<jobjectArray>(status_or_messages.second);
|
|
const int size = env->GetArrayLength(jmessages);
|
|
std::vector<ConversationMessage> messages;
|
for (int i = 0; i < size; i++) {
|
jobject jmessage = env->GetObjectArrayElement(jmessages, i);
|
ConversationMessage message = FromJavaConversationMessage(env, jmessage);
|
messages.push_back(message);
|
}
|
Conversation conversation;
|
conversation.messages = messages;
|
return conversation;
|
}
|
|
jstring GetLocalesFromMmap(JNIEnv* env, libtextclassifier3::ScopedMmap* mmap) {
|
if (!mmap->handle().ok()) {
|
return env->NewStringUTF("");
|
}
|
const ActionsModel* model = libtextclassifier3::ViewActionsModel(
|
mmap->handle().start(), mmap->handle().num_bytes());
|
if (!model || !model->locales()) {
|
return env->NewStringUTF("");
|
}
|
return env->NewStringUTF(model->locales()->c_str());
|
}
|
|
jint GetVersionFromMmap(JNIEnv* env, libtextclassifier3::ScopedMmap* mmap) {
|
if (!mmap->handle().ok()) {
|
return 0;
|
}
|
const ActionsModel* model = libtextclassifier3::ViewActionsModel(
|
mmap->handle().start(), mmap->handle().num_bytes());
|
if (!model) {
|
return 0;
|
}
|
return model->version();
|
}
|
|
jstring GetNameFromMmap(JNIEnv* env, libtextclassifier3::ScopedMmap* mmap) {
|
if (!mmap->handle().ok()) {
|
return env->NewStringUTF("");
|
}
|
const ActionsModel* model = libtextclassifier3::ViewActionsModel(
|
mmap->handle().start(), mmap->handle().num_bytes());
|
if (!model || !model->name()) {
|
return env->NewStringUTF("");
|
}
|
return env->NewStringUTF(model->name()->c_str());
|
}
|
} // namespace
|
} // namespace libtextclassifier3
|
|
using libtextclassifier3::ActionsSuggestionsJniContext;
|
using libtextclassifier3::ActionSuggestionsToJObjectArray;
|
using libtextclassifier3::FromJavaActionSuggestionOptions;
|
using libtextclassifier3::FromJavaConversation;
|
|
TC3_JNI_METHOD(jlong, TC3_ACTIONS_CLASS_NAME, nativeNewActionsModel)
|
(JNIEnv* env, jobject thiz, jint fd, jbyteArray serialized_preconditions) {
|
std::shared_ptr<libtextclassifier3::JniCache> jni_cache =
|
libtextclassifier3::JniCache::Create(env);
|
std::string preconditions;
|
if (serialized_preconditions != nullptr &&
|
!libtextclassifier3::JByteArrayToString(env, serialized_preconditions,
|
&preconditions)) {
|
TC3_LOG(ERROR) << "Could not convert serialized preconditions.";
|
return 0;
|
}
|
#ifdef TC3_UNILIB_JAVAICU
|
return reinterpret_cast<jlong>(ActionsSuggestionsJniContext::Create(
|
jni_cache,
|
ActionsSuggestions::FromFileDescriptor(
|
fd, std::unique_ptr<UniLib>(new UniLib(jni_cache)), preconditions)));
|
#else
|
return reinterpret_cast<jlong>(ActionsSuggestionsJniContext::Create(
|
jni_cache, ActionsSuggestions::FromFileDescriptor(fd, /*unilib=*/nullptr,
|
preconditions)));
|
#endif // TC3_UNILIB_JAVAICU
|
}
|
|
TC3_JNI_METHOD(jlong, TC3_ACTIONS_CLASS_NAME, nativeNewActionsModelFromPath)
|
(JNIEnv* env, jobject thiz, jstring path, jbyteArray serialized_preconditions) {
|
std::shared_ptr<libtextclassifier3::JniCache> jni_cache =
|
libtextclassifier3::JniCache::Create(env);
|
const std::string path_str = ToStlString(env, path);
|
std::string preconditions;
|
if (serialized_preconditions != nullptr &&
|
!libtextclassifier3::JByteArrayToString(env, serialized_preconditions,
|
&preconditions)) {
|
TC3_LOG(ERROR) << "Could not convert serialized preconditions.";
|
return 0;
|
}
|
#ifdef TC3_UNILIB_JAVAICU
|
return reinterpret_cast<jlong>(ActionsSuggestionsJniContext::Create(
|
jni_cache, ActionsSuggestions::FromPath(
|
path_str, std::unique_ptr<UniLib>(new UniLib(jni_cache)),
|
preconditions)));
|
#else
|
return reinterpret_cast<jlong>(ActionsSuggestionsJniContext::Create(
|
jni_cache, ActionsSuggestions::FromPath(path_str, /*unilib=*/nullptr,
|
preconditions)));
|
#endif // TC3_UNILIB_JAVAICU
|
}
|
|
TC3_JNI_METHOD(jobjectArray, TC3_ACTIONS_CLASS_NAME, nativeSuggestActions)
|
(JNIEnv* env, jobject clazz, jlong ptr, jobject jconversation, jobject joptions,
|
jlong annotatorPtr, jobject app_context, jstring device_locales,
|
jboolean generate_intents) {
|
if (!ptr) {
|
return nullptr;
|
}
|
const Conversation conversation = FromJavaConversation(env, jconversation);
|
const ActionSuggestionOptions options =
|
FromJavaActionSuggestionOptions(env, joptions);
|
const ActionsSuggestionsJniContext* context =
|
reinterpret_cast<ActionsSuggestionsJniContext*>(ptr);
|
const Annotator* annotator = reinterpret_cast<Annotator*>(annotatorPtr);
|
|
const ActionsSuggestionsResponse response =
|
context->model()->SuggestActions(conversation, annotator, options);
|
|
const reflection::Schema* anntotations_entity_data_schema =
|
annotator ? annotator->entity_data_schema() : nullptr;
|
return ActionSuggestionsToJObjectArray(
|
env, context, app_context, anntotations_entity_data_schema,
|
response.actions, conversation, device_locales, generate_intents);
|
}
|
|
TC3_JNI_METHOD(void, TC3_ACTIONS_CLASS_NAME, nativeCloseActionsModel)
|
(JNIEnv* env, jobject clazz, jlong model_ptr) {
|
const ActionsSuggestionsJniContext* context =
|
reinterpret_cast<ActionsSuggestionsJniContext*>(model_ptr);
|
delete context;
|
}
|
|
TC3_JNI_METHOD(jstring, TC3_ACTIONS_CLASS_NAME, nativeGetLocales)
|
(JNIEnv* env, jobject clazz, jint fd) {
|
const std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
|
new libtextclassifier3::ScopedMmap(fd));
|
return libtextclassifier3::GetLocalesFromMmap(env, mmap.get());
|
}
|
|
TC3_JNI_METHOD(jstring, TC3_ACTIONS_CLASS_NAME, nativeGetName)
|
(JNIEnv* env, jobject clazz, jint fd) {
|
const std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
|
new libtextclassifier3::ScopedMmap(fd));
|
return libtextclassifier3::GetNameFromMmap(env, mmap.get());
|
}
|
|
TC3_JNI_METHOD(jint, TC3_ACTIONS_CLASS_NAME, nativeGetVersion)
|
(JNIEnv* env, jobject clazz, jint fd) {
|
const std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
|
new libtextclassifier3::ScopedMmap(fd));
|
return libtextclassifier3::GetVersionFromMmap(env, mmap.get());
|
}
|