Skip to content

Commit fef471a

Browse files
authored
Merge pull request #619 from robertknight/simd-u8
Support u8 vectors in new SIMD API
2 parents 7d4d12d + 6029ff9 commit fef471a

File tree

6 files changed

+473
-19
lines changed

6 files changed

+473
-19
lines changed

rten-simd/src/safe/arch/aarch64.rs

+87-10
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,16 @@
11
use std::arch::aarch64::{
22
float32x4_t, int16x8_t, int32x4_t, int8x16_t, uint16x8_t, uint32x4_t, uint8x16_t, vabsq_f32,
3-
vaddq_f32, vaddq_s16, vaddq_s32, vaddq_s8, vaddq_u16, vaddvq_f32, vandq_u16, vandq_u32,
4-
vandq_u8, vbslq_f32, vbslq_s16, vbslq_s32, vbslq_s8, vbslq_u16, vceqq_f32, vceqq_s16,
5-
vceqq_s32, vceqq_s8, vceqq_u16, vcgeq_f32, vcgeq_s16, vcgeq_s32, vcgeq_s8, vcgeq_u16,
6-
vcgtq_f32, vcgtq_s16, vcgtq_s32, vcgtq_s8, vcgtq_u16, vcleq_f32, vcleq_s16, vcleq_s8,
7-
vcleq_u16, vcltq_f32, vcltq_s16, vcltq_s8, vcltq_u16, vcvtnq_s32_f32, vcvtq_s32_f32, vdivq_f32,
8-
vdupq_n_f32, vdupq_n_s16, vdupq_n_s32, vdupq_n_s8, vdupq_n_u16, vfmaq_f32, vld1q_f32,
9-
vld1q_s16, vld1q_s32, vld1q_s8, vld1q_u16, vld1q_u32, vld1q_u8, vmaxq_f32, vminq_f32,
10-
vmulq_f32, vmulq_s16, vmulq_s32, vmulq_s8, vmulq_u16, vnegq_f32, vnegq_s16, vnegq_s32,
11-
vnegq_s8, vshlq_n_s16, vshlq_n_s32, vshlq_n_s8, vst1q_f32, vst1q_s16, vst1q_s32, vst1q_s8,
12-
vst1q_u16, vsubq_f32, vsubq_s16, vsubq_s32, vsubq_s8, vsubq_u16,
3+
vaddq_f32, vaddq_s16, vaddq_s32, vaddq_s8, vaddq_u16, vaddq_u8, vaddvq_f32, vandq_u16,
4+
vandq_u32, vandq_u8, vbslq_f32, vbslq_s16, vbslq_s32, vbslq_s8, vbslq_u16, vbslq_u8, vceqq_f32,
5+
vceqq_s16, vceqq_s32, vceqq_s8, vceqq_u16, vceqq_u8, vcgeq_f32, vcgeq_s16, vcgeq_s32, vcgeq_s8,
6+
vcgeq_u16, vcgeq_u8, vcgtq_f32, vcgtq_s16, vcgtq_s32, vcgtq_s8, vcgtq_u16, vcgtq_u8, vcleq_f32,
7+
vcleq_s16, vcleq_s8, vcleq_u16, vcleq_u8, vcltq_f32, vcltq_s16, vcltq_s8, vcltq_u16, vcltq_u8,
8+
vcvtnq_s32_f32, vcvtq_s32_f32, vdivq_f32, vdupq_n_f32, vdupq_n_s16, vdupq_n_s32, vdupq_n_s8,
9+
vdupq_n_u16, vdupq_n_u8, vfmaq_f32, vld1q_f32, vld1q_s16, vld1q_s32, vld1q_s8, vld1q_u16,
10+
vld1q_u32, vld1q_u8, vmaxq_f32, vminq_f32, vmulq_f32, vmulq_s16, vmulq_s32, vmulq_s8,
11+
vmulq_u16, vmulq_u8, vnegq_f32, vnegq_s16, vnegq_s32, vnegq_s8, vshlq_n_s16, vshlq_n_s32,
12+
vshlq_n_s8, vst1q_f32, vst1q_s16, vst1q_s32, vst1q_s8, vst1q_u16, vst1q_u8, vsubq_f32,
13+
vsubq_s16, vsubq_s32, vsubq_s8, vsubq_u16, vsubq_u8,
1314
};
1415
use std::mem::transmute;
1516

