Per-block 量化混合精度卷积实现与 GEMM、GEMV 优化

问题定义

输入两个 Tensor:\(input[Batch, Ci, Ih, Iw]\) (经典 NCHW 布局)、\(kernel[Co, Ci, Kh, Kw]\)(为适应 Coalesced Access 可重排为 \([Co, Kh, Kw, Ci]\)),其中前者为 FP32/FP16 类型,后者为 INT4/INT8 类型。输出卷积结果。

可以发现,对于 \(Iw = Ih = Ow = Oh = Kw = Kh = Dw = Dh = Sw = Sh = 1\)(输入、输出、卷积核尺寸、扩张率 Dilute、步长 Stride 均为 1x1),\(Pw = Ph = 0\)(不填充)的情况,卷积退化为标准卷积乘法 \([Batch, Ci] \times [Ci, Co]\),以下简记为 \(A[N, K] \times B[K, M]\)。(当然,即使不是如此,也可以通过 Im2Col 或 Implicit Conv 等方式将其转换成可直接使用加速库的矩阵乘法)。

对于 \(B\)(在 LLM 推理中通常为量化后权重矩阵),采用 Per-block 量化,即在 \(K\) 维度上,每 \(block\_size\) 个元素共享一套量化参数(\(scale\)\(zero\_point\)),\(M\) 维度不共享。因此可以理解为,将 \(B\) 的每列按 \(block\_size\) 划分为多个竖条,总块数为 \(K / block\_size \times M\)

初探

CONV_FpAInt4B

首先,卷积可以按照原始定义实现,我们对 \(B\) 在线反量化(离线反量化就失去量化本身的意义了),并要求一个量化块在同一个线程内处理从而共享量化参数。因此,我们让每个线程负责一个输出点,实现它的计算过程即可。

理论上有较大优化空间,但并非本文重点。INT4 需要加一个简单的解包,不再赘述。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
template<typename T>
__global__ void CONV_FpAInt4B(const T* input,
const uint8_t* kernel,
const T* scale, const T* offset, const T* bias,
T *output,
const float maxV, const float minV,
const int ic, const int ic_p, const int iw, const int ih,
const int c, const int c_p, const int ow, const int oh,
const int kw, const int kh, const int dw, const int dh,
const int sw, const int sh, const int pw, const int ph,
const int total, const int quanC,
DivModFast d_oc, DivModFast d_ow, DivModFast d_oh
) {

for (size_t index = blockIdx.x * blockDim.x + threadIdx.x; index < total; index += blockDim.x * gridDim.x) {
int oz_2, tmp2, oy, ox, tmp1, ob;
d_oc.divmod(index, tmp1, oz_2);
d_ow.divmod(tmp1, tmp2, ox);
d_oh.divmod(tmp2, ob, oy);

int oz = oz_2;
int ix = ox * sw - pw;
int iy = oy * sh - ph;
float color0 = bias[oz];

const int num_quan_groups_per_channel = (c > 0 && quanC > 0) ? (quanC / c) : 1;
const int ic_per_group = (num_quan_groups_per_channel > 0) ? (ic / num_quan_groups_per_channel) : ic;
const int quan_param_index_base = oz * num_quan_groups_per_channel;

int fxSta = max(0, (UP_DIV(-ix, dw)));
int fySta = max(0, (UP_DIV(-iy, dh)));
int fxEnd = min(kw, UP_DIV(iw - ix, dw));
int fyEnd = min(kh, UP_DIV(ih - iy, dh));
int fx, fy, fz;
for (fy=fySta; fy<fyEnd; ++fy) {
int sy = fy*dh + iy;
for (fx=fxSta; fx<fxEnd; ++fx) {
int sx = fx*dw + ix;
for (int group_idx = 0; group_idx < num_quan_groups_per_channel; ++group_idx) {
const int quan_param_index = quan_param_index_base + group_idx;
const float x_scale = scale[quan_param_index];
const float x_offset = offset[quan_param_index];

const int sz_start = group_idx * ic_per_group / 2;
const int sz_end = sz_start + ic_per_group / 2;

for (int sz = sz_start; sz < sz_end && sz * 2 < ic_p; ++ sz) {
int src_offset = ((ob * ih + sy) * iw + sx) * ic_p + 2 * sz;
float inp0 = input[src_offset];
float inp1 = input[src_offset+1];

//[Cop, KhKw, Cip]
uint8_t ker = kernel[((oz * kh + fy) * kw + fx) * ic_p / 2 + sz];
int8_t ker0 = (ker >> 4) - 8;
int8_t ker1 = (ker & 15) - 8;
color0 = color0 + inp0 * ((float)ker0 * x_scale + x_offset);
color0 = color0 + inp1 * ((float)ker1 * x_scale + x_offset);
}
}
}
}
color0 = max(color0, minV);
color0 = min(color0, maxV);

int dst_offset = ((ob * oh + oy) * ow + ox) * c_p + oz;

output[dst_offset] = color0;
}
}

ConvInt8CutlassExecution

接下来考察 GEMM 特殊情况,考虑先对 \(A\) 做量化,再对两个矩阵进行 INT*INT 的矩阵乘法,最后再将得到的 INT32 用一套新的量化参数量化。

先思考 \(A\)\(B\) 均为 Per-tensor 量化的特殊情况,即两个输入矩阵分别只有一套量化参数,输出有一套。此外还有一个 \(Bias\) 参数需要加上。

将量化参数记为 \(s\)\(z\),简单推导得:

\(c_{fp, ij} = s_A s_B \left[ \sum_{k=1}^{K} a_{q, ik} b_{q, kj} - z_B \sum_{k=1}^{K} a_{q, ik} - z_A \sum_{k=1}^{K} b_{q, kj} + K z_A z_B \right]\)

因此,以上算法是完全可行的,问题只有如何做好反量化和最终结果的量化。

考虑一个更简单的情况,假设 \(B\) 为对称量化(即 \(Z_B = 0\))。下面将两个矩阵分别记为 \(I\)\(W\)(Input 和 Weight)。本质上我们的操作流程如下:

  1. 首先反量化,将 INT8 的输入和权重变回浮点数:\(I_{fp} = S_I \cdot (I_q - Z_I)\)\(W_{fp} = S_W \cdot W_q\)
  2. 执行标准的浮点卷积和偏置加法:\(O_{fp} = \text{GEMM}(I_{fp}, W_{fp}) + B_{fp}\)
  3. 重新量化,将浮点结果,变回 INT8 输出:\(O_q = O_{fp} / S_O + Z_O\)

