Skip to content

Commit 7d4d12d

Browse files
authored
Merge pull request #618 from robertknight/simd-u16
Support u16 vectors in new SIMD API
2 parents e572952 + 560f7e9 commit 7d4d12d

File tree

6 files changed

+376
-37
lines changed

6 files changed

+376
-37
lines changed

rten-simd/src/safe/arch/aarch64.rs

+86-9
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
use std::arch::aarch64::{
22
float32x4_t, int16x8_t, int32x4_t, int8x16_t, uint16x8_t, uint32x4_t, uint8x16_t, vabsq_f32,
3-
vaddq_f32, vaddq_s16, vaddq_s32, vaddq_s8, vaddvq_f32, vandq_u16, vandq_u32, vandq_u8,
4-
vbslq_f32, vbslq_s16, vbslq_s32, vbslq_s8, vceqq_f32, vceqq_s16, vceqq_s32, vceqq_s8,
5-
vcgeq_f32, vcgeq_s16, vcgeq_s32, vcgeq_s8, vcgtq_f32, vcgtq_s16, vcgtq_s32, vcgtq_s8,
6-
vcleq_f32, vcleq_s16, vcleq_s8, vcltq_f32, vcltq_s16, vcltq_s8, vcvtnq_s32_f32, vcvtq_s32_f32,
7-
vdivq_f32, vdupq_n_f32, vdupq_n_s16, vdupq_n_s32, vdupq_n_s8, vfmaq_f32, vld1q_f32, vld1q_s16,
8-
vld1q_s32, vld1q_s8, vld1q_u16, vld1q_u32, vld1q_u8, vmaxq_f32, vminq_f32, vmulq_f32,
9-
vmulq_s16, vmulq_s32, vmulq_s8, vnegq_f32, vnegq_s16, vnegq_s32, vnegq_s8, vshlq_n_s16,
10-
vshlq_n_s32, vshlq_n_s8, vst1q_f32, vst1q_s16, vst1q_s32, vst1q_s8, vsubq_f32, vsubq_s16,
11-
vsubq_s32, vsubq_s8,
3+
vaddq_f32, vaddq_s16, vaddq_s32, vaddq_s8, vaddq_u16, vaddvq_f32, vandq_u16, vandq_u32,
4+
vandq_u8, vbslq_f32, vbslq_s16, vbslq_s32, vbslq_s8, vbslq_u16, vceqq_f32, vceqq_s16,
5+
vceqq_s32, vceqq_s8, vceqq_u16, vcgeq_f32, vcgeq_s16, vcgeq_s32, vcgeq_s8, vcgeq_u16,
6+
vcgtq_f32, vcgtq_s16, vcgtq_s32, vcgtq_s8, vcgtq_u16, vcleq_f32, vcleq_s16, vcleq_s8,
7+
vcleq_u16, vcltq_f32, vcltq_s16, vcltq_s8, vcltq_u16, vcvtnq_s32_f32, vcvtq_s32_f32, vdivq_f32,
8+
vdupq_n_f32, vdupq_n_s16, vdupq_n_s32, vdupq_n_s8, vdupq_n_u16, vfmaq_f32, vld1q_f32,
9+
vld1q_s16, vld1q_s32, vld1q_s8, vld1q_u16, vld1q_u32, vld1q_u8, vmaxq_f32, vminq_f32,
10+
vmulq_f32, vmulq_s16, vmulq_s32, vmulq_s8, vmulq_u16, vnegq_f32, vnegq_s16, vnegq_s32,
11+
vnegq_s8, vshlq_n_s16, vshlq_n_s32, vshlq_n_s8, vst1q_f32, vst1q_s16, vst1q_s32, vst1q_s8,
12+
vst1q_u16, vsubq_f32, vsubq_s16, vsubq_s32, vsubq_s8, vsubq_u16,
1213
};
1314
use std::mem::transmute;
1415