@@ -32,6 +33,7 @@ unsafe impl Isa for ArmNeonIsa {
3233
type I32 = int32x4_t;
3334
type I16 = int16x8_t;
3435
type I8 = int8x16_t;
36+
type U8 = uint8x16_t;
3537
type U16 = uint16x8_t;
3638
type Bits = int32x4_t;
3739

@@ -51,6 +53,10 @@ unsafe impl Isa for ArmNeonIsa {
5153
self
5254
}
5355

56+
/// Returns `self` as the `SimdOps` implementation for `Self::U8` vectors.
fn u8(self) -> impl SimdOps<Self::U8> {
57+
self
58+
}
59+
5460
/// Returns `self` as the `SimdOps` implementation for `Self::U16` vectors.
fn u16(self) -> impl SimdOps<Self::U16> {
5561
self
5662
}
@@ -458,6 +464,76 @@ impl SimdIntOps<int8x16_t> for ArmNeonIsa {
458464
}
459465
}
460466

467+
// Operations on vectors of sixteen `u8` lanes (`uint8x16_t`), each forwarded
// to the corresponding NEON intrinsic. Comparison results and `select` masks
// use `uint8x16_t` as well.
//
// NOTE(review): the intrinsic calls below are wrapped in `unsafe` without a
// SAFETY justification; presumably soundness rests on NEON being available
// whenever an `ArmNeonIsa` value exists — confirm against the type's
// constructor.
unsafe impl SimdOps<uint8x16_t> for ArmNeonIsa {
468+
// Expands a macro defined outside this view; presumably it supplies the
// items shared by every NEON vector type — TODO confirm.
simd_ops_common!(uint8x16_t, uint8x16_t);
469+
470+
/// Lane-wise `x + y` via `vaddq_u8` (results wrap modulo 256).
#[inline]
471+
fn add(self, x: uint8x16_t, y: uint8x16_t) -> uint8x16_t {
472+
unsafe { vaddq_u8(x, y) }
473+
}
474+
475+
/// Lane-wise `x - y` via `vsubq_u8` (results wrap modulo 256).
#[inline]
476+
fn sub(self, x: uint8x16_t, y: uint8x16_t) -> uint8x16_t {
477+
unsafe { vsubq_u8(x, y) }
478+
}
479+
480+
/// Lane-wise `x * y` via `vmulq_u8`, keeping the low 8 bits of each product.
#[inline]
481+
fn mul(self, x: uint8x16_t, y: uint8x16_t) -> uint8x16_t {
482+
unsafe { vmulq_u8(x, y) }
483+
}
484+
485+
/// Returns a vector with every lane set to `x`.
#[inline]
486+
fn splat(self, x: u8) -> uint8x16_t {
487+
unsafe { vdupq_n_u8(x) }
488+
}
489+
490+
/// Unsigned lane-wise `x < y`; lanes where the comparison holds are all ones,
/// the others are zero.
#[inline]
491+
fn lt(self, x: uint8x16_t, y: uint8x16_t) -> uint8x16_t {
492+
unsafe { vcltq_u8(x, y) }
493+
}
494+
495+
/// Unsigned lane-wise `x <= y`, returned as a mask like `lt`.
#[inline]
496+
fn le(self, x: uint8x16_t, y: uint8x16_t) -> uint8x16_t {
497+
unsafe { vcleq_u8(x, y) }
498+
}
499+
500+
/// Lane-wise `x == y`, returned as a mask like `lt`.
#[inline]
501+
fn eq(self, x: uint8x16_t, y: uint8x16_t) -> uint8x16_t {
502+
unsafe { vceqq_u8(x, y) }
503+
}
504+
505+
/// Unsigned lane-wise `x >= y`, returned as a mask like `lt`.
#[inline]
506+
fn ge(self, x: uint8x16_t, y: uint8x16_t) -> uint8x16_t {
507+
unsafe { vcgeq_u8(x, y) }
508+
}
509+
510+
/// Unsigned lane-wise `x > y`, returned as a mask like `lt`.
#[inline]
511+
fn gt(self, x: uint8x16_t, y: uint8x16_t) -> uint8x16_t {
512+
unsafe { vcgtq_u8(x, y) }
513+
}
514+
515+
/// Loads a vector from `ptr` with `vld1q_u8`, which reads 16 consecutive
/// `u8` values.
///
/// # Safety
///
/// `ptr` must be valid for reading 16 bytes. NOTE(review): the trait-level
/// contract for `load_ptr` is not visible here — confirm it matches.
#[inline]
516+
unsafe fn load_ptr(self, ptr: *const u8) -> uint8x16_t {
517+
unsafe { vld1q_u8(ptr) }
518+
}
519+
520+
/// Builds a mask whose first `n` lanes are `u8::MAX` and whose remaining
/// lanes are zero. Any `n >= 16` sets every lane.
#[inline]
521+
fn first_n_mask(self, n: usize) -> uint8x16_t {
522+
let mask: [u8; 16] = std::array::from_fn(|i| if i < n { u8::MAX } else { 0 });
523+
// SAFETY: `mask` is a live local array of exactly 16 bytes, which is the
// amount `vld1q_u8` reads.
unsafe { vld1q_u8(mask.as_ptr()) }
524+
}
525+
526+
/// Bitwise select via `vbslq_u8`: takes bits from `x` where `mask` is set
/// and from `y` where it is clear.
#[inline]
527+
fn select(self, x: uint8x16_t, y: uint8x16_t, mask: <uint8x16_t as Simd>::Mask) -> uint8x16_t {
528+
// Note the argument order: the intrinsic takes the mask first.
unsafe { vbslq_u8(mask, x, y) }
529+
}
530+
531+
/// Stores the 16 lanes of `x` to `ptr` with `vst1q_u8`.
///
/// # Safety
///
/// `ptr` must be valid for writing 16 bytes. NOTE(review): the trait-level
/// contract for `store_ptr` is not visible here — confirm it matches.
#[inline]
532+
unsafe fn store_ptr(self, x: uint8x16_t, ptr: *mut u8) {
533+
unsafe { vst1q_u8(ptr, x) }
534+
}
535+
}
536+
461537
unsafe impl SimdOps<uint16x8_t> for ArmNeonIsa {
462538
simd_ops_common!(uint16x8_t, uint16x8_t);
463539

@@ -610,4 +686,5 @@ impl_simd!(float32x4_t, f32, 4, uint32x4_t);
610686
impl_simd!(int32x4_t, i32, 4, uint32x4_t);
611687
impl_simd!(int16x8_t, i16, 8, uint16x8_t);
612688
impl_simd!(int8x16_t, i8, 16, uint8x16_t);
689+
impl_simd!(uint8x16_t, u8, 16, uint8x16_t);
613690
impl_simd!(uint16x8_t, u16, 8, uint16x8_t);

rten-simd/src/safe/arch/generic.rs

+8
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ simd_type!(F32x4, f32, LEN_X32);
1919
simd_type!(I32x4, i32, LEN_X32);
2020
simd_type!(I16x8, i16, LEN_X32 * 2);
2121
simd_type!(I8x16, i8, LEN_X32 * 4);
22+
simd_type!(U8x16, u8, LEN_X32 * 4);
2223
simd_type!(U16x8, u16, LEN_X32 * 2);
2324

2425
// Define mask vector types. `Mn` is a mask for a vector with n-bit lanes.
@@ -49,6 +50,7 @@ unsafe impl Isa for GenericIsa {
4950
type I32 = I32x4;
5051
type I16 = I16x8;
5152
type I8 = I8x16;
53+
type U8 = U8x16;
5254
type U16 = U16x8;
5355
type Bits = I32x4;
5456

@@ -68,6 +70,10 @@ unsafe impl Isa for GenericIsa {
6870
self
6971
}
7072

73+
/// Returns `self` as the `SimdOps` implementation for `Self::U8` vectors.
fn u8(self) -> impl SimdOps<Self::U8> {
74+
self
75+
}
76+
7177
/// Returns `self` as the `SimdOps` implementation for `Self::U16` vectors.
fn u16(self) -> impl SimdOps<Self::U16> {
7278
self
7379
}
@@ -274,6 +280,7 @@ macro_rules! impl_simd_unsigned_int_ops {
274280
}
275281
};
276282
}
283+
impl_simd_unsigned_int_ops!(U8x16, u8, 16, M8);
277284
impl_simd_unsigned_int_ops!(U16x8, u16, 8, M16);
278285