我们的目标是跳过中间的浮点步骤。整合以上三个步骤得到:

\(O_q = \frac{ \left( \sum S_I(I_q-Z_I) \cdot S_W W_q \right) + B_{fp} }{S_O} + Z_O\)

分离出纯整数部分 \(\text{Accum}_{32} = \sum I_q \cdot W_q\)(32位整数累加器): \(O_q = \frac{S_I S_W \left( \sum I_q W_q - Z_I \sum W_q \right) + B_{fp}}{S_O} + Z_O\)

提出总缩放因子 \(M\),合并所有偏移项为 \(\text{FusedBias}\)

\(M = \frac{S_I \cdot S_W}{S_O}, \quad \text{FusedBias} = -Z_I \cdot \sum W_q + \frac{B_{fp}}{S_I S_W} + \frac{Z_O}{M}\)

得到最后结果:\(O_q = M \cdot (\text{Accum}_{32} + \text{FusedBias})\)

代码实现:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
void ConvInt8CutlassExecution::Resource::updateInputOutputScale(std::vector<float> inputQuantInfo, std::vector<float> outputQuantInfo) {
if(mUseConvQuan) {
return;
}
// new scales and zero points
float inputScale = inputQuantInfo[0];
float outputScale = outputQuantInfo[0];
float inputZeroPoint = inputQuantInfo[1];
float outputZeroPoint = outputQuantInfo[1];
mClampMin = int8_t(outputQuantInfo[2]);
mClampMax = int8_t(outputQuantInfo[3]);

if (inputScale == 0.f || outputScale == 0.f) return;

mInputScale = inputScale;
mOutputScale = outputScale;
mInputZeroPoint = int8_t(inputZeroPoint);
mOutputZeroPoint = int8_t(outputZeroPoint);
const int kernelNum = static_cast<int>(mInt8WeightKernelSum.size());

auto alphaScale = inputScale / outputScale;
auto alphaData = mScaleFloatVec;
auto biasData = (float *)mBiasInt32Vec;

for (int i = 0; i < kernelNum; i++) {
auto alphaValue = alphaData[i];
if (fabs(alphaValue) < 1e-6) alphaValue = 1e-6;
mScaleFloatVec[i] = alphaValue * alphaScale;
// compute outputZeroPointFused in asymmetric quant
int outputZeroPointFused = static_cast<int32_t>(outputZeroPoint / mScaleFloatVec[i]);
mBiasInt32Vec[i] = static_cast<int32_t>(biasData[i] / (alphaScale * alphaValue)) - mInt8WeightKernelSum[i] * inputZeroPoint + outputZeroPointFused;
}

}

\(M = \frac{S_I \cdot S_W}{S_O}\)mScaleFloatVec[i] = alphaValue * alphaScale

alphaValue 对应权重尺度 \(S_W\)alphaScale 对应\(S_I / S_O\)

\(\text{FusedBias} = -Z_I \sum W_q + \frac{B_{fp}}{S_I S_W} + \frac{Z_O}{M}\)mBiasInt32Vec[i] = term1 + term2 + term3;

term1 = biasData[i] / (alphaScale * alphaValue) 对应 \(B_{fp} / M\)。为了让它和公式中的 \(\frac{B_{fp}}{S_I S_W}\) 匹配,biasData 必须是预先处理过的 \(B_{fp}/S_O\)。这是一个常见的实现技巧。

term2 = - mInt8WeightKernelSum[i] * inputZeroPoint 对应 \(-Z_I \sum W_q\)

term3 = outputZeroPointFused (即 outputZeroPoint / mScaleFloatVec[i]) 对应公式中的 \(Z_O / M\)

llama.cpp 实现

混合精度 GEMM 与 GEMV 代码重点在 ggml/src/ggml-cuda/ 下的 ggml-cuda.cu、mmq.cu、mmq.cuh、mmvq.cu。

量化相关代码重点在 quantize.cu 中;底层向量矩阵乘法实现在 vecdotq.cu 中。

Ilama.cpp 对于 FP16/FP32* INT 混合精度矩阵乘法:

  1. cuBLAS 路径:全反量化成 FP16/FP32(代码在 ggml-cuda.cu 中的 ggml_cuda_mul_mat_batched_cublasggml_cuda_op_mul_mat_cublas);
  2. 手写内核(MMQ)路径:
    1. 对 FP16INT,使用cuBLAS 路径,执行 FP16 FP16,再将 FP16 结果反量化成 FP32;
    2. 对 FP32 * INT,全矩阵量化成 INT,执行 INT *INT,再将 INT32 结果反量化为 FP32;
      1. 分muL_mat_q 和 mul_mat_vec_q 两个版本;还有对 MoE 特化版本;
      2. 没有能够直接调用的在线反量化并做矩阵乘的 cuBLAS 接口;没用 Cutlass;
      3. 大量模板元实现的编译期分支,用于确定核函数常数和调用的函数指针;

以下是几个与 MMQ 路径相关的关键函数:

  • ggml_cuda_op_mul_mat: 一个通用的矩阵乘法执行引擎,它负责处理多GPU数据切分与同步,并能通过函数指针调用任何具体计算实现(cuBLAS或自定义量化核函数)。
  • ggml_cuda_mul_mat: 矩阵乘法操作的顶层“决策者”,它通过分析输入张量的类型、形状和硬件特性,智能地分发任务给最优的后端实现(如自定义量化内核或多种cuBLAS路径)。 在需要 GPU 切分等时调用 ggml_cuda_op_mul_mat,否则直接调用 ggml_cuda_mul_mat_q 等;
  • ggml_cuda_mul_mat_q: 通用“矩阵-矩阵”量化乘法(MMQ)的逻辑主入口,负责动态量化FP32输入并处理标准及混合专家(MoE)两种计算模式。 针对 MoE 的整体解决方案。
  • ggml_cuda_op_mul_mat_q: 作为通用执行引擎调用的底层计算接口,它接收已准备好的量化输入,将其打包为内核参数并启动实际的“矩阵-矩阵”量化计算核函数。 直接得到上层 ggml_cuda_op_mul_mat 量化处理之后的参数。

mul_mat_q_case 是真正的执行函数。

方案一:在线反量化

