@@ -554,13 +554,62 @@ float bf16_to_f32(uint16_t bfloat16) {
     return *reinterpret_cast<float*>(&val_bits);
 }
 
+uint16_t f8_e4m3_to_f16(uint8_t f8) {
+    // do we need to support uz?
+
+    const uint32_t exponent_bias = 7;
+    if (f8 == 0xff) {
+        return ggml_fp32_to_fp16(-NAN);
+    } else if (f8 == 0x7f) {
+        return ggml_fp32_to_fp16(NAN);
+    }
+
+    uint32_t sign     = f8 & 0x80;
+    uint32_t exponent = (f8 & 0x78) >> 3;
+    uint32_t mantissa = f8 & 0x07;
+    uint32_t result   = sign << 24;
+    if (exponent == 0) {
+        if (mantissa > 0) {
+            exponent = 0x7f - exponent_bias;
+
+            // yes, 2 times
+            if ((mantissa & 0x04) == 0) {
+                mantissa &= 0x03;
+                mantissa <<= 1;
+                exponent -= 1;
+            }
+            if ((mantissa & 0x04) == 0) {
+                mantissa &= 0x03;
+                mantissa <<= 1;
+                exponent -= 1;
+            }
+
+            result |= (mantissa & 0x03) << 21;
+            result |= exponent << 23;
+        }
+    } else {
+        result |= mantissa << 20;
+        exponent += 0x7f - exponent_bias;
+        result |= exponent << 23;
+    }
+
+    return ggml_fp32_to_fp16(*reinterpret_cast<const float*>(&result));
+}
+
 void bf16_to_f32_vec(uint16_t* src, float* dst, int64_t n) {
     // support inplace op
     for (int64_t i = n - 1; i >= 0; i--) {
         dst[i] = bf16_to_f32(src[i]);
     }
 }
 
+void f8_e4m3_to_f16_vec(uint8_t* src, uint16_t* dst, int64_t n) {
+    // support inplace op
+    for (int64_t i = n - 1; i >= 0; i--) {
+        dst[i] = f8_e4m3_to_f16(src[i]);
+    }
+}
+
 void convert_tensor(void* src,
                     ggml_type src_type,
                     void* dst,
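The decode above assumes the OCP FP8 E4M3 layout: 1 sign bit, 4 exponent bits with bias 7, 3 mantissa bits, no infinities, and NaN encoded as 0x7f/0xff. The two identical "if ((mantissa & 0x04) == 0)" blocks normalize a subnormal mantissa, and two passes suffice because there are only 3 mantissa bits. The following standalone sketch is not part of the diff (the helper name f8_e4m3_to_f32_ref is illustrative); it applies the same bit manipulation but returns a float directly, so a few reference values can be checked without pulling in ggml:

// Standalone sketch: decode an FP8 E4M3 byte to float with the same bit
// manipulation as f8_e4m3_to_f16 above, minus the ggml dependency.
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstring>

static float f8_e4m3_to_f32_ref(uint8_t f8) {
    const uint32_t exponent_bias = 7;
    if (f8 == 0x7f || f8 == 0xff) {
        return NAN;  // E4M3 has no infinities; these two codes are NaN
    }
    uint32_t sign     = f8 & 0x80;
    uint32_t exponent = (f8 & 0x78) >> 3;
    uint32_t mantissa = f8 & 0x07;
    uint32_t result   = sign << 24;
    if (exponent == 0) {
        if (mantissa > 0) {
            // subnormal: shift the 3-bit mantissa into place (at most twice)
            exponent = 0x7f - exponent_bias;
            for (int i = 0; i < 2 && (mantissa & 0x04) == 0; i++) {
                mantissa <<= 1;
                exponent -= 1;
            }
            result |= (mantissa & 0x03) << 21;
            result |= exponent << 23;
        }
    } else {
        result |= mantissa << 20;
        result |= (exponent + 0x7f - exponent_bias) << 23;
    }
    float f;
    std::memcpy(&f, &result, sizeof(f));  // well-defined bit cast
    return f;
}

int main() {
    assert(f8_e4m3_to_f32_ref(0x00) == 0.0f);           // +0
    assert(f8_e4m3_to_f32_ref(0x38) == 1.0f);           // exponent 7 - bias 7, mantissa 0
    assert(f8_e4m3_to_f32_ref(0x7e) == 448.0f);         // largest finite E4M3 value
    assert(f8_e4m3_to_f32_ref(0x01) == 1.0f / 512.0f);  // smallest subnormal, 2^-9
    assert(f8_e4m3_to_f32_ref(0xc0) == -2.0f);          // sign bit set
    assert(std::isnan(f8_e4m3_to_f32_ref(0x7f)));       // NaN
    return 0;
}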
@@ -794,6 +843,8 @@ ggml_type str_to_ggml_type(const std::string& dtype) {
         ttype = GGML_TYPE_F32;
     } else if (dtype == "F32") {
         ttype = GGML_TYPE_F32;
+    } else if (dtype == "F8_E4M3") {
+        ttype = GGML_TYPE_F16;
     }
     return ttype;
 }
@@ -866,7 +917,7 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const
 
         ggml_type type = str_to_ggml_type(dtype);
         if (type == GGML_TYPE_COUNT) {
-            LOG_ERROR("unsupported dtype '%s'", dtype.c_str());
+            LOG_ERROR("unsupported dtype '%s' (tensor '%s')", dtype.c_str(), name.c_str());
             return false;
         }
 
@@ -903,6 +954,10 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const
         if (dtype == "BF16") {
             tensor_storage.is_bf16 = true;
             GGML_ASSERT(tensor_storage.nbytes() == tensor_data_size * 2);
+        } else if (dtype == "F8_E4M3") {
+            tensor_storage.is_f8_e4m3 = true;
+            // f8 -> f16
+            GGML_ASSERT(tensor_storage.nbytes() == tensor_data_size * 2);
         } else {
             GGML_ASSERT(tensor_storage.nbytes() == tensor_data_size);
         }
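About the "* 2" in the added assert: str_to_ggml_type() maps F8_E4M3 to GGML_TYPE_F16, so tensor_storage.nbytes() is computed from the 2-byte F16 element size while the safetensors file stores one byte per element, the same doubling the BF16 branch relies on (2-byte elements widened to 4-byte F32). A tiny arithmetic sketch of that relationship, with illustrative variable names only:

#include <cassert>
#include <cstdint>

int main() {
    // Hypothetical tensor with 1024 F8_E4M3 elements.
    const int64_t n_elements       = 1024;
    const int64_t tensor_data_size = n_elements * sizeof(uint8_t);   // bytes in the file

    // The storage is typed as GGML_TYPE_F16, so nbytes() counts 2 bytes per element.
    const int64_t nbytes_as_f16 = n_elements * sizeof(uint16_t);

    assert(nbytes_as_f16 == tensor_data_size * 2);  // the relationship GGML_ASSERT checks above
    return 0;
}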
@@ -1537,6 +1592,9 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend
                 if (tensor_storage.is_bf16) {
                     // inplace op
                     bf16_to_f32_vec((uint16_t*)dst_tensor->data, (float*)dst_tensor->data, tensor_storage.nelements());
+                } else if (tensor_storage.is_f8_e4m3) {
+                    // inplace op
+                    f8_e4m3_to_f16_vec((uint8_t*)dst_tensor->data, (uint16_t*)dst_tensor->data, tensor_storage.nelements());
                 }
             } else {
                 read_buffer.resize(tensor_storage.nbytes());
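The in-place branch above relies on the destination buffer already being sized for F16 (two bytes per element) while the raw F8 bytes read from the file occupy only its first half, and on f8_e4m3_to_f16_vec() walking the elements back to front: writing dst[i] only touches source bytes at index >= i, which the remaining iterations no longer need. A self-contained sketch of that aliasing pattern, using a trivial zero-extension in place of the real F8 decode:

#include <cassert>
#include <cstdint>
#include <vector>

int main() {
    // Buffer sized for the destination type (uint16_t), initially holding
    // one source byte per element at its start.
    std::vector<uint8_t> buf(4 * sizeof(uint16_t));
    buf[0] = 10; buf[1] = 20; buf[2] = 30; buf[3] = 40;

    uint8_t*  src = buf.data();
    uint16_t* dst = reinterpret_cast<uint16_t*>(buf.data());  // same cast the loader uses

    // Back to front: writing dst[i] only clobbers bytes of src elements with
    // index >= i, which are not needed by the remaining (lower) iterations.
    for (int64_t i = 3; i >= 0; i--) {
        dst[i] = static_cast<uint16_t>(src[i]);  // stand-in for f8_e4m3_to_f16(src[i])
    }

    assert(dst[0] == 10 && dst[1] == 20 && dst[2] == 30 && dst[3] == 40);
    return 0;
}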
@@ -1545,6 +1603,9 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend
                 if (tensor_storage.is_bf16) {
                     // inplace op
                     bf16_to_f32_vec((uint16_t*)read_buffer.data(), (float*)read_buffer.data(), tensor_storage.nelements());
+                } else if (tensor_storage.is_f8_e4m3) {
+                    // inplace op
+                    f8_e4m3_to_f16_vec((uint8_t*)read_buffer.data(), (uint16_t*)read_buffer.data(), tensor_storage.nelements());
                 }
 
                 convert_tensor((void*)read_buffer.data(), tensor_storage.type, dst_tensor->data,
@@ -1557,6 +1618,9 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend
             if (tensor_storage.is_bf16) {
                 // inplace op
                 bf16_to_f32_vec((uint16_t*)read_buffer.data(), (float*)read_buffer.data(), tensor_storage.nelements());
+            } else if (tensor_storage.is_f8_e4m3) {
+                // inplace op
+                f8_e4m3_to_f16_vec((uint8_t*)read_buffer.data(), (uint16_t*)read_buffer.data(), tensor_storage.nelements());
             }
 
             if (tensor_storage.type == dst_tensor->type) {