/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
you may not use this file except in compliance with the License.
|
You may obtain a copy of the License at
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
Unless required by applicable law or agreed to in writing, software
|
distributed under the License is distributed on an "AS IS" BASIS,
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
See the License for the specific language governing permissions and
|
limitations under the License.
|
==============================================================================*/
|
|
#ifndef TENSORFLOW_STREAM_EXECUTOR_GPU_GPU_RNG_H_
|
#define TENSORFLOW_STREAM_EXECUTOR_GPU_GPU_RNG_H_
|
|
#include "tensorflow/stream_executor/platform/mutex.h"
|
#include "tensorflow/stream_executor/platform/port.h"
|
#include "tensorflow/stream_executor/platform/thread_annotations.h"
|
#include "tensorflow/stream_executor/plugin_registry.h"
|
#include "tensorflow/stream_executor/rng.h"
|
|
#include "tensorflow/stream_executor/gpu/gpu_types.h"
|
|
namespace stream_executor {
|
|
class Stream;
|
template <typename ElemT>
|
class DeviceMemory;
|
|
namespace gpu {
|
|
// Opaque and unique identifier for the GPU RNG plugin.
|
extern const PluginId kGpuRandPlugin;
|
|
class GpuExecutor;
|
|
// GPU-platform implementation of the random number generation support
|
// interface.
|
//
|
// Thread-safe post-initialization.
|
class GpuRng : public rng::RngSupport {
|
public:
|
explicit GpuRng(GpuExecutor* parent);
|
|
// Retrieves a gpu rng library generator handle. This is necessary for
|
// enqueuing random number generation work onto the device.
|
// TODO(leary) provide a way for users to select the RNG algorithm.
|
bool Init();
|
|
// Releases a gpu rng library generator handle, if one was acquired.
|
~GpuRng() override;
|
|
// See rng::RngSupport for details on the following overrides.
|
bool DoPopulateRandUniform(Stream* stream, DeviceMemory<float>* v) override;
|
bool DoPopulateRandUniform(Stream* stream, DeviceMemory<double>* v) override;
|
bool DoPopulateRandUniform(Stream* stream,
|
DeviceMemory<std::complex<float>>* v) override;
|
bool DoPopulateRandUniform(Stream* stream,
|
DeviceMemory<std::complex<double>>* v) override;
|
bool DoPopulateRandGaussian(Stream* stream, float mean, float stddev,
|
DeviceMemory<float>* v) override;
|
bool DoPopulateRandGaussian(Stream* stream, double mean, double stddev,
|
DeviceMemory<double>* v) override;
|
|
bool SetSeed(Stream* stream, const uint8* seed, uint64 seed_bytes) override;
|
|
private:
|
// Actually performs the work of generating random numbers - the public
|
// methods are thin wrappers to this interface.
|
template <typename T>
|
bool DoPopulateRandUniformInternal(Stream* stream, DeviceMemory<T>* v);
|
template <typename ElemT, typename FuncT>
|
bool DoPopulateRandGaussianInternal(Stream* stream, ElemT mean, ElemT stddev,
|
DeviceMemory<ElemT>* v, FuncT func);
|
|
// Sets the stream for the internal gpu rng generator.
|
//
|
// This is a stateful operation, as the handle can only have one stream set at
|
// a given time, so it is usually performed right before enqueuing work to do
|
// with random number generation.
|
bool SetStream(Stream* stream) EXCLUSIVE_LOCKS_REQUIRED(mu_);
|
|
// mutex that guards the gpu rng library handle for this device.
|
mutex mu_;
|
|
// GpuExecutor which instantiated this GpuRng.
|
// Immutable post-initialization.
|
GpuExecutor* parent_;
|
|
// gpu rng library handle on the device.
|
GpuRngHandle rng_ GUARDED_BY(mu_);
|
|
SE_DISALLOW_COPY_AND_ASSIGN(GpuRng);
|
};
|
|
template <typename T>
|
string TypeString();
|
|
template <>
|
string TypeString<float>() {
|
return "float";
|
}
|
|
template <>
|
string TypeString<double>() {
|
return "double";
|
}
|
|
template <>
|
string TypeString<std::complex<float>>() {
|
return "std::complex<float>";
|
}
|
|
template <>
|
string TypeString<std::complex<double>>() {
|
return "std::complex<double>";
|
}
|
|
} // namespace gpu
|
} // namespace stream_executor
|
|
#endif // TENSORFLOW_STREAM_EXECUTOR_GPU_GPU_RNG_H_
|