现在,我们考虑上一节的在线反量化卷积的矩阵乘法与向量-矩阵乘法特化版本。

根据 CUDA 编程加速技巧,我们对 GEMM 做如下优化:

  1. 将矩阵分割为 16x16 的 Tile,每个 thread-block 负责结果的一块(涉及 A 的一”块行“ 与 B 的一”块列“);每个线程负责其中 A 的一行与 B 的一列, k_tile 负责枚举 K 维度上的块;
  2. 代码中矩阵乘法的过程访问的是 B_tile_fp[k][tx]A_tile[ty][k] 。相邻线程 ty 相等, tx 相邻。因此同一个 warp 内,B 访问的是同一行数据(合并访问),而 A 访问的是同一个数(直接触发广播)。均不会发生 32-way bank conflict,因此对 B_tile 列维度 +1 的 Padding 是没有必要的;
  3. GEMM 循环内第一阶段,通过合并访存将 A 和 B 载入到 Shared Memory 中;第二阶段并行进行反量化与转置;第三阶段做矩阵乘法;

代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
const int TILE_DIM = 16;

template<typename T>
__global__ void GEMM_FpAInt8B(
const T* input,
const int8_t* kernel,
const T* scale, const T* offset, const T* bias,
T* output,
const float maxV, const float minV,
const int ic, const int ic_p, const int oc, const int oc_p,
const int batch, const int quanC
) {
__shared__ T A_tile[TILE_DIM][TILE_DIM]; // [batch, ic]
__shared__ int8_t B_tile_s8[TILE_DIM][TILE_DIM]; // [ic, oc] kernel 本身是 B^T [oc, ic]
__shared__ T B_tile_fp[TILE_DIM][TILE_DIM];

const int tx = threadIdx.x;
const int ty = threadIdx.y;

const int block_row = blockIdx.y; // batch
const int block_col = blockIdx.x; // output channel

// 每个线程负责计算输出Tile中的一个元素
const int out_row = block_row * TILE_DIM + ty; // M
const int out_col = block_col * TILE_DIM + tx; // N

float acc = 0.0f;
// 沿K维度(输入通道ic)分块循环
const int num_k_tiles = UP_DIV(ic, TILE_DIM);

for (int k_tile = 0; k_tile < num_k_tiles; ++k_tile) {
const int k_tile_base = k_tile * TILE_DIM;

// 合并访问加载 A_tile (input),线程 (ty, tx) 加载 A_tile[ty][tx]
int a_col_idx = k_tile_base + tx;
A_tile[ty][tx] = (out_row < batch && a_col_idx < ic) ? input[out_row * ic_p + a_col_idx] : (T)0.0f;

// 合并访问加载 B_tile (kernel),线程 (ty, tx) 加载 B_tile[ty][tx]
// kernel 布局为 [oc, ic],需要 B(k,n),即 kernel(n,k)
int b_load_row = block_col * TILE_DIM + ty;
int b_col_idx = k_tile_base + tx;
B_tile_s8[ty][tx] = (b_load_row < oc && b_col_idx < ic) ? kernel[b_load_row * ic_p + b_col_idx] : 0;

__syncthreads();

// 反量化 + 转置
const int K = ty;
const int N = tx;

const int global_k = k_tile_base + K;
const int global_n = block_col * TILE_DIM + N;

if (global_n < oc && global_k < ic) {
const int num_quan_groups_per_channel = (quanC > 0) ? (quanC / oc) : 1;
const int ic_per_group = (num_quan_groups_per_channel > 0) ? (ic / num_quan_groups_per_channel) : ic;
const int group_idx = global_k / ic_per_group;
const int quan_param_index = global_n * num_quan_groups_per_channel + group_idx;

const float x_scale = (float)scale[quan_param_index];
const float x_offset = (float)offset[quan_param_index];

// B_tile_s8(n,k), thread(n,k) -> (tx, ty)
// So we need to read from B_tile_s8[n_dim_in_tile][k_dim_in_tile] -> B_tile_s8[tx][ty]
const float b_quant = (float)B_tile_s8[N][K];

// B_tile_fp[k][n] -> B_tile_fp[ty][tx]
B_tile_fp[K][N] = (T)(b_quant * x_scale + x_offset);
}

__syncthreads();

// 在 SMEM 中进行子矩阵乘法
if (out_col < oc) {
#pragma unroll
for (int k = 0; k < TILE_DIM; ++k) {
acc += (float)A_tile[ty][k] * (float)B_tile_fp[k][tx];
}
}

__syncthreads();
}

// 写回全局内存
if (out_row < batch && out_col < oc) {
acc += (float)bias[out_col];
acc = max(acc, minV);
acc = min(acc, maxV);
output[out_row * oc_p + out_col] = (T)acc;
}
}

// 调用
dim3 threads(TILE_DIM, TILE_DIM);
dim3 blocks(UP_DIV(ocp, TILE_DIM), UP_DIV(batch, TILE_DIM));

对 GEMV 做如下优化:

  1. 每个线程块负责计算一个输出位置,每个线程负责一个 %64 剩余系的位置(合并访存+计算);
  2. 由于向量-矩阵乘的结果是一个向量,因此直接做并行规约(蝶式交换)即可,最终只需要 thread0 写回 Global Memory;
  3. 使用动态 Shared Memory,直接在第一阶段通过合并访存从 Global Memory 中加载;
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
template<typename T>
__global__ void GEMV_FpAInt4B(
const T* input,
const uint8_t* kernel, // kernel 是打包的 int4
const T* scale, const T* offset, const T* bias,
T* output,
const float maxV, const float minV,
const int ic, const int ic_p,
const int oc, const int quanC
) {
extern __shared__ uint8_t smem_buffer[];
T* smem_input = reinterpret_cast<T*>(smem_buffer);
float* partial_sums = reinterpret_cast<float*>(smem_buffer + ic_p * sizeof(T));

const int oz = blockIdx.x; // 当前线程块负责计算的输出通道索引
const int tid = threadIdx.x;

// 加载 input 到共享内存
for (int i = tid; i < ic; i += blockDim.x) smem_input[i] = input[i];
for (int i = ic + tid; i < ic_p; i += blockDim.x) smem_input[i] = (T)0.0f;
__syncthreads();

const int num_quan_groups_per_channel = (quanC > 0) ? (quanC / oc) : 1;
const int ic_per_group = (num_quan_groups_per_channel > 0) ? (ic / num_quan_groups_per_channel) : ic;
const int quan_param_index_base = oz * num_quan_groups_per_channel;

float my_sum = 0.0f;
for (int k = tid * 2; k < ic; k += blockDim.x * 2) {
// 加载打包的权重并解包
const uint8_t ker_packed = kernel[oz * (ic_p / 2) + k / 2];
const int8_t ker0_s8 = (ker_packed >> 4) - 8;
const int8_t ker1_s8 = (ker_packed & 0x0F) - 8;

const int group_idx0 = k / ic_per_group;
const int quan_param_index0 = quan_param_index_base + group_idx0;
const float x_scale0 = scale[quan_param_index0];
const float x_offset0 = offset[quan_param_index0];
my_sum += (float)smem_input[k] * ((float)ker0_s8 * x_scale0 + x_offset0);
my_sum += (float)smem_input[k + 1] * ((float)ker1_s8 * x_scale0 + x_offset0);
}

partial_sums[tid] = my_sum;
__syncthreads();

// 并行规约
for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) {
if (tid < s) partial_sums[tid] += partial_sums[tid + s];
__syncthreads();
}

