/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
you may not use this file except in compliance with the License.
|
You may obtain a copy of the License at
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
Unless required by applicable law or agreed to in writing, software
|
distributed under the License is distributed on an "AS IS" BASIS,
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
See the License for the specific language governing permissions and
|
limitations under the License.
|
==============================================================================*/
|
|
#include "tensorflow/core/framework/tensor_slice.h"

#include <algorithm>
#include <limits>
#include <utility>
#include <vector>

#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
|
|
namespace tensorflow {
|
|
// Constructs a rank-`dim` slice whose every dimension is a full extent.
TensorSlice::TensorSlice(int dim) {
  SetFullSlice(dim);
}
|
|
// Builds a slice from its serialized proto form, one dimension per extent.
// An extent without an explicit length yields GetExtentLength()'s -1
// (the same value as kFullExtent). No validation is performed here.
TensorSlice::TensorSlice(const TensorSliceProto& proto) {
  const int num_extents = proto.extent_size();
  starts_.reserve(num_extents);
  lengths_.reserve(num_extents);
  for (int i = 0; i < num_extents; ++i) {
    const TensorSliceProto::Extent& extent = proto.extent(i);
    starts_.push_back(extent.start());
    lengths_.push_back(GetExtentLength(extent));
  }
}
|
|
// Builds a slice from explicit (start, length) pairs, one per dimension, in
// the order given. The values are stored as-is without validation.
TensorSlice::TensorSlice(
    std::initializer_list<std::pair<int64, int64>> extents) {
  const auto rank = extents.size();
  starts_.reserve(rank);
  lengths_.reserve(rank);
  for (const std::pair<int64, int64>& extent : extents) {
    const int64 extent_start = extent.first;
    const int64 extent_length = extent.second;
    starts_.push_back(extent_start);
    lengths_.push_back(extent_length);
  }
}
|
|
// Parses the textual form emitted by DebugString(): ':'-separated
// per-dimension items, each either "-" (the full extent) or "start,length"
// with start >= 0, length > 0 and start + length representable as int64.
// Empty items are skipped, so "" parses to a rank-0 slice.
//
// On success *slice is replaced by the parsed slice and OK is returned.
// On failure InvalidArgument is returned and *slice is left unmodified.
Status TensorSlice::Parse(const string& str, TensorSlice* slice) {
  const std::vector<string> items =
      str_util::Split(str, ':', str_util::SkipEmpty());
  // Build into a local slice so that (a) any previous contents of *slice are
  // replaced rather than appended to, and (b) a parse error midway through
  // cannot leave *slice half-populated.
  TensorSlice parsed(0);
  parsed.starts_.reserve(items.size());
  parsed.lengths_.reserve(items.size());
  for (const string& x : items) {
    int64 s = 0;
    int64 l = 0;
    if (x == "-") {
      // "everything"
      s = 0;
      l = kFullExtent;
    } else {
      std::vector<string> sl = str_util::Split(x, ',', str_util::SkipEmpty());
      if (sl.size() != 2 || !strings::safe_strto64(sl[0], &s) ||
          !strings::safe_strto64(sl[1], &l)) {
        return errors::InvalidArgument(
            "Expected a pair of numbers or '-' "
            "but got '",
            x, "': string = ", str);
      }
      if (s < 0 || l <= 0) {
        return errors::InvalidArgument(
            "Expected non-negative start and "
            "positive length but got start = ",
            s, ", length = ", l, ": string = ", str);
      }
      // Reject extents whose end (start + length) would overflow int64;
      // otherwise later end-based arithmetic on this slice is signed
      // overflow. l > 0 here, so the subtraction itself cannot overflow.
      if (s > std::numeric_limits<int64>::max() - l) {
        return errors::InvalidArgument(
            "Expected start + length to fit in int64 but got start = ", s,
            ", length = ", l, ": string = ", str);
      }
    }
    parsed.starts_.push_back(s);
    parsed.lengths_.push_back(l);
  }

  *slice = std::move(parsed);
  return Status::OK();
}
|
|
// Resets this slice to rank 0 by discarding every per-dimension extent.
void TensorSlice::Clear() {
  lengths_.clear();
  starts_.clear();
}
|
|
bool TensorSlice::IsFull() const {
|
for (int d = 0; d < dims(); ++d) {
|
if (!IsFullAt(d)) return false;
|
}
|
return true;
|
}
|
|
void TensorSlice::SetFullSlice(int dim) {
|
Clear();
|
starts_.reserve(dim);
|
lengths_.reserve(dim);
|
for (int d = 0; d < dim; ++d) {
|
starts_.push_back(0);
|
lengths_.push_back(kFullExtent);
|
}
|
}
|
|
void TensorSlice::Extend(int dim) {
|
int old_dim = dims();
|
DCHECK_LE(old_dim, dim);
|
starts_.resize(dim);
|
lengths_.resize(dim);
|
for (int d = old_dim; d < dim; ++d) {
|
starts_[d] = 0;
|
lengths_[d] = kFullExtent;
|
}
|
}
|
|
// Appends one extent per dimension to `proto`. Full dimensions are written as
// an extent with neither start nor length set; every other dimension records
// its start and length explicitly. Extents already present in `proto` are
// not cleared.
void TensorSlice::AsProto(TensorSliceProto* proto) const {
  for (int d = 0; d < dims(); ++d) {
    TensorSliceProto::Extent* extent = proto->add_extent();
    if (IsFullAt(d)) continue;  // Leave the extent empty for full dims.
    extent->set_start(starts_[d]);
    extent->set_length(lengths_[d]);
  }
}
|
|
// Renders the slice as ':'-separated per-dimension items: "-" for a full
// dimension and "start,length" otherwise (the format Parse() accepts).
string TensorSlice::DebugString() const {
  string result;
  for (int d = 0; d < dims(); ++d) {
    if (d > 0) result.append(":");
    if (!IsFullAt(d)) {
      strings::StrAppend(&result, starts_[d], ",", lengths_[d]);
    } else {
      result.append("-");
    }
  }
  return result;
}
|
|
bool TensorSlice::Intersect(const TensorSlice& other,
|
TensorSlice* result) const {
|
// First, if two slices have different ranks, they obviously don't overlap
|
// -- in fact they are not compatible.
|
if (dims() != other.dims()) {
|
return false;
|
}
|
|
// Setting the result to the right dimension
|
if (result) {
|
result->SetFullSlice(dims());
|
}
|
// The two slices overlap if they overlap in all dimensions.
|
for (int d = 0; d < dims(); ++d) {
|
if (IsFullAt(d)) {
|
if (result) {
|
result->set_start(d, other.start(d));
|
result->set_length(d, other.length(d));
|
}
|
} else if (other.IsFullAt(d)) {
|
if (result) {
|
result->set_start(d, start(d));
|
result->set_length(d, length(d));
|
}
|
} else {
|
// If we have an intersection here, it should have a start that is the
|
// max of the two starts and an end that is the min of the two ends.
|
int64 s = std::max(start(d), other.start(d));
|
int64 l = std::min(end(d), other.end(d)) - s;
|
if (l > 0) {
|
// We have a real intersection
|
if (result) {
|
result->set_start(d, s);
|
result->set_length(d, l);
|
}
|
} else {
|
// We don't have an intersection for this dimension -- thus we don't
|
// have any intersection at all.
|
if (result) {
|
result->Clear();
|
}
|
return false;
|
}
|
}
|
}
|
// If we are here, we know there is overlap in every dimension.
|
return true;
|
}
|
|
// Two slices are equal iff they have the same rank and identical starts and
// lengths along every dimension.
bool TensorSlice::operator==(const TensorSlice& other) const {
  if (dims() != other.dims()) return false;
  const bool same_starts = starts_ == other.starts_;
  return same_starts && lengths_ == other.lengths_;
}
|
|
void TensorSlice::ComputeRelative(const TensorSlice& sub,
|
TensorSlice* relative) const {
|
DCHECK_EQ(dims(), sub.dims());
|
relative->SetFullSlice(dims());
|
for (int d = 0; d < dims(); ++d) {
|
if (IsFullAt(d)) {
|
relative->set_start(d, sub.start(d));
|
relative->set_length(d, sub.length(d));
|
} else {
|
// Otherwise the relative start is the difference between the start of
|
// sub and the start of base
|
relative->set_start(d, sub.start(d) - start(d));
|
relative->set_length(d, sub.length(d));
|
}
|
}
|
}
|
|
void TensorSlice::UpdateToCover(const TensorSlice& other) {
|
DCHECK_EQ(dims(), other.dims());
|
for (int d = 0; d < dims(); ++d) {
|
if (!IsFullAt(d)) {
|
if (other.IsFullAt(d)) {
|
starts_[d] = 0;
|
lengths_[d] = kFullExtent;
|
} else {
|
const auto new_end = std::max(end(d), other.end(d));
|
set_start(d, std::min(start(d), other.start(d)));
|
set_length(d, new_end - start(d));
|
}
|
}
|
}
|
}
|
|
// static
|
bool TensorSlice::HasExtentLength(const TensorSliceProto::Extent& extent) {
|
return extent.has_length_case() == TensorSliceProto::Extent::kLength;
|
}
|
|
// static
|
int64 TensorSlice::GetExtentLength(const TensorSliceProto::Extent& extent) {
|
if (!HasExtentLength(extent)) return -1;
|
return extent.length();
|
}
|
|
// Computes the shape that results from applying this slice to `shape` and
// stores it in *result_shape. Full dimensions keep shape's size. Every other
// dimension takes the slice's length, provided the slice ends within that
// dimension.
//
// Returns Internal if the ranks differ or if an extent runs past the end of
// its dimension. *result_shape is cleared in either error case.
Status TensorSlice::SliceTensorShape(const TensorShape& shape,
                                     TensorShape* result_shape) const {
  result_shape->Clear();
  if (shape.dims() != dims()) {
    // The slice cannot be applied to a shape of a different rank.
    return errors::Internal("Mismatching ranks: shape = ", shape.DebugString(),
                            ", slice = ", DebugString());
  }
  for (int d = 0; d < dims(); ++d) {
    if (IsFullAt(d)) {
      result_shape->AddDim(shape.dim_size(d));
      continue;
    }
    if (end(d) > shape.dim_size(d)) {
      // The extent does not fit inside dimension d.
      result_shape->Clear();
      return errors::Internal("Extent in dimension ", d,
                              " out of bounds: shape = ", shape.DebugString(),
                              ", slice = ", DebugString());
    }
    result_shape->AddDim(length(d));
  }
  return Status::OK();
}
|
|
// Length value that marks a dimension as covering its whole extent. It is
// stored by SetFullSlice()/Extend() and by Parse() for a "-" item.
const int64 TensorSlice::kFullExtent = -1;
|
|
} // namespace tensorflow
|