@@ -31,6 +32,7 @@ unsafe impl Isa for ArmNeonIsa {
3132
type I32 = int32x4_t;
3233
type I16 = int16x8_t;
3334
type I8 = int8x16_t;
35+
type U16 = uint16x8_t;
3436
type Bits = int32x4_t;
3537

3638
fn f32(self) -> impl SimdFloatOps<Self::F32, Int = Self::I32> {
@@ -48,6 +50,10 @@ unsafe impl Isa for ArmNeonIsa {
4850
fn i8(self) -> impl SimdIntOps<Self::I8> {
4951
self
5052
}
53+
54+
/// Returns the handle used for operations on `Self::U16` vectors.
///
/// The ISA value itself is returned, since it provides the `SimdOps`
/// implementation for that vector type.
fn u16(self) -> impl SimdOps<Self::U16> {
55+
self
56+
}
5157
}
5258

5359
macro_rules! simd_ops_common {
@@ -452,6 +458,76 @@ impl SimdIntOps<int8x16_t> for ArmNeonIsa {
452458
}
453459
}
454460

461+
unsafe impl SimdOps<uint16x8_t> for ArmNeonIsa {
462+
simd_ops_common!(uint16x8_t, uint16x8_t);
463+
464+
#[inline]
465+
fn add(self, x: uint16x8_t, y: uint16x8_t) -> uint16x8_t {
466+
unsafe { vaddq_u16(x, y) }
467+
}
468+
469+
#[inline]
470+
fn sub(self, x: uint16x8_t, y: uint16x8_t) -> uint16x8_t {
471+
unsafe { vsubq_u16(x, y) }
472+
}
473+
474+
#[inline]
475+
fn mul(self, x: uint16x8_t, y: uint16x8_t) -> uint16x8_t {
476+
unsafe { vmulq_u16(x, y) }
477+
}
478+
479+
#[inline]
480+
fn splat(self, x: u16) -> uint16x8_t {
481+
unsafe { vdupq_n_u16(x) }
482+
}
483+
484+
#[inline]
485+
fn lt(self, x: uint16x8_t, y: uint16x8_t) -> uint16x8_t {
486+
unsafe { vcltq_u16(x, y) }
487+
}
488+
489+
#[inline]
490+
fn le(self, x: uint16x8_t, y: uint16x8_t) -> uint16x8_t {
491+
unsafe { vcleq_u16(x, y) }
492+
}
493+
494+
#[inline]
495+
fn eq(self, x: uint16x8_t, y: uint16x8_t) -> uint16x8_t {
496+
unsafe { vceqq_u16(x, y) }
497+
}
498+
499+
#[inline]
500+
fn ge(self, x: uint16x8_t, y: uint16x8_t) -> uint16x8_t {
501+
unsafe { vcgeq_u16(x, y) }
502+
}
503+
504+
#[inline]
505+
fn gt(self, x: uint16x8_t, y: uint16x8_t) -> uint16x8_t {
506+
unsafe { vcgtq_u16(x, y) }
507+
}
508+
509+
#[inline]
510+
unsafe fn load_ptr(self, ptr: *const u16) -> uint16x8_t {
511+
unsafe { vld1q_u16(ptr) }
512+
}
513+
514+
#[inline]
515+
fn first_n_mask(self, n: usize) -> uint16x8_t {
516+
let mask: [u16; 8] = std::array::from_fn(|i| if i < n { u16::MAX } else { 0 });
517+
unsafe { vld1q_u16(mask.as_ptr()) }
518+
}
519+
520+
#[inline]
521+
fn select(self, x: uint16x8_t, y: uint16x8_t, mask: <uint16x8_t as Simd>::Mask) -> uint16x8_t {
522+
unsafe { vbslq_u16(mask, x, y) }
523+
}
524+
525+
#[inline]
526+
unsafe fn store_ptr(self, x: uint16x8_t, ptr: *mut u16) {
527+
unsafe { vst1q_u16(ptr, x) }
528+
}
529+
}
530+
455531
macro_rules! impl_mask {
456532
($mask:ty, $elem:ty, $len:expr) => {
457533
impl Mask for $mask {
@@ -534,3 +610,4 @@ impl_simd!(float32x4_t, f32, 4, uint32x4_t);
534610
impl_simd!(int32x4_t, i32, 4, uint32x4_t);
535611
impl_simd!(int16x8_t, i16, 8, uint16x8_t);
536612
impl_simd!(int8x16_t, i8, 16, uint8x16_t);
613+
impl_simd!(uint16x8_t, u16, 8, uint16x8_t);

rten-simd/src/safe/arch/generic.rs

+20-4
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ simd_type!(F32x4, f32, LEN_X32);
1919
simd_type!(I32x4, i32, LEN_X32);
2020
simd_type!(I16x8, i16, LEN_X32 * 2);
2121
simd_type!(I8x16, i8, LEN_X32 * 4);
22+
simd_type!(U16x8, u16, LEN_X32 * 2);
2223

2324
// Define mask vector types. `Mn` is a mask for a vector with n-bit lanes.
2425
simd_type!(M32, i32, LEN_X32);
@@ -48,6 +49,7 @@ unsafe impl Isa for GenericIsa {
4849
type I32 = I32x4;
4950
type I16 = I16x8;
5051
type I8 = I8x16;
52+
type U16 = U16x8;
5153
type Bits = I32x4;
5254

5355
fn f32(self) -> impl SimdFloatOps<Self::F32, Int = Self::I32> {
@@ -65,6 +67,10 @@ unsafe impl Isa for GenericIsa {
6567
fn i8(self) -> impl SimdIntOps<Self::I8> {
6668
self
6769
}
70+
71+
/// Returns the handle used for operations on `Self::U16` vectors.
///
/// The ISA value itself is returned, since it provides the `SimdOps`
/// implementation for that vector type.
fn u16(self) -> impl SimdOps<Self::U16> {
72+
self
73+
}
6874
}
6975

7076
macro_rules! simd_ops_common {
@@ -235,7 +241,7 @@ impl SimdFloatOps<F32x4> for GenericIsa {
235241
}
236242
}
237243

238-
macro_rules! impl_simd_int_ops {
244+
macro_rules! impl_simd_signed_int_ops {
239245
($simd:ident, $elem:ty, $len:expr, $mask:ident) => {
240246
unsafe impl SimdOps<$simd> for GenericIsa {
241247
simd_ops_common!($simd, $elem, $len, $mask);
@@ -257,9 +263,18 @@ macro_rules! impl_simd_int_ops {
257263
};
258264
}
259265

260-
impl_simd_int_ops!(I32x4, i32, 4, M32);
261-
impl_simd_int_ops!(I16x8, i16, 8, M16);
262-
impl_simd_int_ops!(I8x16, i8, 16, M8);
266+
impl_simd_signed_int_ops!(I32x4, i32, 4, M32);
267+
impl_simd_signed_int_ops!(I16x8, i16, 8, M16);
268+
impl_simd_signed_int_ops!(I8x16, i8, 16, M8);
269+
270+
// Generates the `SimdOps` implementation on `GenericIsa` for an unsigned
// integer vector type.
//
// Arguments, in order: the vector type, its element type, the lane count and
// the matching mask type. All four are forwarded unchanged to
// `simd_ops_common!`, which supplies the whole body of the impl.
macro_rules! impl_simd_unsigned_int_ops {
    ($vector:ident, $lane:ty, $lanes:expr, $mask_ty:ident) => {
        unsafe impl SimdOps<$vector> for GenericIsa {
            simd_ops_common!($vector, $lane, $lanes, $mask_ty);
        }
    };
}
277+
impl_simd_unsigned_int_ops!(U16x8, u16, 8, M16);
263278

264279
macro_rules! impl_mask {
265280
($mask:ident, $len:expr) => {
@@ -319,3 +334,4 @@ impl_simd!(F32x4, f32, M32, 4);
319334
impl_simd!(I32x4, i32, M32, 4);
320335
impl_simd!(I16x8, i16, M16, 8);
321336
impl_simd!(I8x16, i8, M8, 16);
337+
impl_simd!(U16x8, u16, M16, 8);

rten-simd/src/safe/arch/wasm32.rs

+48-2
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@ use std::arch::wasm32::{
55
i16x8_mul, i16x8_neg, i16x8_shl, i16x8_splat, i16x8_sub, i32x4_add, i32x4_eq, i32x4_ge,
66
i32x4_gt, i32x4_mul, i32x4_neg, i32x4_shl, i32x4_shuffle, i32x4_splat, i32x4_sub,
77
i32x4_trunc_sat_f32x4, i8x16_add, i8x16_eq, i8x16_ge, i8x16_gt, i8x16_neg, i8x16_shl,
8-
i8x16_shuffle, i8x16_splat, i8x16_sub, v128, v128_and, v128_bitselect, v128_load, v128_store,
8+
i8x16_shuffle, i8x16_splat, i8x16_sub, u16x8_add, u16x8_eq, u16x8_ge, u16x8_gt, u16x8_mul,
9+
u16x8_splat, u16x8_sub, v128, v128_and, v128_bitselect, v128_load, v128_store,
910
};
1011
use std::mem::transmute;
1112

@@ -16,6 +17,7 @@ simd_type!(F32x4, v128, f32, M32, Wasm32Isa);
1617
simd_type!(I32x4, v128, i32, M32, Wasm32Isa);
1718
simd_type!(I16x8, v128, i16, M16, Wasm32Isa);
1819
simd_type!(I8x16, v128, i8, M8, Wasm32Isa);
20+
simd_type!(U16x8, v128, u16, M16, Wasm32Isa);
1921

2022
#[derive(Copy, Clone)]
2123
pub struct Wasm32Isa {
@@ -35,6 +37,7 @@ unsafe impl Isa for Wasm32Isa {
3537
type I32 = I32x4;
3638
type I16 = I16x8;
3739
type I8 = I8x16;
40+
type U16 = U16x8;
3841
type Bits = I32x4;
3942

4043
fn f32(self) -> impl SimdFloatOps<Self::F32, Int = Self::I32> {
@@ -52,6 +55,10 @@ unsafe impl Isa for Wasm32Isa {
5255
fn i8(self) -> impl SimdIntOps<Self::I8> {
5356
self
5457
}
58+
59+
/// Returns the handle used for operations on `Self::U16` vectors.
///
/// The ISA value itself is returned, since it provides the `SimdOps`
/// implementation for that vector type.
fn u16(self) -> impl SimdOps<Self::U16> {
60+
self
61+
}
5562
}
5663

5764
macro_rules! simd_ops_common {
@@ -69,7 +76,7 @@ macro_rules! simd_ops_common {
6976
#[inline]
7077
fn first_n_mask(self, n: usize) -> $mask {
7178
let mask: [$mask_elem; lanes::<$simd>()] =
72-
std::array::from_fn(|i| if i < n { -1 } else { 0 });
79+
std::array::from_fn(|i| if i < n { !0 } else { 0 });
7380
$mask(unsafe { v128_load(mask.as_ptr() as *const v128) })
7481
}
7582

@@ -392,6 +399,45 @@ impl SimdIntOps<I8x16> for Wasm32Isa {
392399
}
393400
}
394401

402+
unsafe impl SimdOps<U16x8> for Wasm32Isa {
403+
simd_ops_common!(U16x8, M16, u16);
404+
405+
#[inline]
406+
fn add(self, x: U16x8, y: U16x8) -> U16x8 {
407+
U16x8(u16x8_add(x.0, y.0))
408+
}
409+
410+
#[inline]
411+
fn sub(self, x: U16x8, y: U16x8) -> U16x8 {
412+
U16x8(u16x8_sub(x.0, y.0))
413+
}
414+
415+
#[inline]
416+
fn mul(self, x: U16x8, y: U16x8) -> U16x8 {
417+
U16x8(u16x8_mul(x.0, y.0))
418+
}
419+
420+
#[inline]
421+
fn splat(self, x: u16) -> U16x8 {
422+
U16x8(u16x8_splat(x))
423+
}
424+
425+
#[inline]
426+
fn eq(self, x: U16x8, y: U16x8) -> M16 {
427+
M16(u16x8_eq(x.0, y.0))
428+
}
429+
430+
#[inline]
431+
fn ge(self, x: U16x8, y: U16x8) -> M16 {
432+
M16(u16x8_ge(x.0, y.0))
433+
}
434+
435+
#[inline]
436+
fn gt(self, x: U16x8, y: U16x8) -> M16 {
437+
M16(u16x8_gt(x.0, y.0))
438+
}
439+
}
440+
395441
macro_rules! mask_type {
396442
($mask:ident, $elem:ty, $len: expr) => {
397443
#[derive(Copy, Clone, Debug)]

0 commit comments

Comments
 (0)