Skip to content

Commit 8075463

Browse files
authored
[SYCL] Add support for bfloat16 conversion (#4213)
Signed-off-by: Alexey Sotkin <alexey.sotkin@intel.com>
1 parent 46a6889 commit 8075463

File tree

4 files changed

+203
-0
lines changed

4 files changed

+203
-0
lines changed

sycl/include/CL/__spirv/spirv_ops.hpp

+3
Original file line numberDiff line numberDiff line change
@@ -594,6 +594,9 @@ extern SYCL_EXTERNAL void
594594
__spirv_ocl_prefetch(const __attribute__((opencl_global)) char *Ptr,
595595
size_t NumBytes) noexcept;
596596

597+
// Device-side built-in declaration: takes a float and returns a uint16_t.
// bfloat16::from_float() stores the returned value as the raw bfloat16 bits.
extern SYCL_EXTERNAL uint16_t __spirv_ConvertFToBF16INTEL(float) noexcept;
598+
// Device-side built-in declaration: takes a uint16_t and returns a float.
// bfloat16::to_float() passes the stored raw bfloat16 bits to it.
extern SYCL_EXTERNAL float __spirv_ConvertBF16ToFINTEL(uint16_t) noexcept;
599+
597600
#else // if !__SYCL_DEVICE_ONLY__
598601

599602
template <typename dataT>

sycl/include/CL/sycl/feature_test.hpp

+1
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ namespace sycl {
2323
#ifndef SYCL_EXT_ONEAPI_MATRIX
2424
#define SYCL_EXT_ONEAPI_MATRIX 2
2525
#endif
26+
// Feature-test macro advertising the Intel bfloat16 conversion extension.
// NOTE(review): unlike SYCL_EXT_ONEAPI_MATRIX above, this is defined without an
// #ifndef guard - confirm that is intentional.
#define SYCL_EXT_INTEL_BF16_CONVERSION 1
2627

2728
} // namespace sycl
2829
} // __SYCL_INLINE_NAMESPACE(cl)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
//==--------- bfloat16.hpp ------- SYCL bfloat16 conversion ----------------==//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#pragma once
10+
11+
#include <CL/__spirv/spirv_ops.hpp>
12+
13+
__SYCL_INLINE_NAMESPACE(cl) {
14+
namespace sycl {
15+
namespace ext {
16+
namespace intel {
17+
namespace experimental {
18+
19+
class [[sycl_detail::uses_aspects(ext_intel_bf16_conversion)]] bfloat16 {
20+
using storage_t = uint16_t;
21+
storage_t value;
22+
23+
public:
24+
bfloat16() = default;
25+
bfloat16(const bfloat16 &) = default;
26+
~bfloat16() = default;
27+
28+
// Explicit conversion functions
29+
static storage_t from_float(const float &a) {
30+
#if defined(__SYCL_DEVICE_ONLY__)
31+
return __spirv_ConvertFToBF16INTEL(a);
32+
#else
33+
throw exception{errc::feature_not_supported,
34+
"Bfloat16 conversion is not supported on host device"};
35+
#endif
36+
}
37+
static float to_float(const storage_t &a) {
38+
#if defined(__SYCL_DEVICE_ONLY__)
39+
return __spirv_ConvertBF16ToFINTEL(a);
40+
#else
41+
throw exception{errc::feature_not_supported,
42+
"Bfloat16 conversion is not supported on host device"};
43+
#endif
44+
}
45+
46+
// Direct initialization
47+
bfloat16(const storage_t &a) : value(a) {}
48+
49+
// Implicit conversion from float to bfloat16
50+
bfloat16(const float &a) { value = from_float(a); }
51+
52+
bfloat16 &operator=(const float &rhs) {
53+
value = from_float(rhs);
54+
return *this;
55+
}
56+
57+
// Implicit conversion from bfloat16 to float
58+
operator float() const { return to_float(value); }
59+
60+
// Get raw bits representation of bfloat16
61+
operator storage_t() const { return value; }
62+
63+
// Logical operators (!,||,&&) are covered if we can cast to bool
64+
explicit operator bool() { return to_float(value) != 0.0f; }
65+
66+
// Unary minus operator overloading
67+
friend bfloat16 operator-(bfloat16 &lhs) {
68+
return bfloat16{-to_float(lhs.value)};
69+
}
70+
71+
// Increment and decrement operators overloading
72+
#define OP(op) \
73+
friend bfloat16 &operator op(bfloat16 &lhs) { \
74+
float f = to_float(lhs.value); \
75+
lhs.value = from_float(op f); \
76+
return lhs; \
77+
} \
78+
friend bfloat16 operator op(bfloat16 &lhs, int) { \
79+
bfloat16 old = lhs; \
80+
operator op(lhs); \
81+
return old; \
82+
}
83+
OP(++)
84+
OP(--)
85+
#undef OP
86+
87+
// Assignment operators overloading
88+
#define OP(op) \
89+
friend bfloat16 &operator op(bfloat16 &lhs, const bfloat16 &rhs) { \
90+
float f = static_cast<float>(lhs); \
91+
f op static_cast<float>(rhs); \
92+
return lhs = f; \
93+
} \
94+
template <typename T> \
95+
friend bfloat16 &operator op(bfloat16 &lhs, const T &rhs) { \
96+
float f = static_cast<float>(lhs); \
97+
f op static_cast<float>(rhs); \
98+
return lhs = f; \
99+
} \
100+
template <typename T> friend T &operator op(T &lhs, const bfloat16 &rhs) { \
101+
float f = static_cast<float>(lhs); \
102+
f op static_cast<float>(rhs); \
103+
return lhs = f; \
104+
}
105+
OP(+=)
106+
OP(-=)
107+
OP(*=)
108+
OP(/=)
109+
#undef OP
110+
111+
// Binary operators overloading
112+
#define OP(type, op) \
113+
friend type operator op(const bfloat16 &lhs, const bfloat16 &rhs) { \
114+
return type{static_cast<float>(lhs) op static_cast<float>(rhs)}; \
115+
} \
116+
template <typename T> \
117+
friend type operator op(const bfloat16 &lhs, const T &rhs) { \
118+
return type{static_cast<float>(lhs) op static_cast<float>(rhs)}; \
119+
} \
120+
template <typename T> \
121+
friend type operator op(const T &lhs, const bfloat16 &rhs) { \
122+
return type{static_cast<float>(lhs) op static_cast<float>(rhs)}; \
123+
}
124+
OP(bfloat16, +)
125+
OP(bfloat16, -)
126+
OP(bfloat16, *)
127+
OP(bfloat16, /)
128+
OP(bool, ==)
129+
OP(bool, !=)
130+
OP(bool, <)
131+
OP(bool, >)
132+
OP(bool, <=)
133+
OP(bool, >=)
134+
#undef OP
135+
136+
// Bitwise(|,&,~,^), modulo(%) and shift(<<,>>) operations are not supported
137+
// for floating-point types.
138+
};
139+
140+
} // namespace experimental
141+
} // namespace intel
142+
} // namespace ext
143+
144+
// Legacy spelling: sycl::INTEL pulls in every name from sycl::ext::intel via a
// using-directive. The namespace is tagged with __SYCL2020_DEPRECATED, whose
// message tells users to switch to 'ext::intel'.
namespace __SYCL2020_DEPRECATED("use 'ext::intel' instead") INTEL {
145+
using namespace ext::intel;
146+
}
147+
} // namespace sycl
148+
} // __SYCL_INLINE_NAMESPACE(cl)