if (tid == 0) {
float final_val = partial_sums[0] + (float)bias[oz];
final_val = max(final_val, minV);
final_val = min(final_val, maxV);
output[oz] = (T)final_val;
}
}

// 调用
dim3 threads(GEMV_TILE);
dim3 blocks(oc);
size_t input_smem_size = icp * (mFp16Infer ? sizeof(half) : sizeof(float));
size_t reduction_smem_size = GEMV_TILE * sizeof(float);
size_t smem_size = input_smem_size + reduction_smem_size;
GEMV_FpAInt4B<<<blocks, threads, smem_size>>>(...);

方案二:在线量化

方案推导

最后,我们考虑对 \(A\) 做与 \(B\) 相同 \(block\_size\) 的 Per-block 量化(均在 \(K\) 维度切割),再调用 INT*INT 矩阵乘法,最后反量化得到浮点结果。现在,\(A\) 的每行被分割,\(B\) 的每列被分割,其余维度不共享参数。

问题在于,不同 block 中的元素使用的是不同的 量化参数,因此显然无法只通过一次完整的整数矩阵乘法就得到答案。另一方面,逐个 block 地相乘效率又太低,无法享受到 Cutlass 的加速(一个 block 往往只有 64 个元素)。

因此,我们考虑每次拿出 \(A\) 的一”块列“,与 \(B\) 的一”块行“做矩阵乘法,得到一个 \(N \times M\) 的矩阵,然后暴力将矩阵中的每个数(对应一对 block 的乘积)用对应那对 block 的量化参数做反量化,加到答案矩阵上。重复 \(\frac{K}{block\_size}\) 次即可。

更形式化地:

我们将求和维度 \(K\) 分为 \(P\) 个连续的块(block):\(K = G_1 \cup G_2 \cup \dots \cup G_P\)

量化参数 \(s\)\(z\) 不仅依赖于矩阵和行列,更依赖于 \(k\) 所在的块 \(G_i\)

因此,反量化公式应该写成(下标 \(i\) 代表第 \(i\)\(K\) 维度块):

\(A_{fp, mk} = s_{A,m,i} \cdot (A_{q, mk} - z_{A,m,i}) \quad k \in G_i\)

\(B_{fp, kn} = s_{B,n,i} \cdot (B_{q, kn} - z_{B,n,i}) \quad k \in G_i\)

代入矩阵乘法公式( \(g(k)\) 为索引 \(k\) 所属的块的编号):

\(C_{fp, mn} = \sum_{k=1}^{K} \left[ s_{A,m,g(k)} \cdot (A_{q, mk} - z_{A,m,g(k)}) \right] \cdot \left[ s_{B,n,g(k)} \cdot (B_{q, kn} - z_{B,n,g(k)}) \right]\)

由于缩放因子 \(s_A, s_B\) 和零点 \(z_A, z_B\) 的值会随着 \(k\) 的变化而改变。因此不能将它们作为常数从整个求和 \(\sum_{k=1}^{K}\) 中提出来。 改变求和结构,将对 \(k\) 的总求和分解为“对所有块的求和”,内部嵌套“对块内元素的求和”: \(C_{fp, mn} = \sum_{i=1}^{P} \left( \sum_{k \in G_i} \left[ s_{A,m,i} \cdot (A_{q, mk} - z_{A,m,i}) \right] \cdot \left[ s_{B,n,i} \cdot (B_{q, kn} - z_{B,n,i}) \right] \right)\)

对于求和范围 \(\sum_{k \in G_i}\) 内的所有 \(k\),它们都属于同一个块 \(G_i\),因此它们的量化参数不变,可以将这些参数提到内部求和的外面: \(C_{fp, mn} = \sum_{i=1}^{P} s_{A,m,i} \cdot s_{B,n,i} \left( \sum_{k \in G_i} (A_{q, mk} - z_{A,m,i}) \cdot (B_{q, kn} - z_{B,n,i}) \right)\)

以上即逐 block 乘法的数学基础,接下来推导”块行块列“乘法: 我们定义一个\(i\) 块的浮点结果矩阵

\(C_{fp, mn}^{(i)} := s_{A,m,i} \cdot s_{B,n,i} \left( \sum_{k \in G_i} (A_{q, mk} - z_{A,m,i}) \cdot (B_{q, kn} - z_{B,n,i}) \right)\)

那么,最终的输出矩阵 \(C_{fp}\) 就是所有这些局部结果矩阵的简单叠加: \(C_{fp, mn} = \sum_{i=1}^{P} C_{fp, mn}^{(i)}\)

定义与块 \(i\) 相关的子矩阵:

\(A_q^{(i)}\):由矩阵 \(A_q\) 的所有行,以及 只属于块 \(G_i\) 的列 构成的子矩阵。维度为 \(M \times block\_size\)

\(B_q^{(i)}\):由矩阵 \(B_q\) 的所有列,以及 只属于块 \(G_i\) 的行 构成的子矩阵。其维度为 \(block\_size \times N\)

这两个矩阵的乘积结果定义为 \(i\) 块的整数结果矩阵 \(C_{q}^{(i)}\)\(C_{q, mn}^{(i)} := \sum_{k \in G_i} A_{q, mk} B_{q, kn}\)

