Problem Definition
The inputs are two tensors: \(input[Batch, Ci, Ih, Iw]\) (classic NCHW layout) and \(kernel[Co, Ci, Kh, Kw]\) (which may be re-laid-out as \([Co, Kh, Kw, Ci]\) for coalesced access). The former is FP32/FP16, the latter INT4/INT8. The output is the convolution result.
Note that when \(Iw = Ih = Ow = Oh = Kw = Kh = Dw = Dh = Sw = Sh = 1\) (input, output, and kernel sizes, dilation, and stride are all 1x1) and \(Pw = Ph = 0\) (no padding), the convolution degenerates into a standard matrix multiplication \([Batch, Ci] \times [Ci, Co]\), hereafter abbreviated as \(A[N, K] \times B[K, M]\). (Even when this is not the case, the convolution can still be converted into a matrix multiplication that acceleration libraries handle directly, e.g. via Im2Col or implicit convolution.)
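For reference, the im2col view rewrites the general case as a single GEMM whose shapes follow directly from the definitions above:
\([Batch \cdot Oh \cdot Ow,\ Ci \cdot Kh \cdot Kw] \times [Ci \cdot Kh \cdot Kw,\ Co] \rightarrow [Batch \cdot Oh \cdot Ow,\ Co]\)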
\(B\) (in LLM inference this is usually the quantized weight matrix) uses per-block quantization: along the \(K\) dimension, every \(block\_size\) elements share one set of quantization parameters (\(scale\) and \(zero\_point\)); nothing is shared along the \(M\) dimension. Equivalently, each column of \(B\) is cut into vertical strips of length \(block\_size\), for a total of \(K / block\_size \times M\) blocks.
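Concretely, for an element \(b_{km}\) of \(B\) the dequantization reads (with \(i = \lfloor k / block\_size \rfloor\) indexing the block):
\(\hat b_{km} = s_{m,i} \cdot (b_{q,km} - z_{m,i})\)
so \(scale\) and \(zero\_point\) each form an \([M,\ K / block\_size]\) parameter tensor.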
First Attempt
CONV_FpAInt4B
First, the convolution can be implemented directly from its definition: we dequantize \(B\) on the fly (offline dequantization would defeat the purpose of quantization) and require that each quantization block be processed within a single thread, so that the quantization parameters can be shared. Each thread is therefore responsible for one output point and simply implements its computation.
There is considerable room for optimization in theory, but that is not the focus of this article. INT4 only adds a simple unpacking step, which we will not belabor.
template <typename T>
__global__ void CONV_FpAInt4B(const T* input, const uint8_t* kernel, const T* scale, const T* offset, const T* bias,
    T* output, const float maxV, const float minV,
    const int ic, const int ic_p, const int iw, const int ih,
    const int c, const int c_p, const int ow, const int oh,
    const int kw, const int kh, const int dw, const int dh,
    const int sw, const int sh, const int pw, const int ph,
    const int total, const int quanC,
    DivModFast d_oc, DivModFast d_ow, DivModFast d_oh
) {
    // Grid-stride loop: each thread handles one output element per iteration.
    for (size_t index = blockIdx.x * blockDim.x + threadIdx.x; index < total; index += blockDim.x * gridDim.x) {
        // Decode the flat index into (batch ob, output y/x, output channel oz).
        int oz_2, tmp2, oy, ox, tmp1, ob;
        d_oc.divmod(index, tmp1, oz_2);
        d_ow.divmod(tmp1, tmp2, ox);
        d_oh.divmod(tmp2, ob, oy);
        int oz = oz_2;
        int ix = ox * sw - pw;
        int iy = oy * sh - ph;
        float color0 = bias[oz];

        const int num_quan_groups_per_channel = (c > 0 && quanC > 0) ? (quanC / c) : 1;
        const int ic_per_group = (num_quan_groups_per_channel > 0) ? (ic / num_quan_groups_per_channel) : ic;
        const int quan_param_index_base = oz * num_quan_groups_per_channel;

        int fxSta = max(0, (UP_DIV(-ix, dw)));
        int fySta = max(0, (UP_DIV(-iy, dh)));
        int fxEnd = min(kw, UP_DIV(iw - ix, dw));
        int fyEnd = min(kh, UP_DIV(ih - iy, dh));
        int fx, fy, fz;
        for (fy = fySta; fy < fyEnd; ++fy) {
            int sy = fy * dh + iy;
            for (fx = fxSta; fx < fxEnd; ++fx) {
                int sx = fx * dw + ix;
                // One quantization block is handled entirely inside this thread, so the
                // scale/offset pair is loaded once per group.
                for (int group_idx = 0; group_idx < num_quan_groups_per_channel; ++group_idx) {
                    const int quan_param_index = quan_param_index_base + group_idx;
                    const float x_scale = scale[quan_param_index];
                    const float x_offset = offset[quan_param_index];
                    const int sz_start = group_idx * ic_per_group / 2;
                    const int sz_end = sz_start + ic_per_group / 2;
                    for (int sz = sz_start; sz < sz_end && sz * 2 < ic_p; ++sz) {
                        int src_offset = ((ob * ih + sy) * iw + sx) * ic_p + 2 * sz;
                        float inp0 = input[src_offset];
                        float inp1 = input[src_offset + 1];
                        // One packed byte holds two INT4 weights: unpack the high/low nibbles
                        // and recenter them to [-8, 7].
                        uint8_t ker = kernel[((oz * kh + fy) * kw + fx) * ic_p / 2 + sz];
                        int8_t ker0 = (ker >> 4) - 8;
                        int8_t ker1 = (ker & 15) - 8;
                        color0 = color0 + inp0 * ((float)ker0 * x_scale + x_offset);
                        color0 = color0 + inp1 * ((float)ker1 * x_scale + x_offset);
                    }
                }
            }
        }
        // Clamp (fused activation) and write the output.
        color0 = max(color0, minV);
        color0 = min(color0, maxV);
        int dst_offset = ((ob * oh + oy) * ow + ox) * c_p + oz;
        output[dst_offset] = color0;
    }
}
ConvInt8CutlassExecution
Next, consider the GEMM special case: first quantize \(A\), then perform an INT*INT matrix multiplication on the two matrices, and finally requantize the resulting INT32 values with a new set of quantization parameters.
Start with the simpler case where both \(A\) and \(B\) are per-tensor quantized, i.e. each input matrix has a single set of quantization parameters and the output has one as well. There is also a \(Bias\) term to add.
Writing the quantization parameters as \(s\) and \(z\), a short derivation gives:
\(c_{fp, ij} = s_A s_B \left[ \sum_{k=1}^{K} a_{q, ik} b_{q, kj} - z_B \sum_{k=1}^{K} a_{q, ik} - z_A \sum_{k=1}^{K} b_{q, kj} + K z_A z_B \right]\)
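The expansion comes from substituting the per-tensor dequantization definitions and multiplying out:
\(c_{fp, ij} = \sum_{k=1}^{K} s_A (a_{q, ik} - z_A) \cdot s_B (b_{q, kj} - z_B) = s_A s_B \sum_{k=1}^{K} \left( a_{q, ik} b_{q, kj} - z_B a_{q, ik} - z_A b_{q, kj} + z_A z_B \right)\)
where the constant term appears \(K\) times, yielding the \(K z_A z_B\) term above.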
The algorithm above is therefore entirely feasible; the only remaining questions are how to perform the dequantization well and how to quantize the final result.
Consider an even simpler case and assume \(B\) is symmetrically quantized (i.e. \(Z_B = 0\)). Below, the two matrices are denoted \(I\) and \(W\) (Input and Weight). Conceptually, the pipeline is:
Dequantize, turning the INT8 input and weights back into floating point: \(I_{fp} = S_I \cdot (I_q - Z_I)\), \(W_{fp} = S_W \cdot W_q\);
Run the standard floating-point GEMM and bias addition: \(O_{fp} = \text{GEMM}(I_{fp}, W_{fp}) + B_{fp}\);
Requantize, turning the floating-point result back into an INT8 output: \(O_q = O_{fp} / S_O + Z_O\).
Our goal is to skip the intermediate floating-point steps. Combining the three steps above gives:
\(O_q = \frac{ \left( \sum S_I(I_q-Z_I) \cdot S_W W_q \right) + B_{fp} }{S_O} + Z_O\)
Separating out the pure integer part \(\text{Accum}_{32} = \sum I_q \cdot W_q\) (a 32-bit integer accumulator):
\(O_q = \frac{S_I S_W \left( \sum I_q W_q - Z_I \sum W_q \right) + B_{fp}}{S_O} + Z_O\)
Factoring out an overall scale \(M\) and merging all offset terms into \(\text{FusedBias}\):
\(M = \frac{S_I \cdot S_W}{S_O}, \quad \text{FusedBias} = -Z_I \cdot \sum W_q + \frac{B_{fp}}{S_I S_W} + \frac{Z_O}{M}\)
yields the final result: \(O_q = M \cdot (\text{Accum}_{32} + \text{FusedBias})\)
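The per-element epilogue implied by the last formula is just one multiply plus rounding and clamping. A minimal sketch (not the MNN code; the helper name requantize_one and the parameters fusedScale/fusedBias are hypothetical, assumed precomputed per output channel):

// Hypothetical per-element requantization epilogue: O_q = clamp(round(M * (Accum32 + FusedBias))).
__device__ __forceinline__ int8_t requantize_one(int32_t accum, float fusedScale, float fusedBias,
                                                 int clampMin, int clampMax) {
    float v = fusedScale * ((float)accum + fusedBias);   // M * (Accum32 + FusedBias)
    int q = __float2int_rn(v);                           // round to nearest
    q = max(clampMin, min(clampMax, q));                 // clamp to the INT8 output range
    return (int8_t)q;
}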
Code:
void ConvInt8CutlassExecution::Resource::updateInputOutputScale(std::vector<float> inputQuantInfo, std::vector<float> outputQuantInfo) {
    if (mUseConvQuan) {
        return;
    }
    float inputScale = inputQuantInfo[0];
    float outputScale = outputQuantInfo[0];
    float inputZeroPoint = inputQuantInfo[1];
    float outputZeroPoint = outputQuantInfo[1];
    mClampMin = int8_t(outputQuantInfo[2]);
    mClampMax = int8_t(outputQuantInfo[3]);

    if (inputScale == 0.f || outputScale == 0.f) return;

    mInputScale = inputScale;
    mOutputScale = outputScale;
    mInputZeroPoint = int8_t(inputZeroPoint);
    mOutputZeroPoint = int8_t(outputZeroPoint);

    const int kernelNum = static_cast<int>(mInt8WeightKernelSum.size());
    auto alphaScale = inputScale / outputScale;   // S_I / S_O
    auto alphaData = mScaleFloatVec;
    auto biasData = (float*)mBiasInt32Vec;

    for (int i = 0; i < kernelNum; i++) {
        auto alphaValue = alphaData[i];           // per-channel weight scale S_W
        if (fabs(alphaValue) < 1e-6) alphaValue = 1e-6;
        // Fused rescale M = S_W * (S_I / S_O) for output channel i.
        mScaleFloatVec[i] = alphaValue * alphaScale;
        int outputZeroPointFused = static_cast<int32_t>(outputZeroPoint / mScaleFloatVec[i]);
        // FusedBias = B_fp/(S_I*S_W) - Z_I * sum(W_q) + Z_O/M  (biasData holds B_fp/S_O).
        mBiasInt32Vec[i] = static_cast<int32_t>(biasData[i] / (alphaScale * alphaValue))
                           - mInt8WeightKernelSum[i] * inputZeroPoint
                           + outputZeroPointFused;
    }
}
\(M = \frac{S_I \cdot S_W}{S_O}\): mScaleFloatVec[i] = alphaValue * alphaScale; alphaValue corresponds to the weight scale \(S_W\) and alphaScale to \(S_I / S_O\).
\(\text{FusedBias} = -Z_I \sum W_q + \frac{B_{fp}}{S_I S_W} + \frac{Z_O}{M}\): mBiasInt32Vec[i] = term1 + term2 + term3;
term1 = biasData[i] / (alphaScale * alphaValue), i.e. biasData divided by \(M\). For this to match \(\frac{B_{fp}}{S_I S_W}\) in the formula, biasData must hold the pre-processed value \(B_{fp}/S_O\); this is a common implementation trick.
term2 = -mInt8WeightKernelSum[i] * inputZeroPoint corresponds to \(-Z_I \sum W_q\).
term3 = outputZeroPointFused (i.e. outputZeroPoint / mScaleFloatVec[i]) corresponds to \(Z_O / M\).
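Plugging biasData \(= B_{fp}/S_O\) into term1 and recalling \(M = \frac{S_I S_W}{S_O} = \text{alphaScale} \cdot \text{alphaValue}\), the three terms indeed reassemble the fused bias:
\(\underbrace{\frac{B_{fp}/S_O}{M}}_{\texttt{term1}} \underbrace{- Z_I \sum W_q}_{\texttt{term2}} + \underbrace{\frac{Z_O}{M}}_{\texttt{term3}} = \frac{B_{fp}}{S_I S_W} - Z_I \sum W_q + \frac{Z_O}{M} = \text{FusedBias}\)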
llama.cpp Implementation
The mixed-precision GEMM and GEMV code lives mainly under ggml/src/ggml-cuda/, in ggml-cuda.cu, mmq.cu, mmq.cuh, and mmvq.cu.
Quantization-related code is mainly in quantize.cu; the underlying vector-matrix multiplication implementation is in vecdotq.cu.
llama.cpp handles FP16/FP32 * INT mixed-precision matrix multiplication as follows:
cuBLAS path: fully dequantize to FP16/FP32 (see ggml_cuda_mul_mat_batched_cublas and ggml_cuda_op_mul_mat_cublas in ggml-cuda.cu);
hand-written kernel (MMQ) path:
for FP16 * INT, the cuBLAS path is taken: run FP16 * FP16, then convert the FP16 result to FP32;
for FP32 * INT, the full matrix is quantized to INT, INT * INT is executed, and the INT32 result is dequantized back to FP32;
split into mul_mat_q and mul_mat_vec_q variants, plus a specialization for MoE;
there is no cuBLAS interface that can dequantize on the fly as part of the matrix multiplication, and Cutlass is not used;
compile-time branching is done with heavy template metaprogramming, which fixes kernel constants and the function pointers that get called.
Several key functions on the MMQ path:
ggml_cuda_op_mul_mat: a general matrix-multiplication execution engine. It handles multi-GPU data partitioning and synchronization, and can invoke any concrete compute implementation (cuBLAS or a custom quantized kernel) through a function pointer.
ggml_cuda_mul_mat: the top-level "decision maker" for the matrix-multiplication op. It inspects the input tensors' types and shapes and the hardware capabilities, then dispatches the work to the best backend (custom quantized kernels or one of several cuBLAS paths). It calls ggml_cuda_op_mul_mat when GPU partitioning and the like are needed, and otherwise calls ggml_cuda_mul_mat_q and friends directly.
ggml_cuda_mul_mat_q: the logical entry point for general matrix-matrix quantized multiplication (MMQ). It quantizes FP32 inputs on the fly and handles both the standard and the Mixture-of-Experts (MoE) compute modes, serving as the end-to-end solution for MoE.
ggml_cuda_op_mul_mat_q: the lower-level compute interface invoked by the general execution engine. It receives already-quantized inputs (the parameters prepared by the upstream ggml_cuda_op_mul_mat), packs them into kernel parameters, and launches the actual matrix-matrix quantized compute kernel.
mul_mat_q_case is the function that actually runs the computation.
Scheme 1: Online Dequantization
We now consider matrix-multiplication and vector-matrix-multiplication specializations of the online-dequantization convolution from the previous section.
Following standard CUDA optimization techniques, the GEMM is optimized as follows:
Split the matrices into 16x16 tiles. Each thread block is responsible for one tile of the result (involving one "tile row" of A and one "tile column" of B); within it, each thread handles one row of A and one column of B, and k_tile enumerates the tiles along the K dimension;
During the multiplication the code accesses B_tile_fp[k][tx] and A_tile[ty][k]. Adjacent threads share the same ty and have consecutive tx, so within a warp the accesses to B hit consecutive elements of one row while the accesses to A all hit the same element (a broadcast). Neither causes a 32-way bank conflict, so padding the B_tile column dimension by +1 is unnecessary;
In the first stage of the GEMM loop, A and B are loaded into shared memory with coalesced accesses; the second stage dequantizes and transposes in parallel; the third stage performs the matrix multiplication;
Code:
const int TILE_DIM = 16;

template <typename T>
__global__ void GEMM_FpAInt8B(
    const T* input, const int8_t* kernel, const T* scale, const T* offset, const T* bias,
    T* output, const float maxV, const float minV,
    const int ic, const int ic_p, const int oc, const int oc_p,
    const int batch, const int quanC
) {
    __shared__ T A_tile[TILE_DIM][TILE_DIM];
    __shared__ int8_t B_tile_s8[TILE_DIM][TILE_DIM];
    __shared__ T B_tile_fp[TILE_DIM][TILE_DIM];

    const int tx = threadIdx.x;
    const int ty = threadIdx.y;
    const int block_row = blockIdx.y;
    const int block_col = blockIdx.x;
    const int out_row = block_row * TILE_DIM + ty;
    const int out_col = block_col * TILE_DIM + tx;

    float acc = 0.0f;
    const int num_k_tiles = UP_DIV(ic, TILE_DIM);

    for (int k_tile = 0; k_tile < num_k_tiles; ++k_tile) {
        const int k_tile_base = k_tile * TILE_DIM;

        // Stage 1: coalesced loads of the A tile and the INT8 B tile into shared memory.
        int a_col_idx = k_tile_base + tx;
        A_tile[ty][tx] = (out_row < batch && a_col_idx < ic) ? input[out_row * ic_p + a_col_idx] : (T)0.0f;

        int b_load_row = block_col * TILE_DIM + ty;
        int b_col_idx = k_tile_base + tx;
        B_tile_s8[ty][tx] = (b_load_row < oc && b_col_idx < ic) ? kernel[b_load_row * ic_p + b_col_idx] : 0;
        __syncthreads();

        // Stage 2: dequantize B in parallel and transpose it into B_tile_fp.
        const int K = ty;
        const int N = tx;
        const int global_k = k_tile_base + K;
        const int global_n = block_col * TILE_DIM + N;
        if (global_n < oc && global_k < ic) {
            const int num_quan_groups_per_channel = (quanC > 0) ? (quanC / oc) : 1;
            const int ic_per_group = (num_quan_groups_per_channel > 0) ? (ic / num_quan_groups_per_channel) : ic;
            const int group_idx = global_k / ic_per_group;
            const int quan_param_index = global_n * num_quan_groups_per_channel + group_idx;
            const float x_scale = (float)scale[quan_param_index];
            const float x_offset = (float)offset[quan_param_index];
            const float b_quant = (float)B_tile_s8[N][K];
            B_tile_fp[K][N] = (T)(b_quant * x_scale + x_offset);
        }
        __syncthreads();

        // Stage 3: multiply-accumulate over the tile.
        if (out_col < oc) {
            #pragma unroll
            for (int k = 0; k < TILE_DIM; ++k) {
                acc += (float)A_tile[ty][k] * (float)B_tile_fp[k][tx];
            }
        }
        __syncthreads();
    }

    if (out_row < batch && out_col < oc) {
        acc += (float)bias[out_col];
        acc = max(acc, minV);
        acc = min(acc, maxV);
        output[out_row * oc_p + out_col] = (T)acc;
    }
}

dim3 threads(TILE_DIM, TILE_DIM);
dim3 blocks(UP_DIV(ocp, TILE_DIM), UP_DIV(batch, TILE_DIM));
The GEMV is optimized as follows:
Each thread block computes one output element; each thread handles the positions whose index falls in one residue class modulo 64 (coalesced access + compute);
Since the result of a vector-matrix multiply is a vector, a straightforward parallel reduction (a shared-memory tree reduction) suffices, and only thread 0 writes back to global memory;
Dynamic shared memory is used, filled in the first stage directly from global memory with coalesced loads;
template <typename T>
__global__ void GEMV_FpAInt4B(
    const T* input, const uint8_t* kernel, const T* scale, const T* offset, const T* bias,
    T* output, const float maxV, const float minV,
    const int ic, const int ic_p, const int oc, const int quanC
) {
    extern __shared__ uint8_t smem_buffer[];
    T* smem_input = reinterpret_cast<T*>(smem_buffer);
    float* partial_sums = reinterpret_cast<float*>(smem_buffer + ic_p * sizeof(T));

    const int oz = blockIdx.x;
    const int tid = threadIdx.x;

    for (int i = tid; i < ic; i += blockDim.x) smem_input[i] = input[i];
    for (int i = ic + tid; i < ic_p; i += blockDim.x) smem_input[i] = (T)0.0f;
    __syncthreads();

    const int num_quan_groups_per_channel = (quanC > 0) ? (quanC / oc) : 1;
    const int ic_per_group = (num_quan_groups_per_channel > 0) ? (ic / num_quan_groups_per_channel) : ic;
    const int quan_param_index_base = oz * num_quan_groups_per_channel;

    float my_sum = 0.0f;
    for (int k = tid * 2; k < ic; k += blockDim.x * 2) {
        const uint8_t ker_packed = kernel[oz * (ic_p / 2) + k / 2];
        const int8_t ker0_s8 = (ker_packed >> 4) - 8;
        const int8_t ker1_s8 = (ker_packed & 0x0F) - 8;

        const int group_idx0 = k / ic_per_group;
        const int quan_param_index0 = quan_param_index_base + group_idx0;
        const float x_scale0 = scale[quan_param_index0];
        const float x_offset0 = offset[quan_param_index0];

        my_sum += (float)smem_input[k] * ((float)ker0_s8 * x_scale0 + x_offset0);
        my_sum += (float)smem_input[k + 1] * ((float)ker1_s8 * x_scale0 + x_offset0);
    }

    partial_sums[tid] = my_sum;
    __syncthreads();

    for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) {
        if (tid < s) partial_sums[tid] += partial_sums[tid + s];
        __syncthreads();
    }

    if (tid == 0) {
        float final_val = partial_sums[0] + (float)bias[oz];
        final_val = max(final_val, minV);
        final_val = min(final_val, maxV);
        output[oz] = (T)final_val;
    }
}

dim3 threads(GEMV_TILE);
dim3 blocks(oc);
size_t input_smem_size = icp * (mFp16Infer ? sizeof(half) : sizeof(float));
size_t reduction_smem_size = GEMV_TILE * sizeof(float);
size_t smem_size = input_smem_size + reduction_smem_size;
GEMV_FpAInt4B<<<blocks, threads, smem_size>>>(...);
Scheme 2: Online Quantization
Derivation
Finally, we consider quantizing \(A\) per-block with the same \(block\_size\) as \(B\) (both cut along the \(K\) dimension), calling an INT*INT matrix multiplication, and dequantizing at the end to obtain the floating-point result. Now each row of \(A\) is partitioned and each column of \(B\) is partitioned; no parameters are shared along the other dimensions.
The problem is that elements in different blocks use different quantization parameters, so the answer clearly cannot be obtained with a single full integer matrix multiplication. On the other hand, multiplying block by block is far too inefficient to benefit from Cutlass acceleration (a block often holds only 64 elements).
We therefore take one "block column" of \(A\) at a time and multiply it by the corresponding "block row" of \(B\), producing an \(N \times M\) matrix; every entry of that matrix (the product of one pair of blocks) is then brute-force dequantized with that pair's quantization parameters and accumulated into the result matrix. Repeating this \(\frac{K}{block\_size}\) times completes the computation.
More formally:
Partition the reduction dimension \(K\) into \(P\) contiguous blocks: \(K = G_1 \cup G_2 \cup \dots \cup G_P\).
The quantization parameters \(s\) and \(z\) now depend not only on the matrix and its row/column but also on the block \(G_i\) containing \(k\).
The dequantization formulas therefore read (the subscript \(i\) denotes the \(i\)-th block along \(K\)):
\(A_{fp, mk} = s_{A,m,i} \cdot (A_{q, mk} - z_{A,m,i}), \quad k \in G_i\)
\(B_{fp, kn} = s_{B,n,i} \cdot (B_{q, kn} - z_{B,n,i}), \quad k \in G_i\)
Substituting into the matrix-multiplication formula (with \(g(k)\) the index of the block containing \(k\)):
\(C_{fp, mn} = \sum_{k=1}^{K} \left[ s_{A,m,g(k)} \cdot (A_{q, mk} - z_{A,m,g(k)}) \right] \cdot \left[ s_{B,n,g(k)} \cdot (B_{q, kn} - z_{B,n,g(k)}) \right]\)
Because the scales \(s_A, s_B\) and zero points \(z_A, z_B\) change with \(k\), they cannot be pulled out of the full sum \(\sum_{k=1}^{K}\) as constants.
We therefore restructure the summation, splitting the sum over \(k\) into an outer sum over blocks with an inner sum over the elements of each block:
\(C_{fp, mn} = \sum_{i=1}^{P} \left( \sum_{k \in G_i} \left[ s_{A,m,i} \cdot (A_{q, mk} - z_{A,m,i}) \right] \cdot \left[ s_{B,n,i} \cdot (B_{q, kn} - z_{B,n,i}) \right] \right)\)
Within the inner sum \(\sum_{k \in G_i}\), every \(k\) belongs to the same block \(G_i\), so its quantization parameters are constant and can be moved outside the inner sum:
\(C_{fp, mn} = \sum_{i=1}^{P} s_{A,m,i} \cdot s_{B,n,i} \left( \sum_{k \in G_i} (A_{q, mk} - z_{A,m,i}) \cdot (B_{q, kn} - z_{B,n,i}) \right)\)
This is the mathematical basis of block-by-block multiplication. Next we derive the "block row x block column" multiplication:
Define the floating-point result matrix of the \(i\)-th block:
\(C_{fp, mn}^{(i)} := s_{A,m,i} \cdot s_{B,n,i} \left( \sum_{k \in G_i} (A_{q, mk} - z_{A,m,i}) \cdot (B_{q, kn} - z_{B,n,i}) \right)\)
The final output matrix \(C_{fp}\) is then simply the superposition of all these partial result matrices: \(C_{fp, mn} = \sum_{i=1}^{P} C_{fp, mn}^{(i)}\)
Define the sub-matrices associated with block \(i\):
\(A_q^{(i)}\): the sub-matrix formed by all rows of \(A_q\) and only the columns belonging to block \(G_i\); its shape is \(M \times block\_size\);
\(B_q^{(i)}\): the sub-matrix formed by all columns of \(B_q\) and only the rows belonging to block \(G_i\); its shape is \(block\_size \times N\);
Their product defines the integer result matrix of block \(i\), \(C_{q}^{(i)}\): \(C_{q, mn}^{(i)} := \sum_{k \in G_i} A_{q, mk} B_{q, kn}\)
This is a general matrix multiplication (GEMM) of shape \((M \times block\_size) \times (block\_size \times N) \rightarrow (M \times N)\), and its result \(C_{q}^{(i)}\) is a full \(M \times N\) INT32 matrix.
Expanding the definition of \(C_{fp, mn}^{(i)}\) again and substituting \(C_{q, mn}^{(i)}\) gives the final dequantization formula, which describes how to convert a complete intermediate integer matrix into a floating-point matrix element by element:
\(C_{fp, mn}^{(i)} = s_{A,m,i} s_{B,n,i} \cdot \left( \underbrace{\sum_{k \in G_i} A_{q, mk}B_{q, kn}}_{C_{q, mn}^{(i)}} - z_{B,n,i}\sum_{k \in G_i} A_{q, mk} - z_{A,m,i}\sum_{k \in G_i} B_{q, kn} + block\_size \cdot z_{A,m,i}z_{B,n,i} \right)\)
The algorithm is as follows (a CPU reference sketch follows this list):
for i = 1 to P (iterate over all blocks along \(K\)):
Integer sub-matrix multiplication
Extract the sub-matrices \(A_q^{(i)}\) and \(B_q^{(i)}\).
Call an efficient int8*int8 -> int32 GEMM library routine to compute the intermediate result matrix \(C_{q_i} = \text{matmul}(A_{q_i}, B_{q_i})\).
Element-wise dequantization and accumulation
Launch a CUDA kernel in which each thread handles one or more elements of \(C_{q_i}\).
Inside the kernel, for each element \((m, n)\):
Read \(C_{q_i,mn}\) and, using block \(i\)'s quantization parameters \(s_{A,m,i}, z_{A,m,i}, s_{B,n,i}, z_{B,n,i}\), compute the local floating-point value \(c_{fp_i}\).
Add \(c_{fp_i}\) to the corresponding position of the final result matrix, either atomically or directly (if the output writes do not conflict).
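Before the CUDA code, here is a minimal CPU reference of this loop, useful for checking the GPU path numerically. It is a sketch, not the implementation below: the function name RefBlockQuantGemm is hypothetical, zero-point form is assumed, and the parameter tensors are assumed laid out as [rows x P] for A and [columns x P] for B.

// Hypothetical CPU reference: per-block quantized A[MxK] * B[KxN], block_size along K, P = K/block_size.
#include <cstdint>
#include <vector>

void RefBlockQuantGemm(const std::vector<int8_t>& A_q, const std::vector<float>& sA, const std::vector<int8_t>& zA,
                       const std::vector<int8_t>& B_q, const std::vector<float>& sB, const std::vector<int8_t>& zB,
                       std::vector<float>& C_fp, int M, int N, int K, int block_size) {
    const int P = K / block_size;
    C_fp.assign((size_t)M * N, 0.0f);
    for (int i = 0; i < P; ++i) {                              // loop over K-dimension blocks
        const int k0 = i * block_size;
        for (int m = 0; m < M; ++m) {
            for (int n = 0; n < N; ++n) {
                int32_t acc = 0, sumA = 0, sumB = 0;
                for (int k = k0; k < k0 + block_size; ++k) {   // int8*int8 -> int32 sub-GEMM
                    acc  += (int32_t)A_q[(size_t)m * K + k] * B_q[(size_t)k * N + n];
                    sumA += A_q[(size_t)m * K + k];
                    sumB += B_q[(size_t)k * N + n];
                }
                const float sa = sA[(size_t)m * P + i], sb = sB[(size_t)n * P + i];
                const int   za = zA[(size_t)m * P + i], zb = zB[(size_t)n * P + i];
                // Dequantize this block pair and accumulate into the final FP result.
                C_fp[(size_t)m * N + n] += sa * sb * (acc - zb * sumA - za * sumB + block_size * za * zb);
            }
        }
    }
}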
The code below uses the offset form rather than the zero-point form. Writing \(K_i = |G_i| = block\_size\), \(\text{SumA}_{m,i} = \sum_{k \in G_i} A_{q,mk}\), and \(\text{SumB}_{n,i} = \sum_{k \in G_i} B_{q,kn}\):
\(C_{fp, mn}^{(i)} = \sum_{k \in G_i} \left[ \text{scale}_{A,m,i} A_{q,mk} + \text{offset}_{A,m,i} \right] \cdot \left[ \text{scale}_{B,n,i} B_{q,kn} + \text{offset}_{B,n,i} \right] \\ = \left(\text{scale}_{A,m,i}\text{scale}_{B,n,i}\right) C_{q,mn}^{(i)} + \left(\text{scale}_{A,m,i}\text{offset}_{B,n,i}\right) \text{SumA}_{m,i} \\ + \left(\text{offset}_{A,m,i}\text{scale}_{B,n,i}\right) \text{SumB}_{n,i} + K_i \cdot \text{offset}_{A,m,i}\text{offset}_{B,n,i}\)
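These four terms regroup into exactly the two fused terms that the DequantAndAcc kernel below computes per element:
\(C_{fp,mn}^{(i)} = \underbrace{\text{scale}_{A,m,i}\left( \text{scale}_{B,n,i}\, C_{q,mn}^{(i)} + \text{offset}_{B,n,i}\, \text{SumA}_{m,i} \right)}_{\texttt{term1}} + \underbrace{\text{offset}_{A,m,i}\left( \text{scale}_{B,n,i}\, \text{SumB}_{n,i} + K_i\, \text{offset}_{B,n,i} \right)}_{\texttt{term2}}\)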
Code
The kernels include the following optimizations:
A union is used to reuse shared memory and reduce the total footprint;
For the parallel reduction, __shfl_down_sync (a warp shuffle intrinsic) is first used for an intra-warp reduction, after which warp 0 performs a cross-warp reduction (which is itself again an intra-warp reduction via __shfl_down_sync);
Inspecting cuda_fp16.hpp (CUDA Toolkit 12.4) shows that __half atomicAdd(__half *const address, const __half val) is only supported on SM_70 and later, so branching is done with a preprocessor condition (which also copes with fat-binary compilation)
#if (defined(__CUDACC__) && (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 700))) || defined(_NVHPC_CUDA)
together with a compile-time condition if constexpr (std::is_same_v<T, half>);
Cutlass can be called directly, or split into a prepare + run two-phase flow (see the Cutlass code further below);
template <typename T>
__global__ void QuantA(
    const T* A_sub_fp, int8_t* A_sub_q, T* scale_A_out, T* offset_A_out, int32_t* sum_A_q_out,
    const int M, const int K_i, const int lda
) {
    __shared__ union {
        float min_max_vals[2][BLOCK_SIZE / WARP_SIZE];
        int32_t sum_vals[BLOCK_SIZE / WARP_SIZE];
        float scale_offset[2];
    } smem;

    const int m = blockIdx.x;
    if (m >= M) return;

    const int tid = threadIdx.x;
    const int warp_id = tid / WARP_SIZE;
    const int lane_id = tid % WARP_SIZE;

    float my_min = FLT_MAX;
    float my_max = -FLT_MAX;
    for (int k = tid; k < K_i; k += BLOCK_SIZE) {
        float val = (float)A_sub_fp[m * lda + k];
        my_min = min(my_min, val);
        my_max = max(my_max, val);
    }
    for (int offset = WARP_SIZE / 2; offset > 0; offset >>= 1) {
        my_min = min(my_min, __shfl_down_sync(0xFFFFFFFF, my_min, offset));
        my_max = max(my_max, __shfl_down_sync(0xFFFFFFFF, my_max, offset));
    }
    if (lane_id == 0) {
        smem.min_max_vals[0][warp_id] = my_min;
        smem.min_max_vals[1][warp_id] = my_max;
    }
    __syncthreads();

    if (warp_id == 0) {
        if (lane_id < (BLOCK_SIZE / WARP_SIZE)) {
            my_min = smem.min_max_vals[0][lane_id];
            my_max = smem.min_max_vals[1][lane_id];
        } else {
            my_min = FLT_MAX;
            my_max = -FLT_MAX;
        }
        for (int offset = (BLOCK_SIZE / WARP_SIZE) / 2; offset > 0; offset >>= 1) {
            my_min = min(my_min, __shfl_down_sync(0xFFFFFFFF, my_min, offset));
            my_max = max(my_max, __shfl_down_sync(0xFFFFFFFF, my_max, offset));
        }
    }

    if (tid == 0) {
        float scale = (my_max - my_min) / 255.0f;
        float offset;
        if (abs(scale) > 1e-5) {
            offset = my_max - scale * 127.0f;
        } else {
            scale = 1.0f;
            offset = my_max;
        }
        scale_A_out[m] = (T)scale;
        offset_A_out[m] = (T)offset;
        smem.scale_offset[0] = scale;
        smem.scale_offset[1] = offset;
    }
    __syncthreads();

    const float s_scale_A = smem.scale_offset[0];
    const float s_offset_A = smem.scale_offset[1];

    int32_t my_sum_q = 0;
    for (int k = tid; k < K_i; k += BLOCK_SIZE) {
        float val_fp = (float)A_sub_fp[m * lda + k];
        int32_t val_q = roundf((val_fp - s_offset_A) / s_scale_A);
        int8_t a_q = (int8_t)max(-128, min(127, val_q));
        A_sub_q[m * K_i + k] = a_q;
        my_sum_q += a_q;
    }
    for (int offset = WARP_SIZE / 2; offset > 0; offset >>= 1) {
        my_sum_q += __shfl_down_sync(0xFFFFFFFF, my_sum_q, offset);
    }
    if (lane_id == 0) smem.sum_vals[warp_id] = my_sum_q;
    __syncthreads();

    if (warp_id == 0) {
        my_sum_q = (lane_id < (BLOCK_SIZE / WARP_SIZE)) ? smem.sum_vals[lane_id] : 0;
        for (int offset = (BLOCK_SIZE / WARP_SIZE) / 2; offset > 0; offset >>= 1) {
            my_sum_q += __shfl_down_sync(0xFFFFFFFF, my_sum_q, offset);
        }
    }
    if (tid == 0) sum_A_q_out[m] = my_sum_q;
}

__global__ void GEMM_Int8(
    const int8_t* A_q, const int8_t* B_q, int32_t* C_q,
    const int M, const int N, const int K_i,
    const int lda_q, const int ldb, const int ldc
) {
    __shared__ int8_t A_tile_s8[SUB_GEMM_TILE_DIM][SUB_GEMM_TILE_DIM];
    __shared__ int8_t B_tile_s8[SUB_GEMM_TILE_DIM][SUB_GEMM_TILE_DIM];

    const int block_row = blockIdx.y;
    const int block_col = blockIdx.x;
    const int ty = threadIdx.y;
    const int tx = threadIdx.x;
    const int row = block_row * SUB_GEMM_TILE_DIM + ty;
    const int col = block_col * SUB_GEMM_TILE_DIM + tx;

    int32_t acc = 0;
    const int num_k_tiles = UP_DIV(K_i, SUB_GEMM_TILE_DIM);
    for (int k_tile = 0; k_tile < num_k_tiles; ++k_tile) {
        const int k_tile_base = k_tile * SUB_GEMM_TILE_DIM;
        A_tile_s8[ty][tx] = (row < M && (k_tile_base + tx) < K_i) ? A_q[row * lda_q + k_tile_base + tx] : 0;
        B_tile_s8[ty][tx] = (col < N && (k_tile_base + ty) < K_i) ? B_q[col * ldb + k_tile_base + ty] : 0;
        __syncthreads();

        #pragma unroll
        for (int k = 0; k < SUB_GEMM_TILE_DIM; ++k) {
            acc += (int32_t)A_tile_s8[ty][k] * (int32_t)B_tile_s8[k][tx];
        }
        __syncthreads();
    }
    if (row < M && col < N) C_q[row * ldc + col] = acc;
}

template <typename T>
__global__ void DequantAndAcc(
    const int32_t* C_q, T* C_fp_final,
    const T* scale_A_in, const T* offset_A_in,
    const T* base_scale_B, const T* base_offset_B, const int32_t* base_sum_B_q,
    const int group_idx, const int num_oc_groups,
    const int32_t* sum_A_q_in,
    const int M, const int N, const int K_i, const int ldc
) {
    __shared__ float smem_scale_A[DEQUANT_TILE_DIM];
    __shared__ float smem_offset_A[DEQUANT_TILE_DIM];
    __shared__ int32_t smem_sum_A_q[DEQUANT_TILE_DIM];
    __shared__ float smem_scale_B[DEQUANT_TILE_DIM];
    __shared__ float smem_offset_B[DEQUANT_TILE_DIM];
    __shared__ int32_t smem_sum_B_q[DEQUANT_TILE_DIM];

    const int tx = threadIdx.x;
    const int ty = threadIdx.y;
    const int block_row_start = blockIdx.y * DEQUANT_TILE_DIM;
    const int block_col_start = blockIdx.x * DEQUANT_TILE_DIM;

    int m_load_idx = block_row_start + ty;
    if (tx == 0 && m_load_idx < M) {
        smem_scale_A[ty] = (float)scale_A_in[m_load_idx];
        smem_offset_A[ty] = (float)offset_A_in[m_load_idx];
        smem_sum_A_q[ty] = sum_A_q_in[m_load_idx];
    }
    int n_load_idx = block_col_start + tx;
    if (ty == 0 && n_load_idx < N) {
        const size_t b_param_idx = n_load_idx * num_oc_groups + group_idx;
        smem_scale_B[tx] = (float)base_scale_B[b_param_idx];
        smem_offset_B[tx] = (float)base_offset_B[b_param_idx];
        smem_sum_B_q[tx] = base_sum_B_q[b_param_idx];
    }
    __syncthreads();

    const int m = block_row_start + ty;
    const int n = block_col_start + tx;
    if (m < M && n < N) {
        const float scale_A = smem_scale_A[ty];
        const float offset_A = smem_offset_A[ty];
        const float sum_A_q = (float)smem_sum_A_q[ty];
        const float scale_B = smem_scale_B[tx];
        const float offset_B = smem_offset_B[tx];
        const float sum_B_q = (float)smem_sum_B_q[tx];
        const float c_q_val = (float)C_q[m * ldc + n];

        float term1 = scale_A * (c_q_val * scale_B + sum_A_q * offset_B);
        float term2 = offset_A * (sum_B_q * scale_B + K_i * offset_B);
        float final_val = term1 + term2;

#if (defined(__CUDACC__) && (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 700))) || defined(_NVHPC_CUDA)
        if constexpr (std::is_same_v<T, float>) {
            atomicAdd(&C_fp_final[m * ldc + n], final_val);
        } else if constexpr (std::is_same_v<T, half>) {
            atomicAdd(&C_fp_final[m * ldc + n], __float2half(final_val));
        }
#else
        C_fp_final[m * ldc + n] += final_val;
#endif
    }
}

__global__ void Precompute_SumBq(
    const int8_t* B_q, int32_t* sum_B_q_out,
    const int num_groups, const int ic_per_group, const int oc, const int ic_p
) {
    for (size_t index = blockIdx.x * blockDim.x + threadIdx.x; index < (size_t)oc * num_groups; index += blockDim.x * gridDim.x) {
        const int i = index % num_groups;
        const int n = index / num_groups;
        const int k_start = i * ic_per_group;
        int32_t sum = 0;
        for (int k_offset = 0; k_offset < ic_per_group; ++k_offset) {
            sum += (int32_t)B_q[n * ic_p + k_start + k_offset];
        }
        sum_B_q_out[index] = sum;
    }
}

template <typename T>
__global__ void BiasAndActivation(
    T* data, const T* bias, const float minV, const float maxV,
    const int M, const int N, const int ldc
) {
    for (size_t index = blockIdx.x * blockDim.x + threadIdx.x; index < (size_t)M * N; index += blockDim.x * gridDim.x) {
        const int m = index / N;
        const int n = index % N;
        const size_t buffer_idx = m * ldc + n;
        float val = (float)data[buffer_idx];
        val += (float)bias[n];
        val = max(val, minV);
        val = min(val, maxV);
        data[buffer_idx] = (T)val;
    }
}
Here we only optimize INT8, because Cutlass's INT4 GEMM support is tied to the hardware compute capability.
Consulting the PTX ISA and the Turing tuning guide shows that although Tensor Cores on SM_75 (Turing) support INT4 GEMM in principle and PTX exposes instructions such as mma.sync.aligned.m8n8k32.row.col.s32.s4.s4.s32, Cutlass does not integrate this capability well; s4*s4 wmma and Cutlass support only become available from Ampere (SM_80 and above).
Cutlass code:
#ifndef CutlassGemmIntParam_hpp
#define CutlassGemmIntParam_hpp

#include "../CutlassGemmParam.hpp"

namespace MNN {
namespace CUDA {

using EpilogueGemmInt = cutlass::epilogue::thread::LinearCombination<
    int32_t,
    1,
    int32_t,
    int32_t>;

using CutlassGemmInt = cutlass::gemm::device::Gemm<
    int8_t,
    cutlass::layout::RowMajor,
    int8_t,
    cutlass::layout::ColumnMajor,
    int32_t,
    cutlass::layout::RowMajor,
    int32_t,
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm75,
    cutlass::gemm::GemmShape<128, 128, 32>,
    cutlass::gemm::GemmShape<64, 64, 32>,
    cutlass::gemm::GemmShape<8, 8, 16>,
    EpilogueGemmInt,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
    2>;

} // namespace CUDA
} // namespace MNN
#endif

mGemmArguments = {
    {M, N, ic_per_group},
    {nullptr, ic_per_group},
    {nullptr, UP_DIV(K, 8) * 8},
    {nullptr, UP_DIV(N, 8) * 8},
    {nullptr, UP_DIV(N, 8) * 8},
    {1, 0},
    1
};

cutlass::Status status = mCutlassGemmInt.can_implement(mGemmArguments);
if (status != cutlass::Status::kSuccess) {
    MNN_ERROR("CUTLASS GEMM cannot implement this problem\n");
    return NOT_SUPPORT;
}

size_t workspace_size = mCutlassGemmInt.get_workspace_size(mGemmArguments);
if (workspace_size > 0) {
    mWorkspaceTensor.reset(Tensor::createDevice<int8_t>({(int)workspace_size}));
    backend()->onAcquireBuffer(mWorkspaceTensor.get(), Backend::STATIC);
    mWorkspacePtr = (void*)mWorkspaceTensor->buffer().device;
}
status = mCutlassGemmInt.initialize(mGemmArguments, mWorkspacePtr);
cutlass_check(status);
Main function:
const int M = input->batch();
const int K = input->channel();
const int N = output->channel();
const int ic_p = UP_DIV(K, 8) * 8;
const int oc_p = UP_DIV(N, 8) * 8;
const int num_groups = (mResource->mQuanC > 0) ? (mResource->mQuanC / N) : 1;
const int ic_per_group = (num_groups > 0) ? (K / num_groups) : K;

auto Cq_tensor = std::make_shared<Tensor>(Tensor::createDevice<int32_t>({M, oc_p}));
backend()->onAcquireBuffer(Cq_tensor.get(), Backend::STATIC);
int32_t* C_q_buffer = (int32_t*)Cq_tensor->deviceId();

const int type_size = mFp16Infer ? sizeof(half) : sizeof(float);
auto scaleA_tensor = std::make_shared<Tensor>(Tensor::createDevice<uint8_t>({M * type_size}));
backend()->onAcquireBuffer(scaleA_tensor.get(), Backend::STATIC);
void* scale_A_buffer = (void*)scaleA_tensor->deviceId();

auto offsetA_tensor = std::make_shared<Tensor>(Tensor::createDevice<uint8_t>({M * type_size}));
backend()->onAcquireBuffer(offsetA_tensor.get(), Backend::STATIC);
void* offset_A_buffer = (void*)offsetA_tensor->deviceId();

auto sumA_tensor = std::make_shared<Tensor>(Tensor::createDevice<int32_t>({M}));
backend()->onAcquireBuffer(sumA_tensor.get(), Backend::STATIC);
int32_t* sum_A_q_buffer = (int32_t*)sumA_tensor->deviceId();

auto Aq_tensor = std::make_shared<Tensor>(Tensor::createDevice<int8_t>({M, ic_per_group}));
backend()->onAcquireBuffer(Aq_tensor.get(), Backend::STATIC);
int8_t* A_q_buffer = (int8_t*)Aq_tensor->deviceId();

runtime->memset(output_addr, 0, M * N * type_size);

for (int i = 0; i < num_groups; ++i) {
    const int k_start = i * ic_per_group;

    dim3 blocks_quant(M);
    dim3 threads_quant(BLOCK_SIZE);
    dim3 blocks_gemm(UP_DIV(oc_p, SUB_GEMM_TILE_DIM), UP_DIV(M, SUB_GEMM_TILE_DIM));
    dim3 threads_gemm(SUB_GEMM_TILE_DIM, SUB_GEMM_TILE_DIM);
    dim3 blocks_dequant(UP_DIV(N, DEQUANT_TILE_DIM), UP_DIV(M, DEQUANT_TILE_DIM));
    dim3 threads_dequant(DEQUANT_TILE_DIM, DEQUANT_TILE_DIM);

    if (mFp16Infer) {
        const half* A_sub_ptr = (const half*)input_addr + k_start;
        const int8_t* B_sub_ptr = (const int8_t*)mResource->mFilter + k_start;
        const half* base_scale_B_ptr = (const half*)mResource->mScale;
        const half* base_offset_B_ptr = (const half*)mResource->mOffset;
        const int32_t* base_sum_B_q_ptr = (const int32_t*)mResource->mSumBQ;

        QuantA<half><<<blocks_quant, threads_quant>>>(
            A_sub_ptr, A_q_buffer, (half*)scale_A_buffer, (half*)offset_A_buffer,
            sum_A_q_buffer, M, ic_per_group, ic_p
        );

        mGemmArguments.ref_A.reset(A_q_buffer);
        mGemmArguments.ref_B.reset(B_sub_ptr);
        mGemmArguments.ref_C.reset(C_q_buffer);
        mGemmArguments.ref_D.reset(C_q_buffer);
        cutlass::Status status = mCutlassGemmInt(mGemmArguments, mWorkspacePtr, 0);
        cutlass_check(status);

        DequantAndAcc<half><<<blocks_dequant, threads_dequant>>>(
            C_q_buffer, (half*)output_addr,
            (const half*)scale_A_buffer, (const half*)offset_A_buffer,
            base_scale_B_ptr, base_offset_B_ptr, base_sum_B_q_ptr,
            i, num_groups, sum_A_q_buffer, M, N, ic_per_group, oc_p
        );
    } else {
        const float* A_sub_ptr = (const float*)input_addr + k_start;
        const int8_t* B_sub_ptr = (const int8_t*)mResource->mFilter + k_start;
        const float* base_scale_B_ptr = (const float*)mResource->mScale;
        const float* base_offset_B_ptr = (const float*)mResource->mOffset;
        const int32_t* base_sum_B_q_ptr = (const int32_t*)mResource->mSumBQ;

        QuantA<float><<<blocks_quant, threads_quant>>>(
            A_sub_ptr, A_q_buffer, (float*)scale_A_buffer, (float*)offset_A_buffer,
            sum_A_q_buffer, M, ic_per_group, ic_p
        );

        mGemmArguments.ref_A.reset(A_q_buffer);
        mGemmArguments.ref_B.reset(B_sub_ptr);
        mGemmArguments.ref_C.reset(C_q_buffer);
        mGemmArguments.ref_D.reset(C_q_buffer);
        cutlass::Status status = mCutlassGemmInt(mGemmArguments, mWorkspacePtr, 0);
        cutlass_check(status);

        DequantAndAcc<float><<<blocks_dequant, threads_dequant>>>(
            C_q_buffer, (float*)output_addr,
            (const float*)scale_A_buffer, (const float*)offset_A_buffer,
            base_scale_B_ptr, base_offset_B_ptr, base_sum_B_q_ptr,
            i, num_groups, sum_A_q_buffer, M, N, ic_per_group, oc_p
        );
    }
}

float maxV = FLT_MAX, minV = -FLT_MAX;
if (mActivationType == 1) minV = 0.0f;
if (mActivationType == 2) { minV = 0.0f; maxV = 6.0f; }

const int total_threads_act = M * oc_p;
const int block_size_act = BLOCK_SIZE;
const int num_blocks_act = (total_threads_act + block_size_act - 1) / block_size_act;
if (mFp16Infer) {
    BiasAndActivation<half><<<num_blocks_act, block_size_act>>>((half*)output_addr, (const half*)mResource->mBias, minV, maxV, M, N, oc_p);
} else {
    BiasAndActivation<float><<<num_blocks_act, block_size_act>>>((float*)output_addr, (const float*)mResource->mBias, minV, maxV, M, N, oc_p);
}

backend()->onReleaseBuffer(Cq_tensor.get(), Backend::STATIC);
backend()->onReleaseBuffer(scaleA_tensor.get(), Backend::STATIC);
backend()->onReleaseBuffer(offsetA_tensor.get(), Backend::STATIC);
backend()->onReleaseBuffer(sumA_tensor.get(), Backend::STATIC);
backend()->onReleaseBuffer(Aq_tensor.get(), Backend::STATIC);

return NO_ERROR;