Skip to content

Commit 747ec6d

Browse files
committed
add beta and alpha testcase for sbgemv
1 parent bb540dc commit 747ec6d

File tree

1 file changed

+14
-5
lines changed

1 file changed

+14
-5
lines changed

test/compare_sgemm_sbgemm.c

+14-5
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/***************************************************************************
2-
Copyright (c) 2020, The OpenBLAS Project
2+
Copyright (c) 2020,2025 The OpenBLAS Project
33
All rights reserved.
44
Redistribution and use in source and binary forms, with or without
55
modification, are permitted provided that the following conditions are
@@ -202,6 +202,8 @@ main (int argc, char *argv[])
202202
return ret;
203203
}
204204

205+
for (beta = 0; beta < 3; beta += 1) {
206+
for (alpha = 0; alpha < 3; alpha += 1) {
205207
for (l = 0; l < 2; l++) { // l = 1 to test inc_x & inc_y not equal to one.
206208
for (x = 1; x <= loop; x++)
207209
{
@@ -230,7 +232,10 @@ main (int argc, char *argv[])
230232
B[j << l] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5;
231233
sbstobf16_(&one, &B[j << l], &one, &btmp, &one);
232234
BB[j << l].v = btmp;
235+
236+
CC[j << l] = C[j << l] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5;
233237
}
238+
234239
for (y = 0; y < 2; y++)
235240
{
236241
if (y == 0) {
@@ -246,12 +251,14 @@ main (int argc, char *argv[])
246251
SGEMV (&transA, &x, &x, &alpha, A, &x, B, &k, &beta, C, &k);
247252
SBGEMV (&transA, &x, &x, &alpha, (bfloat16*) AA, &x, (bfloat16*) BB, &k, &beta, CC, &k);
248253

254+
for (int i = 0; i < x; i ++) DD[i] *= beta;
255+
249256
for (j = 0; j < x; j++)
250257
for (i = 0; i < x; i++)
251258
if (transA == 'N') {
252-
DD[i] += float16to32 (AA[j * x + i]) * float16to32 (BB[j << l]);
259+
DD[i] += alpha * float16to32 (AA[j * x + i]) * float16to32 (BB[j << l]);
253260
} else if (transA == 'T') {
254-
DD[j] += float16to32 (AA[j * x + i]) * float16to32 (BB[i << l]);
261+
DD[j] += alpha * float16to32 (AA[j * x + i]) * float16to32 (BB[i << l]);
255262
}
256263

257264
for (j = 0; j < x; j++) {
@@ -268,8 +275,10 @@ main (int argc, char *argv[])
268275
free(BB);
269276
free(DD);
270277
free(CC);
271-
}
272-
}
278+
} // x
279+
} // l
280+
} // alpha
281+
} // beta
273282

274283
if (ret != 0)
275284
fprintf (stderr, "FATAL ERROR SBGEMV - Return code: %d\n", ret);

0 commit comments

Comments
 (0)