/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
you may not use this file except in compliance with the License.
|
You may obtain a copy of the License at
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
Unless required by applicable law or agreed to in writing, software
|
distributed under the License is distributed on an "AS IS" BASIS,
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
See the License for the specific language governing permissions and
|
limitations under the License.
|
==============================================================================*/
|
|
#ifndef TENSORFLOW_CORE_KERNELS_AGGREGATE_OPS_H_
|
#define TENSORFLOW_CORE_KERNELS_AGGREGATE_OPS_H_
|
|
// Functor declarations and Eigen-based implementations for the aggregate ops.
// This header must be compilable by nvcc.
|
|
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
#include "tensorflow/core/framework/tensor_types.h"
|
|
namespace tensorflow {
|
namespace functor {
|
|
template <typename Device, typename T>
|
struct Add2Functor {
|
void operator()(const Device& d, typename TTypes<T>::Flat out,
|
typename TTypes<T>::ConstFlat in1,
|
typename TTypes<T>::ConstFlat in2);
|
};
|
|
template <typename Device, typename T>
|
struct Add2EigenImpl {
|
static void Compute(const Device& d, typename TTypes<T>::Flat out,
|
typename TTypes<T>::ConstFlat in1,
|
typename TTypes<T>::ConstFlat in2) {
|
out.device(d) = in1 + in2;
|
}
|
};
|
|
template <typename Device, typename T>
|
struct Add3Functor {
|
void operator()(const Device& d, typename TTypes<T>::Flat out,
|
typename TTypes<T>::ConstFlat in1,
|
typename TTypes<T>::ConstFlat in2,
|
typename TTypes<T>::ConstFlat in3);
|
};
|
|
template <typename Device, typename T>
|
struct Add3EigenImpl {
|
static void Compute(const Device& d, typename TTypes<T>::Flat out,
|
typename TTypes<T>::ConstFlat in1,
|
typename TTypes<T>::ConstFlat in2,
|
typename TTypes<T>::ConstFlat in3) {
|
out.device(d) = in1 + in2 + in3;
|
}
|
};
|
|
template <typename Device, typename T>
|
struct Add4Functor {
|
void operator()(const Device& d, typename TTypes<T>::Flat out,
|
typename TTypes<T>::ConstFlat in1,
|
typename TTypes<T>::ConstFlat in2,
|
typename TTypes<T>::ConstFlat in3,
|
typename TTypes<T>::ConstFlat in4);
|
};
|
|
template <typename Device, typename T>
|
struct Add4EigenImpl {
|
static void Compute(const Device& d, typename TTypes<T>::Flat out,
|
typename TTypes<T>::ConstFlat in1,
|
typename TTypes<T>::ConstFlat in2,
|
typename TTypes<T>::ConstFlat in3,
|
typename TTypes<T>::ConstFlat in4) {
|
out.device(d) = in1 + in2 + in3 + in4;
|
}
|
};
|
|
template <typename Device, typename T>
|
struct Add5Functor {
|
void operator()(const Device& d, typename TTypes<T>::Flat out,
|
typename TTypes<T>::ConstFlat in1,
|
typename TTypes<T>::ConstFlat in2,
|
typename TTypes<T>::ConstFlat in3,
|
typename TTypes<T>::ConstFlat in4,
|
typename TTypes<T>::ConstFlat in5);
|
};
|
|
template <typename Device, typename T>
|
struct Add5EigenImpl {
|
static void Compute(const Device& d, typename TTypes<T>::Flat out,
|
typename TTypes<T>::ConstFlat in1,
|
typename TTypes<T>::ConstFlat in2,
|
typename TTypes<T>::ConstFlat in3,
|
typename TTypes<T>::ConstFlat in4,
|
typename TTypes<T>::ConstFlat in5) {
|
out.device(d) = in1 + in2 + in3 + in4 + in5;
|
}
|
};
|
|
template <typename Device, typename T>
|
struct Add6Functor {
|
void operator()(const Device& d, typename TTypes<T>::Flat out,
|
typename TTypes<T>::ConstFlat in1,
|
typename TTypes<T>::ConstFlat in2,
|
typename TTypes<T>::ConstFlat in3,
|
typename TTypes<T>::ConstFlat in4,
|
typename TTypes<T>::ConstFlat in5,
|
typename TTypes<T>::ConstFlat in6);
|
};
|
|
template <typename Device, typename T>
|
struct Add6EigenImpl {
|
static void Compute(const Device& d, typename TTypes<T>::Flat out,
|
typename TTypes<T>::ConstFlat in1,
|
typename TTypes<T>::ConstFlat in2,
|
typename TTypes<T>::ConstFlat in3,
|
typename TTypes<T>::ConstFlat in4,
|
typename TTypes<T>::ConstFlat in5,
|
typename TTypes<T>::ConstFlat in6) {
|
out.device(d) = in1 + in2 + in3 + in4 + in5 + in6;
|
}
|
};
|
|
template <typename Device, typename T>
|
struct Add7Functor {
|
void operator()(const Device& d, typename TTypes<T>::Flat out,
|
typename TTypes<T>::ConstFlat in1,
|
typename TTypes<T>::ConstFlat in2,
|
typename TTypes<T>::ConstFlat in3,
|
typename TTypes<T>::ConstFlat in4,
|
typename TTypes<T>::ConstFlat in5,
|
typename TTypes<T>::ConstFlat in6,
|
typename TTypes<T>::ConstFlat in7);
|
};
|
|
template <typename Device, typename T>
|
struct Add7EigenImpl {
|
static void Compute(const Device& d, typename TTypes<T>::Flat out,
|
typename TTypes<T>::ConstFlat in1,
|
typename TTypes<T>::ConstFlat in2,
|
typename TTypes<T>::ConstFlat in3,
|
typename TTypes<T>::ConstFlat in4,
|
typename TTypes<T>::ConstFlat in5,
|
typename TTypes<T>::ConstFlat in6,
|
typename TTypes<T>::ConstFlat in7) {
|
out.device(d) = in1 + in2 + in3 + in4 + in5 + in6 + in7;
|
}
|
};
|
|
template <typename Device, typename T>
|
struct Add8Functor {
|
void operator()(
|
const Device& d, typename TTypes<T>::Flat out,
|
typename TTypes<T>::ConstFlat in1, typename TTypes<T>::ConstFlat in2,
|
typename TTypes<T>::ConstFlat in3, typename TTypes<T>::ConstFlat in4,
|
typename TTypes<T>::ConstFlat in5, typename TTypes<T>::ConstFlat in6,
|
typename TTypes<T>::ConstFlat in7, typename TTypes<T>::ConstFlat in8);
|
};
|
|
template <typename Device, typename T>
|
struct Add8EigenImpl {
|
static void Compute(
|
const Device& d, typename TTypes<T>::Flat out,
|
typename TTypes<T>::ConstFlat in1, typename TTypes<T>::ConstFlat in2,
|
typename TTypes<T>::ConstFlat in3, typename TTypes<T>::ConstFlat in4,
|
typename TTypes<T>::ConstFlat in5, typename TTypes<T>::ConstFlat in6,
|
typename TTypes<T>::ConstFlat in7, typename TTypes<T>::ConstFlat in8) {
|
out.device(d) = in1 + in2 + in3 + in4 + in5 + in6 + in7 + in8;
|
}
|
};
|
|
// Add8p is like Add8 except the underlying implementation should +=
|
// rather than assign to the output.
|
template <typename Device, typename T>
|
struct Add8pFunctor {
|
void operator()(
|
const Device& d, typename TTypes<T>::Flat out,
|
typename TTypes<T>::ConstFlat in1, typename TTypes<T>::ConstFlat in2,
|
typename TTypes<T>::ConstFlat in3, typename TTypes<T>::ConstFlat in4,
|
typename TTypes<T>::ConstFlat in5, typename TTypes<T>::ConstFlat in6,
|
typename TTypes<T>::ConstFlat in7, typename TTypes<T>::ConstFlat in8);
|
};
|
|
template <typename Device, typename T>
|
struct Add8pEigenImpl {
|
static void Compute(
|
const Device& d, typename TTypes<T>::Flat out,
|
typename TTypes<T>::ConstFlat in1, typename TTypes<T>::ConstFlat in2,
|
typename TTypes<T>::ConstFlat in3, typename TTypes<T>::ConstFlat in4,
|
typename TTypes<T>::ConstFlat in5, typename TTypes<T>::ConstFlat in6,
|
typename TTypes<T>::ConstFlat in7, typename TTypes<T>::ConstFlat in8) {
|
out.device(d) += in1 + in2 + in3 + in4 + in5 + in6 + in7 + in8;
|
}
|
};
|
|
template <typename Device, typename T>
|
struct Add9Functor {
|
void operator()(
|
const Device& d, typename TTypes<T>::Flat out,
|
typename TTypes<T>::ConstFlat in1, typename TTypes<T>::ConstFlat in2,
|
typename TTypes<T>::ConstFlat in3, typename TTypes<T>::ConstFlat in4,
|
typename TTypes<T>::ConstFlat in5, typename TTypes<T>::ConstFlat in6,
|
typename TTypes<T>::ConstFlat in7, typename TTypes<T>::ConstFlat in8,
|
typename TTypes<T>::ConstFlat in9);
|
};
|
|
template <typename Device, typename T>
|
struct Add9EigenImpl {
|
static void Compute(
|
const Device& d, typename TTypes<T>::Flat out,
|
typename TTypes<T>::ConstFlat in1, typename TTypes<T>::ConstFlat in2,
|
typename TTypes<T>::ConstFlat in3, typename TTypes<T>::ConstFlat in4,
|
typename TTypes<T>::ConstFlat in5, typename TTypes<T>::ConstFlat in6,
|
typename TTypes<T>::ConstFlat in7, typename TTypes<T>::ConstFlat in8,
|
typename TTypes<T>::ConstFlat in9) {
|
out.device(d) = in1 + in2 + in3 + in4 + in5 + in6 + in7 + in8 + in9;
|
}
|
};
|
|
} // namespace functor
|
} // namespace tensorflow
|
|
#endif // TENSORFLOW_CORE_KERNELS_AGGREGATE_OPS_H_
|