/*
|
* Copyright (C) 2018 The Android Open Source Project
|
*
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
*/
|
|
#include "utils/flatbuffers.h"
|
|
#include <vector>
|
#include "utils/strings/numbers.h"
|
#include "utils/variant.h"
|
|
namespace libtextclassifier3 {
|
namespace {
|
bool CreateRepeatedField(
|
const reflection::Schema* schema, const reflection::Type* type,
|
std::unique_ptr<ReflectiveFlatbuffer::RepeatedField>* repeated_field) {
|
switch (type->element()) {
|
case reflection::Bool:
|
repeated_field->reset(new ReflectiveFlatbuffer::TypedRepeatedField<bool>);
|
return true;
|
case reflection::Int:
|
repeated_field->reset(new ReflectiveFlatbuffer::TypedRepeatedField<int>);
|
return true;
|
case reflection::Long:
|
repeated_field->reset(
|
new ReflectiveFlatbuffer::TypedRepeatedField<int64>);
|
return true;
|
case reflection::Float:
|
repeated_field->reset(
|
new ReflectiveFlatbuffer::TypedRepeatedField<float>);
|
return true;
|
case reflection::Double:
|
repeated_field->reset(
|
new ReflectiveFlatbuffer::TypedRepeatedField<double>);
|
return true;
|
case reflection::String:
|
repeated_field->reset(
|
new ReflectiveFlatbuffer::TypedRepeatedField<std::string>);
|
return true;
|
case reflection::Obj:
|
repeated_field->reset(
|
new ReflectiveFlatbuffer::TypedRepeatedField<ReflectiveFlatbuffer>(
|
schema, type));
|
return true;
|
default:
|
TC3_LOG(ERROR) << "Unsupported type: " << type->element();
|
return false;
|
}
|
}
|
} // namespace
|
|
// Specialization that supplies the file identifier for `Model` flatbuffers.
// It simply delegates to ModelIdentifier() (presumably the flatc-generated
// accessor for the schema's file_identifier — confirm against model schema).
template <>
const char* FlatbufferFileIdentifier<Model>() {
  return ModelIdentifier();
}
|
|
std::unique_ptr<ReflectiveFlatbuffer> ReflectiveFlatbufferBuilder::NewRoot()
|
const {
|
if (!schema_->root_table()) {
|
TC3_LOG(ERROR) << "No root table specified.";
|
return nullptr;
|
}
|
return std::unique_ptr<ReflectiveFlatbuffer>(
|
new ReflectiveFlatbuffer(schema_, schema_->root_table()));
|
}
|
|
// Creates an empty mutable instance of the schema table named `table_name`.
// The schema's object list is scanned in order and the first object whose name
// matches is used. Returns nullptr if no such table exists.
std::unique_ptr<ReflectiveFlatbuffer> ReflectiveFlatbufferBuilder::NewTable(
    StringPiece table_name) const {
  const auto* objects = schema_->objects();
  for (flatbuffers::uoffset_t i = 0; i < objects->size(); ++i) {
    const reflection::Object* candidate = objects->Get(i);
    if (!table_name.Equals(candidate->name()->str())) {
      continue;
    }
    return std::unique_ptr<ReflectiveFlatbuffer>(
        new ReflectiveFlatbuffer(schema_, candidate));
  }
  return nullptr;
}
|
|
// Looks up a field of this table by name.
//
// Returns nullptr if the table declares no fields or if no field has that
// name. Fields are sorted by name in the schema data, so LookupByKey can
// binary-search them.
const reflection::Field* ReflectiveFlatbuffer::GetFieldOrNull(
    const StringPiece field_name) const {
  const auto* fields = type_->fields();
  if (fields == nullptr) {
    // Consistent with GetFieldByOffsetOrNull(): a table without fields has
    // nothing to find.
    return nullptr;
  }
  // LookupByKey compares the key against a NUL-terminated C string, but a
  // StringPiece need not be NUL-terminated; materialize it to avoid reading
  // past the end of the piece.
  const std::string name = field_name.ToString();
  return fields->LookupByKey(name.c_str());
}
|
|
// Resolves a FlatbufferField reference against this table.
//
// A reference may carry a field name, a vtable offset, or both. The name is
// preferred when present: fields are sorted by name in the schema data, so a
// keyed lookup is used. Otherwise the fields are scanned for a matching offset.
// Returns nullptr if nothing matches.
const reflection::Field* ReflectiveFlatbuffer::GetFieldOrNull(
    const FlatbufferField* field) const {
  const auto* name = field->field_name();
  return name != nullptr ? GetFieldOrNull(name->str())
                         : GetFieldByOffsetOrNull(field->field_offset());
}
|
|
bool ReflectiveFlatbuffer::GetFieldWithParent(
|
const FlatbufferFieldPath* field_path, ReflectiveFlatbuffer** parent,
|
reflection::Field const** field) {
|
const auto* path = field_path->field();
|
if (path == nullptr || path->size() == 0) {
|
return false;
|
}
|
|
for (int i = 0; i < path->size(); i++) {
|
*parent = (i == 0 ? this : (*parent)->Mutable(*field));
|
if (*parent == nullptr) {
|
return false;
|
}
|
*field = (*parent)->GetFieldOrNull(path->Get(i));
|
if (*field == nullptr) {
|
return false;
|
}
|
}
|
|
return true;
|
}
|
|
// Finds the field of this table whose vtable offset equals `field_offset`.
// This is a linear scan over the declared fields. Returns nullptr if the table
// declares no fields or none has that offset.
const reflection::Field* ReflectiveFlatbuffer::GetFieldByOffsetOrNull(
    const int field_offset) const {
  const auto* fields = type_->fields();
  if (fields == nullptr) {
    return nullptr;
  }
  for (flatbuffers::uoffset_t i = 0; i < fields->size(); ++i) {
    const reflection::Field* candidate = fields->Get(i);
    if (candidate->offset() == field_offset) {
      return candidate;
    }
  }
  return nullptr;
}
|
|
// Reports whether `value` holds the Variant alternative that corresponds to the
// scalar or string base type of `field`. Any other base type (tables, vectors,
// unhandled scalars) never matches.
bool ReflectiveFlatbuffer::IsMatchingType(const reflection::Field* field,
                                          const Variant& value) const {
  const reflection::BaseType base_type = field->type()->base_type();
  if (base_type == reflection::Bool) {
    return value.HasBool();
  }
  if (base_type == reflection::Int) {
    return value.HasInt();
  }
  if (base_type == reflection::Long) {
    return value.HasInt64();
  }
  if (base_type == reflection::Float) {
    return value.HasFloat();
  }
  if (base_type == reflection::Double) {
    return value.HasDouble();
  }
  if (base_type == reflection::String) {
    return value.HasString();
  }
  return false;
}
|
|
// Parses the textual `value` according to `field`'s base type and stores it.
//
// String fields take the text verbatim. Int/Long use ParseInt32/ParseInt64.
// Float and Double are both parsed with ParseDouble; Float is then narrowed to
// float. Returns false (after logging) if parsing fails or the base type is not
// handled here. Otherwise it returns the result of Set().
bool ReflectiveFlatbuffer::ParseAndSet(const reflection::Field* field,
                                       const std::string& value) {
  const reflection::BaseType base_type = field->type()->base_type();
  switch (base_type) {
    case reflection::String:
      return Set(field, value);

    case reflection::Int: {
      int32 parsed;
      if (ParseInt32(value.data(), &parsed)) {
        return Set(field, parsed);
      }
      TC3_LOG(ERROR) << "Could not parse '" << value << "' as int32.";
      return false;
    }

    case reflection::Long: {
      int64 parsed;
      if (ParseInt64(value.data(), &parsed)) {
        return Set(field, parsed);
      }
      TC3_LOG(ERROR) << "Could not parse '" << value << "' as int64.";
      return false;
    }

    case reflection::Float: {
      double parsed;
      if (ParseDouble(value.data(), &parsed)) {
        return Set(field, static_cast<float>(parsed));
      }
      TC3_LOG(ERROR) << "Could not parse '" << value << "' as float.";
      return false;
    }

    case reflection::Double: {
      double parsed;
      if (ParseDouble(value.data(), &parsed)) {
        return Set(field, parsed);
      }
      TC3_LOG(ERROR) << "Could not parse '" << value << "' as double.";
      return false;
    }

    default:
      TC3_LOG(ERROR) << "Unhandled field type: " << base_type;
      return false;
  }
}
|
|
bool ReflectiveFlatbuffer::ParseAndSet(const FlatbufferFieldPath* path,
|
const std::string& value) {
|
ReflectiveFlatbuffer* parent;
|
const reflection::Field* field;
|
if (!GetFieldWithParent(path, &parent, &field)) {
|
return false;
|
}
|
return parent->ParseAndSet(field, value);
|
}
|
|
// Returns the mutable sub-table for the table field named `field_name`.
// Returns nullptr (after logging) if this table has no field with that name.
ReflectiveFlatbuffer* ReflectiveFlatbuffer::Mutable(
    const StringPiece field_name) {
  const reflection::Field* field = GetFieldOrNull(field_name);
  if (field == nullptr) {
    TC3_LOG(ERROR) << "Unknown field: " << field_name.ToString();
    return nullptr;
  }
  return Mutable(field);
}
|
|
// Returns the mutable sub-table stored in the table-typed field `field`. On
// first access an empty sub-table of the field's object type is created and
// cached in `children_`. Returns nullptr (after logging) if `field` is not of
// type Obj.
ReflectiveFlatbuffer* ReflectiveFlatbuffer::Mutable(
    const reflection::Field* field) {
  if (field->type()->base_type() != reflection::Obj) {
    TC3_LOG(ERROR) << "Field is not of type Object.";
    return nullptr;
  }

  auto child = children_.find(field);
  if (child == children_.end()) {
    // The field type's index selects the sub-table's object definition.
    const reflection::Object* child_type =
        schema_->objects()->Get(field->type()->index());
    child = children_.insert(
        /*hint=*/child,
        std::make_pair(field,
                       std::unique_ptr<ReflectiveFlatbuffer>(
                           new ReflectiveFlatbuffer(schema_, child_type))));
  }
  return child->second.get();
}
|
|
// Returns the repeated-field container for the vector field named
// `field_name`. Returns nullptr (after logging) if this table has no field
// with that name.
ReflectiveFlatbuffer::RepeatedField* ReflectiveFlatbuffer::Repeated(
    StringPiece field_name) {
  const reflection::Field* field = GetFieldOrNull(field_name);
  if (field == nullptr) {
    TC3_LOG(ERROR) << "Unknown field: " << field_name.ToString();
    return nullptr;
  }
  return Repeated(field);
}
|
|
// Returns the repeated-field container for the vector field `field`, creating
// and caching it in `repeated_fields_` on first access. Returns nullptr (after
// logging) if `field` is not a Vector or its element type has no container.
ReflectiveFlatbuffer::RepeatedField* ReflectiveFlatbuffer::Repeated(
    const reflection::Field* field) {
  if (field->type()->base_type() != reflection::Vector) {
    TC3_LOG(ERROR) << "Field is not of type Vector.";
    return nullptr;
  }

  auto slot = repeated_fields_.find(field);
  if (slot != repeated_fields_.end()) {
    // Reuse the container created by an earlier call.
    return slot->second.get();
  }

  std::unique_ptr<RepeatedField> created;
  if (!CreateRepeatedField(schema_, field->type(), &created)) {
    TC3_LOG(ERROR) << "Could not create repeated field.";
    return nullptr;
  }
  slot = repeated_fields_.insert(
      /*hint=*/slot, std::make_pair(field, std::move(created)));
  return slot->second.get();
}
|
|
// Writes this table into `builder`, together with its sub-tables, strings and
// repeated fields, and returns the offset of the finished table.
//
// Out-of-line values (sub-tables, strings, vectors) are built before the table
// is started, and their offsets are remembered. Scalars are then added inline
// between StartTable() and EndTable(), followed by the remembered offsets.
flatbuffers::uoffset_t ReflectiveFlatbuffer::Serialize(
    flatbuffers::FlatBufferBuilder* builder) const {
  // (field vtable offset, offset of the field's data in the buffer).
  using PendingOffset = std::pair<int, flatbuffers::uoffset_t>;
  std::vector<PendingOffset> pending;
  pending.reserve(children_.size() + repeated_fields_.size());

  // Sub-tables.
  for (const auto& child : children_) {
    pending.push_back(
        PendingOffset(child.first->offset(), child.second->Serialize(builder)));
  }

  // Strings.
  for (const auto& entry : fields_) {
    if (!entry.second.HasString()) {
      continue;
    }
    pending.push_back(PendingOffset(
        entry.first->offset(),
        builder->CreateString(entry.second.StringValue()).o));
  }

  // Repeated fields.
  for (const auto& repeated : repeated_fields_) {
    pending.push_back(PendingOffset(repeated.first->offset(),
                                    repeated.second->Serialize(builder)));
  }

  // Start the table itself.
  const flatbuffers::uoffset_t table_start = builder->StartTable();

  // Inline scalars. Other variant kinds (e.g. strings, handled above) are not
  // written here.
  for (const auto& entry : fields_) {
    const reflection::Field* field = entry.first;
    const Variant& value = entry.second;
    switch (value.GetType()) {
      case Variant::TYPE_BOOL_VALUE:
        builder->AddElement<uint8_t>(
            field->offset(), static_cast<uint8_t>(value.BoolValue()),
            static_cast<uint8_t>(field->default_integer()));
        break;
      case Variant::TYPE_INT_VALUE:
        builder->AddElement<int32>(
            field->offset(), value.IntValue(),
            static_cast<int32>(field->default_integer()));
        break;
      case Variant::TYPE_INT64_VALUE:
        builder->AddElement<int64>(field->offset(), value.Int64Value(),
                                   field->default_integer());
        break;
      case Variant::TYPE_FLOAT_VALUE:
        builder->AddElement<float>(
            field->offset(), value.FloatValue(),
            static_cast<float>(field->default_real()));
        break;
      case Variant::TYPE_DOUBLE_VALUE:
        builder->AddElement<double>(field->offset(), value.DoubleValue(),
                                    field->default_real());
        break;
      default:
        break;
    }
  }

  // Reference the previously built strings, sub-tables and vectors.
  for (const PendingOffset& entry : pending) {
    builder->AddOffset(entry.first, flatbuffers::Offset<void>(entry.second));
  }

  return builder->EndTable(table_start);
}
|
|
std::string ReflectiveFlatbuffer::Serialize() const {
|
flatbuffers::FlatBufferBuilder builder;
|
builder.Finish(flatbuffers::Offset<void>(Serialize(&builder)));
|
return std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()),
|
builder.GetSize());
|
}
|
|
// Copies every field explicitly present in the serialized table `from` into
// this mutable table, recursing into sub-tables via Mutable().
//
// NOTE(review): `from` is assumed to be a table of this object's type; this is
// not checked here.
//
// Fields absent from `from`'s vtable are skipped. Returns false (after logging)
// on the first field whose base type is not handled here (e.g. Vector), or if
// a nested merge fails. Fields processed before that point stay merged.
bool ReflectiveFlatbuffer::MergeFrom(const flatbuffers::Table* from) {
  const auto* fields = type_->fields();
  if (fields == nullptr) {
    // Nothing to copy.
    return true;
  }

  for (const reflection::Field* field : *fields) {
    const auto offset = field->offset();
    if (!from->CheckField(offset)) {
      continue;
    }

    const reflection::BaseType type = field->type()->base_type();
    switch (type) {
      case reflection::Bool:
        Set<bool>(field,
                  from->GetField<uint8_t>(offset, field->default_integer()));
        break;
      case reflection::Int:
        Set<int32>(field,
                   from->GetField<int32>(offset, field->default_integer()));
        break;
      case reflection::Long:
        Set<int64>(field,
                   from->GetField<int64>(offset, field->default_integer()));
        break;
      case reflection::Float:
        Set<float>(field, from->GetField<float>(offset, field->default_real()));
        break;
      case reflection::Double:
        Set<double>(field,
                    from->GetField<double>(offset, field->default_real()));
        break;
      case reflection::String:
        Set<std::string>(
            field, from->GetPointer<const flatbuffers::String*>(offset)->str());
        break;
      case reflection::Obj: {
        const flatbuffers::Table* const sub_table =
            from->GetPointer<const flatbuffers::Table* const>(offset);
        if (!Mutable(field)->MergeFrom(sub_table)) {
          return false;
        }
        break;
      }
      default:
        TC3_LOG(ERROR) << "Unsupported type: " << type;
        return false;
    }
  }
  return true;
}
|
|
// Treats `from` as a serialized flatbuffer and merges its root table into this
// table (see MergeFrom).
//
// NOTE(review): the bytes are handed straight to GetAnyRoot() without running
// a flatbuffers verifier. Malformed or truncated input is presumably undefined
// behavior, so callers should only pass trusted, complete buffers of this
// table's type.
bool ReflectiveFlatbuffer::MergeFromSerializedFlatbuffer(StringPiece from) {
  const auto* buffer = reinterpret_cast<const unsigned char*>(from.data());
  return MergeFrom(flatbuffers::GetAnyRoot(buffer));
}
|
|
void ReflectiveFlatbuffer::AsFlatMap(
|
const std::string& key_separator, const std::string& key_prefix,
|
std::map<std::string, Variant>* result) const {
|
// Add direct fields.
|
for (auto it : fields_) {
|
(*result)[key_prefix + it.first->name()->str()] = it.second;
|
}
|
|
// Add nested messages.
|
for (auto& it : children_) {
|
it.second->AsFlatMap(key_separator,
|
key_prefix + it.first->name()->str() + key_separator,
|
result);
|
}
|
}
|
|
} // namespace libtextclassifier3
|