Skip to content

Commit 8003076

Browse files
authored
Merge pull request #616 from robertknight/simd-i8
Support i8 vectors in new SIMD API
2 parents 14d429e + cb1aa4c commit 8003076

File tree

6 files changed

+625
-49
lines changed

6 files changed

+625
-49
lines changed

rten-simd/src/safe/arch/aarch64.rs

+106-8
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
use std::arch::aarch64::{
2-
float32x4_t, int16x8_t, int32x4_t, uint16x8_t, uint32x4_t, vabsq_f32, vaddq_f32, vaddq_s16,
3-
vaddq_s32, vaddvq_f32, vandq_u16, vandq_u32, vbslq_f32, vbslq_s16, vbslq_s32, vceqq_f32,
4-
vceqq_s16, vceqq_s32, vcgeq_f32, vcgeq_s16, vcgeq_s32, vcgtq_f32, vcgtq_s16, vcgtq_s32,
5-
vcleq_f32, vcleq_s16, vcltq_f32, vcltq_s16, vcvtnq_s32_f32, vcvtq_s32_f32, vdivq_f32,
6-
vdupq_n_f32, vdupq_n_s16, vdupq_n_s32, vfmaq_f32, vld1q_f32, vld1q_s16, vld1q_s32, vld1q_u16,
7-
vld1q_u32, vmaxq_f32, vminq_f32, vmulq_f32, vmulq_s16, vmulq_s32, vnegq_f32, vnegq_s16,
8-
vnegq_s32, vshlq_n_s16, vshlq_n_s32, vst1q_f32, vst1q_s16, vst1q_s32, vsubq_f32, vsubq_s16,
9-
vsubq_s32,
2+
float32x4_t, int16x8_t, int32x4_t, int8x16_t, uint16x8_t, uint32x4_t, uint8x16_t, vabsq_f32,
3+
vaddq_f32, vaddq_s16, vaddq_s32, vaddq_s8, vaddvq_f32, vandq_u16, vandq_u32, vandq_u8,
4+
vbslq_f32, vbslq_s16, vbslq_s32, vbslq_s8, vceqq_f32, vceqq_s16, vceqq_s32, vceqq_s8,
5+
vcgeq_f32, vcgeq_s16, vcgeq_s32, vcgeq_s8, vcgtq_f32, vcgtq_s16, vcgtq_s32, vcgtq_s8,
6+
vcleq_f32, vcleq_s16, vcleq_s8, vcltq_f32, vcltq_s16, vcltq_s8, vcvtnq_s32_f32, vcvtq_s32_f32,
7+
vdivq_f32, vdupq_n_f32, vdupq_n_s16, vdupq_n_s32, vdupq_n_s8, vfmaq_f32, vld1q_f32, vld1q_s16,
8+
vld1q_s32, vld1q_s8, vld1q_u16, vld1q_u32, vld1q_u8, vmaxq_f32, vminq_f32, vmulq_f32,
9+
vmulq_s16, vmulq_s32, vmulq_s8, vnegq_f32, vnegq_s16, vnegq_s32, vnegq_s8, vshlq_n_s16,
10+
vshlq_n_s32, vshlq_n_s8, vst1q_f32, vst1q_s16, vst1q_s32, vst1q_s8, vsubq_f32, vsubq_s16,
11+
vsubq_s32, vsubq_s8,
1012
};
1113
use std::mem::transmute;
1214

