Skip to content

Commit 2b941c4

Browse files
authored
Merge branch 'develop' into sbgemv_n_neon
2 parents 35bdbca + c797e27 commit 2b941c4

File tree

9 files changed

+265
-12
lines changed

9 files changed

+265
-12
lines changed

CONTRIBUTORS.md

+4-1
Original file line numberDiff line numberDiff line change
@@ -236,9 +236,12 @@ In chronological order:
236236
* Annop Wongwathanarat <annop.wongwathanarat@arm.com>
237237
* [2025-01-10] Add thread throttling profile for SGEMM on NEOVERSEV1
238238
* [2025-01-21] Optimize gemv_t_sve_v1x3 kernel
239+
* [2025-02-26] Add sbgemv_t_bfdot kernel
239240

240-
* Marek Michalowski <https://github.com/michalowski-arm>
241+
* Marek Michalowski <marek.michalowski@arm.com>
241242
* [2025-01-21] Add thread throttling profile for SGEMV on `NEOVERSEV1`
243+
* [2025-02-18] Add thread throttling profile for SGEMM on `NEOVERSEV2`
244+
* [2025-02-19] Add thread throttling profile for SGEMV on `NEOVERSEV2`
242245

243246
* Ye Tao <ye.tao@arm.com>
244247
* [2025-02-03] Optimize SBGEMM kernel on NEOVERSEV1

driver/others/dynamic_arm64.c

+1-1
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ extern gotoblas_t gotoblas_ARMV9SME;
162162

163163
extern gotoblas_t gotoblas_THUNDERX3T110;
164164
#endif
165-
#define gotoblas_NEOVERSEV2 gotoblas_NEOVERSEV1
165+
#define gotoblas_NEOVERSEV2 gotoblas_NEOVERSEN2
166166

167167
extern void openblas_warning(int verbose, const char * msg);
168168
#define FALLBACK_VERBOSE 1

interface/gemm.c

+25
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,7 @@ static int init_amxtile_permission() {
177177
}
178178
#endif
179179

180+
#ifdef SMP
180181
#ifdef DYNAMIC_ARCH
181182
extern char* gotoblas_corename(void);
182183
#endif
@@ -198,14 +199,37 @@ static inline int get_gemm_optimal_nthreads_neoversev1(double MNK, int ncpu) {
198199
}
199200
#endif
200201

202+
#if defined(DYNAMIC_ARCH) || defined(NEOVERSEV2)
203+
/*
 * Thread-count heuristic for SGEMM on Neoverse V2.
 *
 * MNK  - product m*n*k of the GEMM dimensions (problem volume)
 * ncpu - number of CPUs currently available
 *
 * Returns the thread count empirically tuned for the given problem
 * volume, capped at ncpu.  Each cutoff is the cube of a characteristic
 * square-matrix dimension: 50^3, 103^3, 138^3, 200^3, 273^3, 385^3,
 * 450^3, 620^3, 769^3, 950^3 and 1024^3.
 */
static inline int get_gemm_optimal_nthreads_neoversev2(double MNK, int ncpu) {
  if (MNK < 125000L)     return 1;
  if (MNK < 1092727L)    return MIN(ncpu, 6);
  if (MNK < 2628072L)    return MIN(ncpu, 8);
  if (MNK < 8000000L)    return MIN(ncpu, 12);
  if (MNK < 20346417L)   return MIN(ncpu, 16);
  if (MNK < 57066625L)   return MIN(ncpu, 24);
  if (MNK < 91125000L)   return MIN(ncpu, 28);
  if (MNK < 238328000L)  return MIN(ncpu, 40);
  if (MNK < 454756609L)  return MIN(ncpu, 48);
  if (MNK < 857375000L)  return MIN(ncpu, 56);
  if (MNK < 1073741824L) return MIN(ncpu, 64);
  return ncpu;
}
218+
#endif
219+
201220
static inline int get_gemm_optimal_nthreads(double MNK) {
202221
int ncpu = num_cpu_avail(3);
203222
#if defined(NEOVERSEV1) && !defined(COMPLEX) && !defined(DOUBLE) && !defined(BFLOAT16)
204223
return get_gemm_optimal_nthreads_neoversev1(MNK, ncpu);
224+
#elif defined(NEOVERSEV2) && !defined(COMPLEX) && !defined(DOUBLE) && !defined(BFLOAT16)
225+
return get_gemm_optimal_nthreads_neoversev2(MNK, ncpu);
205226
#elif defined(DYNAMIC_ARCH) && !defined(COMPLEX) && !defined(DOUBLE) && !defined(BFLOAT16)
206227
if (strcmp(gotoblas_corename(), "neoversev1") == 0) {
207228
return get_gemm_optimal_nthreads_neoversev1(MNK, ncpu);
208229
}
230+
if (strcmp(gotoblas_corename(), "neoversev2") == 0) {
231+
return get_gemm_optimal_nthreads_neoversev2(MNK, ncpu);
232+
}
209233
#endif
210234
if ( MNK <= (SMP_THRESHOLD_MIN * (double) GEMM_MULTITHREAD_THRESHOLD) ) {
211235
return 1;
@@ -219,6 +243,7 @@ static inline int get_gemm_optimal_nthreads(double MNK) {
219243
}
220244
}
221245
}
246+
#endif
222247

