|
48 | 48 | #define SWITCH_RATIO 2
|
49 | 49 | #endif
|
50 | 50 |
|
| 51 | +#ifndef GEMM_PREFERED_SIZE |
| 52 | +#define GEMM_PREFERED_SIZE 1 |
| 53 | +#endif |
| 54 | + |
51 | 55 | // A job_t array placed on the stack may overflow it.
|
52 | 56 | // Instead, allocate the job_t array with malloc.
|
53 | 57 | #if MAX_CPU_NUMBER > BLAS3_MEM_ALLOC_THRESHOLD
|
@@ -510,6 +514,16 @@ static int inner_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n,
|
510 | 514 | return 0;
|
511 | 515 | }
|
512 | 516 |
|
/*
 * Round a proposed partition width up to the next multiple of `multiple`.
 *
 * remainder  amount of work still left to split (the callers in this file
 *            pass the remaining m or n)
 * width      width proposed for the next partition
 * multiple   granularity to align to (the callers pass GEMM_PREFERED_SIZE)
 *
 * Returns `width` unchanged when there is nothing useful to align to:
 *   - multiple <= 1 (every width is already a multiple of 1; zero or
 *     negative values cannot be divided by safely),
 *   - multiple > remainder (less than one full multiple of work is left),
 *   - width <= multiple (the partition is no larger than one multiple).
 * Otherwise returns the smallest multiple of `multiple` that is >= width.
 *
 * The result may exceed `remainder`; the callers in this file subtract it
 * from the remaining size and shrink the width again if that goes negative.
 *
 * NOTE(review): parameters and result are plain int. If the callers' size
 * variables are wider than int, values narrow here - confirm.
 */
static int round_up(int remainder, int width, int multiple)
{
    int excess;

    /* GEMM_PREFERED_SIZE can be overridden at build time; guard against a
     * zero or negative value reaching the division below. */
    if (multiple <= 1)
        return width;

    if (multiple > remainder || width <= multiple)
        return width;

    /* width > multiple >= 2 here, so the modulo is well defined. */
    excess = width % multiple;
    if (excess == 0)
        return width;

    return width + (multiple - excess);
}
| 525 | + |
| 526 | + |
513 | 527 | static int gemm_driver(blas_arg_t *args, BLASLONG *range_m, BLASLONG
|
514 | 528 | *range_n, FLOAT *sa, FLOAT *sb,
|
515 | 529 | BLASLONG nthreads_m, BLASLONG nthreads_n) {
|
@@ -601,9 +615,14 @@ static int gemm_driver(blas_arg_t *args, BLASLONG *range_m, BLASLONG
|
601 | 615 | num_parts = 0;
|
602 | 616 | while (m > 0){
|
603 | 617 | width = blas_quickdivide(m + nthreads_m - num_parts - 1, nthreads_m - num_parts);
|
| 618 | + |
| 619 | + width = round_up(m, width, GEMM_PREFERED_SIZE); |
| 620 | + |
604 | 621 | m -= width;
|
| 622 | + |
605 | 623 | if (m < 0) width = width + m;
|
606 | 624 | range_M[num_parts + 1] = range_M[num_parts] + width;
|
| 625 | + |
607 | 626 | num_parts ++;
|
608 | 627 | }
|
609 | 628 | for (i = num_parts; i < MAX_CPU_NUMBER; i++) {
|
@@ -645,9 +664,12 @@ static int gemm_driver(blas_arg_t *args, BLASLONG *range_m, BLASLONG
|
645 | 664 | if (width < SWITCH_RATIO) {
|
646 | 665 | width = SWITCH_RATIO;
|
647 | 666 | }
|
| 667 | + width = round_up(n, width, GEMM_PREFERED_SIZE); |
| 668 | + |
648 | 669 | n -= width;
|
649 | 670 | if (n < 0) width = width + n;
|
650 | 671 | range_N[num_parts + 1] = range_N[num_parts] + width;
|
| 672 | + |
651 | 673 | num_parts ++;
|
652 | 674 | }
|
653 | 675 | for (j = num_parts; j < MAX_CPU_NUMBER; j++) {
|
|
0 commit comments