# [RFC]: Int8 Activation Quantization #3975
> Note: this will also be extremely important for the chunked prefill regime.

> #7271 has landed, so closing now!
## Summary

### Motivation and Scope
The high-level goal of this RFC is to speed up prefill by using the int8 tensor cores to increase the rate of computation. We don't anticipate improving decode performance except at very large batch sizes, since decode inference time is dominated by loading the weights and is already well served by weight-only quantization.
Int4 activation quantization is out of scope for this RFC, but we are interested in supporting it. Successful int4 activation quantization (namely QuaRot) requires more work and more extensive modifications to the model definitions than int8 activation quantization, so it is natural to tackle it after int8 quantization.
For this RFC, we are focusing on support for Nvidia GPUs, and leaving other systems as out of scope.
### Quantization Schemes and Zero Points
We are considering quantization of the form:
$$\widehat X = \left\lfloor \frac{X}{s_x} \right\rceil + z_x$$

where $X$ is floating point, $\widehat X$ is its int8 quantized representation, $s_x$ is the scale (or tensor of scales), and $z_x$ is a zero point.
In this case, the original value is recovered, up to rounding error, as $X \approx s_x (\widehat X - z_x)$.
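For concreteness, here is a minimal per-tensor sketch of this scheme, assuming PyTorch; the helper names `quantize_int8`/`dequantize_int8` are illustrative, not vLLM APIs:

```python
import torch

def quantize_int8(x: torch.Tensor, s_x: float, z_x: int) -> torch.Tensor:
    # x_hat = round(x / s_x) + z_x, clamped to the representable int8 range.
    return torch.clamp(torch.round(x / s_x) + z_x, -128, 127).to(torch.int8)

def dequantize_int8(x_hat: torch.Tensor, s_x: float, z_x: int) -> torch.Tensor:
    # Recover x ~= s_x * (x_hat - z_x), up to rounding error.
    return s_x * (x_hat.to(torch.float32) - z_x)

x = torch.randn(4, 8)
s_x = x.abs().max().item() / 127  # a simple symmetric per-tensor scale
x_hat = quantize_int8(x, s_x, z_x=0)
print((x - dequantize_int8(x_hat, s_x, z_x=0)).abs().max())  # small rounding error
```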
There are several cases to consider, each with its own performance and accuracy tradeoffs.
In light of these considerations, this RFC proposes initially supporting the following cases.
For the weights:
For the activations:
Other cases are left as future work, out of scope for this RFC: asymmetric w8a8 weights and asymmetric per-token activations can be handled by additional $\mathcal O(n^2)$ correction terms that are computed during inference.
For asymmetric quantized weights where the activation is stored in a higher precision, such as w4a8, the zero points may be handled via a shift after the weights are up-converted to the activation's precision for computation.
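As an illustration, here is a sketch of that shift under assumed names and shapes (PyTorch; not vLLM's actual implementation):

```python
import torch

# Hypothetical w4a8 illustration: asymmetric 4-bit weights (values in [0, 15])
# stored in an int8 container, with an assumed zero point z_w.
w_hat = torch.randint(0, 16, (256, 128), dtype=torch.int8)
z_w = 8

# Up-convert to the activation's (int8) precision, then shift out the zero
# point. The shifted weights are symmetric, so the subsequent integer GEMM
# needs no zero-point correction terms.
w_int8 = w_hat - z_w  # values now in [-8, 7]
```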
### Zero Point Correction Terms
This section takes a closer look at the linear algebra behind the zero point correction terms, to further motivate the decisions made above on supporting asymmetric vs. symmetric and per-token vs. per-tensor cases.
Suppose we want to compute a quantized GEMM operation $C = AB$, where $A$ is $m \times k$, $B$ is $k \times n$, and $C$ is $m \times n$. In this setting, $A$ is the input activation matrix and $B$ is the weight matrix, which is known offline. We quantize the matrices as $A = s_A (\widehat A - z_A J_A)$, $B = s_B (\widehat B - z_B J_B)$, and $C = s_C (\widehat C - z_C J_C)$, where $s_X$ is the scale of matrix $X$, $z_X$ is the zero point of $X$, and $J_X$ is the conformable matrix of all ones. Here we ignore rounding for simplicity. Let's furthermore assume that $z_C = 0$ and $s_A = s_B = s_C = 1$ just to get them out of the way; the scales of all matrices and the output's zero point are easy to deal with.
This is per-tensor quantization, where each matrix has a single scale and a single zero point.
Let's substitute the above equations into $C = AB$ to see how to compute $\widehat C$.
$C = AB$
$\widehat C = (\widehat A - z_A J_A) (\widehat B - z_B J_B)$
$\widehat C = \widehat A \widehat B - z_A J_A \widehat B - z_B \widehat A J_B + z_A z_B J_A J_B$
A brief remark on each term:

- $\widehat A \widehat B$: the main GEMM, which runs on the int8 tensor cores.
- $z_A J_A \widehat B$: each row is $z_A$ times the column sums of $\widehat B$. The column sums can be precomputed offline; for a static per-tensor $z_A$ the whole term is known ahead of time, while a per-token $z_A$ requires an $\mathcal O(n^2)$ correction at inference time.
- $z_B \widehat A J_B$: each column is $z_B$ times the row sums of $\widehat A$, which depend on the activations, so an asymmetric weight zero point $z_B$ forces an $\mathcal O(n^2)$ computation during inference.
- $z_A z_B J_A J_B$: a constant matrix with every entry equal to $z_A z_B k$, also precomputable offline.
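To make the bookkeeping concrete, here is a small numerical check of this expansion, as a sketch assuming PyTorch, with scales of 1 and $z_C = 0$ as above:

```python
import torch

torch.manual_seed(0)
m, k, n = 4, 8, 3
z_A, z_B = 3, -2  # example zero points

A_hat = torch.randint(-128, 128, (m, k), dtype=torch.int32)
B_hat = torch.randint(-128, 128, (k, n), dtype=torch.int32)

# Reference: dequantize first, then multiply.
C_ref = (A_hat - z_A) @ (B_hat - z_B)

# Expanded form: integer GEMM plus the three correction terms.
col_sums_B = B_hat.sum(dim=0)                # precomputable offline
row_sums_A = A_hat.sum(dim=1, keepdim=True)  # needs the activations at runtime
C_hat = A_hat @ B_hat - z_A * col_sums_B - z_B * row_sums_A + z_A * z_B * k

assert torch.equal(C_ref, C_hat)
```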