@@ -28,6 +30,7 @@ unsafe impl Isa for ArmNeonIsa {
2830
type F32 = float32x4_t;
2931
type I32 = int32x4_t;
3032
type I16 = int16x8_t;
33+
type I8 = int8x16_t;
3134
type Bits = int32x4_t;
3235

3336
fn f32(self) -> impl SimdFloatOps<Self::F32, Int = Self::I32> {
@@ -41,6 +44,10 @@ unsafe impl Isa for ArmNeonIsa {
4144
fn i16(self) -> impl SimdIntOps<Self::I16> {
4245
self
4346
}
47+
48+
fn i8(self) -> impl SimdIntOps<Self::I8> {
49+
self
50+
}
4451
}
4552

4653
macro_rules! simd_ops_common {
@@ -363,6 +370,88 @@ impl SimdIntOps<int16x8_t> for ArmNeonIsa {
363370
}
364371
}
365372

373+
unsafe impl SimdOps<int8x16_t> for ArmNeonIsa {
374+
simd_ops_common!(int8x16_t, uint8x16_t);
375+
376+
#[inline]
377+
fn add(self, x: int8x16_t, y: int8x16_t) -> int8x16_t {
378+
unsafe { vaddq_s8(x, y) }
379+
}
380+
381+
#[inline]
382+
fn sub(self, x: int8x16_t, y: int8x16_t) -> int8x16_t {
383+
unsafe { vsubq_s8(x, y) }
384+
}
385+
386+
#[inline]
387+
fn mul(self, x: int8x16_t, y: int8x16_t) -> int8x16_t {
388+
unsafe { vmulq_s8(x, y) }
389+
}
390+
391+
#[inline]
392+
fn splat(self, x: i8) -> int8x16_t {
393+
unsafe { vdupq_n_s8(x) }
394+
}
395+
396+
#[inline]
397+
fn lt(self, x: int8x16_t, y: int8x16_t) -> uint8x16_t {
398+
unsafe { vcltq_s8(x, y) }
399+
}
400+
401+
#[inline]
402+
fn le(self, x: int8x16_t, y: int8x16_t) -> uint8x16_t {
403+
unsafe { vcleq_s8(x, y) }
404+
}
405+
406+
#[inline]
407+
fn eq(self, x: int8x16_t, y: int8x16_t) -> uint8x16_t {
408+
unsafe { vceqq_s8(x, y) }
409+
}
410+
411+
#[inline]
412+
fn ge(self, x: int8x16_t, y: int8x16_t) -> uint8x16_t {
413+
unsafe { vcgeq_s8(x, y) }
414+
}
415+
416+
#[inline]
417+
fn gt(self, x: int8x16_t, y: int8x16_t) -> uint8x16_t {
418+
unsafe { vcgtq_s8(x, y) }
419+
}
420+
421+
#[inline]
422+
unsafe fn load_ptr(self, ptr: *const i8) -> int8x16_t {
423+
unsafe { vld1q_s8(ptr) }
424+
}
425+
426+
#[inline]
427+
fn first_n_mask(self, n: usize) -> uint8x16_t {
428+
let mask: [u8; 16] = std::array::from_fn(|i| if i < n { u8::MAX } else { 0 });
429+
unsafe { vld1q_u8(mask.as_ptr()) }
430+
}
431+
432+
#[inline]
433+
fn select(self, x: int8x16_t, y: int8x16_t, mask: <int8x16_t as Simd>::Mask) -> int8x16_t {
434+
unsafe { vbslq_s8(mask, x, y) }
435+
}
436+
437+
#[inline]
438+
unsafe fn store_ptr(self, x: int8x16_t, ptr: *mut i8) {
439+
unsafe { vst1q_s8(ptr, x) }
440+
}
441+
}
442+
443+
impl SimdIntOps<int8x16_t> for ArmNeonIsa {
444+
#[inline]
445+
fn neg(self, x: int8x16_t) -> int8x16_t {
446+
unsafe { vnegq_s8(x) }
447+
}
448+
449+
#[inline]
450+
fn shift_left<const SHIFT: i32>(self, x: int8x16_t) -> int8x16_t {
451+
unsafe { vshlq_n_s8::<SHIFT>(x) }
452+
}
453+
}
454+
366455
macro_rules! impl_mask {
367456
($mask:ty, $elem:ty, $len:expr) => {
368457
impl Mask for $mask {
@@ -379,6 +468,7 @@ macro_rules! impl_mask {
379468

380469
impl_mask!(uint32x4_t, u32, 4);
381470
impl_mask!(uint16x8_t, u16, 8);
471+
impl_mask!(uint8x16_t, u8, 16);
382472

383473
unsafe impl MaskOps<uint32x4_t> for ArmNeonIsa {
384474
#[inline]
@@ -394,6 +484,13 @@ unsafe impl MaskOps<uint16x8_t> for ArmNeonIsa {
394484
}
395485
}
396486

487+
unsafe impl MaskOps<uint8x16_t> for ArmNeonIsa {
488+
#[inline]
489+
fn and(self, x: uint8x16_t, y: uint8x16_t) -> uint8x16_t {
490+
unsafe { vandq_u8(x, y) }
491+
}
492+
}
493+
397494
macro_rules! simd_common {
398495
($mask:ty, $len:expr) => {
399496
type Array = [Self::Elem; $len];
@@ -436,3 +533,4 @@ macro_rules! impl_simd {
436533
impl_simd!(float32x4_t, f32, 4, uint32x4_t);
437534
impl_simd!(int32x4_t, i32, 4, uint32x4_t);
438535
impl_simd!(int16x8_t, i16, 8, uint16x8_t);
536+
impl_simd!(int8x16_t, i8, 16, uint8x16_t);

rten-simd/src/safe/arch/generic.rs

+12
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,10 @@ pub struct I32x4([i32; LEN_X32]);
1818
#[derive(Copy, Clone, Debug)]
1919
pub struct I16x8([i16; LEN_X32 * 2]);
2020

21+
#[repr(align(16))]
22+
#[derive(Copy, Clone, Debug)]
23+
pub struct I8x16([i8; LEN_X32 * 4]);
24+
2125
#[derive(Copy, Clone)]
2226
pub struct GenericIsa {
2327
_private: (),
@@ -40,6 +44,7 @@ unsafe impl Isa for GenericIsa {
4044
type F32 = F32x4;
4145
type I32 = I32x4;
4246
type I16 = I16x8;
47+
type I8 = I8x16;
4348
type Bits = I32x4;
4449

4550
fn f32(self) -> impl SimdFloatOps<Self::F32, Int = Self::I32> {
@@ -53,6 +58,10 @@ unsafe impl Isa for GenericIsa {
5358
fn i16(self) -> impl SimdIntOps<Self::I16> {
5459
self
5560
}
61+
62+
fn i8(self) -> impl SimdIntOps<Self::I8> {
63+
self
64+
}
5665
}
5766

5867
macro_rules! simd_ops_common {
@@ -247,6 +256,7 @@ macro_rules! impl_simd_int_ops {
247256

248257
impl_simd_int_ops!(I32x4, i32, 4, I32x4);
249258
impl_simd_int_ops!(I16x8, i16, 8, I16x8);
259+
impl_simd_int_ops!(I8x16, i8, 16, I8x16);
250260

251261
macro_rules! impl_mask {
252262
($mask:ident, $len:expr) => {
@@ -272,6 +282,7 @@ macro_rules! impl_mask {
272282

273283
impl_mask!(I32x4, LEN_X32);
274284
impl_mask!(I16x8, LEN_X32 * 2);
285+
impl_mask!(I8x16, LEN_X32 * 4);
275286

276287
macro_rules! impl_simd {
277288
($simd:ty, $elem:ty, $mask:ty, $len:expr) => {
@@ -304,3 +315,4 @@ macro_rules! impl_simd {
304315
impl_simd!(F32x4, f32, I32x4, 4);
305316
impl_simd!(I32x4, i32, I32x4, 4);
306317
impl_simd!(I16x8, i16, I16x8, 8);
318+
impl_simd!(I8x16, i8, I8x16, 16);

rten-simd/src/safe/arch/wasm32.rs

+72-4
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
use std::arch::wasm32::{
22
f32x4_abs, f32x4_add, f32x4_div, f32x4_eq, f32x4_extract_lane, f32x4_ge, f32x4_gt, f32x4_le,
33
f32x4_lt, f32x4_max, f32x4_min, f32x4_mul, f32x4_nearest, f32x4_neg, f32x4_splat, f32x4_sub,
4-
i16x8_add, i16x8_eq, i16x8_ge, i16x8_gt, i16x8_mul, i16x8_neg, i16x8_shl, i16x8_splat,
5-
i16x8_sub, i32x4_add, i32x4_eq, i32x4_ge, i32x4_gt, i32x4_mul, i32x4_neg, i32x4_shl,
6-
i32x4_shuffle, i32x4_splat, i32x4_sub, i32x4_trunc_sat_f32x4, v128, v128_and, v128_bitselect,
7-
v128_load, v128_store,
4+
i16x8_add, i16x8_eq, i16x8_extmul_high_i8x16, i16x8_extmul_low_i8x16, i16x8_ge, i16x8_gt,
5+
i16x8_mul, i16x8_neg, i16x8_shl, i16x8_splat, i16x8_sub, i32x4_add, i32x4_eq, i32x4_ge,
6+
i32x4_gt, i32x4_mul, i32x4_neg, i32x4_shl, i32x4_shuffle, i32x4_splat, i32x4_sub,
7+
i32x4_trunc_sat_f32x4, i8x16_add, i8x16_eq, i8x16_ge, i8x16_gt, i8x16_neg, i8x16_shl,
8+
i8x16_shuffle, i8x16_splat, i8x16_sub, v128, v128_and, v128_bitselect, v128_load, v128_store,
89
};
910
use std::mem::transmute;
1011

@@ -14,6 +15,7 @@ use crate::safe::{Isa, Mask, MaskOps, Simd, SimdFloatOps, SimdIntOps, SimdOps};
1415
simd_type!(F32x4, v128, f32, I32x4, Wasm32Isa);
1516
simd_type!(I32x4, v128, i32, I32x4, Wasm32Isa);
1617
simd_type!(I16x8, v128, i16, I16x8, Wasm32Isa);
18+
simd_type!(I8x16, v128, i8, I8x16, Wasm32Isa);
1719

1820
#[derive(Copy, Clone)]
1921
pub struct Wasm32Isa {
@@ -32,6 +34,7 @@ unsafe impl Isa for Wasm32Isa {
3234
type F32 = F32x4;
3335
type I32 = I32x4;
3436
type I16 = I16x8;
37+
type I8 = I8x16;
3538
type Bits = I32x4;
3639

3740
fn f32(self) -> impl SimdFloatOps<Self::F32, Int = Self::I32> {
@@ -45,6 +48,10 @@ unsafe impl Isa for Wasm32Isa {
4548
fn i16(self) -> impl SimdIntOps<Self::I16> {
4649
self
4750
}
51+
52+
fn i8(self) -> impl SimdIntOps<Self::I8> {
53+
self
54+
}
4855
}
4956

5057
macro_rules! simd_ops_common {
@@ -325,6 +332,66 @@ impl SimdIntOps<I16x8> for Wasm32Isa {
325332
}
326333
}
327334

335+
unsafe impl SimdOps<I8x16> for Wasm32Isa {
336+
simd_ops_common!(I8x16, I8x16, i8);
337+
338+
#[inline]
339+
fn add(self, x: I8x16, y: I8x16) -> I8x16 {
340+
I8x16(i8x16_add(x.0, y.0))
341+
}
342+
343+
#[inline]
344+
fn sub(self, x: I8x16, y: I8x16) -> I8x16 {
345+
I8x16(i8x16_sub(x.0, y.0))
346+
}
347+
348+
#[inline]
349+
fn mul(self, x: I8x16, y: I8x16) -> I8x16 {
350+
let prod_low = i16x8_extmul_low_i8x16(x.0, y.0);
351+
let prod_high = i16x8_extmul_high_i8x16(x.0, y.0);
352+
353+
// Select even bytes from low and high products. This obtains the
354+
// i8 truncated product.
355+
let prod_i8 = i8x16_shuffle::<0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30>(
356+
prod_low, prod_high,
357+
);
358+
359+
I8x16(prod_i8)
360+
}
361+
362+
#[inline]
363+
fn splat(self, x: i8) -> I8x16 {
364+
I8x16(i8x16_splat(x))
365+
}
366+
367+
#[inline]
368+
fn eq(self, x: I8x16, y: I8x16) -> I8x16 {
369+
I8x16(i8x16_eq(x.0, y.0))
370+
}
371+
372+
#[inline]
373+
fn ge(self, x: I8x16, y: I8x16) -> I8x16 {
374+
I8x16(i8x16_ge(x.0, y.0))
375+
}
376+
377+
#[inline]
378+
fn gt(self, x: I8x16, y: I8x16) -> I8x16 {
379+
I8x16(i8x16_gt(x.0, y.0))
380+
}
381+
}
382+
383+
impl SimdIntOps<I8x16> for Wasm32Isa {
384+
#[inline]
385+
fn neg(self, x: I8x16) -> I8x16 {
386+
I8x16(i8x16_neg(x.0))
387+
}
388+
389+
#[inline]
390+
fn shift_left<const SHIFT: i32>(self, x: I8x16) -> I8x16 {
391+
I8x16(i8x16_shl(x.0, SHIFT as u32))
392+
}
393+
}
394+
328395
macro_rules! mask_type {
329396
($mask:ident, $elem:ty, $len: expr) => {
330397
impl Mask for $mask {
@@ -348,3 +415,4 @@ macro_rules! mask_type {
348415

349416
mask_type!(I32x4, i32, 4);
350417
mask_type!(I16x8, i16, 8);
418+
mask_type!(I8x16, i8, 16);

0 commit comments

Comments
 (0)