Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/Expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,7 @@ struct Expr : public Internal::IRHandle {
Expr(bfloat16_t x)
: IRHandle(Internal::FloatImm::make(BFloat(16), (double)x)) {
}
#ifdef HALIDE_CPP_COMPILER_HAS_FLOAT16
#if HALIDE_CPP_COMPILER_HAS_FLOAT16
explicit Expr(_Float16 x)
: IRHandle(Internal::FloatImm::make(Float(16), (double)x)) {
}
Expand Down
4 changes: 2 additions & 2 deletions src/Float16.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ struct float16_t {
* positive zero.*/
float16_t() = default;

#ifdef HALIDE_CPP_COMPILER_HAS_FLOAT16
#if HALIDE_CPP_COMPILER_HAS_FLOAT16
/** Construct a float16_t from compiler's built-in _Float16 type. */
explicit float16_t(_Float16 value) {
memcpy(&data, &value, sizeof(_Float16));
Expand All @@ -57,7 +57,7 @@ struct float16_t {
/** Cast to int */
explicit operator int() const;

#ifdef HALIDE_CPP_COMPILER_HAS_FLOAT16
#if HALIDE_CPP_COMPILER_HAS_FLOAT16
/** Cast to compiler's built-in _Float16 type. */
explicit operator _Float16() const {
_Float16 result;
Expand Down
2 changes: 1 addition & 1 deletion src/Type.h
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ HALIDE_DECLARE_EXTERN_SIMPLE_TYPE(Halide::float16_t);
HALIDE_DECLARE_EXTERN_SIMPLE_TYPE(Halide::bfloat16_t);
HALIDE_DECLARE_EXTERN_SIMPLE_TYPE(halide_task_t);
HALIDE_DECLARE_EXTERN_SIMPLE_TYPE(halide_loop_task_t);
#ifdef HALIDE_CPP_COMPILER_HAS_FLOAT16
#if HALIDE_CPP_COMPILER_HAS_FLOAT16
HALIDE_DECLARE_EXTERN_SIMPLE_TYPE(_Float16);
#endif
HALIDE_DECLARE_EXTERN_SIMPLE_TYPE(float);
Expand Down
14 changes: 10 additions & 4 deletions src/runtime/HalideRuntime.h
Original file line number Diff line number Diff line change
Expand Up @@ -98,17 +98,21 @@ extern "C" {
#define HALIDE_RUNTIME_ASAN_DETECTED
#endif

#if !defined(HALIDE_CPP_COMPILER_HAS_FLOAT16)
#define HALIDE_CPP_COMPILER_HAS_FLOAT16 0

#if !defined(HALIDE_RUNTIME_ASAN_DETECTED)

// clang had _Float16 added as a reserved name in clang 8, but
// doesn't actually support it on most platforms until clang 15.
// Ideally there would be a better way to detect if the type
// is supported, even in a compiler independent fashion, but
// coming up with one has proven elusive.
#if defined(__clang__) && (__clang_major__ >= 15) && !defined(__EMSCRIPTEN__) && !defined(__i386__)
#if defined(__clang__) && (__clang_major__ >= 15) && !defined(__EMSCRIPTEN__) && !defined(__i386__) && !defined(__wasm__)
#if defined(__is_identifier)
#if !__is_identifier(_Float16)
#define HALIDE_CPP_COMPILER_HAS_FLOAT16
#undef HALIDE_CPP_COMPILER_HAS_FLOAT16
#define HALIDE_CPP_COMPILER_HAS_FLOAT16 1
#endif
#endif
#endif
Expand All @@ -121,11 +125,13 @@ extern "C" {
(defined(__i386__) && (__GNUC__ >= 14) && defined(__SSE2__)) || \
(defined(__arm__) && (__GNUC__ >= 13) && __ARM_FP16_FORMAT_IEEE) || \
(defined(__aarch64__) && (__GNUC__ >= 13))
#define HALIDE_CPP_COMPILER_HAS_FLOAT16
#undef HALIDE_CPP_COMPILER_HAS_FLOAT16
#define HALIDE_CPP_COMPILER_HAS_FLOAT16 1
#endif
#endif

#endif // !HALIDE_RUNTIME_ASAN_DETECTED
#endif // !defined(HALIDE_CPP_COMPILER_HAS_FLOAT16)

#endif // !COMPILING_HALIDE_RUNTIME

Expand Down Expand Up @@ -2152,7 +2158,7 @@ HALIDE_ALWAYS_INLINE constexpr halide_type_t halide_type_of() {
return halide_type_t(halide_type_handle, 64);
}

#ifdef HALIDE_CPP_COMPILER_HAS_FLOAT16
#if HALIDE_CPP_COMPILER_HAS_FLOAT16
template<>
HALIDE_ALWAYS_INLINE constexpr halide_type_t halide_type_of<_Float16>() {
return halide_type_t(halide_type_float, 16);
Expand Down
4 changes: 2 additions & 2 deletions test/correctness/float16_t.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,7 @@ int run_test() {
std::abs(halfway_plus_eps - (double)to_even));

assert(float(halfway_plus_eps) == halfway);
#ifdef HALIDE_CPP_COMPILER_HAS_FLOAT16
#if HALIDE_CPP_COMPILER_HAS_FLOAT16
assert(_Float16(halfway_plus_eps) == _Float16(float(to_odd)));
#endif
assert(float16_t(halfway_plus_eps) == to_odd);
Expand Down Expand Up @@ -484,7 +484,7 @@ int main(int argc, char **argv) {
}

printf("Testing _Float16...\n");
#ifdef HALIDE_CPP_COMPILER_HAS_FLOAT16
#if HALIDE_CPP_COMPILER_HAS_FLOAT16
if (run_test<_Float16>() != 0) {
fprintf(stderr, "_Float16 test failed!\n");
return 1;
Expand Down
2 changes: 1 addition & 1 deletion test/correctness/image_io.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ int main(int argc, char **argv) {
do_test<uint32_t>();
do_test<uint64_t>();
do_test<float>();
#ifdef HALIDE_CPP_COMPILER_HAS_FLOAT16
#if HALIDE_CPP_COMPILER_HAS_FLOAT16
do_test<_Float16>();
#endif
do_test<double>();
Expand Down
28 changes: 14 additions & 14 deletions tools/halide_image_io.h
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ template<>
inline bool convert(const int64_t &in) {
return in != 0;
}
#ifdef HALIDE_CPP_COMPILER_HAS_FLOAT16
#if HALIDE_CPP_COMPILER_HAS_FLOAT16
template<>
inline bool convert(const _Float16 &in) {
return (float)in != 0;
Expand Down Expand Up @@ -171,7 +171,7 @@ template<>
inline uint8_t convert(const int64_t &in) {
return convert<uint8_t, uint64_t>(in);
}
#ifdef HALIDE_CPP_COMPILER_HAS_FLOAT16
#if HALIDE_CPP_COMPILER_HAS_FLOAT16
template<>
inline uint8_t convert(const _Float16 &in) {
return (uint8_t)std::lround((float)in * 255.0f);
Expand Down Expand Up @@ -223,7 +223,7 @@ template<>
inline uint16_t convert(const int64_t &in) {
return convert<uint16_t, uint64_t>(in);
}
#ifdef HALIDE_CPP_COMPILER_HAS_FLOAT16
#if HALIDE_CPP_COMPILER_HAS_FLOAT16
template<>
inline uint16_t convert(const _Float16 &in) {
return (uint16_t)std::lround((float)in * 65535.0f);
Expand Down Expand Up @@ -275,7 +275,7 @@ template<>
inline uint32_t convert(const int64_t &in) {
return convert<uint32_t, uint64_t>(in);
}
#ifdef HALIDE_CPP_COMPILER_HAS_FLOAT16
#if HALIDE_CPP_COMPILER_HAS_FLOAT16
template<>
inline uint32_t convert(const _Float16 &in) {
return (uint32_t)std::llround((float)in * 4294967295.0);
Expand Down Expand Up @@ -327,7 +327,7 @@ template<>
inline uint64_t convert(const int64_t &in) {
return convert<uint64_t, uint64_t>(in);
}
#ifdef HALIDE_CPP_COMPILER_HAS_FLOAT16
#if HALIDE_CPP_COMPILER_HAS_FLOAT16
template<>
inline uint64_t convert(const _Float16 &in) {
return convert<uint64_t, uint32_t>((uint32_t)std::llround((float)in * 4294967295.0));
Expand Down Expand Up @@ -379,7 +379,7 @@ template<>
inline int8_t convert(const int64_t &in) {
return convert<uint8_t, int64_t>(in);
}
#ifdef HALIDE_CPP_COMPILER_HAS_FLOAT16
#if HALIDE_CPP_COMPILER_HAS_FLOAT16
template<>
inline int8_t convert(const _Float16 &in) {
return convert<uint8_t, float>((float)in);
Expand Down Expand Up @@ -431,7 +431,7 @@ template<>
inline int16_t convert(const int64_t &in) {
return convert<uint16_t, int64_t>(in);
}
#ifdef HALIDE_CPP_COMPILER_HAS_FLOAT16
#if HALIDE_CPP_COMPILER_HAS_FLOAT16
template<>
inline int16_t convert(const _Float16 &in) {
return convert<uint16_t, float>((float)in);
Expand Down Expand Up @@ -483,7 +483,7 @@ template<>
inline int32_t convert(const int64_t &in) {
return convert<uint32_t, int64_t>(in);
}
#ifdef HALIDE_CPP_COMPILER_HAS_FLOAT16
#if HALIDE_CPP_COMPILER_HAS_FLOAT16
template<>
inline int32_t convert(const _Float16 &in) {
return convert<uint32_t, float>((float)in);
Expand Down Expand Up @@ -535,7 +535,7 @@ template<>
inline int64_t convert(const int64_t &in) {
return convert<uint64_t, int64_t>(in);
}
#ifdef HALIDE_CPP_COMPILER_HAS_FLOAT16
#if HALIDE_CPP_COMPILER_HAS_FLOAT16
template<>
inline int64_t convert(const _Float16 &in) {
return convert<uint64_t, float>((float)in);
Expand All @@ -550,7 +550,7 @@ inline int64_t convert(const double &in) {
return convert<uint64_t, double>(in);
}

#ifdef HALIDE_CPP_COMPILER_HAS_FLOAT16
#if HALIDE_CPP_COMPILER_HAS_FLOAT16
// Convert to f16
template<>
inline _Float16 convert(const bool &in) {
Expand Down Expand Up @@ -639,7 +639,7 @@ template<>
inline float convert(const int64_t &in) {
return convert<float, uint64_t>(in);
}
#ifdef HALIDE_CPP_COMPILER_HAS_FLOAT16
#if HALIDE_CPP_COMPILER_HAS_FLOAT16
template<>
inline float convert(const _Float16 &in) {
return (float)in;
Expand Down Expand Up @@ -691,7 +691,7 @@ template<>
inline double convert(const int64_t &in) {
return convert<double, uint64_t>(in);
}
#ifdef HALIDE_CPP_COMPILER_HAS_FLOAT16
#if HALIDE_CPP_COMPILER_HAS_FLOAT16
template<>
inline double convert(const _Float16 &in) {
return (double)in;
Expand Down Expand Up @@ -2496,7 +2496,7 @@ struct ImageTypeConversion {

const halide_type_t src_type = src.type();
switch (src_type.element_of().as_u32()) {
#ifdef HALIDE_CPP_COMPILER_HAS_FLOAT16
#if HALIDE_CPP_COMPILER_HAS_FLOAT16
case halide_type_t(halide_type_float, 16).as_u32():
return convert_image<DstElemType>(src.template as<_Float16, AnyDims>());
#endif
Expand Down Expand Up @@ -2545,7 +2545,7 @@ struct ImageTypeConversion {
// Call the appropriate static-to-static conversion routine
// based on the desired dst type.
switch (dst_type.element_of().as_u32()) {
#ifdef HALIDE_CPP_COMPILER_HAS_FLOAT16
#if HALIDE_CPP_COMPILER_HAS_FLOAT16
case halide_type_t(halide_type_float, 16).as_u32():
return convert_image<_Float16>(src);
#endif
Expand Down
Loading