// fp8.hpp
//
// Software 8-bit floating point type with 1 sign bit, 4 exponent bits and
// 3 mantissa bits, plus batch conversion helpers.
//
// Encoding implemented by this header (IEEE-754 style, exponent bias 7):
//   exponent field 0      : zero / subnormals, value = mantissa * 2^-9
//   exponent field 1..14  : normals, value = (1 + mantissa/8) * 2^(field - 7)
//   exponent field 15     : mantissa == 0 -> +/-infinity, otherwise NaN
// The largest finite magnitude is therefore 240 (0x77).
#pragma once

#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstring>

#ifdef __CUDACC__
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#endif

// Expands to the CUDA execution-space qualifiers under nvcc and to nothing
// for a plain host compiler, so the header builds in both environments.
#ifndef FP8_HOST_DEVICE
#ifdef __CUDACC__
#define FP8_HOST_DEVICE __host__ __device__
#else
#define FP8_HOST_DEVICE
#endif
#endif

class fp8_e4m3 {
private:
    uint8_t data;  // raw encoded bits: s eeee mmm

public:
    // ---- Constructors ------------------------------------------------------

    // Default: positive zero.
    FP8_HOST_DEVICE fp8_e4m3() : data(0) {}

    // Wraps an already-encoded byte without any conversion.
    FP8_HOST_DEVICE explicit fp8_e4m3(uint8_t val) : data(val) {}

    // Converts from float32 using round-to-nearest-even.
    FP8_HOST_DEVICE explicit fp8_e4m3(float f) : data(float_to_fp8(f)) {}

#ifdef __CUDACC__
    // Converts from float16 (CUDA builds only). The value is widened to
    // float32 first and then encoded by the float32 path.
    FP8_HOST_DEVICE explicit fp8_e4m3(__half h)
        : data(float_to_fp8(__half2float(h))) {}
#endif

    // ---- Conversion --------------------------------------------------------

    // Decodes to float32. Every fp8 value is exactly representable.
    FP8_HOST_DEVICE operator float() const { return fp8_to_float(data); }

    // ---- Arithmetic (computed in float32, rounded back to fp8) -------------

    FP8_HOST_DEVICE fp8_e4m3 operator+(const fp8_e4m3& other) const {
        return fp8_e4m3(static_cast<float>(*this) + static_cast<float>(other));
    }
    FP8_HOST_DEVICE fp8_e4m3 operator-(const fp8_e4m3& other) const {
        return fp8_e4m3(static_cast<float>(*this) - static_cast<float>(other));
    }
    FP8_HOST_DEVICE fp8_e4m3 operator*(const fp8_e4m3& other) const {
        return fp8_e4m3(static_cast<float>(*this) * static_cast<float>(other));
    }
    FP8_HOST_DEVICE fp8_e4m3 operator/(const fp8_e4m3& other) const {
        return fp8_e4m3(static_cast<float>(*this) / static_cast<float>(other));
    }

    // ---- Compound assignment -----------------------------------------------

    FP8_HOST_DEVICE fp8_e4m3& operator+=(const fp8_e4m3& other) {
        *this = *this + other;
        return *this;
    }
    FP8_HOST_DEVICE fp8_e4m3& operator-=(const fp8_e4m3& other) {
        *this = *this - other;
        return *this;
    }
    FP8_HOST_DEVICE fp8_e4m3& operator*=(const fp8_e4m3& other) {
        *this = *this * other;
        return *this;
    }
    FP8_HOST_DEVICE fp8_e4m3& operator/=(const fp8_e4m3& other) {
        *this = *this / other;
        return *this;
    }

    // ---- Comparison --------------------------------------------------------