这个运算是一个维度为 \((M \times block\_size) \times (block\_size \times N) \rightarrow (M \times N)\) 的通用矩阵乘法 (GEMM)。其结果 \(C_{q}^{(i)}\) 是一个完整的 \(M \times N\) 的 INT32 矩阵。

再次展开 \(C_{fp, mn}^{(i)}\) 的定义并将 \(C_{q, mn}^{(i)}\) 代入,我们得到了最终的反量化公式,它描述了如何将一个完整的中间整数矩阵逐元素地转换为浮点矩阵:

\(C_{fp, mn}^{(i)} = s_{A,m,i} s_{B,n,i} \cdot \\ \left( \underbrace{\sum_{k \in G_i} A_{q, mk}B_{q, kn}}_{C_{q, mn}^{(i)}} - z_{B,n,i}\sum_{k \in G_i} A_{q, mk} - z_{A,m,i}\sum_{k \in G_i} B_{q, kn} + block\_size \cdot z_{A,m,i}z_{B,n,i} \right)\)

算法流程如下:

for i = 1 to P(遍历所有 \(K\) 维度的块):

  1. 整数子矩阵乘法
    1. 提取子矩阵 \(A_q^{(i)}\)\(B_q^{(i)}\)
    2. 调用高效 int8*int8 -> int32 GEMM 库函数,计算出中间结果矩阵 \(C_{q_i} = \text{matmul}(A_{q_i}, B_{q_i})\)
  2. 逐元素反量化与累加
    1. 启动一个 CUDA Kernel,每个线程处理 \(C_{q_i}\) 的一个或多个元素。
    2. 在 Kernel 内部,对于每个元素 \((m, n)\)
      1. 读取 \(C_{q_i,mn}\),根据块 \(i\) 的量化参数 \(s_{A,m,i}, z_{A,m,i}, s_{B,n,i}, z_{B,n,i}\),计算出局部浮点值 \(c_{fp_i}\)
      2. 以原子方式或直接(如果输出空间不冲突)将 \(c_{fp_i}\) 加到最终结果矩阵的对应位置。

以下代码采用 offset 而非 zero-point 形式:

\(C_{fp, mn}^{(i)} = \sum_{k \in G_i} \left[ \text{scale}_{A,m,i} A_{q,mk} + \text{offset}_{A,m,i} \right] \cdot \left[ \text{scale}_{B,n,i} B_{q,kn} + \text{offset}_{B,n,i} \right] \\ = \left(\text{scale}_{A,m,i}\text{scale}_{B,n,i}\right) C_{q,mn}^{(i)} + \left(\text{scale}_{A,m,i}\text{offset}_{B,n,i}\right) A_{q,m,i} \\ + \left(\text{offset}_{A,m,i}\text{scale}_{B,n,i}\right) B_{q,n,i} + K_i \cdot \text{offset}_{A,m,i}\text{offset}_{B,n,i}\)

代码

对于核函数部分,做了以下优化:

  1. 使用 union 复用共享内存,减少总占用量;
  2. 对于并行规约,先试用 __shfl_down_sync 做 Warp 内规约(原子操作),再在 Warp 0 做一次 Warp 间规约(实际上仍是 __shfl_down_sync 做 Warp 内规约);
  3. 查看 cuda_fp16.hpp(CUDA Toolkit 12.4)可知,__half atomicAdd(__half *const address, const __half val) 需要在 SM_70 及以后才支持,因此使用预编译期条件(这种方式同样可以应对胖二进制(Fat Binary)编译方式的情况) #if (defined(__CUDACC__) && (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 700))) || defined(_NVHPC_CUDA) 和编译期条件 if constexpr (std::is_same_v<T, half>) 进行分支判断;
  4. Cutlass 可以直接调用,也可以分成准备+执行两阶段(在下文 Cutlass 相关代码中);
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
// 动态量化 A、计算 sum_A
template<typename T>
__global__ void QuantA(
const T* A_sub_fp, int8_t* A_sub_q,
T* scale_A_out, T* offset_A_out, int32_t* sum_A_q_out,
const int M, const int K_i, const int lda
) {
// 使用 union 复用共享内存,减少总占用量
__shared__ union {
float min_max_vals[2][BLOCK_SIZE / WARP_SIZE]; // 用于 min/max 的 warp 间规约
int32_t sum_vals[BLOCK_SIZE / WARP_SIZE]; // 用于 sum 的 warp 间规约
float scale_offset[2]; // 用于向块内所有线程广播 scale 和 offset
} smem;

const int m = blockIdx.x;
if (m >= M) return;

const int tid = threadIdx.x;
const int warp_id = tid / WARP_SIZE;
const int lane_id = tid % WARP_SIZE;

float my_min = FLT_MAX;
float my_max = -FLT_MAX;

for (int k = tid; k < K_i; k += BLOCK_SIZE) {
float val = (float)A_sub_fp[m * lda + k];
my_min = min(my_min, val);
my_max = max(my_max, val);
}

// Warp 内规约:使用 __shfl_down_sync 在 warp 内无锁计算 min/max
for (int offset = WARP_SIZE / 2; offset > 0; offset >>= 1) {
my_min = min(my_min, __shfl_down_sync(0xFFFFFFFF, my_min, offset));
my_max = max(my_max, __shfl_down_sync(0xFFFFFFFF, my_max, offset));
}

// Warp 间规约:每个 warp 的 0 号线程将 warp 结果写入共享内存
if (lane_id == 0) {
smem.min_max_vals[0][warp_id] = my_min;
smem.min_max_vals[1][warp_id] = my_max;
}
__syncthreads();

if (warp_id == 0) {
// 从共享内存加载各 warp 的结果
if (lane_id < (BLOCK_SIZE / WARP_SIZE)) {
my_min = smem.min_max_vals[0][lane_id];
my_max = smem.min_max_vals[1][lane_id];
} else {
my_min = FLT_MAX;
my_max = -FLT_MAX;
}

// 在第一个 warp 内部完成最终规约
for (int offset = (BLOCK_SIZE / WARP_SIZE) / 2; offset > 0; offset >>= 1) {
my_min = min(my_min, __shfl_down_sync(0xFFFFFFFF, my_min, offset));
my_max = max(my_max, __shfl_down_sync(0xFFFFFFFF, my_max, offset));
}
}
// 线程 0 计算 scale/offset 并写入共享内存进行广播
if (tid == 0) {
float scale = (my_max - my_min) / 255.0f;
float offset;
if (abs(scale) > 1e-5) {
offset = my_max - scale * 127.0f; // my_max 映射到 127, my_min 映射到 -128
} else {
scale = 1.0f;
offset = my_max;
}

scale_A_out[m] = (T)scale;
offset_A_out[m] = (T)offset;

smem.scale_offset[0] = scale;
smem.scale_offset[1] = offset;
}
__syncthreads();

// 所有线程从共享内存读取 scale/offset, 并进行量化和求和
const float s_scale_A = smem.scale_offset[0];
const float s_offset_A = smem.scale_offset[1];

int32_t my_sum_q = 0;
for (int k = tid; k < K_i; k += BLOCK_SIZE) {
float val_fp = (float)A_sub_fp[m * lda + k];
int32_t val_q = roundf((val_fp - s_offset_A) / s_scale_A);
int8_t a_q = (int8_t)max(-128, min(127, val_q));
A_sub_q[m * K_i + k] = a_q;
my_sum_q += a_q;
}

for (int offset = WARP_SIZE / 2; offset > 0; offset >>= 1) {
my_sum_q += __shfl_down_sync(0xFFFFFFFF, my_sum_q, offset);
}

if (lane_id == 0) smem.sum_vals[warp_id] = my_sum_q;
__syncthreads();

if (warp_id == 0) {
my_sum_q = (lane_id < (BLOCK_SIZE / WARP_SIZE)) ? smem.sum_vals[lane_id] : 0;
for (int offset = (BLOCK_SIZE / WARP_SIZE) / 2; offset > 0; offset >>= 1) {
my_sum_q += __shfl_down_sync(0xFFFFFFFF, my_sum_q, offset);
}
}

// 线程 0 将最终的和写入全局内存
if (tid == 0) sum_A_q_out[m] = my_sum_q;
}