223248
#ifndef CBLAS
224249

interface/gemv.c

+18-7
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ static int (*gemv_thread[])(BLASLONG, BLASLONG, FLOAT, FLOAT *, BLASLONG, FLOAT
6363
};
6464
#endif
6565

66+
#ifdef SMP
6667
#ifdef DYNAMIC_ARCH
6768
extern char* gotoblas_corename(void);
6869
#endif
@@ -77,21 +78,38 @@ static inline int get_gemv_optimal_nthreads_neoversev1(BLASLONG MN, int ncpu) {
7778
}
7879
#endif
7980

81+
#if defined(DYNAMIC_ARCH) || defined(NEOVERSEV2)
82+
/*
 * Thread-count heuristic for SGEMV on Neoverse V2.
 *
 * MN   - product m*n of the GEMV dimensions (problem size)
 * ncpu - number of CPUs currently available
 *
 * Returns the thread count empirically tuned for the given problem
 * size, capped at ncpu.  Each cutoff is the square of a characteristic
 * matrix dimension: 158^2, 256^2, 512^2 and 1280^2.
 */
static inline int get_gemv_optimal_nthreads_neoversev2(BLASLONG MN, int ncpu) {
  if (MN < 24964L)   return 1;
  if (MN < 65536L)   return MIN(ncpu, 8);
  if (MN < 262144L)  return MIN(ncpu, 32);
  if (MN < 1638400L) return MIN(ncpu, 64);
  return ncpu;
}
90+
#endif
91+
8092
/*
 * Choose the number of threads for a GEMV call of size MN (= m*n).
 *
 * Single-precision real builds for Neoverse V1/V2 use a per-core tuned
 * heuristic, selected either at compile time (fixed TARGET builds) or
 * at run time via gotoblas_corename() (DYNAMIC_ARCH builds).  Every
 * other configuration falls through to a single size threshold.
 */
static inline int get_gemv_optimal_nthreads(BLASLONG MN) {
  int ncpu = num_cpu_avail(3);
#if defined(NEOVERSEV1) && !defined(COMPLEX) && !defined(DOUBLE) && !defined(BFLOAT16)
  return get_gemv_optimal_nthreads_neoversev1(MN, ncpu);
#elif defined(NEOVERSEV2) && !defined(COMPLEX) && !defined(DOUBLE) && !defined(BFLOAT16)
  return get_gemv_optimal_nthreads_neoversev2(MN, ncpu);
#elif defined(DYNAMIC_ARCH) && !defined(COMPLEX) && !defined(DOUBLE) && !defined(BFLOAT16)
  /* Run-time dispatch: DYNAMIC_ARCH builds identify the core by name.
   * Unrecognized cores fall through to the generic threshold below. */
  if (strcmp(gotoblas_corename(), "neoversev1") == 0) {
    return get_gemv_optimal_nthreads_neoversev1(MN, ncpu);
  }
  if (strcmp(gotoblas_corename(), "neoversev2") == 0) {
    return get_gemv_optimal_nthreads_neoversev2(MN, ncpu);
  }
#endif

  /* Generic fallback: stay single-threaded below the size threshold. */
  if ( MN < 115200L * GEMM_MULTITHREAD_THRESHOLD )
    return 1;
  else
    return num_cpu_avail(2);
}
112+
#endif
95113

96114
#ifndef CBLAS
97115