sycl/test/extensions/bfloat16.cpp

+51
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
// RUN: %clangxx -fsycl-device-only -S -Xclang -emit-llvm %s -o - | FileCheck %s
2+
3+
#include <sycl/sycl.hpp>
4+
#include <sycl/ext/intel/experimental/bfloat16.hpp>
5+
6+
using sycl::ext::intel::experimental::bfloat16;
7+
8+
// External function taking and returning uint16_t; declared here without a
// definition. op() below passes bfloat16 objects to it, which exercises the
// class's conversion to its raw 16-bit representation.
SYCL_EXTERNAL uint16_t some_bf16_intrinsic(uint16_t x, uint16_t y);
9+
10+
// Device function under test. It builds two bfloat16 values from floats, adds
// them, passes raw bits through some_bf16_intrinsic and returns the result as
// a float. The FileCheck directives after each statement pin the expected IR:
// every float <-> bfloat16 conversion must be a call to the matching __spirv_
// built-in, and no fptoui / uitofp instruction may appear in between.
// The directives refer to the parameters by name (%a, %b), so the parameter
// names and the statement order must not change.
__attribute__((noinline))
11+
float op(float a, float b) {
12+
bfloat16 A {a};
13+
// CHECK: [[A:%.*]] = tail call spir_func zeroext i16 @_Z27__spirv_ConvertFToBF16INTELf(float %a)
14+
// CHECK-NOT: fptoui
15+
16+
bfloat16 B {b};
17+
// CHECK: [[B:%.*]] = tail call spir_func zeroext i16 @_Z27__spirv_ConvertFToBF16INTELf(float %b)
18+
// CHECK-NOT: fptoui
19+
20+
bfloat16 C = A + B;
21+
// CHECK: [[A_float:%.*]] = tail call spir_func float @_Z27__spirv_ConvertBF16ToFINTELt(i16 zeroext [[A]])
22+
// CHECK: [[B_float:%.*]] = tail call spir_func float @_Z27__spirv_ConvertBF16ToFINTELt(i16 zeroext [[B]])
23+
// CHECK: [[Add:%.*]] = fadd float [[A_float]], [[B_float]]
24+
// CHECK: [[C:%.*]] = tail call spir_func zeroext i16 @_Z27__spirv_ConvertFToBF16INTELf(float [[Add]])
25+
// CHECK-NOT: uitofp
26+
// CHECK-NOT: fptoui
27+
28+
bfloat16 D = some_bf16_intrinsic(A, C);
29+
// CHECK: [[D:%.*]] = tail call spir_func zeroext i16 @_Z19some_bf16_intrinsictt(i16 zeroext [[A]], i16 zeroext [[C]])
30+
// CHECK-NOT: uitofp
31+
// CHECK-NOT: fptoui
32+
33+
return D;
34+
// CHECK: [[RetVal:%.*]] = tail call spir_func float @_Z27__spirv_ConvertBF16ToFINTELt(i16 zeroext [[D]])
35+
// CHECK: ret float [[RetVal]]
36+
// CHECK-NOT: uitofp
37+
// CHECK-NOT: fptoui
38+
}
39+
40+
int main(int argc, char *argv[]) {
41+
float data[3] = {7.0, 8.1, 0.0};
42+
cl::sycl::queue deviceQueue;
43+
cl::sycl::buffer<float, 1> buf{data, cl::sycl::range<1>{3}};
44+
45+
deviceQueue.submit([&](cl::sycl::handler &cgh) {
46+
auto numbers = buf.get_access<cl::sycl::access::mode::read_write>(cgh);
47+
cgh.single_task<class simple_kernel>(
48+
[=]() { numbers[2] = op(numbers[0], numbers[1]); });
49+
});
50+
return 0;
51+
}

0 commit comments

Comments
 (0)