Writing an AVX2 DGEMM kernel
A double-precision general matrix multiply (DGEMM) computes:

$$C \leftarrow \alpha A B + \beta C$$

where:

- $A$ is an $m \times k$ matrix
- $B$ is a $k \times n$ matrix
- $C$ is an $m \times n$ matrix
- $\alpha$ and $\beta$ are scalar values
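For a baseline, the same operation in plain scalar C is a triple loop (a naive sketch, with no attention to performance; lda, ldb, ldc are the row strides):

void dgemm_naive(int m, int n, int k, double alpha, const double *a, int lda,
                 const double *b, int ldb, double beta, double *c, int ldc)
{
    for (int i = 0; i < m; i++) {
        for (int j = 0; j < n; j++) {
            double acc = 0.0;
            for (int p = 0; p < k; p++)
                acc += a[i * lda + p] * b[p * ldb + j];
            c[i * ldc + j] = alpha * acc + beta * c[i * ldc + j];
        }
    }
}

Everything that follows is about making this same computation fast.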
A common approach is to build the computation around an inner microkernel that accumulates outer products.
For each element:

$$c_{ij} = \sum_{p=0}^{k-1} a_{ip} \, b_{pj}$$
We process 4 doubles simultaneously using 256-bit AVX2 registers and use fused multiply-add (FMA) instructions to improve throughput.
In the microkernel, we compute a 4×4 block of $C$.
For each row $i$ of the tile:
- Load $a_{ip}$ into a register via broadcast (the same value replicated 4 times)
- Load 4 consecutive elements of $B$ (4 columns of row $p$) into a 256-bit register
- Perform the FMA: $c_{i,0:3} \leftarrow c_{i,0:3} + a_{ip} \cdot b_{p,0:3}$
We can express this directly with AVX2 intrinsics (from immintrin.h).
Initialize the accumulators.
#include <immintrin.h>

void dgemm_micro_kernel(int k, double alpha, const double *a, int lda,
                        const double *b, int ldb, double *c, int ldc)
{
    __m256d c00 = _mm256_setzero_pd();
    __m256d c10 = _mm256_setzero_pd();
    __m256d c20 = _mm256_setzero_pd();
    __m256d c30 = _mm256_setzero_pd();
    /* ... main loop and write-back, shown below ... */
}
Enter the main computation loop.
for (int p = 0; p < k; p++) { /* ... body shown below ... */ }
Broadcast each element of column $p$ of the packed $A$ micropanel, one element per register. (The panel is packed column by column, so element $i$ of column $p$ sits at a[p*lda + i].)
__m256d a0 = _mm256_broadcast_sd(&a[p*lda + 0]);
__m256d a1 = _mm256_broadcast_sd(&a[p*lda + 1]);
__m256d a2 = _mm256_broadcast_sd(&a[p*lda + 2]);
__m256d a3 = _mm256_broadcast_sd(&a[p*lda + 3]);
Load a row from the packed block of $B$ (4 consecutive columns).
__m256d b0 = _mm256_load_pd(&b[p*ldb]);
Update accumulators with outer products using FMA.
c00 = _mm256_fmadd_pd(a0, b0, c00);
c10 = _mm256_fmadd_pd(a1, b0, c10);
c20 = _mm256_fmadd_pd(a2, b0, c20);
c30 = _mm256_fmadd_pd(a3, b0, c30);
Scale by $\alpha$ and accumulate into $C$. We use unaligned stores because ldc is arbitrary, and we add into $C$ rather than overwrite it because the microkernel is invoked once per $k$-block.
__m256d valpha = _mm256_set1_pd(alpha);
_mm256_storeu_pd(&c[0*ldc], _mm256_fmadd_pd(valpha, c00, _mm256_loadu_pd(&c[0*ldc])));
_mm256_storeu_pd(&c[1*ldc], _mm256_fmadd_pd(valpha, c10, _mm256_loadu_pd(&c[1*ldc])));
_mm256_storeu_pd(&c[2*ldc], _mm256_fmadd_pd(valpha, c20, _mm256_loadu_pd(&c[2*ldc])));
_mm256_storeu_pd(&c[3*ldc], _mm256_fmadd_pd(valpha, c30, _mm256_loadu_pd(&c[3*ldc])));
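Before moving on, it helps to have a scalar reference with the same signature and the same packed-data access pattern. This helper is hypothetical (not part of the original kernel) but is handy for validating the intrinsics version:

void dgemm_micro_kernel_ref(int k, double alpha, const double *a, int lda,
                            const double *b, int ldb, double *c, int ldc)
{
    for (int i = 0; i < 4; i++) {
        for (int j = 0; j < 4; j++) {
            double acc = 0.0;
            for (int p = 0; p < k; p++)
                acc += a[p * lda + i] * b[p * ldb + j]; /* same layout as the packed panels */
            c[i * ldc + j] += alpha * acc;              /* accumulate, matching the AVX2 kernel */
        }
    }
}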
Tiling
In practice the matrices are far larger than the CPU caches, so cache efficiency would be very low if we streamed them through the microkernel directly. We therefore decompose the matrices into blocks along the shared dimension:

$$C \leftarrow C + \alpha \sum_{p} A_p B_p$$

where:

- $A_p$ is the block of $A$ spanning columns $p\,k_c$ to $(p+1)\,k_c - 1$
- $B_p$ is the block of $B$ spanning rows $p\,k_c$ to $(p+1)\,k_c - 1$
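Concretely, the blocked algorithm we build below is a loop nest over three block dimensions; schematically (packing and edge handling omitted, block sizes defined next):

for (j = 0; j < n; j += NC)          /* columns of C, one NC-wide panel at a time */
    for (p = 0; p < k; p += KC)      /* shared dimension, one KC-deep slice       */
        for (i = 0; i < m; i += MC)  /* rows of C, one MC-tall block              */
            /* multiply the MC x KC block of A by the KC x NC block of B,
               one 4 x 4 tile of C per microkernel call */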
We define macros for the cache-blocking sizes and for the microkernel tile size.
#define MC 96    /* rows of the A block per packing step      */
#define KC 256   /* depth of the packed panels                */
#define NC 4096  /* columns of the B block per packing step   */
#define MR 4     /* microkernel tile rows                     */
#define NR 4     /* microkernel tile columns                  */
Each block multiplication is further decomposed into microkernel operations on MR×NR = 4×4 tiles, which is why MR and NR appear above. To feed the microkernel unit-stride, cache-resident data, each block is first packed into a contiguous buffer; pack_a copies an MC×KC block of A into slabs of MR rows.
void pack_a(int mc, int kc, const double *a, int lda, double *a_packed)
{
    for (int i = 0; i < mc; i += MR)
    {
        /* ... per-slab packing shown below ... */
    }
}
Calculate the slab height, handling the edge case where mc is not a multiple of MR.
int ib = (i + MR <= mc) ? MR : mc - i;
Reorder the elements so the microkernel can read each column of the slab with unit stride.
for (int p = 0; p < kc; p++)
{
    for (int ii = 0; ii < ib; ii++)
    {
        a_packed[p * MR + ii] = a[(i + ii) * lda + p];
    }
    /* ... zero padding shown below ... */
}
Pad with zeros when the slab has fewer than MR rows (this runs inside the p loop), so the microkernel always reads full MR-wide columns.
for (int ii = ib; ii < MR; ii++)
{
    a_packed[p * MR + ii] = 0.0;
}
a_packed += kc * MR;  /* advance to the next MR-row slab (still inside the i loop) */
Matrix $B$ is packed analogously: pack_b copies a KC×NC block into contiguous slivers of NR columns, with the same zero padding at the edges. The top-level DGEMM function then loops over blocks, packs them, and applies the microkernel to every 4×4 tile.
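pack_b is called below but its body never appears in the fragments; here is a minimal sketch consistent with how the main loop and the microkernel index b_packed:

void pack_b(int kc, int nc, const double *b, int ldb, double *b_packed)
{
    for (int j = 0; j < nc; j += NR)
    {
        int jb = (j + NR <= nc) ? NR : nc - j;
        for (int p = 0; p < kc; p++)
        {
            /* copy one NR-wide sliver of row p, zero-padding at the right edge */
            for (int jj = 0; jj < jb; jj++)
                b_packed[p * NR + jj] = b[p * ldb + j + jj];
            for (int jj = jb; jj < NR; jj++)
                b_packed[p * NR + jj] = 0.0;
        }
        b_packed += kc * NR;
    }
}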
Allocate packed buffers aligned for AVX2.
void dgemm_avx2(int m, int n, int k, double alpha, const double *a, int lda,
                const double *b, int ldb, double beta, double *c, int ldc)
{
    double *a_packed = (double *)_mm_malloc(MC * KC * sizeof(double), 32);
    double *b_packed = (double *)_mm_malloc(KC * NC * sizeof(double), 32);
    /* ... beta scaling, blocked loops, and cleanup shown below ... */
}
Apply $\beta$ scaling to $C$ once, up front, since the microkernel accumulates.
if (beta != 1.0)
{
for (int i = 0; i < m; i++)
{
for (int j = 0; j < n; j++)
{
c[i * ldc + j] *= beta;
}
}
}
Loop over blocks in the $n$ dimension.
for (int j = 0; j < n; j += NC)
{
    int jb = (j + NC <= n) ? NC : n - j;
    /* ... k-dimension loop shown below ... */
}
Loop over blocks in the $k$ dimension.
for (int p = 0; p < k; p += KC)
{
    int pb = (p + KC <= k) ? KC : k - p;
    /* ... packing and m-dimension loop shown below ... */
}
Pack a block of B.
pack_b(pb, jb, &b[p*ldb + j], ldb, b_packed);
Loop over blocks in the $m$ dimension.
for (int i = 0; i < m; i += MC)
{
    int ib = (i + MC <= m) ? MC : m - i;
    /* ... packing and microkernel calls shown below ... */
}
Pack a block of A.
pack_a(ib, pb, &a[i*lda + p], lda, a_packed);
Compute on packed data with microkernels.
for (int ii = 0; ii < ib; ii += MR)
{
    for (int jj = 0; jj < jb; jj += NR)
    {
        double *c_block = &c[(i + ii) * ldc + j + jj];
        /* ... microkernel call shown below ... */
    }
}
Invoke the microkernel on the packed panels. (For simplicity this assumes $m$ and $n$ are multiples of MR and NR; edge tiles of $C$ would otherwise need a masked or scalar write-out path.)
dgemm_micro_kernel(pb, alpha,
                   &a_packed[(ii / MR) * pb * MR], MR,
                   &b_packed[(jj / NR) * pb * NR], NR,
                   c_block, ldc);
Free packed buffers.
_mm_free(a_packed);
_mm_free(b_packed);
Compile with -O3 -mavx2 -mfma (the intrinsics require the AVX2 and FMA target flags; -ffast-math is optional since the kernel is vectorized by hand). The remaining scalar loops, such as the packing code, can still be autovectorized: Clang does this at -O3, and GCC already implies -ftree-vectorize at -O3.
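A minimal driver to exercise the kernel, assuming it is compiled in the same translation unit as the code above (the size and random fill are arbitrary choices for illustration):

#include <stdlib.h>

int main(void)
{
    int n = 1024;  /* arbitrary square size, a multiple of MR and NR */
    double *a = (double *)_mm_malloc((size_t)n * n * sizeof(double), 32);
    double *b = (double *)_mm_malloc((size_t)n * n * sizeof(double), 32);
    double *c = (double *)_mm_malloc((size_t)n * n * sizeof(double), 32);

    for (int i = 0; i < n * n; i++) {
        a[i] = (double)rand() / RAND_MAX;
        b[i] = (double)rand() / RAND_MAX;
        c[i] = 0.0;
    }

    /* C = 1.0 * A * B + 0.0 * C */
    dgemm_avx2(n, n, n, 1.0, a, n, b, n, 0.0, c, n);

    _mm_free(a);
    _mm_free(b);
    _mm_free(c);
    return 0;
}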
There is an interesting result when we log matrix $C$:
61.1275 63.3697 66.7324 63.6599 66.6417 63.5247
65.3739 63.7844 63.8110 62.6364 64.2466 65.4696
63.0888 62.6113 66.2171 60.3866 65.0647 64.5883
63.2216 64.7124 61.5693 59.3284 64.4507 66.4562
61.7609 65.7762 63.7569 62.1691 64.8857 64.7059
65.2668 66.8201 62.2359 60.5313 62.4041 66.6678
Check the final source on Godbolt.