/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
you may not use this file except in compliance with the License.
|
You may obtain a copy of the License at
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
Unless required by applicable law or agreed to in writing, software
|
distributed under the License is distributed on an "AS IS" BASIS,
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
See the License for the specific language governing permissions and
|
limitations under the License.
|
==============================================================================*/
|
|
#include "tensorflow/core/framework/op.h"
|
#include "tensorflow/core/framework/op_kernel.h"
|
#include "tensorflow/core/framework/shape_inference.h"
|
|
using namespace tensorflow; // NOLINT(build/namespaces)
|
|
// Op interface for "ZeroOut":
//   - attr  `preserve_index` (int, default 0): flat index of the element to keep
//   - input `to_zero` (int32)
//   - output `zeroed` (int32)
// The shape function declares that the output has exactly the input's shape.
REGISTER_OP("ZeroOut")
    .Attr("preserve_index: int = 0")
    .Input("to_zero: int32")
    .Output("zeroed: int32")
    .SetShapeFn([](shape_inference::InferenceContext* ctx) {
      // Output 0 reuses the shape handle of input 0 unchanged.
      ctx->set_output(0, ctx->input(0));
      return Status::OK();
    });
|
|
class ZeroOutOp : public OpKernel {
|
public:
|
explicit ZeroOutOp(OpKernelConstruction* context) : OpKernel(context) {
|
// Get the index of the value to preserve
|
OP_REQUIRES_OK(context,
|
context->GetAttr("preserve_index", &preserve_index_));
|
// Check that preserve\_index is positive
|
OP_REQUIRES(context, preserve_index_ >= 0,
|
errors::InvalidArgument("Need preserve_index >= 0, got ",
|
preserve_index_));
|
}
|
|
void Compute(OpKernelContext* context) override {
|
// Grab the input tensor
|
const Tensor& input_tensor = context->input(0);
|
auto input = input_tensor.flat<int32>();
|
|
// Check that preserve_index is in range
|
OP_REQUIRES(context, preserve_index_ < input.dimension(0),
|
errors::InvalidArgument("preserve_index out of range"));
|
|
// Create an output tensor
|
Tensor* output_tensor = nullptr;
|
OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(),
|
&output_tensor));
|
auto output = output_tensor->template flat<int32>();
|
|
// Set all the elements of the output tensor to 0
|
const int N = input.size();
|
for (int i = 0; i < N; i++) {
|
output(i) = 0;
|
}
|
|
// Preserve the requested input value
|
output(preserve_index_) = input(preserve_index_);
|
}
|
|
private:
|
int preserve_index_;
|
};
|
|
// Register ZeroOutOp as the CPU implementation of the "ZeroOut" op.
REGISTER_KERNEL_BUILDER(Name("ZeroOut").Device(DEVICE_CPU),
                        ZeroOutOp);
|