
Commit 3b0c33c

Green-Sky authored and SkutteOleg committed

add inplace conversion support for f8_e4m3 (leejet#359)

In the same way it is done for bf16: just as bf16 converts losslessly to fp32, f8_e4m3 converts losslessly to fp16.

1 parent a86f694 · commit 3b0c33c
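
The losslessness claim can be checked directly: E4M3 has 1 sign bit, 4 exponent bits (bias 7) and 3 mantissa bits, so every finite value it encodes needs at most 11 significand bits and an exponent well inside fp16's range. Below is a minimal standalone sketch (not part of the commit; f8_e4m3_ref_to_f32 and fits_fp16_exactly are hypothetical helper names) that decodes all 256 E4M3 bit patterns arithmetically and confirms each one is exactly representable as fp16.

#include <cmath>
#include <cstdint>
#include <cstdio>

// Reference E4M3 decoder: 1 sign bit, 4 exponent bits (bias 7), 3 mantissa bits;
// 0x7f and 0xff are NaN, there is no infinity. Same layout the commit's
// f8_e4m3_to_f16() parses, written arithmetically instead of with bit shifts.
static float f8_e4m3_ref_to_f32(uint8_t f8) {
    float sign   = (f8 & 0x80) ? -1.0f : 1.0f;
    int exponent = (f8 & 0x78) >> 3;
    int mantissa = f8 & 0x07;
    if ((f8 & 0x7f) == 0x7f) {
        return sign * NAN;
    }
    if (exponent == 0) {  // subnormal: mantissa/8 * 2^-6
        return sign * (mantissa / 8.0f) * std::ldexp(1.0f, -6);
    }
    return sign * (1.0f + mantissa / 8.0f) * std::ldexp(1.0f, exponent - 7);
}

// True if v is exactly representable as an IEEE fp16: at most 11 significand
// bits and an exponent inside fp16's normal range (E4M3's smallest subnormal,
// 2^-9, is still a normal fp16 value, so the normal range suffices here).
static bool fits_fp16_exactly(float v) {
    if (std::isnan(v) || v == 0.0f) {
        return true;
    }
    int exp;
    float m = std::frexp(std::fabs(v), &exp);  // |v| = m * 2^exp, m in [0.5, 1)
    return m * 2048.0f == std::floor(m * 2048.0f) && exp >= -13 && exp <= 16;
}

int main() {
    for (int bits = 0; bits < 256; bits++) {
        float v = f8_e4m3_ref_to_f32((uint8_t)bits);
        if (!fits_fp16_exactly(v)) {
            printf("0x%02x (%g) would lose precision in fp16\n", bits, v);
            return 1;
        }
    }
    printf("all 256 e4m3 bit patterns convert to fp16 without loss\n");
    return 0;
}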

2 files changed (+72 −3 lines changed)


model.cpp

+65 −1
@@ -554,13 +554,62 @@ float bf16_to_f32(uint16_t bfloat16) {
     return *reinterpret_cast<float*>(&val_bits);
 }
 
+uint16_t f8_e4m3_to_f16(uint8_t f8) {
+    // do we need to support uz?
+
+    const uint32_t exponent_bias = 7;
+    if (f8 == 0xff) {
+        return ggml_fp32_to_fp16(-NAN);
+    } else if (f8 == 0x7f) {
+        return ggml_fp32_to_fp16(NAN);
+    }
+
+    uint32_t sign = f8 & 0x80;
+    uint32_t exponent = (f8 & 0x78) >> 3;
+    uint32_t mantissa = f8 & 0x07;
+    uint32_t result = sign << 24;
+    if (exponent == 0) {
+        if (mantissa > 0) {
+            exponent = 0x7f - exponent_bias;
+
+            // yes, 2 times
+            if ((mantissa & 0x04) == 0) {
+                mantissa &= 0x03;
+                mantissa <<= 1;
+                exponent -= 1;
+            }
+            if ((mantissa & 0x04) == 0) {
+                mantissa &= 0x03;
+                mantissa <<= 1;
+                exponent -= 1;
+            }
+
+            result |= (mantissa & 0x03) << 21;
+            result |= exponent << 23;
+        }
+    } else {
+        result |= mantissa << 20;
+        exponent += 0x7f - exponent_bias;
+        result |= exponent << 23;
+    }
+
+    return ggml_fp32_to_fp16(*reinterpret_cast<const float*>(&result));
+}
+
 void bf16_to_f32_vec(uint16_t* src, float* dst, int64_t n) {
     // support inplace op
     for (int64_t i = n - 1; i >= 0; i--) {
         dst[i] = bf16_to_f32(src[i]);
     }
 }
 
+void f8_e4m3_to_f16_vec(uint8_t* src, uint16_t* dst, int64_t n) {
+    // support inplace op
+    for (int64_t i = n - 1; i >= 0; i--) {
+        dst[i] = f8_e4m3_to_f16(src[i]);
+    }
+}
+
 void convert_tensor(void* src,
                     ggml_type src_type,
                     void* dst,
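
Because each fp8 element widens to two bytes, the loader (see the load_tensors hunks below) allocates the full f16-sized buffer up front, reads the one-byte-per-element payload into its front half, and lets f8_e4m3_to_f16_vec() expand it in place; walking from the last element down means dst[i] only ever overwrites source bytes that have already been consumed. A minimal standalone sketch of that pattern follows (assumption: a trivial zero-extension stands in for f8_e4m3_to_f16(), which needs ggml's fp16 helpers).

#include <cstdint>
#include <cstdio>
#include <vector>

// Stand-in for f8_e4m3_to_f16(): any 8-bit -> 16-bit per-element conversion works.
static uint16_t widen_stub(uint8_t x) { return static_cast<uint16_t>(x); }

int main() {
    const int64_t n = 4;
    // Allocate the final (wider) size up front, as the loader does via nbytes().
    std::vector<uint8_t> buffer(n * sizeof(uint16_t));
    // The file supplies only n bytes (nbytes_to_read() == nbytes() / 2);
    // they occupy the front half of the buffer.
    for (int64_t i = 0; i < n; i++) {
        buffer[i] = static_cast<uint8_t>(0x10 + i);
    }

    uint8_t* src = buffer.data();
    uint16_t* dst = reinterpret_cast<uint16_t*>(buffer.data());
    // Walk backwards: writing dst[i] touches bytes 2i and 2i+1, which only
    // overlap source bytes at indices >= i, and those were already read.
    for (int64_t i = n - 1; i >= 0; i--) {
        dst[i] = widen_stub(src[i]);
    }

    for (int64_t i = 0; i < n; i++) {
        printf("dst[%lld] = 0x%04x\n", (long long)i, (unsigned)dst[i]);
    }
    return 0;
}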
@@ -794,6 +843,8 @@ ggml_type str_to_ggml_type(const std::string& dtype) {
         ttype = GGML_TYPE_F32;
     } else if (dtype == "F32") {
         ttype = GGML_TYPE_F32;
+    } else if (dtype == "F8_E4M3") {
+        ttype = GGML_TYPE_F16;
     }
     return ttype;
 }
@@ -866,7 +917,7 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const
 
         ggml_type type = str_to_ggml_type(dtype);
         if (type == GGML_TYPE_COUNT) {
-            LOG_ERROR("unsupported dtype '%s'", dtype.c_str());
+            LOG_ERROR("unsupported dtype '%s' (tensor '%s')", dtype.c_str(), name.c_str());
             return false;
         }
 
@@ -903,6 +954,10 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const
         if (dtype == "BF16") {
             tensor_storage.is_bf16 = true;
             GGML_ASSERT(tensor_storage.nbytes() == tensor_data_size * 2);
+        } else if (dtype == "F8_E4M3") {
+            tensor_storage.is_f8_e4m3 = true;
+            // f8 -> f16
+            GGML_ASSERT(tensor_storage.nbytes() == tensor_data_size * 2);
         } else {
             GGML_ASSERT(tensor_storage.nbytes() == tensor_data_size);
         }
@@ -1537,6 +1592,9 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend
                 if (tensor_storage.is_bf16) {
                     // inplace op
                     bf16_to_f32_vec((uint16_t*)dst_tensor->data, (float*)dst_tensor->data, tensor_storage.nelements());
+                } else if (tensor_storage.is_f8_e4m3) {
+                    // inplace op
+                    f8_e4m3_to_f16_vec((uint8_t*)dst_tensor->data, (uint16_t*)dst_tensor->data, tensor_storage.nelements());
                 }
             } else {
                 read_buffer.resize(tensor_storage.nbytes());
@@ -1545,6 +1603,9 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend
                 if (tensor_storage.is_bf16) {
                     // inplace op
                     bf16_to_f32_vec((uint16_t*)read_buffer.data(), (float*)read_buffer.data(), tensor_storage.nelements());
+                } else if (tensor_storage.is_f8_e4m3) {
+                    // inplace op
+                    f8_e4m3_to_f16_vec((uint8_t*)read_buffer.data(), (uint16_t*)read_buffer.data(), tensor_storage.nelements());
                 }
 
                 convert_tensor((void*)read_buffer.data(), tensor_storage.type, dst_tensor->data,
@@ -1557,6 +1618,9 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend
                 if (tensor_storage.is_bf16) {
                     // inplace op
                     bf16_to_f32_vec((uint16_t*)read_buffer.data(), (float*)read_buffer.data(), tensor_storage.nelements());
+                } else if (tensor_storage.is_f8_e4m3) {
+                    // inplace op
+                    f8_e4m3_to_f16_vec((uint8_t*)read_buffer.data(), (uint16_t*)read_buffer.data(), tensor_storage.nelements());
                 }
 
                 if (tensor_storage.type == dst_tensor->type) {

model.h

+7 −2
@@ -32,6 +32,7 @@ struct TensorStorage {
     std::string name;
     ggml_type type = GGML_TYPE_F32;
     bool is_bf16 = false;
+    bool is_f8_e4m3 = false;
    int64_t ne[SD_MAX_DIMS] = {1, 1, 1, 1, 1};
    int n_dims = 0;
 
@@ -61,7 +62,7 @@ struct TensorStorage {
     }
 
     int64_t nbytes_to_read() const {
-        if (is_bf16) {
+        if (is_bf16 || is_f8_e4m3) {
             return nbytes() / 2;
         } else {
             return nbytes();
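
A small illustration of the byte accounting this change enables (simplified standalone sketch, not the actual TensorStorage members): an F8_E4M3 tensor is typed as GGML_TYPE_F16 in memory, so nbytes() reports two bytes per element, while nbytes_to_read() halves that because the safetensors file stores only one byte per element.

#include <cstdint>
#include <cstdio>

int main() {
    const int64_t nelements = 1000;           // e.g. a small F8_E4M3 tensor
    const int64_t nbytes = nelements * 2;     // held as GGML_TYPE_F16 in memory
    const bool is_f8_e4m3 = true;
    // Mirrors nbytes_to_read(): the file holds the narrow representation.
    const int64_t bytes_read = is_f8_e4m3 ? nbytes / 2 : nbytes;
    printf("read %lld bytes from the file, allocate %lld bytes for the tensor\n",
           (long long)bytes_read, (long long)nbytes);
    return 0;
}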
@@ -109,6 +110,8 @@ struct TensorStorage {
         const char* type_name = ggml_type_name(type);
         if (is_bf16) {
             type_name = "bf16";
+        } else if (is_f8_e4m3) {
+            type_name = "f8_e4m3";
         }
         ss << name << " | " << type_name << " | ";
         ss << n_dims << " [";
@@ -160,4 +163,6 @@ class ModelLoader {
     static std::string load_merges();
     static std::string load_t5_tokenizer_json();
 };
-#endif // __MODEL_H__
+
+#endif // __MODEL_H__
+
