/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/interpreter.h"

#include <cassert>
#include <cstdarg>
#include <cstdint>
#include <cstring>
#include <mutex>  // NOLINT(build/c++11): only using std::call_once, not mutex.

#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/context_util.h"
#include "tensorflow/lite/core/api/error_reporter.h"
#include "tensorflow/lite/graph_info.h"
#include "tensorflow/lite/memory_planner.h"
#include "tensorflow/lite/minimal_logging.h"
#include "tensorflow/lite/nnapi_delegate.h"
#include "tensorflow/lite/profiling/profiler.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/util.h"

namespace tflite {

namespace {

// Gets the current TfLiteQuantization from the legacy TfLiteQuantizationParams.
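// The returned struct owns a heap-allocated TfLiteAffineQuantization; if it is
// not handed off to a tensor, release it with TfLiteQuantizationFree().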
TfLiteQuantization GetQuantizationFromLegacy(
    const TfLiteQuantizationParams& legacy_quantization) {
  TfLiteQuantization quantization;
  quantization.type = kTfLiteAffineQuantization;
  auto* affine_quantization = reinterpret_cast<TfLiteAffineQuantization*>(
      malloc(sizeof(TfLiteAffineQuantization)));
  affine_quantization->scale = TfLiteFloatArrayCreate(1);
  affine_quantization->zero_point = TfLiteIntArrayCreate(1);
  affine_quantization->scale->data[0] = legacy_quantization.scale;
  affine_quantization->zero_point->data[0] = legacy_quantization.zero_point;
  quantization.params = affine_quantization;

  return quantization;
}

}  // namespace

Interpreter::Interpreter(ErrorReporter* error_reporter)
    : error_reporter_(error_reporter ? error_reporter
                                     : DefaultErrorReporter()) {
  // Only log initialization once per-process to avoid log spam.
  static std::once_flag init_log_once_flag;
  std::call_once(init_log_once_flag, []() {
    // TODO(b/128420794): Include the TFLite runtime version in the log.
    TFLITE_LOG_PROD(TFLITE_LOG_INFO, "Initialized TensorFlow Lite runtime.");
  });

  // There's always at least 1 subgraph which is the primary subgraph.
  AddSubgraphs(1);
  context_ = primary_subgraph().context();

  // Initialize all external contexts to nullptr; clients may register them
  // later via SetExternalContext().
  for (int i = 0; i < kTfLiteMaxExternalContexts; ++i) {
    external_contexts_[i] = nullptr;
  }

  UseNNAPI(false);
}

Interpreter::~Interpreter() {}

void Interpreter::SetExternalContext(TfLiteExternalContextType type,
                                     TfLiteExternalContext* ctx) {
  primary_subgraph().SetExternalContext(type, ctx);
}

TfLiteStatus Interpreter::SetInputs(std::vector<int> inputs) {
  return primary_subgraph().SetInputs(inputs);
}

TfLiteStatus Interpreter::SetOutputs(std::vector<int> outputs) {
  return primary_subgraph().SetOutputs(outputs);
}

TfLiteStatus Interpreter::SetVariables(std::vector<int> variables) {
  return primary_subgraph().SetVariables(variables);
}

TfLiteStatus Interpreter::AllocateTensors() {
  return primary_subgraph().AllocateTensors();
}

void Interpreter::ReserveNodes(int count) {
  primary_subgraph().ReserveNodes(count);
}
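
// Appends `subgraphs_to_add` empty subgraphs to the interpreter and, if
// requested, reports the index of the first new one through
// `first_new_subgraph_index`.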
void Interpreter::AddSubgraphs(int subgraphs_to_add,
                               int* first_new_subgraph_index) {
  const size_t base_index = subgraphs_.size();
  if (first_new_subgraph_index) *first_new_subgraph_index = base_index;

  subgraphs_.reserve(base_index + subgraphs_to_add);
  for (int i = 0; i < subgraphs_to_add; ++i) {
    Subgraph* subgraph =
        new Subgraph(error_reporter_, external_contexts_, &subgraphs_);
    subgraphs_.emplace_back(subgraph);
  }
}

TfLiteStatus Interpreter::AddNodeWithParameters(
    const std::vector<int>& inputs, const std::vector<int>& outputs,
    const char* init_data, size_t init_data_size, void* builtin_data,
    const TfLiteRegistration* registration, int* node_index) {
  return primary_subgraph().AddNodeWithParameters(inputs, outputs, init_data,
                                                  init_data_size, builtin_data,
                                                  registration, node_index);
}

TfLiteStatus Interpreter::ResizeInputTensor(int tensor_index,
                                            const std::vector<int>& dims) {
  return primary_subgraph().ResizeInputTensor(tensor_index, dims);
}
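
// Runs the primary subgraph. Unless buffer handle output has been explicitly
// allowed, output tensors that live in delegate buffers are copied back so
// their data is readable on the CPU.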
TfLiteStatus Interpreter::Invoke() {
  TF_LITE_ENSURE_STATUS(primary_subgraph().Invoke());

  if (!allow_buffer_handle_output_) {
    for (int tensor_index : outputs()) {
      TF_LITE_ENSURE_STATUS(
          primary_subgraph().EnsureTensorDataIsReadable(tensor_index));
    }
  }

  return kTfLiteOk;
}
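
// Illustrative sketch (not part of this translation unit): a typical client
// drives the interpreter roughly as follows, assuming it was already built by
// InterpreterBuilder:
//
//   interpreter->AllocateTensors();
//   float* input = interpreter->typed_input_tensor<float>(0);
//   // ... fill `input` ...
//   if (interpreter->Invoke() != kTfLiteOk) { /* handle error */ }
//   const float* output = interpreter->typed_output_tensor<float>(0);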

TfLiteStatus Interpreter::AddTensors(int tensors_to_add,
                                     int* first_new_tensor_index) {
  return primary_subgraph().AddTensors(tensors_to_add, first_new_tensor_index);
}
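
// Resets all variable tensors in the primary subgraph.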
TfLiteStatus Interpreter::ResetVariableTensors() {
  return primary_subgraph().ResetVariableTensors();
}

TfLiteStatus Interpreter::SetTensorParametersReadOnly(
    int tensor_index, TfLiteType type, const char* name,
    const std::vector<int>& dims, TfLiteQuantization quantization,
    const char* buffer, size_t bytes, const Allocation* allocation) {
  return primary_subgraph().SetTensorParametersReadOnly(
      tensor_index, type, name, dims.size(), dims.data(), quantization, buffer,
      bytes, allocation);
}

TfLiteStatus Interpreter::SetTensorParametersReadWrite(
    int tensor_index, TfLiteType type, const char* name,
    const std::vector<int>& dims, TfLiteQuantization quantization,
    bool is_variable) {
  return primary_subgraph().SetTensorParametersReadWrite(
      tensor_index, type, name, dims.size(), dims.data(), quantization,
      is_variable);
}
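
// The two legacy overloads below convert TfLiteQuantizationParams to the newer
// TfLiteQuantization representation before forwarding to the primary subgraph,
// and free the converted struct if the forwarding call fails.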
TfLiteStatus Interpreter::SetTensorParametersReadOnly(
    int tensor_index, TfLiteType type, const char* name, const size_t rank,
    const int* dims, TfLiteQuantizationParams quantization, const char* buffer,
    size_t bytes, const Allocation* allocation) {
  TfLiteQuantization new_quantization = GetQuantizationFromLegacy(quantization);
  if (primary_subgraph().SetTensorParametersReadOnly(
          tensor_index, type, name, rank, dims, new_quantization, buffer, bytes,
          allocation) != kTfLiteOk) {
    TfLiteQuantizationFree(&new_quantization);
    return kTfLiteError;
  }
  return kTfLiteOk;
}

TfLiteStatus Interpreter::SetTensorParametersReadWrite(
    int tensor_index, TfLiteType type, const char* name, const size_t rank,
    const int* dims, TfLiteQuantizationParams quantization, bool is_variable) {
  TfLiteQuantization new_quantization = GetQuantizationFromLegacy(quantization);
  if (primary_subgraph().SetTensorParametersReadWrite(
          tensor_index, type, name, rank, dims, new_quantization,
          is_variable) != kTfLiteOk) {
    TfLiteQuantizationFree(&new_quantization);
    return kTfLiteError;
  }
  return kTfLiteOk;
}
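
// Overrides the order in which the nodes of the primary subgraph are executed.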
TfLiteStatus Interpreter::SetExecutionPlan(const std::vector<int>& new_plan) {
  return primary_subgraph().SetExecutionPlan(new_plan);
}

void Interpreter::UseNNAPI(bool enable) { primary_subgraph().UseNNAPI(enable); }
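
// Propagates the thread-count recommendation to every subgraph and asks any
// registered external contexts to refresh themselves with the new setting.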
void Interpreter::SetNumThreads(int num_threads) {
  for (auto& subgraph : subgraphs_) {
    subgraph->context()->recommended_num_threads = num_threads;
  }

  for (int i = 0; i < kTfLiteMaxExternalContexts; ++i) {
    auto* c = external_contexts_[i];
    if (c && c->Refresh) {
      c->Refresh(context_);
    }
  }
}
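
// Propagates the fp16-relaxation flag to every subgraph's context; kernels and
// delegates may then evaluate fp32 operations with reduced fp16 precision.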
void Interpreter::SetAllowFp16PrecisionForFp32(bool allow) {
  for (auto& subgraph : subgraphs_) {
    subgraph->context()->allow_fp32_relax_to_fp16 = allow;
  }
}

// TODO(b/121264966): Subgraphs added after cancellation is set will not get the
// cancellation function added to their context.
void Interpreter::SetCancellationFunction(void* data,
                                          bool (*check_cancelled_func)(void*)) {
  for (auto& subgraph : subgraphs_) {
    subgraph->SetCancellationFunction(data, check_cancelled_func);
  }
}
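
// Applies the delegate to every subgraph; fails fast on the first subgraph
// that cannot be modified.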
TfLiteStatus Interpreter::ModifyGraphWithDelegate(TfLiteDelegate* delegate) {
  for (auto& subgraph : subgraphs_) {
    TF_LITE_ENSURE_OK(context_, subgraph->ModifyGraphWithDelegate(delegate));
  }
  return kTfLiteOk;
}
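
// Associates `buffer_handle` (owned by `delegate`) with the given tensor. Any
// previously set handle is released through the delegate's FreeBufferHandle.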
TfLiteStatus Interpreter::SetBufferHandle(int tensor_index,
                                          TfLiteBufferHandle buffer_handle,
                                          TfLiteDelegate* delegate) {
  TF_LITE_ENSURE(context_, tensor_index < tensors_size());
  std::vector<TfLiteTensor>& tensors = primary_subgraph().tensors();
  TfLiteTensor* tensor = &tensors[tensor_index];

  TF_LITE_ENSURE(context_,
                 tensor->delegate == nullptr || tensor->delegate == delegate);
  tensor->delegate = delegate;
  if (tensor->buffer_handle != kTfLiteNullBufferHandle) {
    TF_LITE_ENSURE(context_, tensor->delegate->FreeBufferHandle != nullptr);
    tensor->delegate->FreeBufferHandle(context_, tensor->delegate,
                                       &tensor->buffer_handle);
  }
  tensor->buffer_handle = buffer_handle;

  return kTfLiteOk;
}
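
// Retrieves the buffer handle and owning delegate previously registered for
// `tensor_index` via SetBufferHandle.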
TfLiteStatus Interpreter::GetBufferHandle(int tensor_index,
                                          TfLiteBufferHandle* buffer_handle,
                                          TfLiteDelegate** delegate) {
  TF_LITE_ENSURE(context_, tensor_index < tensors_size());
  std::vector<TfLiteTensor>& tensors = primary_subgraph().tensors();
  TfLiteTensor* tensor = &tensors[tensor_index];

  *delegate = tensor->delegate;
  *buffer_handle = tensor->buffer_handle;

  return kTfLiteOk;
}

void Interpreter::SetProfiler(profiling::Profiler* profiler) {
  for (auto& subgraph : subgraphs_) subgraph->SetProfiler(profiler);
}

profiling::Profiler* Interpreter::GetProfiler() {
  return primary_subgraph().GetProfiler();
}

}  // namespace tflite