@@ -232,13 +250,6 @@ void CNAME(enum CBLAS_ORDER order,
232250

233251
if (alpha == ZERO) return;
234252

235-
#if 0
236-
/* this optimization causes stack corruption on x86_64 under OSX, Windows and FreeBSD */
237-
if (trans == 0 && incx == 1 && incy == 1 && m*n < 2304 *GEMM_MULTITHREAD_THRESHOLD) {
238-
GEMV_N(m, n, 0, alpha, a, lda, x, incx, y, incy, NULL);
239-
return;
240-
}
241-
#endif
242253
IDEBUG_START;
243254

244255
FUNCTION_PROFILE_START();

kernel/arm64/KERNEL.NEOVERSEN2

+1
Original file line numberDiff line numberDiff line change
@@ -198,3 +198,4 @@ SBGEMMINCOPYOBJ = sbgemm_incopy$(TSUFFIX).$(SUFFIX)
198198
SBGEMMITCOPYOBJ = sbgemm_itcopy$(TSUFFIX).$(SUFFIX)
199199
SBGEMMONCOPYOBJ = sbgemm_oncopy$(TSUFFIX).$(SUFFIX)
200200
SBGEMMOTCOPYOBJ = sbgemm_otcopy$(TSUFFIX).$(SUFFIX)
201+
SBGEMVTKERNEL = sbgemv_t_bfdot.c

kernel/arm64/KERNEL.NEOVERSEV1

+2
Original file line numberDiff line numberDiff line change
@@ -17,4 +17,6 @@ SBGEMMONCOPYOBJ = sbgemm_oncopy$(TSUFFIX).$(SUFFIX)
1717
SBGEMMOTCOPYOBJ = sbgemm_otcopy$(TSUFFIX).$(SUFFIX)
1818

1919
SBGEMVNKERNEL = sbgemv_n_neon.c
20+
SBGEMVTKERNEL = sbgemv_t_bfdot.c
21+
2022
endif

kernel/arm64/KERNEL.NEOVERSEV2

+4
Original file line numberDiff line numberDiff line change
@@ -1 +1,5 @@
11
include $(KERNELDIR)/KERNEL.ARMV8SVE
2+
3+
ifeq ($(BUILD_BFLOAT16), 1)
4+
SBGEMVTKERNEL = sbgemv_t_bfdot.c
5+
endif

kernel/arm64/sbgemv_t_bfdot.c

+207
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,207 @@
1+
/***************************************************************************
2+
Copyright (c) 2025, The OpenBLAS Project
3+
All rights reserved.
4+
5+
Redistribution and use in source and binary forms, with or without
6+
modification, are permitted provided that the following conditions are
7+
met:
8+
9+
1. Redistributions of source code must retain the above copyright
10+
notice, this list of conditions and the following disclaimer.
11+
12+
2. Redistributions in binary form must reproduce the above copyright
13+
notice, this list of conditions and the following disclaimer in
14+
the documentation and/or other materials provided with the
15+
distribution.
16+
3. Neither the name of the OpenBLAS project nor the names of
17+
its contributors may be used to endorse or promote products
18+
derived from this software without specific prior written
19+
permission.
20+
21+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
22+
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
23+
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
24+
ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
25+
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
26+
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
27+
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
28+
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
29+
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
30+
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31+
*****************************************************************************/
32+
33+
#include <arm_neon.h>
34+
#include "common.h"
35+
36+
/*
 * Widen one bfloat16 value to IEEE-754 single precision.
 *
 * bfloat16 is the upper 16 bits of a float32, so the conversion is a
 * plain shift of the bits into the high half of a 32-bit word; the
 * result is exact.  The bit pattern is moved into the float with
 * memcpy rather than by dereferencing a casted pointer: reading a
 * uint32_t object through a float* violates strict aliasing and is
 * undefined behavior.  The memcpy compiles to a single register move.
 */
static inline float bf16_to_fp32(bfloat16 bf16) {
  uint32_t bits = (uint32_t)bf16 << 16;
  float fp32;
  memcpy(&fp32, &bits, sizeof(fp32));
  return fp32;
}
40+
41+
/*
 * Transposed BF16 GEMV kernel using the Arm NEON BFDOT extension:
 *
 *     y := alpha * A^T * x + beta * y
 *
 * m, n       - dimensions of A (m rows, n columns, column stride lda)
 * a, x       - bfloat16 input matrix and vector
 * y          - fp32 result vector
 * incx, incy - element strides of x and y
 *
 * Each result element y[j] is the dot product of column j of A with x.
 * When x is contiguous (incx == 1) four columns are processed per
 * iteration with BFDOT; otherwise a scalar fallback converts every
 * element through bf16_to_fp32.
 * Per BLAS convention, y is not read when beta == 0.
 */
int CNAME(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat16 *x, BLASLONG incx, float beta, float *y, BLASLONG incy)
{
    if (m < 1 || n < 1) return(0);
    BLASLONG i;
    BLASLONG ix,iy;
    BLASLONG j;
    bfloat16_t *a_ptr;
    bfloat16_t *x_ptr;
    float *y_ptr;
    float temp;

    iy = 0;
    /* Reinterpret the storage type bfloat16 as the ACLE vector element
     * type bfloat16_t for the intrinsics below. */
    a_ptr = (bfloat16_t*)(a);
    x_ptr = (bfloat16_t*)(x);

    if (incx == 1) {
        /* Split the n columns into four contiguous chunks of `width`
         * columns each; iteration j computes one dot product from each
         * chunk so that the four accumulators share the loads of x. */
        BLASLONG width = n / 4;

        bfloat16_t *a0_ptr = a_ptr + lda * width * 0;
        bfloat16_t *a1_ptr = a_ptr + lda * width * 1;
        bfloat16_t *a2_ptr = a_ptr + lda * width * 2;
        bfloat16_t *a3_ptr = a_ptr + lda * width * 3;

        float *y0_ptr = y + incy * width * 0;
        float *y1_ptr = y + incy * width * 1;
        float *y2_ptr = y + incy * width * 2;
        float *y3_ptr = y + incy * width * 3;

        for (j = 0; j < width; j++) {
            float32x4_t temp0_vec = vdupq_n_f32(0.0f);
            float32x4_t temp1_vec = vdupq_n_f32(0.0f);
            float32x4_t temp2_vec = vdupq_n_f32(0.0f);
            float32x4_t temp3_vec = vdupq_n_f32(0.0f);

            /* Main loop: 8 bf16 elements per step.  vbfdotq_f32
             * accumulates pairwise bf16 products into 4 fp32 lanes. */
            i = 0;
            while (i + 7 < m) {
                bfloat16x8_t x_vec = vld1q_bf16(x_ptr + i);

                bfloat16x8_t a0_vec = vld1q_bf16(a0_ptr + i);
                bfloat16x8_t a1_vec = vld1q_bf16(a1_ptr + i);
                bfloat16x8_t a2_vec = vld1q_bf16(a2_ptr + i);
                bfloat16x8_t a3_vec = vld1q_bf16(a3_ptr + i);

                temp0_vec = vbfdotq_f32(temp0_vec, a0_vec, x_vec);
                temp1_vec = vbfdotq_f32(temp1_vec, a1_vec, x_vec);
                temp2_vec = vbfdotq_f32(temp2_vec, a2_vec, x_vec);
                temp3_vec = vbfdotq_f32(temp3_vec, a3_vec, x_vec);

                i += 8;
            }
            /* 4-element tail: 64-bit BFDOT, folded into the low half
             * of the 128-bit accumulator. */
            if (i + 3 < m) {
                float32x2_t t0 = vdup_n_f32(0.0f);
                float32x2_t t1 = vdup_n_f32(0.0f);
                float32x2_t t2 = vdup_n_f32(0.0f);
                float32x2_t t3 = vdup_n_f32(0.0f);

                bfloat16x4_t x_vec = vld1_bf16(x_ptr + i);

                bfloat16x4_t a0_vec = vld1_bf16(a0_ptr + i);
                bfloat16x4_t a1_vec = vld1_bf16(a1_ptr + i);
                bfloat16x4_t a2_vec = vld1_bf16(a2_ptr + i);
                bfloat16x4_t a3_vec = vld1_bf16(a3_ptr + i);

                t0 = vbfdot_f32(t0, a0_vec, x_vec);
                t1 = vbfdot_f32(t1, a1_vec, x_vec);
                t2 = vbfdot_f32(t2, a2_vec, x_vec);
                t3 = vbfdot_f32(t3, a3_vec, x_vec);

                float32x2_t temp0_vec_low = vget_low_f32(temp0_vec);
                float32x2_t temp1_vec_low = vget_low_f32(temp1_vec);
                float32x2_t temp2_vec_low = vget_low_f32(temp2_vec);
                float32x2_t temp3_vec_low = vget_low_f32(temp3_vec);

                temp0_vec = vcombine_f32(vadd_f32(t0, temp0_vec_low), vget_high_f32(temp0_vec));
                temp1_vec = vcombine_f32(vadd_f32(t1, temp1_vec_low), vget_high_f32(temp1_vec));
                temp2_vec = vcombine_f32(vadd_f32(t2, temp2_vec_low), vget_high_f32(temp2_vec));
                temp3_vec = vcombine_f32(vadd_f32(t3, temp3_vec_low), vget_high_f32(temp3_vec));

                i += 4;
            }
            /* Horizontal reduction and store; the scalar tail below
             * then adds the remaining (< 4) products on top. */
            if (beta == 0.0f) {
                y0_ptr[iy] = alpha * vaddvq_f32(temp0_vec);
                y1_ptr[iy] = alpha * vaddvq_f32(temp1_vec);
                y2_ptr[iy] = alpha * vaddvq_f32(temp2_vec);
                y3_ptr[iy] = alpha * vaddvq_f32(temp3_vec);
            }
            else {
                y0_ptr[iy] = alpha * vaddvq_f32(temp0_vec) + beta * y0_ptr[iy];
                y1_ptr[iy] = alpha * vaddvq_f32(temp1_vec) + beta * y1_ptr[iy];
                y2_ptr[iy] = alpha * vaddvq_f32(temp2_vec) + beta * y2_ptr[iy];
                y3_ptr[iy] = alpha * vaddvq_f32(temp3_vec) + beta * y3_ptr[iy];
            }

            /* NOTE(review): scalar arithmetic on bfloat16_t relies on
             * the compiler promoting __bf16 operands to float (exact
             * widening) -- confirm minimum supported toolchains. */
            for (; i < m; ++i) {
                y0_ptr[iy] += alpha * a0_ptr[i] * x_ptr[i];
                y1_ptr[iy] += alpha * a1_ptr[i] * x_ptr[i];
                y2_ptr[iy] += alpha * a2_ptr[i] * x_ptr[i];
                y3_ptr[iy] += alpha * a3_ptr[i] * x_ptr[i];
            }

            iy += incy;

            a0_ptr += lda;
            a1_ptr += lda;
            a2_ptr += lda;
            a3_ptr += lda;
        }

        /* Leftover columns (n % 4).  After the loop a3_ptr points at
         * column 4*width, and y3_ptr[iy] (iy == width*incy) addresses
         * y element 4*width, so indexing continues seamlessly. */
        a_ptr = a3_ptr;
        y_ptr = y3_ptr;
        for (j = width * 4; j < n; j++) {
            float32x4_t temp0_vec = vdupq_n_f32(0.0f);
            i = 0;
            while (i + 7 < m) {
                bfloat16x8_t x_vec = vld1q_bf16(x_ptr + i);
                bfloat16x8_t a0_vec = vld1q_bf16(a_ptr + i);
                temp0_vec = vbfdotq_f32(temp0_vec, a0_vec, x_vec);

                i += 8;
            }
            if (i + 3 < m) {
                float32x2_t t0 = vdup_n_f32(0.0f);
                bfloat16x4_t x_vec = vld1_bf16(x_ptr + i);
                bfloat16x4_t a0_vec = vld1_bf16(a_ptr + i);

                t0 = vbfdot_f32(t0, a0_vec, x_vec);
                float32x2_t temp0_vec_low = vget_low_f32(temp0_vec);
                temp0_vec = vcombine_f32(vadd_f32(t0, temp0_vec_low), vget_high_f32(temp0_vec));

                i += 4;
            }
            if (beta == 0.0f) {
                y_ptr[iy] = alpha * vaddvq_f32(temp0_vec);
            }
            else {
                y_ptr[iy] = alpha * vaddvq_f32(temp0_vec) + beta * y_ptr[iy];
            }

            for (; i < m; ++i) {
                y_ptr[iy] += alpha * a_ptr[i] * x_ptr[i];
            }

            iy += incy;

            a_ptr += lda;
        }
        return(0);
    }

    /* Scalar fallback for strided x (incx != 1): plain column-wise
     * dot products with explicit bf16 -> fp32 conversion. */
    for (j = 0; j < n; j++) {
        temp = 0.0;
        ix = 0;
        for (i = 0; i < m; i++) {
            temp += bf16_to_fp32(a[i]) * bf16_to_fp32(x[ix]);
            ix += incx;
        }
        if (beta == 0.0f) {
            y[iy] = alpha * temp;
        }
        else {
            y[iy] = alpha * temp + beta * y[iy];
        }
        iy += incy;
        a += lda;
    }
    return (0);
}

kernel/power/scal.S

+3-3
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@
5151
#else
5252
#define X r7
5353
#define INCX r8
54-
#define FLAG r12
54+
#define FLAG r11
5555
#endif
5656
#endif
5757

@@ -63,7 +63,7 @@
6363
#else
6464
#define X r7
6565
#define INCX r8
66-
#define FLAG r12
66+
#define FLAG r11
6767
#endif
6868
#endif
6969

@@ -91,7 +91,7 @@
9191
fcmpu cr0, FZERO, ALPHA
9292
bne- cr0, LL(A1I1)
9393

94-
LDLONG FLAG, 48+64+8(SP)
94+
LDLONG FLAG, 104(SP)
9595
cmpwi cr0, FLAG, 1
9696
beq- cr0, LL(A1I1)
9797

0 commit comments

Comments
 (0)