/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
you may not use this file except in compliance with the License.
|
You may obtain a copy of the License at
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
Unless required by applicable law or agreed to in writing, software
|
distributed under the License is distributed on an "AS IS" BASIS,
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
See the License for the specific language governing permissions and
|
limitations under the License.
|
==============================================================================*/
|
|
#include "tensorflow/core/framework/bfloat16.h"

#include <cmath>
#include <limits>
#include <vector>

#include "absl/base/casts.h"
#include "tensorflow/core/framework/numeric_types.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
|
|
namespace tensorflow {
|
namespace {
|
|
// A default-constructed bfloat16 must convert to exactly 0.0f.
TEST(Bfloat16Test, DefaultValueIsZero) {
  const bfloat16 default_constructed = bfloat16();
  EXPECT_EQ(0.0f, static_cast<float>(default_constructed));
}
|
|
// Converting each listed float to bfloat16 and back must compare equal to the
// original value.
TEST(Bfloat16Test, RepresentableFloatsRoundTripViaBfloat16) {
  const float infinity = std::numeric_limits<float>::infinity();
  const std::vector<float> representable = {
      -infinity, -1.0, -0.5, -0.0, 0.0, 0.5, 1.0, infinity,
  };
  for (const float original : representable) {
    const bfloat16 narrowed = static_cast<bfloat16>(original);
    const float restored = static_cast<float>(narrowed);
    EXPECT_EQ(original, restored);
  }
}
|
|
// Checks the raw 16-bit payload produced when a bfloat16 is built from 12.
TEST(Bfloat16Test, Simple) {
  // 12 as a 32-bit float is 0x41400000; the expected payload is its upper
  // 16 bits.
  const bfloat16 twelve(12);
  EXPECT_EQ(0x4140, twelve.value);
}
|
|
float BinaryToFloat(uint32_t sign, uint32_t exponent, uint32_t high_mantissa,
|
uint32_t low_mantissa) {
|
return absl::bit_cast<float>((sign << 31) + (exponent << 23) +
|
(high_mantissa << 16) + low_mantissa);
|
}
|
|
// One row of the parameterized conversion test: a float input together with
// the float values expected after converting it to bfloat16 by truncation and
// by rounding.
struct Bfloat16TestParam {
  float input;                // Value handed to the conversion routines.
  float expected_truncation;  // Expected result of truncate_to_bfloat16.
  float expected_rounding;    // Expected result of round_to_bfloat16.
};
|
|
class Bfloat16Test : public ::testing::Test,
|
public ::testing::WithParamInterface<Bfloat16TestParam> {};
|
|
// Converts the parameter's input to bfloat16 both by truncation and by
// rounding, and compares each result with the expectation stored in the
// parameter.
//
// For a NaN input the exact result is not compared (NaN never compares equal);
// instead each conversion only has to produce NaN or infinity. Both
// conversions are checked on that path: previously the early return taken
// after the truncation check meant the rounding conversion was never exercised
// for NaN inputs and its NaN branch was unreachable.
TEST_P(Bfloat16Test, TruncateTest) {
  const Bfloat16TestParam& param = GetParam();
  bfloat16 truncated = bfloat16::truncate_to_bfloat16(param.input);
  bfloat16 rounded = bfloat16::round_to_bfloat16(param.input);

  if (std::isnan(param.input)) {
    EXPECT_TRUE(std::isnan(static_cast<float>(truncated)) ||
                std::isinf(static_cast<float>(truncated)));
    EXPECT_TRUE(std::isnan(static_cast<float>(rounded)) ||
                std::isinf(static_cast<float>(rounded)));
    return;
  }

  EXPECT_EQ(param.expected_truncation, static_cast<float>(truncated));
  EXPECT_EQ(param.expected_rounding, static_cast<float>(rounded));
}
|
|
INSTANTIATE_TEST_SUITE_P(
|
Bfloat16Test_Instantiation, Bfloat16Test,
|
::testing::Values(
|
Bfloat16TestParam{
|
BinaryToFloat(0, 0b10000000, 0b1001000, 0b1111010111000011),
|
BinaryToFloat(0, 0b10000000, 0b1001000, 0b0000000000000000),
|
BinaryToFloat(0, 0b10000000, 0b1001001, 0b0000000000000000)},
|
Bfloat16TestParam{
|
BinaryToFloat(1, 0b10000000, 0b1001000, 0b1111010111000011),
|
BinaryToFloat(1, 0b10000000, 0b1001000, 0b0000000000000000),
|
BinaryToFloat(1, 0b10000000, 0b1001001, 0b0000000000000000)},
|
Bfloat16TestParam{
|
BinaryToFloat(0, 0b10000000, 0b1001000, 0b1000000000000000),
|
BinaryToFloat(0, 0b10000000, 0b1001000, 0b0000000000000000),
|
BinaryToFloat(0, 0b10000000, 0b1001000, 0b0000000000000000)},
|
Bfloat16TestParam{
|
BinaryToFloat(0, 0b11111111, 0b0000000, 0b0000000000000001),
|
BinaryToFloat(0, 0b11111111, 0b0000000, 0b0000000000000000),
|
BinaryToFloat(0, 0b11111111, 0b1000000, 0b0000000000000000)},
|
Bfloat16TestParam{
|
BinaryToFloat(0, 0b11111111, 0b1111111, 0b1111111111111111),
|
BinaryToFloat(0, 0b11111111, 0b1111111, 0b0000000000000000),
|
BinaryToFloat(0, 0b11111111, 0b1000000, 0b0000000000000000)},
|
Bfloat16TestParam{
|
BinaryToFloat(1, 0b10000000, 0b1001000, 0b1100000000000000),
|
BinaryToFloat(1, 0b10000000, 0b1001000, 0b0000000000000000),
|
BinaryToFloat(1, 0b10000000, 0b1001001, 0b0000000000000000)},
|
Bfloat16TestParam{
|
BinaryToFloat(0, 0b10000000, 0b1001000, 0b0000000000000000),
|
BinaryToFloat(0, 0b10000000, 0b1001000, 0b0000000000000000),
|
BinaryToFloat(0, 0b10000000, 0b1001000, 0b0000000000000000)},
|
Bfloat16TestParam{
|
BinaryToFloat(0, 0b10000000, 0b1001000, 0b0100000000000000),
|
BinaryToFloat(0, 0b10000000, 0b1001000, 0b0000000000000000),
|
BinaryToFloat(0, 0b10000000, 0b1001000, 0b0000000000000000)},
|
Bfloat16TestParam{
|
BinaryToFloat(0, 0b10000000, 0b1001000, 0b1000000000000000),
|
BinaryToFloat(0, 0b10000000, 0b1001000, 0b0000000000000000),
|
BinaryToFloat(0, 0b10000000, 0b1001000, 0b0000000000000000)},
|
Bfloat16TestParam{
|
BinaryToFloat(0, 0b00000000, 0b1001000, 0b1000000000000000),
|
BinaryToFloat(0, 0b00000000, 0b1001000, 0b0000000000000000),
|
BinaryToFloat(0, 0b00000000, 0b1001000, 0b0000000000000000)},
|
Bfloat16TestParam{
|
BinaryToFloat(0, 0b00000000, 0b1111111, 0b1100000000000000),
|
BinaryToFloat(0, 0b00000000, 0b1111111, 0b0000000000000000),
|
BinaryToFloat(0, 0b00000001, 0b0000000, 0b0000000000000000)}));
|
|
// Round-trips an array of floats through the bulk conversion routines and
// bounds the relative error of every element.
TEST(Bfloat16Test, Conversion) {
  constexpr int kCount = 100;

  float original[kCount];
  for (int i = 0; i < kCount; ++i) {
    original[i] = i + 1.25;
  }

  bfloat16 narrowed[kCount];
  FloatToBFloat16(original, narrowed, kCount);

  float restored[kCount];
  BFloat16ToFloat(narrowed, restored, kCount);

  // bfloat16 carries a 7-bit mantissa, so the relative error of each
  // round-tripped value should stay within 1/(2^7).
  for (int i = 0; i < kCount; ++i) {
    EXPECT_LE(fabs(restored[i] - original[i]) / original[i], 1.0 / 128);
  }
}
|
|
// Adding epsilon to 1 must give a value greater than 1, while adding half of
// epsilon must give back exactly 1.
TEST(Bfloat16Test, Epsilon) {
  auto eps = bfloat16::epsilon();
  bfloat16 one(1.0f);

  EXPECT_LT(1.0f, static_cast<float>(eps + one));

  auto half_eps = eps / bfloat16(2.0f);
  EXPECT_EQ(1.0f, static_cast<float>(half_eps + one));
}
|
|
// Unary minus must flip the sign of both positive and negative values.
TEST(Bfloat16Test, Negate) {
  bfloat16 positive(3.0f);
  bfloat16 negative(-4.5f);

  EXPECT_EQ(-3.0f, static_cast<float>(-positive));
  EXPECT_EQ(4.5f, static_cast<float>(-negative));
}
|
|
static void BM_FloatToBFloat16(int iters) {
|
testing::StopTiming();
|
static const int N = 32 << 20;
|
const int64 tot = static_cast<int64>(iters) * N;
|
testing::ItemsProcessed(tot);
|
testing::BytesProcessed(tot * (sizeof(float) + sizeof(bfloat16)));
|
|
float* inp = new float[N];
|
bfloat16* out = new bfloat16[N];
|
|
testing::StartTiming();
|
while (iters--) {
|
FloatToBFloat16(inp, out, N);
|
}
|
delete[] inp;
|
delete[] out;
|
}
|
BENCHMARK(BM_FloatToBFloat16);
|
|
static void BM_BFloat16ToFloat(int iters) {
|
testing::StopTiming();
|
static const int N = 32 << 20;
|
const int64 tot = static_cast<int64>(iters) * N;
|
testing::ItemsProcessed(tot);
|
testing::BytesProcessed(tot * (sizeof(float) + sizeof(bfloat16)));
|
|
bfloat16* inp = new bfloat16[N];
|
float* out = new float[N];
|
|
testing::StartTiming();
|
while (iters--) {
|
BFloat16ToFloat(inp, out, N);
|
}
|
delete[] inp;
|
delete[] out;
|
}
|
BENCHMARK(BM_BFloat16ToFloat);
|
|
} // namespace
|
} // namespace tensorflow
|