/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
you may not use this file except in compliance with the License.
|
You may obtain a copy of the License at
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
Unless required by applicable law or agreed to in writing, software
|
distributed under the License is distributed on an "AS IS" BASIS,
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
See the License for the specific language governing permissions and
|
limitations under the License.
|
==============================================================================*/
|
|
#ifndef TENSORFLOW_CORE_FRAMEWORK_VARIANT_ENCODE_DECODE_H_
|
#define TENSORFLOW_CORE_FRAMEWORK_VARIANT_ENCODE_DECODE_H_
|
|
#include <iostream>
|
#include <type_traits>
|
#include <utility>
|
#include <vector>
|
|
#include "tensorflow/core/framework/tensor.h"
|
#include "tensorflow/core/framework/type_index.h"
|
#include "tensorflow/core/framework/variant_tensor_data.h"
|
#include "tensorflow/core/lib/strings/strcat.h"
|
#include "tensorflow/core/platform/abi.h"
|
#include "tensorflow/core/platform/protobuf.h"
|
|
namespace tensorflow {
|
|
// Type used for tag-dispatch of the Encode/Decode Variant implementations. This
|
// template can determine whether the first type parameter `T` is one of the
|
// following:
|
//
|
// * A POD type (TypeResolver<T, true>)
|
// * A tensorflow::Tensor (TypeResolver<T, false, true>)
|
// * A protocol buffer (TypeResolver<T, false, false, true>)
|
// * None of the above (TypeResolver<T, false, false, false>)
|
//
|
template <typename T, bool = std::is_pod<typename std::decay<T>::type>::value,
|
bool = std::is_same<typename std::decay<T>::type,
|
::tensorflow::Tensor>::value,
|
bool = std::is_base_of<protobuf::MessageLite,
|
typename std::decay<T>::type>::value>
|
struct TypeResolver {};
|
|
// Specialization for POD type
|
template <typename T>
|
void EncodeVariantImpl(const T& value, TypeResolver<T, true /* is_pod */>,
|
VariantTensorData* data) {
|
data->set_metadata(value);
|
}
|
|
// Specialization for tensorflow::Tensor
|
template <typename T>
|
void EncodeVariantImpl(const T& value,
|
TypeResolver<T, false /* is_pod */, true /* Tensor */>,
|
VariantTensorData* data) {
|
data->tensors_.clear();
|
data->tensors_.push_back(value);
|
}
|
|
// Specialization for protobuf
|
template <typename T>
|
void EncodeVariantImpl(const T& value,
|
TypeResolver<T, false /* is_pod */, false /* Tensor */,
|
true /* protobuf */>,
|
VariantTensorData* data) {
|
value.SerializeToString(&data->metadata_);
|
}
|
|
// Specialization for other types
|
template <typename T>
|
void EncodeVariantImpl(const T& value,
|
TypeResolver<T, false /* is_pod */, false /* Tensor */,
|
false /* protobuf */>,
|
VariantTensorData* data) {
|
value.Encode(data);
|
}
|
|
// Specialization for POD type
|
template <typename T>
|
bool DecodeVariantImpl(VariantTensorData data,
|
TypeResolver<T, true /* is_pod */, false /* Tensor */,
|
false /* protobuf */>,
|
T* value) {
|
return data.get_metadata(value);
|
}
|
|
// Specialization for tensorflow::Tensor
|
template <typename T>
|
bool DecodeVariantImpl(VariantTensorData data,
|
TypeResolver<T, false /* is_pod */, true /* Tensor */,
|
false /* protobuf */>,
|
T* value) {
|
*value = data.tensors(0);
|
return true;
|
}
|
|
// Specialization for protobuf
|
template <typename T>
|
bool DecodeVariantImpl(VariantTensorData data,
|
TypeResolver<T, false /* is_pod */, false /* Tensor */,
|
true /* protobuf */>,
|
T* value) {
|
string metadata;
|
data.get_metadata(&metadata);
|
return value->ParseFromString(std::move(metadata));
|
}
|
|
// Specialization for other types
|
template <typename T>
|
bool DecodeVariantImpl(VariantTensorData data,
|
TypeResolver<T, false /* is_pod */, false /* Tensor */,
|
false /* protobuf */>,
|
T* value) {
|
return value->Decode(std::move(data));
|
}
|
|
// Compile-time trait: `value` is true exactly when calling `TypeName()` on a
// `C` is well-formed and its return type is exactly `string`. Any other case
// (no such member, different return type) falls back to the primary
// template, which is false.
template <typename C, typename Enable = void>
struct has_type_name : std::integral_constant<bool, false> {};

template <typename C>
struct has_type_name<
    C, typename std::enable_if<
           std::is_same<decltype(std::declval<C>().TypeName()), string>::value,
           void>::type> : std::integral_constant<bool, true> {};
|
|
// Empty tag type used to pick a TypeNameVariantImpl overload. The flags are
// computed from the decayed `T`: whether it has a `string TypeName()` member,
// whether it is tensorflow::Tensor, and whether it derives from
// protobuf::MessageLite.
template <typename T,
          bool HasTypeName = has_type_name<typename std::decay<T>::type>::value,
          bool IsTensor = std::is_same<typename std::decay<T>::type,
                                       ::tensorflow::Tensor>::value,
          bool IsProtobuf =
              std::is_base_of<protobuf::MessageLite,
                              typename std::decay<T>::type>::value>
struct TypeNameResolver {};
|
|
template <typename T>
|
string TypeNameVariantImpl(const T& value,
|
TypeNameResolver<T, true /* has_type_name */>) {
|
return value.TypeName();
|
}
|
|
template <typename T>
|
string TypeNameVariantImpl(
|
const T& value,
|
TypeNameResolver<T, false /* has_type_name */, true /* Tensor */>) {
|
return "tensorflow::Tensor";
|
}
|
|
template <typename T>
|
string TypeNameVariantImpl(
|
const T& value, TypeNameResolver<T, false /* has_type_name */,
|
false /* Tensor */, true /* protobuf */>) {
|
return value.GetTypeName();
|
}
|
|
template <typename T>
|
string TypeNameVariantImpl(
|
const T& value,
|
TypeNameResolver<T, false /* has_type_name */, false /* Tensor */,
|
false /* protobuf */>) {
|
return port::MaybeAbiDemangle(MakeTypeIndex<T>().name());
|
}
|
|
template <typename T>
|
string TypeNameVariant(const T& value) {
|
return TypeNameVariantImpl(value, TypeNameResolver<T>());
|
}
|
|
// Compile-time trait: `value` is true exactly when calling `DebugString()` on
// a `C` is well-formed and its return type is exactly `string`; otherwise the
// primary template applies and `value` is false.
template <typename C, typename Enable = void>
struct has_debug_string : std::integral_constant<bool, false> {};

template <typename C>
struct has_debug_string<
    C,
    typename std::enable_if<
        std::is_same<decltype(std::declval<C>().DebugString()), string>::value,
        void>::type> : std::integral_constant<bool, true> {};
|
|
template <typename C, typename = void>
|
struct can_strcat : std::false_type {};
|
|
template <typename C>
|
struct can_strcat<
|
C, typename std::enable_if<std::is_same<
|
decltype(strings::StrCat(std::declval<C>())), string>::value>::type>
|
: std::true_type {};
|
|
template <typename T,
|
bool = has_debug_string<typename std::decay<T>::type>::value,
|
bool = can_strcat<typename std::decay<T>::type>::value>
|
struct DebugStringResolver {};
|
|
// TODO(ebrevdo): Expand DebugStringResolver to return TypeString if
|
// there is no StrCat<T>() constructor.
|
template <typename T>
|
string DebugStringVariantImpl(
|
const T& value, DebugStringResolver<T, true /* has_debug_string */>) {
|
return value.DebugString();
|
}
|
|
template <typename T>
|
string DebugStringVariantImpl(
|
const T& value, DebugStringResolver<T, false /* has_debug_string */,
|
true /* can_strcat */>) {
|
return strings::StrCat(value);
|
}
|
|
template <typename T>
|
string DebugStringVariantImpl(
|
const T& value, DebugStringResolver<T, false /* has_debug_string */,
|
false /* can_strcat */>) {
|
return "?";
|
}
|
|
template <typename T>
|
string DebugStringVariant(const T& value) {
|
return DebugStringVariantImpl(value, DebugStringResolver<T>());
|
}
|
|
template <typename T>
|
void EncodeVariant(const T& value, VariantTensorData* data) {
|
EncodeVariantImpl(value, TypeResolver<T>(), data);
|
data->set_type_name(TypeNameVariant(value));
|
}
|
|
template <typename T>
|
bool DecodeVariant(VariantTensorData* data, T* value) {
|
return DecodeVariantImpl(std::move(*data), TypeResolver<T>(), value);
|
}
|
|
template <typename T>
|
void EncodeVariant(const T& value, string* buf) {
|
VariantTensorData data;
|
EncodeVariantImpl(value, TypeResolver<T>(), &data);
|
data.set_type_name(TypeNameVariant(value));
|
DCHECK(buf != nullptr);
|
data.SerializeToString(buf);
|
}
|
|
template <typename T>
|
bool DecodeVariant(string* buf, T* value) {
|
VariantTensorData data;
|
if (!data.ParseFromString(*buf)) return false;
|
if (!DecodeVariantImpl(std::move(data), TypeResolver<T>(), value)) {
|
return false;
|
}
|
return true;
|
}
|
|
// Specializations for VariantTensorDataProto
|
template <>
|
string TypeNameVariant(const VariantTensorDataProto& value);
|
|
template <>
|
void EncodeVariant(const VariantTensorDataProto& value,
|
VariantTensorData* data);
|
|
template <>
|
bool DecodeVariant(VariantTensorData* data, VariantTensorDataProto* value);
|
|
template <>
|
void EncodeVariant(const VariantTensorDataProto& value, string* buf);
|
|
template <>
|
bool DecodeVariant(string* buf, VariantTensorDataProto* value);
|
|
// Encodes an array of Variant objects into the given StringListEncoder.
|
// `variant_array` is assumed to point to an array of `n` Variant objects.
// The encoder `e` is passed by value as a std::unique_ptr, so the caller
// hands over ownership of it. Declaration only; defined outside this header.
|
void EncodeVariantList(const Variant* variant_array, int64 n,
|
std::unique_ptr<port::StringListEncoder> e);
|
|
// Decodes an array of Variant objects from the given StringListDecoder.
|
// `variant_array` is assumed to point to an array of `n` Variant objects.
// The decoder `d` is passed by value as a std::unique_ptr, so the caller
// hands over ownership of it. Declaration only; defined outside this header.
// NOTE(review): the bool result presumably signals decode success -- confirm
// against the definition.
|
bool DecodeVariantList(std::unique_ptr<port::StringListDecoder> d,
|
Variant* variant_array, int64 n);
|
|
} // end namespace tensorflow
|
|
#endif // TENSORFLOW_CORE_FRAMEWORK_VARIANT_ENCODE_DECODE_H_
|