__global__ void GEMM_Int8(
const int8_t* A_q, const int8_t* B_q, int32_t* C_q,
const int M, const int N, const int K_i,
const int lda_q, const int ldb, const int ldc
) {
__shared__ int8_t A_tile_s8[SUB_GEMM_TILE_DIM][SUB_GEMM_TILE_DIM];
__shared__ int8_t B_tile_s8[SUB_GEMM_TILE_DIM][SUB_GEMM_TILE_DIM];

const int block_row = blockIdx.y;
const int block_col = blockIdx.x;
const int ty = threadIdx.y;
const int tx = threadIdx.x;

const int row = block_row * SUB_GEMM_TILE_DIM + ty;
const int col = block_col * SUB_GEMM_TILE_DIM + tx;

int32_t acc = 0;
const int num_k_tiles = UP_DIV(K_i, SUB_GEMM_TILE_DIM);

for (int k_tile = 0; k_tile < num_k_tiles; ++k_tile) {
const int k_tile_base = k_tile * SUB_GEMM_TILE_DIM;

A_tile_s8[ty][tx] = (row < M && (k_tile_base + tx) < K_i) ? A_q[row * lda_q + k_tile_base + tx] : 0;
B_tile_s8[ty][tx] = (col < N && (k_tile_base + ty) < K_i) ? B_q[col * ldb + k_tile_base + ty] : 0;
__syncthreads();

#pragma unroll
for (int k = 0; k < SUB_GEMM_TILE_DIM; ++k) {
acc += (int32_t)A_tile_s8[ty][k] * (int32_t)B_tile_s8[k][tx];
}
__syncthreads();
}

if (row < M && col < N) C_q[row * ldc + col] = acc;
}

template<typename T>
__global__ void DequantAndAcc(
const int32_t* C_q, T* C_fp_final,
const T* scale_A_in, const T* offset_A_in,
const T* base_scale_B, const T* base_offset_B,
const int32_t* base_sum_B_q,
const int group_idx, const int num_oc_groups,
const int32_t* sum_A_q_in,
const int M, const int N, const int K_i, const int ldc
) {
// 使用共享内存缓存行和列的量化参数
__shared__ float smem_scale_A[DEQUANT_TILE_DIM];
__shared__ float smem_offset_A[DEQUANT_TILE_DIM];
__shared__ int32_t smem_sum_A_q[DEQUANT_TILE_DIM];

__shared__ float smem_scale_B[DEQUANT_TILE_DIM];
__shared__ float smem_offset_B[DEQUANT_TILE_DIM];
__shared__ int32_t smem_sum_B_q[DEQUANT_TILE_DIM];

const int tx = threadIdx.x;
const int ty = threadIdx.y;

const int block_row_start = blockIdx.y * DEQUANT_TILE_DIM;
const int block_col_start = blockIdx.x * DEQUANT_TILE_DIM;

// 每个线程加载一个 A 的参数
int m_load_idx = block_row_start + ty;
if (tx == 0 && m_load_idx < M) {
smem_scale_A[ty] = (float)scale_A_in[m_load_idx];
smem_offset_A[ty] = (float)offset_A_in[m_load_idx];
smem_sum_A_q[ty] = sum_A_q_in[m_load_idx];
}

// 每个线程加载一个 B 的参数
int n_load_idx = block_col_start + tx;
if (ty == 0 && n_load_idx < N) {
const size_t b_param_idx = n_load_idx * num_oc_groups + group_idx;
smem_scale_B[tx] = (float)base_scale_B[b_param_idx];
smem_offset_B[tx] = (float)base_offset_B[b_param_idx];
smem_sum_B_q[tx] = base_sum_B_q[b_param_idx];
}

__syncthreads();

// 计算每个线程负责的输出点
const int m = block_row_start + ty;
const int n = block_col_start + tx;

if (m < M && n < N) {
// 从共享内存中读取参数
const float scale_A = smem_scale_A[ty];
const float offset_A = smem_offset_A[ty];
const float sum_A_q = (float)smem_sum_A_q[ty];

const float scale_B = smem_scale_B[tx];
const float offset_B = smem_offset_B[tx];
const float sum_B_q = (float)smem_sum_B_q[tx];

const float c_q_val = (float)C_q[m * ldc + n];

float term1 = scale_A * (c_q_val * scale_B + sum_A_q * offset_B);
float term2 = offset_A * (sum_B_q * scale_B + K_i * offset_B);
float final_val = term1 + term2;

#if (defined(__CUDACC__) && (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 700))) || defined(_NVHPC_CUDA)
if constexpr (std::is_same_v<T, float>) {
atomicAdd(&C_fp_final[m * ldc + n], final_val);
} else if constexpr (std::is_same_v<T, half>) {
// printf("[final_val]%.4f %.4f\n", __half2float(C_fp_final[m * ldc + n]), final_val);
atomicAdd(&C_fp_final[m * ldc + n], __float2half(final_val));
// printf("[C_fp_final]%.4f\n", __half2float(C_fp_final[m * ldc + n]));
}
#else
C_fp_final[m * ldc + n] += final_val;
#endif
}
}