    // Equality with floating-point semantics: NaN never compares equal
    // (not even to itself) and +0 compares equal to -0.
    FP8_HOST_DEVICE bool operator==(const fp8_e4m3& other) const {
        if (is_nan_bits(data) || is_nan_bits(other.data)) return false;
        // Both operands are zero (of either sign) when no magnitude bit is set.
        if (((data | other.data) & 0x7F) == 0) return true;
        // Every other value has exactly one encoding.
        return data == other.data;
    }
    FP8_HOST_DEVICE bool operator!=(const fp8_e4m3& other) const {
        return !(*this == other);
    }
    FP8_HOST_DEVICE bool operator<(const fp8_e4m3& other) const {
        return static_cast<float>(*this) < static_cast<float>(other);
    }
    FP8_HOST_DEVICE bool operator>(const fp8_e4m3& other) const {
        return static_cast<float>(*this) > static_cast<float>(other);
    }
    FP8_HOST_DEVICE bool operator<=(const fp8_e4m3& other) const {
        return static_cast<float>(*this) <= static_cast<float>(other);
    }
    FP8_HOST_DEVICE bool operator>=(const fp8_e4m3& other) const {
        return static_cast<float>(*this) >= static_cast<float>(other);
    }

    // Returns the raw encoded byte.
    FP8_HOST_DEVICE uint8_t get_raw() const { return data; }

    // ---- Special values ----------------------------------------------------
    // The casts are required: a bare integer literal is an int, which is
    // ambiguous between the uint8_t and float constructors.

    FP8_HOST_DEVICE static fp8_e4m3 zero() {
        return fp8_e4m3(static_cast<uint8_t>(0x00));
    }
    FP8_HOST_DEVICE static fp8_e4m3 one() {
        return fp8_e4m3(static_cast<uint8_t>(0x38));  // field 7, mantissa 0
    }
    FP8_HOST_DEVICE static fp8_e4m3 nan() {
        return fp8_e4m3(static_cast<uint8_t>(0x7F));
    }
    FP8_HOST_DEVICE static fp8_e4m3 inf() {
        return fp8_e4m3(static_cast<uint8_t>(0x78));  // +inf
    }
    FP8_HOST_DEVICE static fp8_e4m3 neg_inf() {
        return fp8_e4m3(static_cast<uint8_t>(0xF8));  // -inf
    }

    // ---- Memory operations -------------------------------------------------

    // Forwards to ::memcpy; `count` is in bytes.
    FP8_HOST_DEVICE static void memcpy(void* dst, const void* src, size_t count) {
        ::memcpy(dst, src, count);
    }
    // Forwards to ::memset; `count` is in bytes.
    FP8_HOST_DEVICE static void memset(void* ptr, int value, size_t count) {
        ::memset(ptr, value, count);
    }

private:
    // True when the byte encodes NaN: exponent field 15 with a non-zero
    // mantissa, i.e. magnitude bits strictly above the infinity pattern 0x78.
    FP8_HOST_DEVICE static bool is_nan_bits(uint8_t bits) {
        return (bits & 0x7F) > 0x78;
    }