279286
macro_rules! impl_mask {
@@ -334,4 +341,5 @@ impl_simd!(F32x4, f32, M32, 4);
334341
impl_simd!(I32x4, i32, M32, 4);
335342
impl_simd!(I16x8, i16, M16, 8);
336343
impl_simd!(I8x16, i8, M8, 16);
344+
impl_simd!(U8x16, u8, M8, 16);
337345
impl_simd!(U16x8, u16, M16, 8);

rten-simd/src/safe/arch/wasm32.rs

+58-2
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,10 @@ use std::arch::wasm32::{
55
i16x8_mul, i16x8_neg, i16x8_shl, i16x8_splat, i16x8_sub, i32x4_add, i32x4_eq, i32x4_ge,
66
i32x4_gt, i32x4_mul, i32x4_neg, i32x4_shl, i32x4_shuffle, i32x4_splat, i32x4_sub,
77
i32x4_trunc_sat_f32x4, i8x16_add, i8x16_eq, i8x16_ge, i8x16_gt, i8x16_neg, i8x16_shl,
8-
i8x16_shuffle, i8x16_splat, i8x16_sub, u16x8_add, u16x8_eq, u16x8_ge, u16x8_gt, u16x8_mul,
9-
u16x8_splat, u16x8_sub, v128, v128_and, v128_bitselect, v128_load, v128_store,
8+
i8x16_shuffle, i8x16_splat, i8x16_sub, u16x8_add, u16x8_eq, u16x8_extmul_high_u8x16,
9+
u16x8_extmul_low_u8x16, u16x8_ge, u16x8_gt, u16x8_mul, u16x8_splat, u16x8_sub, u8x16_add,
10+
u8x16_eq, u8x16_ge, u8x16_gt, u8x16_shuffle, u8x16_splat, u8x16_sub, v128, v128_and,
11+
v128_bitselect, v128_load, v128_store,
1012
};
1113
use std::mem::transmute;
1214

@@ -17,6 +19,7 @@ simd_type!(F32x4, v128, f32, M32, Wasm32Isa);
1719
simd_type!(I32x4, v128, i32, M32, Wasm32Isa);
1820
simd_type!(I16x8, v128, i16, M16, Wasm32Isa);
1921
simd_type!(I8x16, v128, i8, M8, Wasm32Isa);
22+
simd_type!(U8x16, v128, u8, M8, Wasm32Isa);
2023
simd_type!(U16x8, v128, u16, M16, Wasm32Isa);
2124

2225
#[derive(Copy, Clone)]
@@ -37,6 +40,7 @@ unsafe impl Isa for Wasm32Isa {
3740
type I32 = I32x4;
3841
type I16 = I16x8;
3942
type I8 = I8x16;
43+
type U8 = U8x16;
4044
type U16 = U16x8;
4145
type Bits = I32x4;
4246

@@ -56,6 +60,10 @@ unsafe impl Isa for Wasm32Isa {
5660
self
5761
}
5862

63+
/// Returns `self` as the `SimdOps` implementation for `Self::U8` vectors.
fn u8(self) -> impl SimdOps<Self::U8> {
64+
self
65+
}
66+
5967
/// Returns `self` as the `SimdOps` implementation for `Self::U16` vectors.
fn u16(self) -> impl SimdOps<Self::U16> {
6068
self
6169
}
@@ -399,6 +407,54 @@ impl SimdIntOps<I8x16> for Wasm32Isa {
399407
}
400408
}
401409

410+
// Operations on `U8x16`, a wrapper around a `v128` holding sixteen `u8`
// lanes. Arithmetic returns `U8x16`; comparisons return the `M8` mask type.
unsafe impl SimdOps<U8x16> for Wasm32Isa {
411+
// Expands a macro defined outside this view.
// NOTE(review): the third argument is `i8` here, whereas the `U16x8` impl
// below passes `u16` — confirm `i8` is intended for an unsigned vector.
simd_ops_common!(U8x16, M8, i8);
412+
413+
/// Lane-wise `x + y` via `u8x16_add` (results wrap modulo 256).
#[inline]
414+
fn add(self, x: U8x16, y: U8x16) -> U8x16 {
415+
U8x16(u8x16_add(x.0, y.0))
416+
}
417+
418+
/// Lane-wise `x - y` via `u8x16_sub` (results wrap modulo 256).
#[inline]
419+
fn sub(self, x: U8x16, y: U8x16) -> U8x16 {
420+
U8x16(u8x16_sub(x.0, y.0))
421+
}
422+
423+
/// Lane-wise `x * y`, truncated to 8 bits per lane.
///
/// No 8-bit multiply intrinsic is used; instead the inputs are multiplied
/// into two vectors of 16-bit products (low and high halves) and the low
/// byte of every product is gathered back into one vector.
#[inline]
424+
fn mul(self, x: U8x16, y: U8x16) -> U8x16 {
425+
// Widening multiplies: lanes 0-7 go to `prod_low`, lanes 8-15 to
// `prod_high`, each as eight 16-bit products.
let prod_low = u16x8_extmul_low_u8x16(x.0, y.0);
426+
let prod_high = u16x8_extmul_high_u8x16(x.0, y.0);
427+
428+
// Select even bytes from low and high products. This obtains the
// u8 truncated product: indices 0..=14 address `prod_low`, 16..=30
429+
// address `prod_high`, and each even byte is the low byte of a 16-bit
// lane because WebAssembly memory/lane layout is little-endian.
430+
let prod_u8 = u8x16_shuffle::<0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30>(
431+
prod_low, prod_high,
432+
);
433+
434+
U8x16(prod_u8)
435+
}
436+
437+
/// Returns a vector with every lane set to `x`.
#[inline]
438+
fn splat(self, x: u8) -> U8x16 {
439+
U8x16(u8x16_splat(x))
440+
}
441+
442+
/// Lane-wise `x == y`, returned as an `M8` mask.
#[inline]
443+
fn eq(self, x: U8x16, y: U8x16) -> M8 {
444+
M8(u8x16_eq(x.0, y.0))
445+
}
446+
447+
/// Unsigned lane-wise `x >= y`, returned as an `M8` mask.
#[inline]
448+
fn ge(self, x: U8x16, y: U8x16) -> M8 {
449+
M8(u8x16_ge(x.0, y.0))
450+
}
451+
452+
/// Unsigned lane-wise `x > y`, returned as an `M8` mask.
#[inline]
453+
fn gt(self, x: U8x16, y: U8x16) -> M8 {
454+
M8(u8x16_gt(x.0, y.0))
455+
}
456+
}
457+
402458
unsafe impl SimdOps<U16x8> for Wasm32Isa {
403459
simd_ops_common!(U16x8, M16, u16);
404460

0 commit comments

Comments
 (0)