// 预计算权重 B 的修正项 (sum_B_q)
__global__ void Precompute_SumBq(
const int8_t* B_q, int32_t* sum_B_q_out,
const int num_groups, const int ic_per_group,
const int oc, const int ic_p
) {
// 每个线程处理一个 (output_channel, group) 的修正项
for (size_t index = blockIdx.x * blockDim.x + threadIdx.x; index < (size_t)oc * num_groups; index += blockDim.x * gridDim.x) {

const int i = index % num_groups; // group_idx
const int n = index / num_groups; // output_channel

const int k_start = i * ic_per_group;

int32_t sum = 0;
for (int k_offset = 0; k_offset < ic_per_group; ++k_offset) {
// 访问 B 矩阵的布局是 [oc][ic_p]
sum += (int32_t)B_q[n * ic_p + k_start + k_offset];
}
sum_B_q_out[index] = sum;
}
}

template<typename T>
__global__ void BiasAndActivation(
T* data, const T* bias,
const float minV, const float maxV,
const int M, const int N, const int ldc
) {
for (size_t index = blockIdx.x * blockDim.x + threadIdx.x; index < (size_t)M * N; index += blockDim.x * gridDim.x) {
const int m = index / N;
const int n = index % N;

const size_t buffer_idx = m * ldc + n;
float val = (float)data[buffer_idx];
val += (float)bias[n];
val = max(val, minV);
val = min(val, maxV);
data[buffer_idx] = (T)val;
}
}

这里我们只对 INT8 进行优化,因为 Cutlass 对 INT4 GEMM 的支持与硬件计算能力绑定。

查阅 PTX ISATuring Guide 可知,SM_75 架构(Turing)虽然 Tensor Core 理论支持 INT4 GEMM,也提供了诸如 mma.sync.aligned.m8n8k16.row.col.s32.s4.s4.s32 的 PTX ISA,但 Cutlass 并没有很好地集成相应能力。而 s4.s4 的 wmma 和 Cutlass 支持则要到 Ampere(SM_80 以上)才能支持。

Cutlass 代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
#ifndef CutlassGemmIntParam_hpp
#define CutlassGemmIntParam_hpp

#include "../CutlassGemmParam.hpp"

namespace MNN {
namespace CUDA {

// int8 * int8 => int32 GEMM.
using EpilogueGemmInt = cutlass::epilogue::thread::LinearCombination<
int32_t, // ElementC
1, // Elements per access.
int32_t, // ElementAccumulator
int32_t // ElementCompute, not used for this epilogue
>;

using CutlassGemmInt = cutlass::gemm::device::Gemm<
int8_t, // ElementA
cutlass::layout::RowMajor, // LayoutA
int8_t, // ElementB
cutlass::layout::ColumnMajor, // LayoutB (Using the same trick as before)
int32_t, // ElementC
cutlass::layout::RowMajor, // LayoutC
int32_t, // ElementAccumulator
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm75, // Target GPU Architecture
cutlass::gemm::GemmShape<128, 128, 32>, // ThreadblockShape
cutlass::gemm::GemmShape<64, 64, 32>, // WarpShape
cutlass::gemm::GemmShape<8, 8, 16>, // InstructionShape
EpilogueGemmInt,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, // Swizzle
2 // Stages
>;

} // namespace CUDA
} // namespace MNN

#endif // CutlassGemmIntParam_hpp

// 准备阶段
mGemmArguments = {
{M, N, ic_per_group},
{nullptr, ic_per_group}, // ptr_A and lda_A
{nullptr, UP_DIV(K, 8) * 8}, // ptr_B and ldb_B
{nullptr, UP_DIV(N, 8) * 8}, // ptr_C and ldc_C
{nullptr, UP_DIV(N, 8) * 8}, // ptr_D and ldd_D
{1, 0}, // Epilogue: D = 1 * A*B + 0 * C
1
};

// Check if CUTLASS can support this problem
cutlass::Status status = mCutlassGemmInt.can_implement(mGemmArguments);
if (status != cutlass::Status::kSuccess) {
MNN_ERROR("CUTLASS GEMM cannot implement this problem\n");
return NOT_SUPPORT;
}

// Allocate workspace if needed
size_t workspace_size = mCutlassGemmInt.get_workspace_size(mGemmArguments);
if (workspace_size > 0) {
mWorkspaceTensor.reset(Tensor::createDevice<int8_t>({(int)workspace_size}));
backend()->onAcquireBuffer(mWorkspaceTensor.get(), Backend::STATIC);
mWorkspacePtr = (void*)mWorkspaceTensor->buffer().device;
}

// Initialize the GEMM kernel
status = mCutlassGemmInt.initialize(mGemmArguments, mWorkspacePtr);
cutlass_check(status);

主函数:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
const int M = input->batch();
const int K = input->channel();
const int N = output->channel();

const int ic_p = UP_DIV(K, 8) * 8;
const int oc_p = UP_DIV(N, 8) * 8;

const int num_groups = (mResource->mQuanC > 0) ? (mResource->mQuanC / N) : 1;
const int ic_per_group = (num_groups > 0) ? (K / num_groups) : K;

auto Cq_tensor = std::make_shared<Tensor>(Tensor::createDevice<int32_t>({M, oc_p}));
backend()->onAcquireBuffer(Cq_tensor.get(), Backend::STATIC);
int32_t* C_q_buffer = (int32_t*)Cq_tensor->deviceId();

const int type_size = mFp16Infer ? sizeof(half) : sizeof(float);
auto scaleA_tensor = std::make_shared<Tensor>(Tensor::createDevice<uint8_t>({M * type_size}));
backend()->onAcquireBuffer(scaleA_tensor.get(), Backend::STATIC);
void* scale_A_buffer = (void*)scaleA_tensor->deviceId();