    // Encodes a float32 as fp8 with round-to-nearest-even.
    //
    // NaN maps to NaN (0x7F, sign preserved), infinities and finite values
    // that round beyond 240 map to +/-infinity, and magnitudes at or below
    // 2^-10 round to a signed zero. The same code runs on host and device so
    // both produce identical encodings.
    //
    // NOTE(review): no hardware conversion is used here. The FP8 intrinsics
    // documented for cuda_fp8.h target an E4M3 variant without infinities,
    // which would not agree with inf()/neg_inf() above - confirm before
    // adding a native path.
    FP8_HOST_DEVICE static uint8_t float_to_fp8(float f) {
        uint32_t bits;
        // Qualified call: unqualified memcpy would name the class member.
        ::memcpy(&bits, &f, sizeof(bits));

        const uint32_t sign = (bits >> 24) & 0x80u;  // sign moved to bit 7
        const uint32_t magnitude = bits & 0x7FFFFFFFu;

        if (magnitude > 0x7F800000u) return static_cast<uint8_t>(sign | 0x7Fu);   // NaN
        if (magnitude == 0x7F800000u) return static_cast<uint8_t>(sign | 0x78u);  // Inf

        const int32_t exp = static_cast<int32_t>(magnitude >> 23) - 127;
        if (exp > 7) return static_cast<uint8_t>(sign | 0x78u);  // overflow -> Inf
        // Below half of the smallest subnormal (2^-9): rounds to zero. This
        // also covers float32 zero and float32 subnormals (exp == -127).
        if (exp < -10) return static_cast<uint8_t>(sign);

        // 24-bit significand with the implicit leading one made explicit.
        const uint32_t significand = (magnitude & 0x7FFFFFu) | 0x800000u;

        // Normals keep 4 significant bits (1.mmm); subnormals lose one more
        // bit for every step the exponent lies below -6.
        const bool is_normal = exp >= -6;
        const int shift = is_normal ? 20 : 20 + (-6 - exp);

        uint32_t quotient = significand >> shift;
        const uint32_t remainder = significand & ((1u << shift) - 1u);
        const uint32_t half = 1u << (shift - 1);
        if (remainder > half || (remainder == half && (quotient & 1u))) {
            ++quotient;  // round to nearest, ties to even
        }

        // For normals the quotient is 8..16 (leading one included), so adding
        // it to (field - 1) << 3 both drops the implicit bit and carries a
        // mantissa overflow into the exponent; at the top of the range that
        // carry yields 0x78 (infinity). For subnormals the quotient is 0..8
        // and 8 is exactly the smallest normal (0x08).
        const uint32_t code =
            is_normal ? ((static_cast<uint32_t>(exp + 6) << 3) + quotient) : quotient;
        return static_cast<uint8_t>(sign | code);
    }

    // Decodes an fp8 byte to float32 (exact).
    FP8_HOST_DEVICE static float fp8_to_float(uint8_t fp8) {
        const uint32_t sign = static_cast<uint32_t>(fp8 & 0x80u) << 24;
        const uint32_t exp_field = (fp8 >> 3) & 0xFu;
        const uint32_t mant = fp8 & 0x7u;

        if (exp_field == 0) {
            // Zero and subnormals: mantissa counts units of 2^-9. Negating a
            // zero magnitude yields -0.0f, so signed zero is preserved.
            const float mag = static_cast<float>(mant) * 0.001953125f;
            return sign ? -mag : mag;
        }

        uint32_t bits;
        if (exp_field == 0xFu) {
            // Infinity, or NaN with the float32 quiet bit set.
            bits = sign | (mant ? (0x7FC00000u | (mant << 20)) : 0x7F800000u);
        } else {
            // Re-bias from 7 to 127 and left-align the 3 mantissa bits.
            bits = sign | ((exp_field + 120u) << 23) | (mant << 20);
        }
        float result;
        ::memcpy(&result, &bits, sizeof(result));
        return result;
    }
};

// Batch helpers operating on raw fp8 bytes. These are host functions that
// dereference their pointer arguments directly.
namespace fp8_ops {

// Encodes `count` floats from `src` into `count` fp8 bytes at `dst`.
// Uses OpenMP (when enabled at compile time) for more than 1024 elements.
inline void convert_float_to_fp8(uint8_t* dst, const float* src, size_t count) {
    #pragma omp parallel for simd if(count > 1024)
    for (size_t i = 0; i < count; ++i) {
        dst[i] = fp8_e4m3(src[i]).get_raw();
    }
}

// Decodes `count` fp8 bytes from `src` into `count` floats at `dst`.
// Uses OpenMP (when enabled at compile time) for more than 1024 elements.
inline void convert_fp8_to_float(float* dst, const uint8_t* src, size_t count) {
    #pragma omp parallel for simd if(count > 1024)
    for (size_t i = 0; i < count; ++i) {
        dst[i] = static_cast<float>(fp8_e4m3(src[i]));
    }
}

// Fills `count` bytes at `ptr` with the encoding of `value`. Because each fp8
// element is a single byte, a byte-wise memset is an exact element-wise fill.
inline void memset_fp8(void* ptr, fp8_e4m3 value, size_t count) {
    ::memset(ptr, value.get_raw(), count);
}

}  // namespace fp8_ops