
[RFC]: Int8 Activation Quantization #3975

Closed
tlrmchlsmth opened this issue Apr 10, 2024 · 3 comments

Comments

@tlrmchlsmth
Collaborator

tlrmchlsmth commented Apr 10, 2024

Summary

  • We (engineering at @neuralmagic) are working on support for int8 quantized activations.
  • This RFC proposes an incremental approach to quantization, where the initial support makes minimal, local changes to the PyTorch model definitions. We propose swapping out Linear and Attention modules for their quantized counterparts without modifying the graphs around them (a sketch of such a drop-in module follows this list). The upside is quicker support for quantized models; the downside is that we will be quantizing the activations on the fly prior to computation.
  • To reduce the additional data movement from quantizing the activations on the fly, the activations will need to remain quantized throughout the graph, requiring more extensive and nonlocal modifications to the model definitions. We will be working on abstractions for the quantized model definitions to make adding support for new models as easy as possible.
  • Activation quantization will introduce additional elementwise operations to the model. To reduce the additional data movement of the activations from these operations, operator fusion will be needed. Rather than manually writing fused kernels for these, this RFC proposes committing to a torch.compile-based solution, to be explored in a future RFC.
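To make the incremental approach concrete, here is a rough sketch of what a drop-in quantized linear module could look like. This is illustrative only: the class name and layout are hypothetical, it assumes a 2D input and dynamic per-token symmetric quantization, and it emulates the int8 GEMM in floating point rather than calling a real int8 kernel.

```python
import torch


class QuantizedLinearSketch(torch.nn.Module):
    """Hypothetical drop-in replacement for nn.Linear with w8a8 compute.

    The graph around the module is unchanged: floating-point activations
    come in and floating-point activations go out; quantization happens
    on the fly inside forward().
    """

    def __init__(self, weight_int8: torch.Tensor, weight_scale: torch.Tensor):
        super().__init__()
        self.register_buffer("weight_int8", weight_int8)    # [n, k], int8
        self.register_buffer("weight_scale", weight_scale)  # [n], per-column scales

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Dynamic, symmetric, per-token quantization of the activation.
        x_scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
        x_int8 = torch.round(x / x_scale).clamp(-128, 127).to(torch.int8)
        # A real kernel would run this GEMM on int8 tensor cores with int32
        # accumulation; floating point is used here only for illustration.
        acc = x_int8.to(torch.float32) @ self.weight_int8.to(torch.float32).t()
        # Rescale back to the activation dtype.
        return (acc * x_scale * self.weight_scale).to(x.dtype)
```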

Motivation and Scope

The high-level goal of this RFC is to speed up Prefill by increasing the rate of computation by using int8 tensor cores. We don't anticipate improving decode performance except for very large batch sizes, as inference time in that case is dominated by loading the weights and is already well-served by weight-only quantization.

Int4 activation quantization is out of scope for this RFC, but we are interested in support for it. Successful int4 activation quantization (namely QuaRot) requires more work and more extensive modifications to the model definitions than int8 activation quantization, so it's natural to do this after int8 quantization.

For this RFC, we are focusing on support for Nvidia GPUs, and leaving other systems as out of scope.

Quantization Schemes and Zero Points

We are considering quantization of the form:
$$\widehat X = \lfloor \frac{X}{s_x} \rceil + z_x$$
In this case, $X$ is floating point, and $\widehat X$ will be its int8 quantized representation. $s_x$ is the scale or tensor of scales, and $z_x$ is a zero point.
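As a minimal sketch of this formula (static per-tensor case, with the scale and zero point assumed known; names are hypothetical, not vLLM APIs):

```python
import torch


def quantize_static_per_tensor(x: torch.Tensor, s_x: float, z_x: int) -> torch.Tensor:
    # x_hat = round(x / s_x) + z_x, saturated to the int8 range.
    return (torch.round(x / s_x) + z_x).clamp(-128, 127).to(torch.int8)


def dequantize(x_hat: torch.Tensor, s_x: float, z_x: int) -> torch.Tensor:
    # Approximate inverse: x ≈ s_x * (x_hat - z_x).
    return s_x * (x_hat.to(torch.float32) - z_x)
```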

There are several cases to consider, with performance and accuracy tradeoffs in each case.

  • Static vs dynamic quantization. The scales and zero points may be known ahead of time, or may instead be determined at runtime after inspecting the values of the tensor. Dynamic quantization will provide more accuracy, but requires multiple passes over the activation.
  • Asymmetric vs symmetric quantization. In symmetric quantization, $z_x$ is equal to 0; in asymmetric quantization, $z_x$ is nonzero. When up-converting before computation, $z_x$ can be applied as a shift prior to the GEMM. If there is no upconversion, then an additional term (which this RFC will call a zero point correction term) can be computed and added to the output. This costs an additional $\mathcal O(n^2)$ work, either at runtime or computed offline.
  • Per-tensor vs per-token quantized activations. Generally per-token quantization has higher accuracy but requires more data movement. The particular case of per-token and asymmetric is unfavorable as it increases the dimensionality of the zero point correction term.
  • Per-tensor vs per-column vs group quantized weights. Group quantization will require kernel work for the activation quantization case, so it is out of scope for this RFC. If weight quantization is symmetric, per-tensor or per-column quantization can be handled by scaling the output tensor of a linear layer, either by a scalar value in the case of per-tensor quantization or by a vector (with tensor expansion) in the case of per-column quantization (see the sketch after this list).
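As a concrete illustration of the dynamic, symmetric, per-token scheme and of output rescaling for per-tensor or per-column weight scales, here is a minimal PyTorch sketch (function names are hypothetical and not part of vLLM):

```python
import torch


def quantize_dynamic_per_token_symmetric(x: torch.Tensor):
    """Quantize each row (token) of x to int8 with its own scale; z_x = 0."""
    # One pass to find per-token maxima and one to quantize: this is the
    # extra data movement noted above for dynamic quantization.
    s_x = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
    x_hat = torch.round(x / s_x).clamp(-128, 127).to(torch.int8)
    return x_hat, s_x  # s_x has shape [num_tokens, 1]


def rescale_output(acc_int32: torch.Tensor, s_x: torch.Tensor,
                   s_w: torch.Tensor) -> torch.Tensor:
    """Apply activation scales (per token) and weight scales (a scalar for
    per-tensor, or an [n]-vector for per-column) to the int32 accumulator
    of the quantized GEMM; broadcasting covers both weight cases."""
    return acc_int32.to(torch.float32) * s_x * s_w
```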

In light of these considerations, this RFC proposes initially supporting the following cases.

For the weights:

  • w8a8 case: Static, symmetric and either per-tensor or per-column.

For the activations:

  • Static, either symmetric or asymmetric, per-tensor quantization.
  • Dynamic, symmetric, per-token quantization.

Other cases are left as future work, out of scope for this RFC: asymmetric w8a8 weights and asymmetric per-token activations can be handled by additional $\mathcal O(n^2)$ terms that are computed during inference.
For asymmetric quantized weights where the activation is stored in a higher precision, such as w4a8, the zero points may be handled via a shift after the weights are up-converted to the activation's precision for computation.
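A minimal sketch of the up-convert-then-shift handling described above, assuming (hypothetically) that the int4 weights are stored unpacked, one value per byte:

```python
import torch


def upconvert_asymmetric_int4_weight(w_q: torch.Tensor, z_w: torch.Tensor) -> torch.Tensor:
    """Up-convert asymmetric int4 weights to int8 by applying the zero point
    as a shift, so the downstream GEMM can treat the weights as symmetric.

    w_q holds unsigned 4-bit values in [0, 15], stored unpacked one per byte
    for clarity (a real kernel would unpack two nibbles per byte on the fly).
    z_w is a zero point broadcastable against w_q (e.g. a scalar for the
    per-tensor case), also in [0, 15].
    """
    # The result lies in [-15, 15], which fits comfortably in int8.
    return w_q.to(torch.int8) - z_w.to(torch.int8)
```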

Zero Point Correction Terms

This section is a zoom-in on the linear algebra for the zero point correction terms, to further motivate some of the decisions made above on support for asymmetric vs symmetric and per-token vs per-tensor cases.

Suppose we want to compute a quantized GEMM operation $C = AB$, where $A$ is $m \times k$, $B$ is $k \times n$, and $C$ is $m \times n$. In this setting, $A$ is the input activation matrix and $B$ is the weight matrix, known offline. We quantize the matrices as $C = s_C (\widehat C - z_C J_C)$, $B = s_B (\widehat B - z_B J_B)$, $A = s_A (\widehat A - z_A J_A)$.
This is per-tensor quantization, where $s_X$ is the scale of matrix $X$, $z_X$ is the zero point of $X$, and $J_X$ is the conformable matrix of all ones. Here we ignore rounding for simplicity. Let's furthermore assume that $z_C = 0$ and $s_A = s_B = s_C = 1$ just to get them out of the way -- the scales of all matrices and the output's zero point are easy to deal with.

Let's substitute the above equations into $C = AB$ to see how to compute $\widehat C$.
$C = AB$
$\widehat C = (\widehat A - z_A J_A) (\widehat B - z_B J_B)$
$\widehat C = \widehat A \widehat B - z_A J_A \widehat B - z_B \widehat A J_B + z_A z_B J_A J_B$

A brief remark on each term (a numeric check of the full decomposition follows this list):

  • $\widehat A \widehat B$: will be computed by our quantized GEMM kernel.

  • $z_A z_B J_A J_B$: If per-tensor quantization is used, every entry of $z_A z_B J_A J_B$ is the same and equal to $k z_A z_B$, so it depends only on $k$ and the zero points of $A$ and $B$.

  • $z_A J_A \widehat B$: A few remarks on this one.

    • This term can be computed offline, since $\widehat B$ is known ahead of time.
    • Each row of $z_A J_A \widehat B$ is the same and is equal to $z_A \mathbf 1^\top \widehat B$, where $\mathbf 1$ is the vector of all ones. The row sum $\mathbf 1^\top \widehat B$ can be computed via a ReduceSum operation or a GEMV with a vector of ones.
    • If per-tensor quantization is used, then $z_A \mathbf 1^\top \widehat B$ can be computed and subtracted from the output via tensor expansion. If we further have static quantization and know $z_A$ in the Linear module's constructor, we can fully compute this term and possibly fold it into the bias if one exists. In that case, asymmetric activation quantization can be implemented at zero cost compared to the symmetric case.
    • If we are using per-token quantization, this term becomes $z_A \circ (J_A \widehat B)$, where $\circ$ is the Hadamard product with tensor expansion and $z_A$ is a column vector. This is equivalent to the outer product of $z_A$ with $\mathbf 1^\top \widehat B$. It is more expensive to handle than the per-tensor case, but can be applied as a rank-1 update to avoid materializing $z_A \circ (J_A \widehat B)$, which is the size of the output matrix.
  • $z_B \widehat A J_B$: This term depends on the activation matrix, so must be computed at runtime if asymmetric weight quantization is used.
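The decomposition above can be checked numerically. The sketch below uses small integer-valued matrices, per-tensor zero points, and scales of 1 (as in the derivation) to verify that the quantized product plus the correction terms reproduces $AB$, and to show the offline term $z_A \mathbf 1^\top \widehat B$ that could be folded into a bias:

```python
import torch

torch.manual_seed(0)
m, k, n = 4, 8, 5
z_A, z_B = 3, -2  # per-tensor zero points (scales taken as 1, as above)

# "True" integer-valued A and B, and their quantized representations.
A = torch.randint(-10, 10, (m, k)).float()
B = torch.randint(-10, 10, (k, n)).float()
A_hat = A + z_A   # A = A_hat - z_A * J_A (element-wise, since J_A is all ones)
B_hat = B + z_B

# Term computed by the quantized GEMM kernel.
main = A_hat @ B_hat

# Correction terms from expanding (A_hat - z_A J_A)(B_hat - z_B J_B).
ones_A = torch.ones(m, k)
ones_B = torch.ones(k, n)
term_A = z_A * (ones_A @ B_hat)          # z_A J_A B_hat, computable offline
term_B = z_B * (A_hat @ ones_B)          # z_B A_hat J_B, needs runtime work
term_AB = z_A * z_B * (ones_A @ ones_B)  # constant: k * z_A * z_B everywhere

C = main - term_A - term_B + term_AB
assert torch.allclose(C, A @ B)

# The offline term's rows are identical: z_A * (1^T B_hat), foldable into a bias.
bias_correction = -z_A * B_hat.sum(dim=0)
assert torch.allclose(term_A[0], -bias_correction)
```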

@robertgshaw2-redhat
Collaborator

> Motivation and Scope
>
> The high-level goal of this RFC is to speed up Prefill by increasing the rate of computation by using int8 tensor cores. We don't anticipate improving decode performance except for very large batch sizes, as inference time in that case is dominated by loading the weights and is already well-served by weight-only quantization.

Note: this will also be extremely important for the chunked prefill regime.

@tlrmchlsmth
Collaborator Author

tlrmchlsmth commented Sep 3, 2024

Will close this after #7270 and #7271 land

@tlrmchlsmth
Collaborator Author

#7271 has landed, so closing now!
