/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
you may not use this file except in compliance with the License.
|
You may obtain a copy of the License at
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
Unless required by applicable law or agreed to in writing, software
|
distributed under the License is distributed on an "AS IS" BASIS,
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
See the License for the specific language governing permissions and
|
limitations under the License.
|
==============================================================================*/
|
|
#ifndef TENSORFLOW_STREAM_EXECUTOR_HOST_OR_DEVICE_SCALAR_H_
|
#define TENSORFLOW_STREAM_EXECUTOR_HOST_OR_DEVICE_SCALAR_H_
|
|
#include "tensorflow/stream_executor/device_memory.h"
|
#include "tensorflow/stream_executor/platform/logging.h"
|
|
namespace stream_executor {
|
|
// Allows to represent a value that is either a host scalar or a scalar stored
|
// on the GPU device.
|
template <typename ElemT>
|
class HostOrDeviceScalar {
|
public:
|
// Not marked as explicit because when using this constructor, we usually want
|
// to set this to a compile-time constant.
|
HostOrDeviceScalar(ElemT value) : value_(value), is_pointer_(false) {}
|
explicit HostOrDeviceScalar(const DeviceMemory<ElemT>& pointer)
|
: pointer_(pointer), is_pointer_(true) {
|
CHECK_EQ(1, pointer.ElementCount());
|
}
|
|
bool is_pointer() const { return is_pointer_; }
|
const DeviceMemory<ElemT>& pointer() const {
|
CHECK(is_pointer());
|
return pointer_;
|
}
|
const ElemT& value() const {
|
CHECK(!is_pointer());
|
return value_;
|
}
|
|
private:
|
union {
|
ElemT value_;
|
DeviceMemory<ElemT> pointer_;
|
};
|
bool is_pointer_;
|
};
|
|
} // namespace stream_executor
|
#endif // TENSORFLOW_STREAM_EXECUTOR_HOST_OR_DEVICE_SCALAR_H_
|