auto offsetA_tensor = std::make_shared<Tensor>(Tensor::createDevice<uint8_t>({M * type_size}));
backend()->onAcquireBuffer(offsetA_tensor.get(), Backend::STATIC);
void* offset_A_buffer = (void*)offsetA_tensor->deviceId();

auto sumA_tensor = std::make_shared<Tensor>(Tensor::createDevice<int32_t>({M}));
backend()->onAcquireBuffer(sumA_tensor.get(), Backend::STATIC);
int32_t* sum_A_q_buffer = (int32_t*)sumA_tensor->deviceId();

auto Aq_tensor = std::make_shared<Tensor>(Tensor::createDevice<int8_t>({M, ic_per_group}));
backend()->onAcquireBuffer(Aq_tensor.get(), Backend::STATIC);
int8_t* A_q_buffer = (int8_t*)Aq_tensor->deviceId();

runtime->memset(output_addr, 0, M * N * type_size);

// 按 K 维度块循环执行
for (int i = 0; i < num_groups; ++i) {
const int k_start = i * ic_per_group;

// QuantA
dim3 blocks_quant(M);
dim3 threads_quant(BLOCK_SIZE);

// GEMM_Int8
dim3 blocks_gemm(UP_DIV(oc_p, SUB_GEMM_TILE_DIM), UP_DIV(M, SUB_GEMM_TILE_DIM));
dim3 threads_gemm(SUB_GEMM_TILE_DIM, SUB_GEMM_TILE_DIM);

// DequantAndAcc
dim3 blocks_dequant(UP_DIV(N, DEQUANT_TILE_DIM), UP_DIV(M, DEQUANT_TILE_DIM));
dim3 threads_dequant(DEQUANT_TILE_DIM, DEQUANT_TILE_DIM);

if (mFp16Infer) {
const half* A_sub_ptr = (const half*)input_addr + k_start;
const int8_t* B_sub_ptr = (const int8_t*)mResource->mFilter + k_start;

const half* base_scale_B_ptr = (const half*)mResource->mScale;
const half* base_offset_B_ptr = (const half*)mResource->mOffset;
const int32_t* base_sum_B_q_ptr = (const int32_t*)mResource->mSumBQ;

QuantA<half><<<blocks_quant, threads_quant>>>(
A_sub_ptr, A_q_buffer,
(half*)scale_A_buffer, (half*)offset_A_buffer, sum_A_q_buffer,
M, ic_per_group, ic_p
);

// GEMM_Int8<<<blocks_gemm, threads_gemm>>>(
// A_q_buffer, B_sub_ptr, C_q_buffer,
// M, N, ic_per_group, ic_per_group, ic_p, oc_p
// );

mGemmArguments.ref_A.reset(A_q_buffer);
mGemmArguments.ref_B.reset(B_sub_ptr);
mGemmArguments.ref_C.reset(C_q_buffer);
mGemmArguments.ref_D.reset(C_q_buffer);

cutlass::Status status = mCutlassGemmInt(mGemmArguments, mWorkspacePtr, 0);
cutlass_check(status);

DequantAndAcc<half><<<blocks_dequant, threads_dequant>>>(
C_q_buffer, (half*)output_addr,
(const half*)scale_A_buffer, (const half*)offset_A_buffer,
base_scale_B_ptr, base_offset_B_ptr,
base_sum_B_q_ptr,
i, num_groups,
sum_A_q_buffer,
M, N, ic_per_group, oc_p
);
} else { // FP32
const float* A_sub_ptr = (const float*)input_addr + k_start; // A 第 k 块的起始位置
const int8_t* B_sub_ptr = (const int8_t*)mResource->mFilter + k_start; // B 第 k 块的起始位置

const float* base_scale_B_ptr = (const float*)mResource->mScale;
const float* base_offset_B_ptr = (const float*)mResource->mOffset;
const int32_t* base_sum_B_q_ptr = (const int32_t*)mResource->mSumBQ;

QuantA<float><<<blocks_quant, threads_quant>>>(
A_sub_ptr, A_q_buffer,
(float*)scale_A_buffer, (float*)offset_A_buffer, sum_A_q_buffer,
M, ic_per_group, ic_p
);

// GEMM_Int8<<<blocks_gemm, threads_gemm>>>(
// A_q_buffer, B_sub_ptr, C_q_buffer,
// M, N, ic_per_group, ic_per_group, ic_p, oc_p
// );

mGemmArguments.ref_A.reset(A_q_buffer);
mGemmArguments.ref_B.reset(B_sub_ptr);
mGemmArguments.ref_C.reset(C_q_buffer);
mGemmArguments.ref_D.reset(C_q_buffer);

// Run the pre-initialized GEMM operation
cutlass::Status status = mCutlassGemmInt(mGemmArguments, mWorkspacePtr, 0);
cutlass_check(status);

DequantAndAcc<float><<<blocks_dequant, threads_dequant>>>(
C_q_buffer, (float*)output_addr,
(const float*)scale_A_buffer, (const float*)offset_A_buffer,
base_scale_B_ptr, base_offset_B_ptr,
base_sum_B_q_ptr,
i, num_groups,
sum_A_q_buffer,
M, N, ic_per_group, oc_p
);
}
}

float maxV = FLT_MAX, minV = -FLT_MAX;
if (mActivationType == 1) minV = 0.0f;
if (mActivationType == 2) { minV = 0.0f; maxV = 6.0f; }

const int total_threads_act = M * oc_p;
const int block_size_act = BLOCK_SIZE;
const int num_blocks_act = (total_threads_act + block_size_act - 1) / block_size_act;

if (mFp16Infer) {
BiasAndActivation<half><<<num_blocks_act, block_size_act>>>((half*)output_addr, (const half*)mResource->mBias, minV, maxV, M, N, oc_p);
} else {
BiasAndActivation<float><<<num_blocks_act, block_size_act>>>((float*)output_addr, (const float*)mResource->mBias, minV, maxV, M, N, oc_p);
}

backend()->onReleaseBuffer(Cq_tensor.get(), Backend::STATIC);
backend()->onReleaseBuffer(scaleA_tensor.get(), Backend::STATIC);
backend()->onReleaseBuffer(offsetA_tensor.get(), Backend::STATIC);
backend()->onReleaseBuffer(sumA_tensor.get(), Backend::STATIC);
backend()->onReleaseBuffer(Aq_tensor.get(), Backend::STATIC);

return NO_ERROR;