From ed7780d65925ee329d687e9dfb57a5dd547e99e6 Mon Sep 17 00:00:00 2001 From: Meghana Date: Wed, 18 Mar 2020 16:07:47 +0530 Subject: [PATCH 1/5] Made some critical changes to small_gemm kernels Details: - In case of GEMM, whenever beta is zero, we need to perform C = alpha *(A * B) instead of C = beta * C + alpha * (A * B) Added conditions to check the value of beta at different levels inside small_gemm kernels and decide whether to perform scaling C with beta or not. -Modified small_gemm kernels to use BLIS specific functions to retrieve different fields of objects. -Calling bli_gemm_check before entering bli_gemm_small to facilitate early return in case of invalid inputs. -For corner cases inside small_gemm kernels, a buffer called f_temp is used to load and store data to and from registers. populating the buffer with zeroes before use. -In bli_gemm_front, datatypes of status and return value from bli_gemm_small are not matching. Corrected the datatype of the variable 'status' inside bli_gemm_front to err_t. Change-Id: I8b52ad55008f028d6c8b7e0d20f746a869d9daea Signed-off-by: Meghana Vankadari AMD-Internal: [CPUPL-689,SWLCSG-104] --- frame/3/gemm/bli_gemm_front.c | 3 +-- kernels/zen/3/bli_gemm_small.c | 16 +++++++--------- 2 files changed, 8 insertions(+), 11 deletions(-) diff --git a/frame/3/gemm/bli_gemm_front.c b/frame/3/gemm/bli_gemm_front.c index bd815a4c82..882c076f5b 100644 --- a/frame/3/gemm/bli_gemm_front.c +++ b/frame/3/gemm/bli_gemm_front.c @@ -73,7 +73,7 @@ void bli_gemm_front return; } -#if 0 + #ifdef BLIS_ENABLE_SMALL_MATRIX // Only handle small problems separately for homogeneous datatypes. if ( bli_obj_dt( a ) == bli_obj_dt( b ) && @@ -83,7 +83,6 @@ void bli_gemm_front err_t status = bli_gemm_small( alpha, a, b, beta, c, cntx, cntl ); if ( status == BLIS_SUCCESS ) return; } -#endif #endif // Alias A, B, and C in case we need to apply transformations. diff --git a/kernels/zen/3/bli_gemm_small.c b/kernels/zen/3/bli_gemm_small.c index b04ffea580..eb2536e914 100644 --- a/kernels/zen/3/bli_gemm_small.c +++ b/kernels/zen/3/bli_gemm_small.c @@ -194,7 +194,6 @@ static err_t bli_sgemm_small beta, c ); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); return BLIS_SUCCESS; } @@ -1747,18 +1746,14 @@ static err_t bli_dgemm_small beta, c ); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); return BLIS_SUCCESS; } if (N<3) //Implemenation assumes that N is atleast 3. 
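// Illustrative sketch (not part of this patch): the beta handling described
// in the commit message reduces, element-wise, to
//
//     bool is_beta_non_zero = !bli_obj_equals( beta, &BLIS_ZERO );
//
//     if ( is_beta_non_zero ) C(i,j) = beta * C(i,j) + alpha * AB(i,j);
//     else                    C(i,j) =                alpha * AB(i,j);
//
// where AB(i,j) stands for the (i,j) element of A*B. With beta == 0, C is
// only written and never read, so stale or uninitialized values in C cannot
// leak into the result.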
- { - AOCL_DTL_TRACE_EXIT_ERR( - AOCL_DTL_LEVEL_INFO, - "N < 3, cannot be processed by small gemm" - ); + { return BLIS_NOT_YET_IMPLEMENTED; - } + } + #ifdef BLIS_ENABLE_SMALL_MATRIX_ROME if( (L && K) && ((K < D_BLIS_SMALL_MATRIX_K_THRES_ROME) || ((N < BLIS_SMALL_MATRIX_THRES_ROME) && (K < BLIS_SMALL_MATRIX_THRES_ROME)))) @@ -1810,7 +1805,8 @@ static err_t bli_dgemm_small //if true, we should perform C=alpha * A*B operation //instead of C = beta * C + alpha * (A * B) bool is_beta_non_zero = 0; - if(!bli_obj_equals(beta, &BLIS_ZERO)) + + if(!bli_obj_equals(beta, &BLIS_ZERO)) is_beta_non_zero = 1; /* @@ -3847,7 +3843,9 @@ static err_t bli_dgemm_small_atbn //check if beta is zero //if true, we need to perform C = alpha * (A * B) //instead of C = beta * C + alpha * (A * B) + bool is_beta_non_zero = 0; + if(!bli_obj_equals(beta,&BLIS_ZERO)) is_beta_non_zero = 1; From 1c6d455d51c4100b92d0120885be10a999953b4e Mon Sep 17 00:00:00 2001 From: Meghana Vankadari Date: Thu, 18 Feb 2021 13:50:36 +0530 Subject: [PATCH 2/5] Implemented 16x3 based gemm kernel for the case where A has transpose Details: - This implementation does a transpose operation while packing 16xk of A buffer and passes it to 16x3-nn kernel. - The same implementation works for the case where B has transpose. AMD-Internal: [CPUPL-1376] Change-Id: I81f74deb609926598f62c30f5bd6fc80fb1b9a17 --- frame/compat/bla_gemm.c | 1 - kernels/zen/3/bli_gemm_small.c | 1488 +++++++++++++++++++++++++++++++- kernels/zen/bli_kernels_zen.h | 15 + 3 files changed, 1497 insertions(+), 7 deletions(-) diff --git a/frame/compat/bla_gemm.c b/frame/compat/bla_gemm.c index e04e48cf50..36617bc4cd 100644 --- a/frame/compat/bla_gemm.c +++ b/frame/compat/bla_gemm.c @@ -218,7 +218,6 @@ void PASTEF77(ch,blasname) \ /* Finalize BLIS. */ \ bli_finalize_auto(); \ } - #endif #ifdef BLIS_ENABLE_BLAS diff --git a/kernels/zen/3/bli_gemm_small.c b/kernels/zen/3/bli_gemm_small.c index eb2536e914..4ba9d08b1b 100644 --- a/kernels/zen/3/bli_gemm_small.c +++ b/kernels/zen/3/bli_gemm_small.c @@ -136,16 +136,25 @@ err_t bli_gemm_small if (bli_obj_has_trans( a )) { - if (bli_obj_has_notrans( b )) + if (dt == BLIS_DOUBLE) + { +#ifndef BLIS_ENABLE_MULTITHREADING + // bli_dgemm_small_At is called directly from blas interface for + // sizes within thresholds. + // Avoinding calling of bli_dgemm_small_At from gemm_front + // and directing to native implementation. + return BLIS_NOT_YET_IMPLEMENTED; +#else + return bli_dgemm_small_At(alpha, a, b, beta, c, cntx, cntl); +#endif + } + + if (bli_obj_has_notrans( b )) { if (dt == BLIS_FLOAT) { return bli_sgemm_small_atbn(alpha, a, b, beta, c, cntx, cntl); } - else if (dt == BLIS_DOUBLE) - { - return bli_dgemm_small_atbn(alpha, a, b, beta, c, cntx, cntl); - } } return BLIS_NOT_YET_IMPLEMENTED; @@ -153,7 +162,14 @@ err_t bli_gemm_small if (dt == BLIS_DOUBLE) { - return bli_dgemm_small(alpha, a, b, beta, c, cntx, cntl); +#ifndef BLIS_ENABLE_MULTITHREADING + // bli_dgemm_small is called directly from BLAS interface for sizes within thresholds. + // Avoiding calling bli_dgemm_small from gemm_front and directing to + // native implementation. 
+ return BLIS_NOT_YET_IMPLEMENTED; +#else + return bli_dgemm_small(alpha, a, b, beta, c, cntx, cntl); +#endif } if (dt == BLIS_FLOAT) @@ -4239,5 +4255,1465 @@ static err_t bli_dgemm_small_atbn return BLIS_NONCONFORMAL_DIMENSIONS; } } + +static err_t bli_dgemm_small_At + ( + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + cntl_t* cntl + ) +{ + + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_INFO); + + gint_t M = bli_obj_length( c ); // number of rows of Matrix C + gint_t N = bli_obj_width( c ); // number of columns of Matrix C + gint_t K = bli_obj_width_after_trans( a ); // number of columns of OP(A), will be updated if OP(A) is Transpose(A) . + + + if (N<3) //Implemenation assumes that N is atleast 3. + { + AOCL_DTL_TRACE_EXIT_ERR( + AOCL_DTL_LEVEL_INFO, + "N < 3, cannot be processed by small gemm" + ); + return BLIS_NOT_YET_IMPLEMENTED; + } + +/* #ifdef BLIS_ENABLE_SMALL_MATRIX_ROME + * if( (L && K) && ((K < D_BLIS_SMALL_MATRIX_K_THRES_ROME) || ((N < BLIS_SMALL_MATRIX_THRES_ROME) && (K < BLIS_SMALL_MATRIX_THRES_ROME)))) + * #else + * if ((((L) < (D_BLIS_SMALL_MATRIX_THRES * D_BLIS_SMALL_MATRIX_THRES)) + * || ((M < D_BLIS_SMALL_M_RECT_MATRIX_THRES) && (K < D_BLIS_SMALL_K_RECT_MATRIX_THRES))) && ((L!=0) && (K!=0))) + * #endif + */ + if( M && N && K ) + { + guint_t lda = bli_obj_col_stride( a ); // column stride of matrix OP(A), where OP(A) is Transpose(A) if transA enabled. + guint_t ldb = bli_obj_col_stride( b ); // column stride of matrix OP(B), where OP(B) is Transpose(B) if transB enabled. + guint_t ldc = bli_obj_col_stride( c ); // column stride of matrix C + guint_t row_idx, col_idx, k; + double *A = bli_obj_buffer_at_off(a); // pointer to elements of Matrix A + double *B = bli_obj_buffer_at_off(b); // pointer to elements of Matrix B + double *C = bli_obj_buffer_at_off(c); // pointer to elements of Matrix C + + double *tA = A, *tB = B, *tC = C;//, *tA_pack; + double *tA_packed; // temprorary pointer to hold packed A memory pointer + guint_t row_idx_packed; //packed A memory row index + guint_t lda_packed; //lda of packed A + dim_t tb_inc_row = 1; // row stride of matrix B + dim_t tb_inc_col = ldb; // column stride of matrix B + + double *alpha_cast, *beta_cast; // alpha, beta multiples + alpha_cast = bli_obj_buffer_for_1x1(BLIS_DOUBLE, alpha); + beta_cast = bli_obj_buffer_for_1x1(BLIS_DOUBLE, beta); + + gint_t required_packing_A = 1; + mem_t local_mem_buf_A_s; + double *D_A_pack = NULL; + rntm_t rntm; + + if( bli_obj_has_trans( b ) ) + { + tb_inc_col = 1; // switch row and column strides + tb_inc_row = ldb; + } + + __m256d ymm4, ymm5, ymm6, ymm7; + __m256d ymm8, ymm9, ymm10, ymm11; + __m256d ymm12, ymm13, ymm14, ymm15; + __m256d ymm0, ymm1, ymm2, ymm3; + + double result; + double scratch[8] = {0.0}; + + gint_t n_remainder; // If the N is non multiple of 3.(N%3) + gint_t m_remainder; // If the M is non multiple of 16.(M%16) + + //checking whether beta value is zero. + //if true, we should perform C=alpha * A*B operation + //instead of C = beta * C + alpha * (A * B) + bool is_beta_non_zero = 0; + if(!bli_obj_equals(beta, &BLIS_ZERO)) + is_beta_non_zero = 1; + + /* + * This function was using global array to pack part of A input when needed. + * However, using this global array make the function non-reentrant. + * Instead of using a global array we should allocate buffer for each invocation. 
+ * Since the buffer size is too big or stack and doing malloc every time will be too expensive, + * better approach is to get the buffer from the pre-allocated pool and return + * it the pool once we are doing. + * + * In order to get the buffer from pool, we need access to memory broker, + * currently this function is not invoked in such a way that it can receive + * the memory broker (via rntm). Following hack will get the global memory + * broker that can be use it to access the pool. + * + * Note there will be memory allocation at least on first innovation + * as there will not be any pool created for this size. + * Subsequent invocations will just reuse the buffer from the pool. + */ + + bli_rntm_init_from_global( &rntm ); + bli_rntm_set_num_threads_only( 1, &rntm ); + bli_membrk_rntm_set_membrk( &rntm ); + + // Get the current size of the buffer pool for A block packing. + // We will use the same size to avoid pool re-initliazaton + siz_t buffer_size = bli_pool_block_size( + bli_membrk_pool(bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK), + bli_rntm_membrk(&rntm))); + + // + // This kernel assumes that "A" will be unpackged if N <= 3. + // Usually this range (N <= 3) is handled by SUP, however, + // if SUP is disabled or for any other condition if we do + // enter this kernel with N <= 3, we want to make sure that + // "A" remains unpacked. + // + // If this check is removed it will result in the crash as + // reported in CPUPL-587. + // + + if ((N < 3) || ((D_MR * K) << 3) > buffer_size) + { + required_packing_A = 0; + return BLIS_NOT_YET_IMPLEMENTED; + } + + if (required_packing_A == 1) + { +#ifdef BLIS_ENABLE_MEM_TRACING + printf( "bli_dgemm_small: Requesting mem pool block of size %lu\n", buffer_size); +#endif + // Get the buffer from the pool. + bli_membrk_acquire_m(&rntm, + buffer_size, + BLIS_BITVAL_BUFFER_FOR_A_BLOCK, + &local_mem_buf_A_s); + + D_A_pack = bli_mem_buffer(&local_mem_buf_A_s); + } + + /* + * The computation loop runs for D_MRxN columns of C matrix, thus + * accessing the D_MRxK A matrix data and KxNR B matrix data. + * The computation is organized as inner loops of dimension D_MRxNR. + */ + // Process D_MR rows of C matrix at a time. 
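// (Illustrative reference, not part of the kernel.) The AVX2 packing loop
// below is equivalent to the following scalar transpose-pack of a
// D_MR x K panel of op(A) = A^T, assuming column-major A with leading
// dimension lda and D_MR == 16 as in this kernel:
//
//     for ( dim_t kk = 0; kk < K; ++kk )
//         for ( dim_t ii = 0; ii < D_MR; ++ii )
//             D_A_pack[ kk * D_MR + ii ] = tA[ ii * lda + kk ];
//
// so the compute loops can stream the packed panel with unit stride.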
+ for (row_idx = 0; (row_idx + (D_MR - 1)) < M; row_idx += D_MR) + { + + tA = A + row_idx * lda; + tA_packed = D_A_pack; + lda_packed = D_MR; + + // Pack 16xk of matrix A into buffer + // continuous access for A and strided stores to B + for(inc_t x = 0; (x) < 4; x += 1) + { + double* tA_temp = tA; + + for(k = 0; (k+3) < K; k += 4) + { + ymm0 = _mm256_loadu_pd(tA_temp + 0 * lda); + ymm1 = _mm256_loadu_pd(tA_temp + 1 * lda); + ymm2 = _mm256_loadu_pd(tA_temp + 2 * lda); + ymm3 = _mm256_loadu_pd(tA_temp + 3 * lda); + + ymm10 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm11 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm12 = _mm256_unpacklo_pd(ymm2, ymm3); + ymm13 = _mm256_unpackhi_pd(ymm2, ymm3); + + ymm0 = _mm256_permute2f128_pd(ymm10, ymm12, 0x20); + ymm1 = _mm256_permute2f128_pd(ymm11, ymm13, 0x20); + + ymm2 = _mm256_permute2f128_pd(ymm10, ymm12, 0x31); + ymm3 = _mm256_permute2f128_pd(ymm11, ymm13, 0x31); + + _mm256_storeu_pd(tA_packed + 0 * lda_packed, ymm0); + _mm256_storeu_pd(tA_packed + 1 * lda_packed, ymm1); + _mm256_storeu_pd(tA_packed + 2 * lda_packed, ymm2); + _mm256_storeu_pd(tA_packed + 3 * lda_packed, ymm3); + + tA_temp += 4; + tA_packed += 4 * lda_packed; + } + + for(; k < K; k += 1) + { + tA_packed[0] = tA_temp[0 * lda]; + tA_packed[1] = tA_temp[1 * lda]; + tA_packed[2] = tA_temp[2 * lda]; + tA_packed[3] = tA_temp[3 * lda]; + + tA_temp += 1; + tA_packed += lda_packed; + } + + tA += 4 * lda; + tA_packed = D_A_pack +(x +1) * 4; + } + + tA_packed = D_A_pack; + row_idx_packed = 0; + lda_packed = D_MR; + + // Process NR columns of C matrix at a time. + for (col_idx = 0; (col_idx + (NR - 1)) < N; col_idx += NR) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = tA_packed + row_idx_packed; + +#ifdef BLIS_ENABLE_PREFETCH + _mm_prefetch((char*)(tC + 0), _MM_HINT_T0); + _mm_prefetch((char*)(tC + 8), _MM_HINT_T0); + _mm_prefetch((char*)(tC + ldc), _MM_HINT_T0); + _mm_prefetch((char*)(tC + ldc + 8), _MM_HINT_T0); + _mm_prefetch((char*)(tC + 2 * ldc), _MM_HINT_T0); + _mm_prefetch((char*)(tC + 2 * ldc + 8), _MM_HINT_T0); +#endif + // clear scratch registers. + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm11 = _mm256_setzero_pd(); + ymm12 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0); + ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1); + ymm2 = _mm256_broadcast_sd(tB + tb_inc_col * 2); + tB += tb_inc_row; + + //broadcasted matrix B elements are multiplied + //with matrix A columns. 
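// Scalar view of one k-iteration below (illustrative only): with acc[j]
// denoting the 16 partial sums for C column col_idx + j,
//
//     for ( int j = 0; j < 3; ++j )
//         for ( int i = 0; i < 16; ++i )
//             acc[j][i] += tB[ j * tb_inc_col ] * tA[i];
//
// ymm4-ymm7 hold acc[0], ymm8-ymm11 hold acc[1], ymm12-ymm15 hold acc[2].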
+ ymm3 = _mm256_loadu_pd(tA); + // ymm4 += ymm0 * ymm3; + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + // ymm8 += ymm1 * ymm3; + ymm8 = _mm256_fmadd_pd(ymm1, ymm3, ymm8); + // ymm12 += ymm2 * ymm3; + ymm12 = _mm256_fmadd_pd(ymm2, ymm3, ymm12); + + ymm3 = _mm256_loadu_pd(tA + 4); + // ymm5 += ymm0 * ymm3; + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + // ymm9 += ymm1 * ymm3; + ymm9 = _mm256_fmadd_pd(ymm1, ymm3, ymm9); + // ymm13 += ymm2 * ymm3; + ymm13 = _mm256_fmadd_pd(ymm2, ymm3, ymm13); + + ymm3 = _mm256_loadu_pd(tA + 8); + // ymm6 += ymm0 * ymm3; + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + // ymm10 += ymm1 * ymm3; + ymm10 = _mm256_fmadd_pd(ymm1, ymm3, ymm10); + // ymm14 += ymm2 * ymm3; + ymm14 = _mm256_fmadd_pd(ymm2, ymm3, ymm14); + + ymm3 = _mm256_loadu_pd(tA + 12); + // ymm7 += ymm0 * ymm3; + ymm7 = _mm256_fmadd_pd(ymm0, ymm3, ymm7); + // ymm11 += ymm1 * ymm3; + ymm11 = _mm256_fmadd_pd(ymm1, ymm3, ymm11); + // ymm15 += ymm2 * ymm3; + ymm15 = _mm256_fmadd_pd(ymm2, ymm3, ymm15); + + tA += lda_packed; + } + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_sd(alpha_cast); + ymm1 = _mm256_broadcast_sd(beta_cast); + + //multiply A*B by alpha. + ymm4 = _mm256_mul_pd(ymm4, ymm0); + ymm5 = _mm256_mul_pd(ymm5, ymm0); + ymm6 = _mm256_mul_pd(ymm6, ymm0); + ymm7 = _mm256_mul_pd(ymm7, ymm0); + ymm8 = _mm256_mul_pd(ymm8, ymm0); + ymm9 = _mm256_mul_pd(ymm9, ymm0); + ymm10 = _mm256_mul_pd(ymm10, ymm0); + ymm11 = _mm256_mul_pd(ymm11, ymm0); + ymm12 = _mm256_mul_pd(ymm12, ymm0); + ymm13 = _mm256_mul_pd(ymm13, ymm0); + ymm14 = _mm256_mul_pd(ymm14, ymm0); + ymm15 = _mm256_mul_pd(ymm15, ymm0); + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. + ymm2 = _mm256_loadu_pd(tC); + ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4); + ymm2 = _mm256_loadu_pd(tC + 4); + ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5); + ymm2 = _mm256_loadu_pd(tC + 8); + ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6); + ymm2 = _mm256_loadu_pd(tC + 12); + ymm7 = _mm256_fmadd_pd(ymm2, ymm1, ymm7); + + // multiply C by beta and accumulate, col 2. + double* ttC = tC + ldc; + ymm2 = _mm256_loadu_pd(ttC); + ymm8 = _mm256_fmadd_pd(ymm2, ymm1, ymm8); + ymm2 = _mm256_loadu_pd(ttC + 4); + ymm9 = _mm256_fmadd_pd(ymm2, ymm1, ymm9); + ymm2 = _mm256_loadu_pd(ttC + 8); + ymm10 = _mm256_fmadd_pd(ymm2, ymm1, ymm10); + ymm2 = _mm256_loadu_pd(ttC + 12); + ymm11 = _mm256_fmadd_pd(ymm2, ymm1, ymm11); + + // multiply C by beta and accumulate, col 3. + ttC += ldc; + ymm2 = _mm256_loadu_pd(ttC); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + ymm2 = _mm256_loadu_pd(ttC + 4); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + ymm2 = _mm256_loadu_pd(ttC + 8); + ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); + ymm2 = _mm256_loadu_pd(ttC + 12); + ymm15 = _mm256_fmadd_pd(ymm2, ymm1, ymm15); + } + _mm256_storeu_pd(tC, ymm4); + _mm256_storeu_pd(tC + 4, ymm5); + _mm256_storeu_pd(tC + 8, ymm6); + _mm256_storeu_pd(tC + 12, ymm7); + + tC += ldc; + + _mm256_storeu_pd(tC, ymm8); + _mm256_storeu_pd(tC + 4, ymm9); + _mm256_storeu_pd(tC + 8, ymm10); + _mm256_storeu_pd(tC + 12, ymm11); + + tC += ldc; + + _mm256_storeu_pd(tC, ymm12); + _mm256_storeu_pd(tC + 4, ymm13); + _mm256_storeu_pd(tC + 8, ymm14); + _mm256_storeu_pd(tC + 12, ymm15); + + } + n_remainder = N - col_idx; + + // if the N is not multiple of 3. + // handling edge case. + if (n_remainder == 2) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = tA_packed + row_idx_packed; + + // clear scratch registers. 
+ ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm11 = _mm256_setzero_pd(); + ymm12 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix data and + // multiplies it with the A matrix. + ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0); + ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1); + tB += tb_inc_row; + + //broadcasted matrix B elements are multiplied + //with matrix A columns. + ymm3 = _mm256_loadu_pd(tA); + ymm8 = _mm256_fmadd_pd(ymm0, ymm3, ymm8); + ymm12 = _mm256_fmadd_pd(ymm1, ymm3, ymm12); + + ymm3 = _mm256_loadu_pd(tA + 4); + ymm9 = _mm256_fmadd_pd(ymm0, ymm3, ymm9); + ymm13 = _mm256_fmadd_pd(ymm1, ymm3, ymm13); + + ymm3 = _mm256_loadu_pd(tA + 8); + ymm10 = _mm256_fmadd_pd(ymm0, ymm3, ymm10); + ymm14 = _mm256_fmadd_pd(ymm1, ymm3, ymm14); + + ymm3 = _mm256_loadu_pd(tA + 12); + ymm11 = _mm256_fmadd_pd(ymm0, ymm3, ymm11); + ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); + + tA += lda_packed; + + } + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_sd(alpha_cast); + ymm1 = _mm256_broadcast_sd(beta_cast); + + //multiply A*B by alpha. + ymm8 = _mm256_mul_pd(ymm8, ymm0); + ymm9 = _mm256_mul_pd(ymm9, ymm0); + ymm10 = _mm256_mul_pd(ymm10, ymm0); + ymm11 = _mm256_mul_pd(ymm11, ymm0); + ymm12 = _mm256_mul_pd(ymm12, ymm0); + ymm13 = _mm256_mul_pd(ymm13, ymm0); + ymm14 = _mm256_mul_pd(ymm14, ymm0); + ymm15 = _mm256_mul_pd(ymm15, ymm0); + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate, col 1. + ymm2 = _mm256_loadu_pd(tC + 0); + ymm8 = _mm256_fmadd_pd(ymm2, ymm1, ymm8); + ymm2 = _mm256_loadu_pd(tC + 4); + ymm9 = _mm256_fmadd_pd(ymm2, ymm1, ymm9); + ymm2 = _mm256_loadu_pd(tC + 8); + ymm10 = _mm256_fmadd_pd(ymm2, ymm1, ymm10); + ymm2 = _mm256_loadu_pd(tC + 12); + ymm11 = _mm256_fmadd_pd(ymm2, ymm1, ymm11); + + // multiply C by beta and accumulate, col 2. + double *ttC = tC + ldc; + + ymm2 = _mm256_loadu_pd(ttC); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + ymm2 = _mm256_loadu_pd(ttC + 4); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + ymm2 = _mm256_loadu_pd(ttC + 8); + ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); + ymm2 = _mm256_loadu_pd(ttC + 12); + ymm15 = _mm256_fmadd_pd(ymm2, ymm1, ymm15); + } + + _mm256_storeu_pd(tC + 0, ymm8); + _mm256_storeu_pd(tC + 4, ymm9); + _mm256_storeu_pd(tC + 8, ymm10); + _mm256_storeu_pd(tC + 12, ymm11); + + tC += ldc; + + _mm256_storeu_pd(tC, ymm12); + _mm256_storeu_pd(tC + 4, ymm13); + _mm256_storeu_pd(tC + 8, ymm14); + _mm256_storeu_pd(tC + 12, ymm15); + col_idx += 2; + } + // if the N is not multiple of 3. + // handling edge case. + if (n_remainder == 1) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = tA_packed + row_idx_packed; + + // clear scratch registers. + ymm12 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix data and + // multiplies it with the A matrix. + ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0); + tB += tb_inc_row; + + //broadcasted matrix B elements are multiplied + //with matrix A columns. 
+ ymm3 = _mm256_loadu_pd(tA); + ymm12 = _mm256_fmadd_pd(ymm0, ymm3, ymm12); + + ymm3 = _mm256_loadu_pd(tA + 4); + ymm13 = _mm256_fmadd_pd(ymm0, ymm3, ymm13); + + ymm3 = _mm256_loadu_pd(tA + 8); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + + ymm3 = _mm256_loadu_pd(tA + 12); + ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); + + tA += lda_packed; + + } + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_sd(alpha_cast); + ymm1 = _mm256_broadcast_sd(beta_cast); + + //multiply A*B by alpha. + ymm12 = _mm256_mul_pd(ymm12, ymm0); + ymm13 = _mm256_mul_pd(ymm13, ymm0); + ymm14 = _mm256_mul_pd(ymm14, ymm0); + ymm15 = _mm256_mul_pd(ymm15, ymm0); + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate. + ymm2 = _mm256_loadu_pd(tC + 0); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + ymm2 = _mm256_loadu_pd(tC + 4); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + ymm2 = _mm256_loadu_pd(tC + 8); + ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); + ymm2 = _mm256_loadu_pd(tC + 12); + ymm15 = _mm256_fmadd_pd(ymm2, ymm1, ymm15); + } + + _mm256_storeu_pd(tC + 0, ymm12); + _mm256_storeu_pd(tC + 4, ymm13); + _mm256_storeu_pd(tC + 8, ymm14); + _mm256_storeu_pd(tC + 12, ymm15); + } + } + + m_remainder = M - row_idx; + + if (m_remainder >= 12) + { + m_remainder -= 12; + + tA = A + row_idx * lda; + tA_packed = D_A_pack; + lda_packed = 12; + + // Pack 12xk of matrix A into buffer + // continuous access for A and strided stores to B + for(inc_t x = 0; (x) < 3; x += 1) + { + double* tA_temp = tA; + + for(k = 0; (k+3) < K; k += 4) + { + ymm0 = _mm256_loadu_pd(tA_temp + 0 * lda); + ymm1 = _mm256_loadu_pd(tA_temp + 1 * lda); + ymm2 = _mm256_loadu_pd(tA_temp + 2 * lda); + ymm3 = _mm256_loadu_pd(tA_temp + 3 * lda); + + ymm10 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm11 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm12 = _mm256_unpacklo_pd(ymm2, ymm3); + ymm13 = _mm256_unpackhi_pd(ymm2, ymm3); + + ymm0 = _mm256_permute2f128_pd(ymm10, ymm12, 0x20); + ymm1 = _mm256_permute2f128_pd(ymm11, ymm13, 0x20); + + ymm2 = _mm256_permute2f128_pd(ymm10, ymm12, 0x31); + ymm3 = _mm256_permute2f128_pd(ymm11, ymm13, 0x31); + + _mm256_storeu_pd(tA_packed + 0 * lda_packed, ymm0); + _mm256_storeu_pd(tA_packed + 1 * lda_packed, ymm1); + _mm256_storeu_pd(tA_packed + 2 * lda_packed, ymm2); + _mm256_storeu_pd(tA_packed + 3 * lda_packed, ymm3); + + tA_temp += 4; + tA_packed += 4 * lda_packed; + } + + for(; k < K; k += 1) + { + tA_packed[0] = tA_temp[0 * lda]; + tA_packed[1] = tA_temp[1 * lda]; + tA_packed[2] = tA_temp[2 * lda]; + tA_packed[3] = tA_temp[3 * lda]; + + tA_temp += 1; + tA_packed += lda_packed; + } + + tA += 4 * lda; + tA_packed = D_A_pack +(x +1) * 4; + } + + tA_packed = D_A_pack; + row_idx_packed = 0; + lda_packed = 12; + + for (col_idx = 0; (col_idx + 2) < N; col_idx += 3) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = tA_packed + row_idx_packed; + + // clear scratch registers. + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm12 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix data and + // multiplies it with the A matrix. 
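// (Illustrative note.) This m-remainder block follows the same rank-1
// update scheme as the 16x3 loop above, but with a 12-row micro-tile:
// three 4-wide accumulators per C column (ymm4-ymm6, ymm8-ymm10,
// ymm12-ymm14) instead of four.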
+ ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0); + ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1); + ymm2 = _mm256_broadcast_sd(tB + tb_inc_col * 2); + tB += tb_inc_row; + + //broadcasted matrix B elements are multiplied + //with matrix A columns. + ymm3 = _mm256_loadu_pd(tA); + // ymm4 += ymm0 * ymm3; + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + // ymm8 += ymm1 * ymm3; + ymm8 = _mm256_fmadd_pd(ymm1, ymm3, ymm8); + // ymm12 += ymm2 * ymm3; + ymm12 = _mm256_fmadd_pd(ymm2, ymm3, ymm12); + + ymm3 = _mm256_loadu_pd(tA + 4); + // ymm5 += ymm0 * ymm3; + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + // ymm9 += ymm1 * ymm3; + ymm9 = _mm256_fmadd_pd(ymm1, ymm3, ymm9); + // ymm13 += ymm2 * ymm3; + ymm13 = _mm256_fmadd_pd(ymm2, ymm3, ymm13); + + ymm3 = _mm256_loadu_pd(tA + 8); + // ymm6 += ymm0 * ymm3; + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + // ymm10 += ymm1 * ymm3; + ymm10 = _mm256_fmadd_pd(ymm1, ymm3, ymm10); + // ymm14 += ymm2 * ymm3; + ymm14 = _mm256_fmadd_pd(ymm2, ymm3, ymm14); + + tA += lda_packed; + } + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_sd(alpha_cast); + ymm1 = _mm256_broadcast_sd(beta_cast); + + //multiply A*B by alpha. + ymm4 = _mm256_mul_pd(ymm4, ymm0); + ymm5 = _mm256_mul_pd(ymm5, ymm0); + ymm6 = _mm256_mul_pd(ymm6, ymm0); + ymm8 = _mm256_mul_pd(ymm8, ymm0); + ymm9 = _mm256_mul_pd(ymm9, ymm0); + ymm10 = _mm256_mul_pd(ymm10, ymm0); + ymm12 = _mm256_mul_pd(ymm12, ymm0); + ymm13 = _mm256_mul_pd(ymm13, ymm0); + ymm14 = _mm256_mul_pd(ymm14, ymm0); + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate. + ymm2 = _mm256_loadu_pd(tC); + ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4); + ymm2 = _mm256_loadu_pd(tC + 4); + ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5); + ymm2 = _mm256_loadu_pd(tC + 8); + ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6); + + // multiply C by beta and accumulate. + double *ttC = tC +ldc; + ymm2 = _mm256_loadu_pd(ttC); + ymm8 = _mm256_fmadd_pd(ymm2, ymm1, ymm8); + ymm2 = _mm256_loadu_pd(ttC + 4); + ymm9 = _mm256_fmadd_pd(ymm2, ymm1, ymm9); + ymm2 = _mm256_loadu_pd(ttC + 8); + ymm10 = _mm256_fmadd_pd(ymm2, ymm1, ymm10); + + // multiply C by beta and accumulate. + ttC += ldc; + ymm2 = _mm256_loadu_pd(ttC); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + ymm2 = _mm256_loadu_pd(ttC + 4); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + ymm2 = _mm256_loadu_pd(ttC + 8); + ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); + + } + _mm256_storeu_pd(tC, ymm4); + _mm256_storeu_pd(tC + 4, ymm5); + _mm256_storeu_pd(tC + 8, ymm6); + + tC += ldc; + + _mm256_storeu_pd(tC, ymm8); + _mm256_storeu_pd(tC + 4, ymm9); + _mm256_storeu_pd(tC + 8, ymm10); + + tC += ldc; + + _mm256_storeu_pd(tC, ymm12); + _mm256_storeu_pd(tC + 4, ymm13); + _mm256_storeu_pd(tC + 8, ymm14); + } + n_remainder = N - col_idx; + // if the N is not multiple of 3. + // handling edge case. + if (n_remainder == 2) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = tA_packed + row_idx_packed; + + // clear scratch registers. + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm12 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix data and + // multiplies it with the A matrix. + ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0); + ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1); + tB += tb_inc_row; + + //broadcasted matrix B elements are multiplied + //with matrix A columns. 
+ ymm3 = _mm256_loadu_pd(tA); + ymm8 = _mm256_fmadd_pd(ymm0, ymm3, ymm8); + ymm12 = _mm256_fmadd_pd(ymm1, ymm3, ymm12); + + ymm3 = _mm256_loadu_pd(tA + 4); + ymm9 = _mm256_fmadd_pd(ymm0, ymm3, ymm9); + ymm13 = _mm256_fmadd_pd(ymm1, ymm3, ymm13); + + ymm3 = _mm256_loadu_pd(tA + 8); + ymm10 = _mm256_fmadd_pd(ymm0, ymm3, ymm10); + ymm14 = _mm256_fmadd_pd(ymm1, ymm3, ymm14); + + tA += lda_packed; + + } + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_sd(alpha_cast); + ymm1 = _mm256_broadcast_sd(beta_cast); + + //multiply A*B by alpha. + ymm8 = _mm256_mul_pd(ymm8, ymm0); + ymm9 = _mm256_mul_pd(ymm9, ymm0); + ymm10 = _mm256_mul_pd(ymm10, ymm0); + ymm12 = _mm256_mul_pd(ymm12, ymm0); + ymm13 = _mm256_mul_pd(ymm13, ymm0); + ymm14 = _mm256_mul_pd(ymm14, ymm0); + + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate. + ymm2 = _mm256_loadu_pd(tC + 0); + ymm8 = _mm256_fmadd_pd(ymm2, ymm1, ymm8); + ymm2 = _mm256_loadu_pd(tC + 4); + ymm9 = _mm256_fmadd_pd(ymm2, ymm1, ymm9); + ymm2 = _mm256_loadu_pd(tC + 8); + ymm10 = _mm256_fmadd_pd(ymm2, ymm1, ymm10); + + double *ttC = tC + ldc; + + // multiply C by beta and accumulate. + ymm2 = _mm256_loadu_pd(ttC); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + ymm2 = _mm256_loadu_pd(ttC + 4); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + ymm2 = _mm256_loadu_pd(ttC + 8); + ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); + + } + _mm256_storeu_pd(tC + 0, ymm8); + _mm256_storeu_pd(tC + 4, ymm9); + _mm256_storeu_pd(tC + 8, ymm10); + + tC += ldc; + + _mm256_storeu_pd(tC, ymm12); + _mm256_storeu_pd(tC + 4, ymm13); + _mm256_storeu_pd(tC + 8, ymm14); + + col_idx += 2; + } + // if the N is not multiple of 3. + // handling edge case. + if (n_remainder == 1) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = tA_packed + row_idx_packed; + + // clear scratch registers. + ymm12 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix data and + // multiplies it with the A matrix. + ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0); + tB += tb_inc_row; + + //broadcasted matrix B elements are multiplied + //with matrix A columns. + ymm3 = _mm256_loadu_pd(tA); + ymm12 = _mm256_fmadd_pd(ymm0, ymm3, ymm12); + + ymm3 = _mm256_loadu_pd(tA + 4); + ymm13 = _mm256_fmadd_pd(ymm0, ymm3, ymm13); + + ymm3 = _mm256_loadu_pd(tA + 8); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + + tA += lda_packed; + + } + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_sd(alpha_cast); + ymm1 = _mm256_broadcast_sd(beta_cast); + + //multiply A*B by alpha. + ymm12 = _mm256_mul_pd(ymm12, ymm0); + ymm13 = _mm256_mul_pd(ymm13, ymm0); + ymm14 = _mm256_mul_pd(ymm14, ymm0); + + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate. 
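// (Illustrative note.) The accumulators already hold alpha*(A*B), so each
// _mm256_fmadd_pd( C, beta, acc ) below computes beta*C + alpha*(A*B);
// when beta is zero this block is skipped and C is simply overwritten.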
+ ymm2 = _mm256_loadu_pd(tC + 0); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + ymm2 = _mm256_loadu_pd(tC + 4); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + ymm2 = _mm256_loadu_pd(tC + 8); + ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); + + } + _mm256_storeu_pd(tC + 0, ymm12); + _mm256_storeu_pd(tC + 4, ymm13); + _mm256_storeu_pd(tC + 8, ymm14); + } + + row_idx += 12; + } + + if (m_remainder >= 8) + { + m_remainder -= 8; + + tA = A + row_idx * lda; + tA_packed = D_A_pack; + lda_packed = 8; + + // Pack 8xk of matrix A into buffer + // continuous access for A and strided stores to B + for(inc_t x = 0; (x) < 2; x += 1) + { + double* tA_temp = tA; + + for(k = 0; (k+3) < K; k += 4) + { + ymm0 = _mm256_loadu_pd(tA_temp + 0 * lda); + ymm1 = _mm256_loadu_pd(tA_temp + 1 * lda); + ymm2 = _mm256_loadu_pd(tA_temp + 2 * lda); + ymm3 = _mm256_loadu_pd(tA_temp + 3 * lda); + + ymm10 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm11 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm12 = _mm256_unpacklo_pd(ymm2, ymm3); + ymm13 = _mm256_unpackhi_pd(ymm2, ymm3); + + ymm0 = _mm256_permute2f128_pd(ymm10, ymm12, 0x20); + ymm1 = _mm256_permute2f128_pd(ymm11, ymm13, 0x20); + + ymm2 = _mm256_permute2f128_pd(ymm10, ymm12, 0x31); + ymm3 = _mm256_permute2f128_pd(ymm11, ymm13, 0x31); + + _mm256_storeu_pd(tA_packed + 0 * lda_packed, ymm0); + _mm256_storeu_pd(tA_packed + 1 * lda_packed, ymm1); + _mm256_storeu_pd(tA_packed + 2 * lda_packed, ymm2); + _mm256_storeu_pd(tA_packed + 3 * lda_packed, ymm3); + + tA_temp += 4; + tA_packed += 4 * lda_packed; + } + + for(; k < K; k += 1) + { + tA_packed[0] = tA_temp[0 * lda]; + tA_packed[1] = tA_temp[1 * lda]; + tA_packed[2] = tA_temp[2 * lda]; + tA_packed[3] = tA_temp[3 * lda]; + + tA_temp += 1; + tA_packed += lda_packed; + } + + tA += 4 * lda; + tA_packed = D_A_pack +(x +1) * 4; + } + + tA_packed = D_A_pack; + row_idx_packed = 0; + lda_packed = 8; + + for (col_idx = 0; (col_idx + 2) < N; col_idx += 3) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = tA_packed + row_idx_packed; + + // clear scratch registers. + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix data and + // multiplies it with the A matrix. + ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0); + ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1); + ymm2 = _mm256_broadcast_sd(tB + tb_inc_col * 2); + tB += tb_inc_row; + + //broadcasted matrix B elements are multiplied + //with matrix A columns. + ymm3 = _mm256_loadu_pd(tA); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm6 = _mm256_fmadd_pd(ymm1, ymm3, ymm6); + ymm8 = _mm256_fmadd_pd(ymm2, ymm3, ymm8); + + ymm3 = _mm256_loadu_pd(tA + 4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + ymm9 = _mm256_fmadd_pd(ymm2, ymm3, ymm9); + + tA += lda_packed; + } + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_sd(alpha_cast); + ymm1 = _mm256_broadcast_sd(beta_cast); + + //multiply A*B by alpha. + ymm4 = _mm256_mul_pd(ymm4, ymm0); + ymm5 = _mm256_mul_pd(ymm5, ymm0); + ymm6 = _mm256_mul_pd(ymm6, ymm0); + ymm7 = _mm256_mul_pd(ymm7, ymm0); + ymm8 = _mm256_mul_pd(ymm8, ymm0); + ymm9 = _mm256_mul_pd(ymm9, ymm0); + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate. 
+ ymm2 = _mm256_loadu_pd(tC); + ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4); + ymm2 = _mm256_loadu_pd(tC + 4); + ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5); + + double* ttC = tC + ldc; + + // multiply C by beta and accumulate. + ymm2 = _mm256_loadu_pd(ttC); + ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6); + ymm2 = _mm256_loadu_pd(ttC + 4); + ymm7 = _mm256_fmadd_pd(ymm2, ymm1, ymm7); + + ttC += ldc; + + // multiply C by beta and accumulate. + ymm2 = _mm256_loadu_pd(ttC); + ymm8 = _mm256_fmadd_pd(ymm2, ymm1, ymm8); + ymm2 = _mm256_loadu_pd(ttC + 4); + ymm9 = _mm256_fmadd_pd(ymm2, ymm1, ymm9); + } + + _mm256_storeu_pd(tC, ymm4); + _mm256_storeu_pd(tC + 4, ymm5); + + tC += ldc; + _mm256_storeu_pd(tC, ymm6); + _mm256_storeu_pd(tC + 4, ymm7); + + tC += ldc; + _mm256_storeu_pd(tC, ymm8); + _mm256_storeu_pd(tC + 4, ymm9); + + } + n_remainder = N - col_idx; + // if the N is not multiple of 3. + // handling edge case. + if (n_remainder == 2) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = tA_packed + row_idx_packed; + + // clear scratch registers. + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix data and + // multiplies it with the A matrix. + ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0); + ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1); + tB += tb_inc_row; + + //broadcasted matrix B elements are multiplied + //with matrix A columns. + ymm3 = _mm256_loadu_pd(tA); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm6 = _mm256_fmadd_pd(ymm1, ymm3, ymm6); + + ymm3 = _mm256_loadu_pd(tA + 4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + tA += lda_packed; + } + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_sd(alpha_cast); + ymm1 = _mm256_broadcast_sd(beta_cast); + + //multiply A*B by alpha. + ymm4 = _mm256_mul_pd(ymm4, ymm0); + ymm5 = _mm256_mul_pd(ymm5, ymm0); + ymm6 = _mm256_mul_pd(ymm6, ymm0); + ymm7 = _mm256_mul_pd(ymm7, ymm0); + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate. + ymm2 = _mm256_loadu_pd(tC); + ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4); + ymm2 = _mm256_loadu_pd(tC + 4); + ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5); + + double* ttC = tC + ldc; + + // multiply C by beta and accumulate. + ymm2 = _mm256_loadu_pd(ttC); + ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6); + ymm2 = _mm256_loadu_pd(ttC + 4); + ymm7 = _mm256_fmadd_pd(ymm2, ymm1, ymm7); + } + _mm256_storeu_pd(tC, ymm4); + _mm256_storeu_pd(tC + 4, ymm5); + + tC += ldc; + _mm256_storeu_pd(tC, ymm6); + _mm256_storeu_pd(tC + 4, ymm7); + + col_idx += 2; + + } + // if the N is not multiple of 3. + // handling edge case. + if (n_remainder == 1) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = tA_packed + row_idx_packed; + + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix data and + // multiplies it with the A matrix. + ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0); + tB += tb_inc_row; + + //broadcasted matrix B elements are multiplied + //with matrix A columns. + ymm3 = _mm256_loadu_pd(tA); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm3 = _mm256_loadu_pd(tA + 4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + + tA += lda_packed; + } + // alpha, beta multiplication. 
+ ymm0 = _mm256_broadcast_sd(alpha_cast); + ymm1 = _mm256_broadcast_sd(beta_cast); + + ymm4 = _mm256_mul_pd(ymm4, ymm0); + ymm5 = _mm256_mul_pd(ymm5, ymm0); + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate. + ymm2 = _mm256_loadu_pd(tC); + ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4); + ymm2 = _mm256_loadu_pd(tC + 4); + ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5); + } + _mm256_storeu_pd(tC, ymm4); + _mm256_storeu_pd(tC + 4, ymm5); + + } + + row_idx += 8; + } + + if (m_remainder >= 4) + { + //printf("HERE\n"); + m_remainder -= 4; + + tA = A + row_idx * lda; + tA_packed = D_A_pack; + lda_packed = 4; + + // Pack 4xk of matrix A into buffer + // continuous access for A and strided stores to B +// for(inc_t x = 0; (x) < 1; x += 1) + { + double* tA_temp = tA; + + for(k = 0; (k+3) < K; k += 4) + { + ymm0 = _mm256_loadu_pd(tA_temp + 0 * lda); + ymm1 = _mm256_loadu_pd(tA_temp + 1 * lda); + ymm2 = _mm256_loadu_pd(tA_temp + 2 * lda); + ymm3 = _mm256_loadu_pd(tA_temp + 3 * lda); + + ymm10 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm11 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm12 = _mm256_unpacklo_pd(ymm2, ymm3); + ymm13 = _mm256_unpackhi_pd(ymm2, ymm3); + + ymm0 = _mm256_permute2f128_pd(ymm10, ymm12, 0x20); + ymm1 = _mm256_permute2f128_pd(ymm11, ymm13, 0x20); + + ymm2 = _mm256_permute2f128_pd(ymm10, ymm12, 0x31); + ymm3 = _mm256_permute2f128_pd(ymm11, ymm13, 0x31); + + _mm256_storeu_pd(tA_packed + 0 * lda_packed, ymm0); + _mm256_storeu_pd(tA_packed + 1 * lda_packed, ymm1); + _mm256_storeu_pd(tA_packed + 2 * lda_packed, ymm2); + _mm256_storeu_pd(tA_packed + 3 * lda_packed, ymm3); + + tA_temp += 4; + tA_packed += 4 * lda_packed; + } + + for(; k < K; k += 1) + { + tA_packed[0] = tA_temp[0 * lda]; + tA_packed[1] = tA_temp[1 * lda]; + tA_packed[2] = tA_temp[2 * lda]; + tA_packed[3] = tA_temp[3 * lda]; + + tA_temp += 1; + tA_packed += lda_packed; + } + + tA += 4 * lda; + tA_packed = D_A_pack + 4; + } + + tA_packed = D_A_pack; + row_idx_packed = 0; + lda_packed = 4; + + for (col_idx = 0; (col_idx + 2) < N; col_idx += 3) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = tA_packed + row_idx_packed; + + // clear scratch registers. + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix data and + // multiplies it with the A matrix. + ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0); + ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1); + ymm2 = _mm256_broadcast_sd(tB + tb_inc_col * 2); + tB += tb_inc_row; + + //broadcasted matrix B elements are multiplied + //with matrix A columns. + ymm3 = _mm256_loadu_pd(tA); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + ymm6 = _mm256_fmadd_pd(ymm2, ymm3, ymm6); + + tA += lda_packed; + } + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_sd(alpha_cast); + ymm1 = _mm256_broadcast_sd(beta_cast); + + //multiply A*B by alpha. + ymm4 = _mm256_mul_pd(ymm4, ymm0); + ymm5 = _mm256_mul_pd(ymm5, ymm0); + ymm6 = _mm256_mul_pd(ymm6, ymm0); + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate. + ymm2 = _mm256_loadu_pd(tC); + ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4); + + double* ttC = tC + ldc; + + // multiply C by beta and accumulate. + ymm2 = _mm256_loadu_pd(ttC); + ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5); + + ttC += ldc; + + // multiply C by beta and accumulate. 
+ ymm2 = _mm256_loadu_pd(ttC); + ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6); + } + _mm256_storeu_pd(tC, ymm4); + + tC += ldc; + _mm256_storeu_pd(tC, ymm5); + + tC += ldc; + _mm256_storeu_pd(tC, ymm6); + } + n_remainder = N - col_idx; + // if the N is not multiple of 3. + // handling edge case. + if (n_remainder == 2) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = tA_packed + row_idx_packed; + + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix data and + // multiplies it with the A matrix. + ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0); + ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1); + tB += tb_inc_row; + + //broadcasted matrix B elements are multiplied + //with matrix A columns. + ymm3 = _mm256_loadu_pd(tA); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + tA += lda_packed; + } + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_sd(alpha_cast); + ymm1 = _mm256_broadcast_sd(beta_cast); + + //multiply A*B by alpha. + ymm4 = _mm256_mul_pd(ymm4, ymm0); + ymm5 = _mm256_mul_pd(ymm5, ymm0); + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate. + ymm2 = _mm256_loadu_pd(tC); + ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4); + + double* ttC = tC + ldc; + + // multiply C by beta and accumulate. + ymm2 = _mm256_loadu_pd(ttC); + ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5); + } + _mm256_storeu_pd(tC, ymm4); + + tC += ldc; + _mm256_storeu_pd(tC, ymm5); + + col_idx += 2; + + } + // if the N is not multiple of 3. + // handling edge case. + if (n_remainder == 1) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = tA_packed + row_idx_packed; + + ymm4 = _mm256_setzero_pd(); + + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix data and + // multiplies it with the A matrix. + ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0); + tB += tb_inc_row; + + //broadcasted matrix B elements are multiplied + //with matrix A columns. + ymm3 = _mm256_loadu_pd(tA); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + tA += lda_packed; + } + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_sd(alpha_cast); + ymm1 = _mm256_broadcast_sd(beta_cast); + + ymm4 = _mm256_mul_pd(ymm4, ymm0); + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate. + ymm2 = _mm256_loadu_pd(tC); + ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4); + + } + _mm256_storeu_pd(tC, ymm4); + + } + + row_idx += 4; + } + + if (m_remainder) + { + if(bli_obj_has_notrans(b)) + { + for (; row_idx < M; row_idx += 1) + { + for (col_idx = 0; col_idx < N; col_idx += 1) + { + tA = A + row_idx * lda; + tB = B + col_idx * ldb; + tC = C + col_idx * ldc + row_idx; + // clear scratch registers. + ymm4 = _mm256_setzero_pd(); + + for (k = 0; (k + 3) < K; k += 4) + { + ymm0 = _mm256_loadu_pd(tB + 0); + ymm3 = _mm256_loadu_pd(tA); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + tA += 4; + tB += 4; + } + + // if K is not a multiple of 4, padding is done before load using temproary array. 
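// (Illustrative note.) Because data_feeder is zero-initialized and only the
// first K - k entries are overwritten, the padded lanes contribute
// 0.0 * 0.0 to the FMA and the dot product stays correct without masking.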
+ if (k < K) + { + int iter; + double data_feeder[4] = { 0.0 }; + + for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tB[iter]; + ymm0 = _mm256_loadu_pd(data_feeder); + + for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[iter]; + ymm3 = _mm256_loadu_pd(data_feeder); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + } + + //horizontal addition and storage of the data. + ymm4 = _mm256_hadd_pd(ymm4, ymm4); + _mm256_storeu_pd(scratch, ymm4); + result = scratch[0] + scratch[2]; + result *= (*alpha_cast); + if(is_beta_non_zero) + tC[0] = result + tC[0] * (*beta_cast); + else + tC[0] = result; + } + } + + } + else + { + double result; + for(; row_idx < M; row_idx += 1) + { + for(col_idx = 0; col_idx < N; col_idx += 1) + { + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = A + row_idx * lda; + + result = 0; + for(k = 0; k < K; k++) + { + result += (*tA) * (*tB); + + tA += 1; + tB += tb_inc_row; + } + + result *= (*alpha_cast); + if(is_beta_non_zero) + (*tC) = (*tC) * (*beta_cast) + result; + else + (*tC) = result; + } + } + } + } + + // Return the buffer to pool + if ((required_packing_A == 1) && bli_mem_is_alloc( &local_mem_buf_A_s )) { +#ifdef BLIS_ENABLE_MEM_TRACING + printf( "bli_dgemm_small_At(): releasing mem pool block\n" ); +#endif + bli_membrk_release(&rntm, + &local_mem_buf_A_s); + } + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + return BLIS_SUCCESS; + } + else + { + AOCL_DTL_TRACE_EXIT_ERR( + AOCL_DTL_LEVEL_INFO, + "Invalid dimesions for dgemm_small_At." + ); + return BLIS_NONCONFORMAL_DIMENSIONS; + } +}; #endif diff --git a/kernels/zen/bli_kernels_zen.h b/kernels/zen/bli_kernels_zen.h index 161bcef1aa..cf649f8bf0 100644 --- a/kernels/zen/bli_kernels_zen.h +++ b/kernels/zen/bli_kernels_zen.h @@ -199,3 +199,18 @@ GEMMSUP_KER_PROT( dcomplex, z, gemmsup_rv_zen_asm_2x4n ) GEMMSUP_KER_PROT( dcomplex, z, gemmsup_rv_zen_asm_1x4n ) GEMMSUP_KER_PROT( dcomplex, z, gemmsup_rv_zen_asm_3x2 ) GEMMSUP_KER_PROT( dcomplex, z, gemmsup_rv_zen_asm_3x1 ) + + +// gemm square matrix size friendly implementation +err_t bli_gemm_sqp + ( + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + cntl_t* cntl + ); + + From c597fa677bb04dd786a17a262ccaa915d778177b Mon Sep 17 00:00:00 2001 From: Meghana Vankadari Date: Tue, 16 Feb 2021 10:29:13 +0530 Subject: [PATCH 3/5] Disabled calling of bli_dgemm_small from gemm_front Details: - Decision logic to choose small_gemm has been moved to blas interface. - Redirecting all the calls to small_gemm from gemm_front to native implementation. AMD-Internal: [CPUPL-1376] Change-Id: I6490f67113e9f7c272269f441c86f2a0b3c89a53 --- kernels/zen/3/bli_gemm_small.c | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/kernels/zen/3/bli_gemm_small.c b/kernels/zen/3/bli_gemm_small.c index 4ba9d08b1b..5d9bb62016 100644 --- a/kernels/zen/3/bli_gemm_small.c +++ b/kernels/zen/3/bli_gemm_small.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2017 - 2020, Advanced Micro Devices, Inc. + Copyright (C) 2017-2021, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -162,7 +162,7 @@ err_t bli_gemm_small if (dt == BLIS_DOUBLE) { -#ifndef BLIS_ENABLE_MULTITHREADING +#ifndef BLIS_ENABLE_MULTITHREADING // bli_dgemm_small is called directly from BLAS interface for sizes within thresholds. 
// Avoiding calling bli_dgemm_small from gemm_front and directing to // native implementation. From ac2a50fc4e062e7353b0a735c20eae4d081a7dcb Mon Sep 17 00:00:00 2001 From: Meghana Vankadari Date: Thu, 6 May 2021 13:23:18 +0530 Subject: [PATCH 4/5] Fixed blastest failure for haswell configuration Details: - Placed optimized version of BLAS DGEMM, ZGEMM definitions under BLIS_CONFIG_EPYC as they use gemm small which are defined only for zen family configurations. - Added code to query and set cntx in gemv and trsv framework before cntx is referred for any function pointers to avoid querying from NULL pointer. AMD-Internal: [CPUPL-1562] Change-Id: I977d028ec4ddb57dcdc70e443e7708f36c01cca9 --- frame/2/gemv/bli_gemv_unf_var2.c | 4 + frame/2/trsv/bli_trsv_unf_var2.c | 350 ++++++++++++++++--------------- 2 files changed, 181 insertions(+), 173 deletions(-) diff --git a/frame/2/gemv/bli_gemv_unf_var2.c b/frame/2/gemv/bli_gemv_unf_var2.c index fe7702e4c3..49a1315824 100644 --- a/frame/2/gemv/bli_gemv_unf_var2.c +++ b/frame/2/gemv/bli_gemv_unf_var2.c @@ -52,6 +52,10 @@ void PASTEMAC(ch,varname) \ ) \ { \ const num_t dt = PASTEMAC(ch,type); \ +\ + bli_init_once(); \ +\ + if(cntx == NULL) cntx = bli_gks_query_cntx(); \ \ ctype* zero = PASTEMAC(ch,0); \ ctype* A1; \ diff --git a/frame/2/trsv/bli_trsv_unf_var2.c b/frame/2/trsv/bli_trsv_unf_var2.c index 10741d2918..1860e8d6b7 100644 --- a/frame/2/trsv/bli_trsv_unf_var2.c +++ b/frame/2/trsv/bli_trsv_unf_var2.c @@ -49,179 +49,183 @@ void PASTEMAC(ch,varname) \ cntx_t* cntx \ ) \ { \ - const num_t dt = PASTEMAC(ch,type); \ -\ - ctype* minus_one = PASTEMAC(ch,m1); \ - ctype* A01; \ - ctype* A11; \ - ctype* A21; \ - ctype* a01; \ - ctype* alpha11; \ - ctype* a21; \ - ctype* x0; \ - ctype* x1; \ - ctype* x2; \ - ctype* x01; \ - ctype* chi11; \ - ctype* x21; \ - ctype alpha11_conj; \ - ctype minus_chi11; \ - dim_t iter, i, k, j, l; \ - dim_t b_fuse, f; \ - dim_t n_ahead, f_ahead; \ - inc_t rs_at, cs_at; \ - uplo_t uploa_trans; \ - conj_t conja; \ -\ - /* x = alpha * x; */ \ - PASTEMAC2(ch,scalv,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - m, \ - alpha, \ - x, incx, \ - cntx, \ - NULL \ - ); \ -\ - if ( bli_does_notrans( transa ) ) \ - { \ - rs_at = rs_a; \ - cs_at = cs_a; \ - uploa_trans = uploa; \ - } \ - else /* if ( bli_does_trans( transa ) ) */ \ - { \ - rs_at = cs_a; \ - cs_at = rs_a; \ - uploa_trans = bli_uplo_toggled( uploa ); \ - } \ -\ - conja = bli_extract_conj( transa ); \ -\ - PASTECH(ch,axpyf_ker_ft) kfp_af; \ -\ - /* Query the context for the kernel function pointer and fusing factor. */ \ - kfp_af = bli_cntx_get_l1f_ker_dt( dt, BLIS_AXPYF_KER, cntx ); \ - b_fuse = bli_cntx_get_blksz_def_dt( dt, BLIS_AF, cntx ); \ -\ - /* We reduce all of the possible cases down to just lower/upper. 
*/ \ - if ( bli_is_upper( uploa_trans ) ) \ - { \ - for ( iter = 0; iter < m; iter += f ) \ - { \ - f = bli_determine_blocksize_dim_b( iter, m, b_fuse ); \ - i = m - iter - f; \ - n_ahead = i; \ - A11 = a + (i )*rs_at + (i )*cs_at; \ - A01 = a + (0 )*rs_at + (i )*cs_at; \ - x1 = x + (i )*incx; \ - x0 = x + (0 )*incx; \ -\ - /* x1 = x1 / triu( A11 ); */ \ - for ( k = 0; k < f; ++k ) \ - { \ - l = f - k - 1; \ - f_ahead = l; \ - alpha11 = A11 + (l )*rs_at + (l )*cs_at; \ - a01 = A11 + (0 )*rs_at + (l )*cs_at; \ - chi11 = x1 + (l )*incx; \ - x01 = x1 + (0 )*incx; \ -\ - /* chi11 = chi11 / alpha11; */ \ - if ( bli_is_nonunit_diag( diaga ) ) \ - { \ - PASTEMAC(ch,copycjs)( conja, *alpha11, alpha11_conj ); \ - PASTEMAC(ch,invscals)( alpha11_conj, *chi11 ); \ - } \ -\ - /* x01 = x01 - chi11 * a01; */ \ - PASTEMAC(ch,neg2s)( *chi11, minus_chi11 ); \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( j = 0; j < f_ahead; ++j ) \ - PASTEMAC(ch,axpyjs)( minus_chi11, *(a01 + j*rs_at), *(x01 + j*incx) ); \ - } \ - else \ - { \ - for ( j = 0; j < f_ahead; ++j ) \ - PASTEMAC(ch,axpys)( minus_chi11, *(a01 + j*rs_at), *(x01 + j*incx) ); \ - } \ - } \ -\ - /* x0 = x0 - A01 * x1; */ \ - kfp_af \ - ( \ - conja, \ - BLIS_NO_CONJUGATE, \ - n_ahead, \ - f, \ - minus_one, \ - A01, rs_at, cs_at, \ - x1, incx, \ - x0, incx, \ - cntx \ - ); \ - } \ - } \ - else /* if ( bli_is_lower( uploa_trans ) ) */ \ - { \ - for ( iter = 0; iter < m; iter += f ) \ - { \ - f = bli_determine_blocksize_dim_f( iter, m, b_fuse ); \ - i = iter; \ - n_ahead = m - iter - f; \ - A11 = a + (i )*rs_at + (i )*cs_at; \ - A21 = a + (i+f)*rs_at + (i )*cs_at; \ - x1 = x + (i )*incx; \ - x2 = x + (i+f)*incx; \ -\ - /* x1 = x1 / tril( A11 ); */ \ - for ( k = 0; k < f; ++k ) \ - { \ - l = k; \ - f_ahead = f - k - 1; \ - alpha11 = A11 + (l )*rs_at + (l )*cs_at; \ - a21 = A11 + (l+1)*rs_at + (l )*cs_at; \ - chi11 = x1 + (l )*incx; \ - x21 = x1 + (l+1)*incx; \ -\ - /* chi11 = chi11 / alpha11; */ \ - if ( bli_is_nonunit_diag( diaga ) ) \ - { \ - PASTEMAC(ch,copycjs)( conja, *alpha11, alpha11_conj ); \ - PASTEMAC(ch,invscals)( alpha11_conj, *chi11 ); \ - } \ -\ - /* x21 = x21 - chi11 * a21; */ \ - PASTEMAC(ch,neg2s)( *chi11, minus_chi11 ); \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( j = 0; j < f_ahead; ++j ) \ - PASTEMAC(ch,axpyjs)( minus_chi11, *(a21 + j*rs_at), *(x21 + j*incx) ); \ - } \ - else \ - { \ - for ( j = 0; j < f_ahead; ++j ) \ - PASTEMAC(ch,axpys)( minus_chi11, *(a21 + j*rs_at), *(x21 + j*incx) ); \ - } \ - } \ -\ - /* x2 = x2 - A21 * x1; */ \ - kfp_af \ - ( \ - conja, \ - BLIS_NO_CONJUGATE, \ - n_ahead, \ - f, \ - minus_one, \ - A21, rs_at, cs_at, \ - x1, incx, \ - x2, incx, \ - cntx \ - ); \ - } \ - } \ + const num_t dt = PASTEMAC(ch,type); \ +\ + bli_init_once(); \ +\ + if( cntx == NULL ) cntx = bli_gks_query_cntx(); \ +\ + ctype* minus_one = PASTEMAC(ch,m1); \ + ctype* A01; \ + ctype* A11; \ + ctype* A21; \ + ctype* a01; \ + ctype* alpha11; \ + ctype* a21; \ + ctype* x0; \ + ctype* x1; \ + ctype* x2; \ + ctype* x01; \ + ctype* chi11; \ + ctype* x21; \ + ctype alpha11_conj; \ + ctype minus_chi11; \ + dim_t iter, i, k, j, l; \ + dim_t b_fuse, f; \ + dim_t n_ahead, f_ahead; \ + inc_t rs_at, cs_at; \ + uplo_t uploa_trans; \ + conj_t conja; \ +\ + /* x = alpha * x; */ \ + PASTEMAC2(ch,scalv,BLIS_TAPI_EX_SUF) \ + ( \ + BLIS_NO_CONJUGATE, \ + m, \ + alpha, \ + x, incx, \ + cntx, \ + NULL \ + ); \ +\ + if ( bli_does_notrans( transa ) ) \ + { \ + rs_at = rs_a; \ + cs_at = cs_a; \ + uploa_trans = uploa; \ + } \ + else /* if ( bli_does_trans( transa 
) ) */ \ + { \ + rs_at = cs_a; \ + cs_at = rs_a; \ + uploa_trans = bli_uplo_toggled( uploa ); \ + } \ +\ + conja = bli_extract_conj( transa ); \ +\ + PASTECH(ch,axpyf_ker_ft) kfp_af; \ +\ + /* Query the context for the kernel function pointer and fusing factor. */ \ + kfp_af = bli_cntx_get_l1f_ker_dt( dt, BLIS_AXPYF_KER, cntx ); \ + b_fuse = bli_cntx_get_blksz_def_dt( dt, BLIS_AF, cntx ); \ +\ + /* We reduce all of the possible cases down to just lower/upper. */ \ + if ( bli_is_upper( uploa_trans ) ) \ + { \ + for ( iter = 0; iter < m; iter += f ) \ + { \ + f = bli_determine_blocksize_dim_b( iter, m, b_fuse ); \ + i = m - iter - f; \ + n_ahead = i; \ + A11 = a + (i )*rs_at + (i )*cs_at; \ + A01 = a + (0 )*rs_at + (i )*cs_at; \ + x1 = x + (i )*incx; \ + x0 = x + (0 )*incx; \ +\ + /* x1 = x1 / triu( A11 ); */ \ + for ( k = 0; k < f; ++k ) \ + { \ + l = f - k - 1; \ + f_ahead = l; \ + alpha11 = A11 + (l )*rs_at + (l )*cs_at; \ + a01 = A11 + (0 )*rs_at + (l )*cs_at; \ + chi11 = x1 + (l )*incx; \ + x01 = x1 + (0 )*incx; \ +\ + /* chi11 = chi11 / alpha11; */ \ + if ( bli_is_nonunit_diag( diaga ) ) \ + { \ + PASTEMAC(ch,copycjs)( conja, *alpha11, alpha11_conj ); \ + PASTEMAC(ch,invscals)( alpha11_conj, *chi11 ); \ + } \ +\ + /* x01 = x01 - chi11 * a01; */ \ + PASTEMAC(ch,neg2s)( *chi11, minus_chi11 ); \ + if ( bli_is_conj( conja ) ) \ + { \ + for ( j = 0; j < f_ahead; ++j ) \ + PASTEMAC(ch,axpyjs)( minus_chi11, *(a01 + j*rs_at), *(x01 + j*incx) ); \ + } \ + else \ + { \ + for ( j = 0; j < f_ahead; ++j ) \ + PASTEMAC(ch,axpys)( minus_chi11, *(a01 + j*rs_at), *(x01 + j*incx) ); \ + } \ + } \ +\ + /* x0 = x0 - A01 * x1; */ \ + kfp_af \ + ( \ + conja, \ + BLIS_NO_CONJUGATE, \ + n_ahead, \ + f, \ + minus_one, \ + A01, rs_at, cs_at, \ + x1, incx, \ + x0, incx, \ + cntx \ + ); \ + } \ + } \ + else /* if ( bli_is_lower( uploa_trans ) ) */ \ + { \ + for ( iter = 0; iter < m; iter += f ) \ + { \ + f = bli_determine_blocksize_dim_f( iter, m, b_fuse ); \ + i = iter; \ + n_ahead = m - iter - f; \ + A11 = a + (i )*rs_at + (i )*cs_at; \ + A21 = a + (i+f)*rs_at + (i )*cs_at; \ + x1 = x + (i )*incx; \ + x2 = x + (i+f)*incx; \ +\ + /* x1 = x1 / tril( A11 ); */ \ + for ( k = 0; k < f; ++k ) \ + { \ + l = k; \ + f_ahead = f - k - 1; \ + alpha11 = A11 + (l )*rs_at + (l )*cs_at; \ + a21 = A11 + (l+1)*rs_at + (l )*cs_at; \ + chi11 = x1 + (l )*incx; \ + x21 = x1 + (l+1)*incx; \ +\ + /* chi11 = chi11 / alpha11; */ \ + if ( bli_is_nonunit_diag( diaga ) ) \ + { \ + PASTEMAC(ch,copycjs)( conja, *alpha11, alpha11_conj ); \ + PASTEMAC(ch,invscals)( alpha11_conj, *chi11 ); \ + } \ +\ + /* x21 = x21 - chi11 * a21; */ \ + PASTEMAC(ch,neg2s)( *chi11, minus_chi11 ); \ + if ( bli_is_conj( conja ) ) \ + { \ + for ( j = 0; j < f_ahead; ++j ) \ + PASTEMAC(ch,axpyjs)( minus_chi11, *(a21 + j*rs_at), *(x21 + j*incx) ); \ + } \ + else \ + { \ + for ( j = 0; j < f_ahead; ++j ) \ + PASTEMAC(ch,axpys)( minus_chi11, *(a21 + j*rs_at), *(x21 + j*incx) ); \ + } \ + } \ +\ + /* x2 = x2 - A21 * x1; */ \ + kfp_af \ + ( \ + conja, \ + BLIS_NO_CONJUGATE, \ + n_ahead, \ + f, \ + minus_one, \ + A21, rs_at, cs_at, \ + x1, incx, \ + x2, incx, \ + cntx \ + ); \ + } \ + } \ } INSERT_GENTFUNC_BASIC0( trsv_unf_var2 ) From faf55400610eafd1b715047498c87c49cd541e09 Mon Sep 17 00:00:00 2001 From: Meghana Vankadari Date: Fri, 26 Nov 2021 13:58:05 +0530 Subject: [PATCH 5/5] Removed unwanted function declarations from kernel.h file --- kernels/zen/bli_kernels_zen.h | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/kernels/zen/bli_kernels_zen.h 
b/kernels/zen/bli_kernels_zen.h index cf649f8bf0..2c24931365 100644 --- a/kernels/zen/bli_kernels_zen.h +++ b/kernels/zen/bli_kernels_zen.h @@ -201,16 +201,3 @@ GEMMSUP_KER_PROT( dcomplex, z, gemmsup_rv_zen_asm_3x2 ) GEMMSUP_KER_PROT( dcomplex, z, gemmsup_rv_zen_asm_3x1 ) -// gemm square matrix size friendly implementation -err_t bli_gemm_sqp - ( - obj_t* alpha, - obj_t* a, - obj_t* b, - obj_t* beta, - obj_t* c, - cntx_t* cntx, - cntl_t* cntl - ); - -
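Illustrative usage sketch (not part of the patch series): a small column-major
DGEMM call through the BLAS compatibility layer with transposed A and beta == 0,
i.e. the shape of problem the bli_dgemm_small_At path and the beta checks above
target. The sizes are arbitrary; whether the call actually takes the small-GEMM
path depends on build flags (e.g. BLIS_ENABLE_SMALL_MATRIX, multithreading) and
the thresholds referenced in the patches.

#include <stdio.h>
#include "blis.h"

int main( void )
{
    f77_int m = 16, n = 3, k = 8;            /* C is m x n, op(A) is m x k     */
    f77_int lda = 8, ldb = 8, ldc = 16;      /* A is k x m since transa == 'T' */
    double  alpha = 1.0, beta = 0.0;         /* beta == 0: C is overwritten    */

    double A[ 8 * 16 ], B[ 8 * 3 ], C[ 16 * 3 ];

    for ( int i = 0; i < 8 * 16; ++i ) A[ i ] = ( double )( i % 7 );
    for ( int i = 0; i < 8 *  3; ++i ) B[ i ] = ( double )( i % 5 );
    /* C is intentionally not initialized; with beta == 0 its previous
       contents must not influence the result. */

    dgemm_( "T", "N", &m, &n, &k, &alpha, A, &lda, B, &ldb, &beta, C, &ldc );

    printf( "C(0,0) = %f\n", C[ 0 ] );
    return 0;
}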