// Source: stupidsimcpp/util/fp8.hpp (2025-12-05 16:01:00 -05:00, 268 lines, 8.1 KiB, C++)
// fp8.hpp
#pragma once
#include <cstdint>
#include <cstring>
#include <cmath>
#include <type_traits>
#ifdef __CUDACC__
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#endif
// Fallback so this header also compiles with a plain host C++ compiler,
// where nvcc's __host__/__device__ attribute macros are not predefined.
#ifndef __CUDACC__
#ifndef __host__
#define __host__
#endif
#ifndef __device__
#define __device__
#endif
#endif

// 8-bit float, E4M3 layout: [sign:1][exp:4][mantissa:3], exponent bias 6,
// interpreted IEEE-754 style:
//   exp field 1..14 : normal,    value = +/-2^(e-6) * (1 + m/8)
//   exp field 0     : subnormal, value = +/-m * 2^-8  (+/-0 when m == 0)
//   exp field 15    : +/-inf when m == 0, NaN otherwise
// NOTE(review): this differs from the OCP / NVIDIA cuda_fp8.h E4M3 encoding
// (bias 7, no infinities) — do not exchange raw bytes with __nv_fp8_e4m3.
class fp8_e4m3 {
private:
uint8_t data; // raw encoded byte
// True when the byte encodes NaN (exp field all ones, mantissa nonzero).
__host__ __device__ static bool is_nan_bits(uint8_t v) {
    return ((v >> 3) & 0xF) == 0xF && (v & 0x7) != 0;
}
public:
// Constructors
__host__ __device__ fp8_e4m3() : data(0) {}
// Construct from a raw encoded byte (no numeric conversion).
__host__ __device__ explicit fp8_e4m3(uint8_t val) : data(val) {}
// Convert from float32, round-to-nearest-even, saturating to +/-inf.
__host__ __device__ explicit fp8_e4m3(float f) : data(cpu_float_to_fp8(f)) {}
#ifdef __CUDACC__
// Convert from float16 by widening to float32 first.
// (__half2float has host and device implementations in cuda_fp16.h.)
__host__ __device__ explicit fp8_e4m3(__half h) : data(cpu_float_to_fp8(__half2float(h))) {}
#endif
// Conversion to float32 (exact: every finite fp8 value is a float32).
__host__ __device__ operator float() const { return cpu_fp8_to_float(data); }
// Arithmetic operators: widen to float32, compute, re-round to fp8.
__host__ __device__ fp8_e4m3 operator+(const fp8_e4m3& other) const {
    return fp8_e4m3(float(*this) + float(other));
}
__host__ __device__ fp8_e4m3 operator-(const fp8_e4m3& other) const {
    return fp8_e4m3(float(*this) - float(other));
}
__host__ __device__ fp8_e4m3 operator*(const fp8_e4m3& other) const {
    return fp8_e4m3(float(*this) * float(other));
}
__host__ __device__ fp8_e4m3 operator/(const fp8_e4m3& other) const {
    return fp8_e4m3(float(*this) / float(other));
}
// Compound assignment operators
__host__ __device__ fp8_e4m3& operator+=(const fp8_e4m3& other) {
    *this = *this + other;
    return *this;
}
__host__ __device__ fp8_e4m3& operator-=(const fp8_e4m3& other) {
    *this = *this - other;
    return *this;
}
__host__ __device__ fp8_e4m3& operator*=(const fp8_e4m3& other) {
    *this = *this * other;
    return *this;
}
__host__ __device__ fp8_e4m3& operator/=(const fp8_e4m3& other) {
    *this = *this / other;
    return *this;
}
// Equality follows IEEE semantics: NaN compares unequal to everything
// (including itself), and +0 == -0 even though the bytes differ.
__host__ __device__ bool operator==(const fp8_e4m3& other) const {
    if (is_nan_bits(data) || is_nan_bits(other.data)) return false; // NaN
    if (((data | other.data) & 0x7F) == 0) return true;             // +/-0
    return data == other.data;
}
__host__ __device__ bool operator!=(const fp8_e4m3& other) const {
    return !(*this == other);
}
// Ordering comparisons go through float32 (NaN makes all of them false).
__host__ __device__ bool operator<(const fp8_e4m3& other) const {
    return float(*this) < float(other);
}
__host__ __device__ bool operator>(const fp8_e4m3& other) const {
    return float(*this) > float(other);
}
__host__ __device__ bool operator<=(const fp8_e4m3& other) const {
    return float(*this) <= float(other);
}
__host__ __device__ bool operator>=(const fp8_e4m3& other) const {
    return float(*this) >= float(other);
}
// Get raw data
__host__ __device__ uint8_t get_raw() const { return data; }
// Special values (explicit uint8_t casts: an int literal is ambiguous
// between the uint8_t and float constructors).
__host__ __device__ static fp8_e4m3 zero()    { return fp8_e4m3(static_cast<uint8_t>(0x00)); }
__host__ __device__ static fp8_e4m3 one()     { return fp8_e4m3(static_cast<uint8_t>(0x30)); } // exp field 6 -> 2^0 * 1.0
__host__ __device__ static fp8_e4m3 nan()     { return fp8_e4m3(static_cast<uint8_t>(0x7F)); }
__host__ __device__ static fp8_e4m3 inf()     { return fp8_e4m3(static_cast<uint8_t>(0x78)); } // +inf
__host__ __device__ static fp8_e4m3 neg_inf() { return fp8_e4m3(static_cast<uint8_t>(0xF8)); } // -inf
// Memory operations (raw bytes; count is in elements == bytes for fp8)
__host__ __device__ static void memcpy(void* dst, const void* src, size_t count) {
    ::memcpy(dst, src, count);
}
__host__ __device__ static void memset(void* ptr, int value, size_t count) {
    ::memset(ptr, value, count);
}
private:
// float32 -> fp8, round-to-nearest-even, saturating to inf, with proper
// subnormal support. Works identically on host and device.
// NOTE(review): SM89+ has native fp8 conversion (__nv_cvt_float_to_fp8 in
// cuda_fp8.h), but it implements the bias-7 OCP encoding, not this one, so
// the portable bit path is used everywhere for consistency.
__host__ __device__ static uint8_t cpu_float_to_fp8(float f) {
    uint32_t bits;
    ::memcpy(&bits, &f, sizeof(float));
    const uint32_t sign_byte = (bits >> 31) << 7;
    const uint32_t fexp = (bits >> 23) & 0xFF;
    const uint32_t fmant = bits & 0x7FFFFF;
    if (fexp == 0xFF) { // float inf (mant==0) or NaN (mant!=0)
        return (uint8_t)(sign_byte | (fmant ? 0x7F : 0x78));
    }
    if (fexp == 0) { // float zero/subnormal: far below fp8 range
        return (uint8_t)sign_byte;
    }
    int32_t exp = (int32_t)fexp - 127; // unbiased exponent
    if (exp > 8) { // above max normal 1.875*2^8 -> saturate to inf
        return (uint8_t)(sign_byte | 0x78);
    }
    if (exp >= -5) { // normal fp8 range (exp field 1..14)
        uint32_t m3 = fmant >> 20;          // top 3 mantissa bits
        uint32_t round = (fmant >> 19) & 1; // first dropped bit
        uint32_t sticky = (fmant & 0x7FFFF) != 0;
        if (round && (sticky || (m3 & 1))) { // round to nearest, ties to even
            if (++m3 == 8) {                 // mantissa carry into exponent
                m3 = 0;
                if (++exp > 8) return (uint8_t)(sign_byte | 0x78);
            }
        }
        return (uint8_t)(sign_byte | ((uint32_t)(exp + 6) << 3) | m3);
    }
    if (exp < -9) { // below half of the smallest subnormal -> zero
        return (uint8_t)sign_byte;
    }
    // Subnormal: express |f| in units of the fp8 ULP 2^-8 and round.
    uint32_t sig = 0x800000u | fmant;      // 24-bit significand (1.m * 2^-23)
    uint32_t shift = (uint32_t)(15 - exp); // 21..24 for exp in [-9,-6]
    uint32_t m3 = sig >> shift;
    uint32_t round = (sig >> (shift - 1)) & 1;
    uint32_t sticky = (sig & ((1u << (shift - 1)) - 1)) != 0;
    if (round && (sticky || (m3 & 1))) ++m3;
    if (m3 == 8) return (uint8_t)(sign_byte | (1u << 3)); // rounded up to min normal 2^-5
    return (uint8_t)(sign_byte | m3);
}
// fp8 -> float32 (always exact for finite values).
__host__ __device__ static float cpu_fp8_to_float(uint8_t fp8) {
    uint32_t sign = (uint32_t)(fp8 >> 7) & 1u;
    uint32_t e = (fp8 >> 3) & 0xF;
    uint32_t m = fp8 & 0x7;
    uint32_t bits;
    if (e == 0xF) {
        // inf when mantissa is zero, quiet NaN otherwise
        bits = (sign << 31) | 0x7F800000u | (m ? 0x00400000u : 0u);
    } else if (e == 0) {
        if (m == 0) {
            bits = sign << 31; // +/-0
        } else {
            // Subnormal m * 2^-8: normalize so bit 3 becomes the implicit 1.
            int32_t exp = -5;
            while (!(m & 0x8)) { m <<= 1; --exp; }
            bits = (sign << 31) | ((uint32_t)(exp + 127) << 23) | ((m & 0x7) << 20);
        }
    } else {
        // Normal: remove bias 6, re-bias for float32.
        bits = (sign << 31) | ((e - 6 + 127) << 23) | (m << 20);
    }
    float result;
    ::memcpy(&result, &bits, sizeof(float));
    return result;
}
};
// Vectorized operations for performance
namespace fp8_ops {
// Convert array of floats to fp8 (efficient batch conversion)
static void convert_float_to_fp8(uint8_t* dst, const float* src, size_t count) {
#pragma omp parallel for simd if(count > 1024)
for (size_t i = 0; i < count; ++i) {
dst[i] = fp8_e4m3(src[i]).get_raw();
}
}
// Convert array of fp8 to floats
static void convert_fp8_to_float(float* dst, const uint8_t* src, size_t count) {
#pragma omp parallel for simd if(count > 1024)
for (size_t i = 0; i < count; ++i) {
dst[i] = fp8_e4m3(src[i]);
}
}
// Direct memory operations
static void memset_fp8(void* ptr, fp8_e4m3 value, size_t count) {
uint8_t val = value.get_raw();
::memset(ptr, val, count);
}
}