Skip to content

Commit 217324d

Browse files
authored
Merge pull request #5162 from taoye9/add_sbgemv_tests
add beta and alpha testcase for sbgemv
2 parents e4630ed + 4346b91 commit 217324d

File tree

1 file changed

+14
-5
lines changed

1 file changed

+14
-5
lines changed

test/compare_sgemm_sbgemm.c

+14-5
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/***************************************************************************
2-
Copyright (c) 2020, The OpenBLAS Project
2+
Copyright (c) 2020,2025 The OpenBLAS Project
33
All rights reserved.
44
Redistribution and use in source and binary forms, with or without
55
modification, are permitted provided that the following conditions are
@@ -202,6 +202,8 @@ main (int argc, char *argv[])
202202
return ret;
203203
}
204204

205+
for (beta = 0; beta < 3; beta += 1) {
206+
for (alpha = 0; alpha < 3; alpha += 1) {
205207
for (l = 0; l < 2; l++) { // l = 1 to test inc_x & inc_y not equal to one.
206208
for (x = 1; x <= loop; x++)
207209
{
@@ -230,7 +232,10 @@ main (int argc, char *argv[])
230232
B[j << l] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5;
231233
sbstobf16_(&one, &B[j << l], &one, &btmp, &one);
232234
BB[j << l].v = btmp;
235+
236+
CC[j << l] = C[j << l] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5;
233237
}
238+
234239
for (y = 0; y < 2; y++)
235240
{
236241
if (y == 0) {
@@ -246,12 +251,14 @@ main (int argc, char *argv[])
246251
SGEMV (&transA, &x, &x, &alpha, A, &x, B, &k, &beta, C, &k);
247252
SBGEMV (&transA, &x, &x, &alpha, (bfloat16*) AA, &x, (bfloat16*) BB, &k, &beta, CC, &k);
248253

254+
for (int i = 0; i < x; i ++) DD[i] *= beta;
255+
249256
for (j = 0; j < x; j++)
250257
for (i = 0; i < x; i++)
251258
if (transA == 'N') {
252-
DD[i] += float16to32 (AA[j * x + i]) * float16to32 (BB[j << l]);
259+
DD[i] += alpha * float16to32 (AA[j * x + i]) * float16to32 (BB[j << l]);
253260
} else if (transA == 'T') {
254-
DD[j] += float16to32 (AA[j * x + i]) * float16to32 (BB[i << l]);
261+
DD[j] += alpha * float16to32 (AA[j * x + i]) * float16to32 (BB[i << l]);
255262
}
256263

257264
for (j = 0; j < x; j++) {
@@ -268,8 +275,10 @@ main (int argc, char *argv[])
268275
free(BB);
269276
free(DD);
270277
free(CC);
271-
}
272-
}
278+
} // x
279+
} // l
280+
} // alpha
281+
} // beta
273282

274283
if (ret != 0)
275284
fprintf (stderr, "FATAL ERROR SBGEMV - Return code: %d\n", ret);

0 commit comments

Comments
 (0)