Introducing dmx.compressor

Quantization plays a key role in reducing memory usage, speeding up inference, and lowering energy consumption. As large language models (LLMs) continue to grow rapidly in size (some now exceed a trillion parameters), quantization becomes crucial to running them efficiently on specialized hardware, such as d-Matrix’s Corsair hardware (https://www.d-matrix.ai/product/).

In this article, we introduce dmx.compressor, a user-friendly model compression toolkit for functional fake quantization of LLMs, built on top of torch.fx. The toolkit simplifies Post-Training Quantization (PTQ) and Quantization Aware Training (QAT), making both more accessible to users.

The key features of dmx.compressor include high customizability: it supports custom numerical formats, operations, and quantization configurations, which makes it powerful for research. The backward pass is also supported, using straight-through estimators (STEs) for the quantization operations, so quantized models can be fine-tuned. The quantized model retains the original model’s interface, so it integrates easily with the PyTorch ecosystem and Hugging Face pipelines. In addition, we offer an API for further optimizations such as quantization calibration, PTQ, and QAT. These features make dmx.compressor well suited for researchers who want to experiment with new numerical formats, quantization algorithms, and chip design concepts. The fake-quantized models can be executed on any PyTorch-supported backend, including CPUs and GPUs.
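The straight-through estimator mentioned above can be sketched in a few lines of plain PyTorch. This is a minimal illustration, not dmx.compressor’s implementation: it assumes a simple symmetric per-tensor integer format chosen for the example, rounds in the forward pass, and passes the gradient through unchanged in the backward pass, since the true derivative of rounding is zero almost everywhere.

```python
import torch


class FakeQuantSTE(torch.autograd.Function):
    """Symmetric per-tensor fake quantization with a straight-through estimator (STE).

    Illustrative sketch only; the integer format here is an assumption for the example.
    """

    @staticmethod
    def forward(ctx, x: torch.Tensor, num_bits: int) -> torch.Tensor:
        qmax = 2 ** (num_bits - 1) - 1
        scale = x.abs().max().clamp(min=1e-8) / qmax
        # Quantize to integers in [-qmax, qmax], then dequantize back to float.
        return torch.clamp(torch.round(x / scale), -qmax, qmax) * scale

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        # STE: treat round/clamp as the identity, so the gradient passes through unchanged.
        # num_bits is a Python int and receives no gradient.
        return grad_output, None
```

Calling `FakeQuantSTE.apply(x, 4)` yields a tensor on a 4-bit grid whose gradient with respect to `x` is the identity, which is what makes fine-tuning a fake-quantized model possible.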

torch.fx is a popular framework for fake quantization. It allows users to create a quantized GraphModule by symbolically tracing the original model and then inserting quantization operations or replacing original operations with their quantized versions.

Quantization using torch.fx
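To make the torch.fx workflow concrete, here is a minimal sketch that traces a model and inserts a fake-quantization call in front of every `nn.Linear` input. `fake_quant_int8` and `quantize_linear_inputs` are hypothetical helpers written for this illustration; only the torch.fx calls (`symbolic_trace`, `inserting_before`, `call_function`, `replace_input_with`, `recompile`) are library APIs.

```python
import torch
from torch import fx


def fake_quant_int8(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: symmetric per-tensor int8 fake quantization.

    Values are rounded onto an int8 grid and immediately scaled back to float.
    """
    scale = x.detach().abs().max().clamp(min=1e-8) / 127.0
    return torch.clamp(torch.round(x / scale), -127, 127) * scale


def quantize_linear_inputs(model: torch.nn.Module) -> fx.GraphModule:
    """Trace `model` and insert fake_quant_int8 in front of every nn.Linear input."""
    gm = fx.symbolic_trace(model)
    for node in list(gm.graph.nodes):  # snapshot, since we add nodes while iterating
        if node.op == "call_module" and isinstance(
            gm.get_submodule(node.target), torch.nn.Linear
        ):
            with gm.graph.inserting_before(node):
                q = gm.graph.call_function(fake_quant_int8, args=(node.args[0],))
            # Feed the quantized tensor into the Linear call instead of the original one.
            node.replace_input_with(node.args[0], q)
    gm.graph.lint()
    gm.recompile()  # regenerate forward() from the modified graph
    return gm
```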

However, quantizing large language models using FX presents the following challenges:
1. The quantized model generated by torch.fx is an fx.GraphModule, which is incompatible with the original model’s API. For example, an LLM obtained from Hugging Face comes with functionality such as the generate method, but the fx.GraphModule loses all attributes and class methods of the LLM, so it cannot be used for tasks like text generation.
2. Torch.fx only supports static graphs, while LLMs have highly dynamic computation graphs. Consider the following code snippet as an example.

from transformers import pipeline
pipe = pipeline(task="text-generation", model="meta-llama/Meta-Llama-3-8B")
pipe("Once upon a time, ")  # this inference API invokes the LLM multiple times

The pipeline call triggers the generate function of the Llama model, which iteratively calls the model’s forward function. As shown in the table below, the input signature differs between the first and second calls.

Input arguments to Llama forward calls

Diving deeper into the Llama forward function, we see that some input arguments are conditional: their values determine which blocks of code are executed. For instance, when past_key_values is not None, additional KV cache handling runs, resulting in a different computation graph.

if past_key_value is not None:
    # sin and cos are specific to RoPE models; cache_position needed for the static cache
    cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
    key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
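The effect of such a conditional on symbolic tracing can be reproduced with a small, hypothetical `CachedBlock` module. During `torch.fx.symbolic_trace`, an argument without a concrete value becomes a `Proxy`, and a `Proxy` is never `None`, so an `is not None` check silently takes the cache branch. Passing `concrete_args` specializes the trace to one value and produces a different graph. This sketch assumes a recent PyTorch release.

```python
import torch
from torch import fx


class CachedBlock(torch.nn.Module):
    """Hypothetical stand-in for an attention layer with an optional KV cache."""

    def forward(self, k: torch.Tensor, cache=None) -> torch.Tensor:
        if cache is not None:
            k = torch.cat([cache, k], dim=1)  # the "update the cache" path
        return k * 2


def traced_targets(**concrete_args):
    """Trace CachedBlock and list the functions called in the resulting graph."""
    gm = fx.symbolic_trace(CachedBlock(), concrete_args=concrete_args or None)
    return [n.target for n in gm.graph.nodes if n.op == "call_function"]


# Without concrete_args, `cache` is a Proxy during tracing and is never None,
# so the cache branch is baked into the graph.
print(torch.cat in traced_targets())            # True
# Specializing cache=None yields a different graph without the concatenation.
print(torch.cat in traced_targets(cache=None))  # False
```

Either way, each trace captures exactly one path, which is why a single static graph cannot serve every call made during generation.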

Double Graph Representation

To address the compatibility issue of quantized models, we introduce a double graph representation behind the scenes. The process starts with an instance of a model, such as a Hugging Face PreTrainedModel or a standard torch.nn.Module. We then define tracing rules that tell the FX tracer which submodules to leave untraced and keep as module calls. After tracing produces a GraphModule, we substitute the original operation nodes with new nodes that contain quantization. Substitution rules map original operations to the corresponding DmxModules, which are modules wrapped with quantization operations.

Once tracing and transformation are complete, we obtain a GraphModule containing DmxModules, which is attached to the original model and packed into a DmxModel. This double graph representation is key: the DmxModel keeps the original model’s structure, while its forward call triggers the GraphModule’s forward function, which contains all the quantization operations.

Once the compression-related operations are inserted, we configure them into the appropriate formats. Users can define custom formats by subclassing the Format class, which is useful for research, or select from predefined formats, which include the emerging Microscaling formats. Configuration rules then define which formats apply to which operations, down to the granularity of an individual operation, yielding a fully quantized DmxModel.
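The routing principle behind the double graph can be illustrated with plain PyTorch. In this sketch, `TinyLM`, `halve_outputs`, and `attach_graph` are hypothetical: `halve_outputs` merely stands in for the substitution step (it scales the output so the transformation is observable), and `attach_graph` keeps the original object, with all of its methods, while routing `forward` through the transformed GraphModule. DmxModel is more involved; this only shows the idea.

```python
import torch
from torch import fx


class TinyLM(torch.nn.Module):
    """Hypothetical stand-in for a Hugging Face model with an extra API method."""

    def __init__(self, dim: int = 4) -> None:
        super().__init__()
        self.proj = torch.nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)

    def generate(self, x: torch.Tensor, steps: int = 3) -> torch.Tensor:
        # Calls forward repeatedly, like an autoregressive decoding loop.
        for _ in range(steps):
            x = self(x)
        return x


def halve_outputs(gm: fx.GraphModule) -> fx.GraphModule:
    """Hypothetical placeholder transform: multiply the graph output by 0.5."""
    out = next(n for n in gm.graph.nodes if n.op == "output")
    with gm.graph.inserting_before(out):
        scaled = gm.graph.call_function(torch.mul, args=(out.args[0], 0.5))
    out.args = (scaled,)
    gm.recompile()
    return gm


def attach_graph(model: torch.nn.Module, transform) -> torch.nn.Module:
    """Keep the original object, but route its forward() through a transformed GraphModule."""
    gm = transform(fx.symbolic_trace(model))  # gm shares submodules (and weights) with model
    model.forward = gm.forward  # instance attribute shadows the class's forward
    return model
```

Because `generate` calls `self(...)`, it now runs through the transformed graph, whereas the bare GraphModule returned by `symbolic_trace` has no `generate` method at all.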

If you are curious about what a DmxModule looks like, let’s peek into one of them: a DmxLinear module. It preserves the features of torch.nn.Linear while adding quantization ops in the form of CastTo operations. After configuration, the formats of these CastTo ops change. More specifically, following the quantization config defined in the previous figure, the formats of input_cast and weight_cast are updated. For memory efficiency, the parameters of DmxModules are tied to those of the original module.

DmxLinear
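A DmxLinear-style module can be approximated as follows. This is a hypothetical sketch, not the toolkit’s DmxLinear: `CastTo` here implements a simple symmetric per-tensor integer format rather than a Microscaling format, and an unconfigured cast is the identity. The `weight` and `bias` attributes reference the original `nn.Linear` parameters rather than copies, mirroring the parameter tying described above.

```python
import torch
import torch.nn.functional as F


class CastTo(torch.nn.Module):
    """Hypothetical cast op: fake-quantizes to `bits`-bit integers, or passes through if unset."""

    def __init__(self, bits=None) -> None:
        super().__init__()
        self.bits = bits

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.bits is None:  # unconfigured: identity
            return x
        qmax = 2 ** (self.bits - 1) - 1
        scale = x.detach().abs().max().clamp(min=1e-8) / qmax
        return torch.clamp(torch.round(x / scale), -qmax, qmax) * scale


class QuantLinear(torch.nn.Module):
    """Hypothetical DmxLinear-style wrapper whose parameters are tied to an nn.Linear."""

    def __init__(self, linear: torch.nn.Linear) -> None:
        super().__init__()
        self.weight = linear.weight  # same Parameter object, no copy
        self.bias = linear.bias
        self.input_cast = CastTo()
        self.weight_cast = CastTo()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.linear(self.input_cast(x), self.weight_cast(self.weight), self.bias)
```

Setting `weight_cast.bits = 4` is the analogue of applying a configuration: the module structure stays the same and only the cast’s format changes. The rounding here has no straight-through estimator, so this sketch is meant for inference only.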

JIT Tracing

Next, we tackle the issue of fx.Tracer only generating static graphs. Our solution is Just-In-Time (JIT) tracing, which defers model tracing and transformation to runtime, allowing the computation graph to be based on actual input arguments. If subsequent forward calls detect changes in the input signature that would lead to a new computation graph, retracing is automatically triggered. It’s important to note that the quantization configuration is decoupled from the GraphModule itself, meaning that even when retracing occurs and a new GraphModule is generated, the existing quantization configuration is seamlessly applied to the new graph.

Let’s dive deeper into the process; the animation below illustrates it. First, we instantiate a DmxModel. At this stage, no GraphModule has been created yet. The DmxModel closely resembles the original model, except that its forward function is refactored to handle JIT tracing. The next step is defining the quantization configuration. When an input is passed through the DmxModel and no GraphModule exists yet, tracing and substitution are triggered to create the GraphModule, and the quantization configuration is then applied to it. The input tensor proceeds through the GraphModule’s forward function, producing the output. For subsequent forward passes, the DmxModel checks whether the new input would result in a different computation graph. If no changes are detected, the input is passed directly to the existing GraphModule. If the input signature changes in a way that alters the computation graph, retracing is triggered, creating a new GraphModule with the same quantization configuration applied, and the output is produced by this updated GraphModule.

JIT transformation
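A stripped-down version of JIT tracing can be written on top of torch.fx. The sketch below is hypothetical and simpler than DmxModel: it treats every non-tensor argument (for example `None` versus a tensor, or a boolean flag) as part of the input signature, traces with `concrete_args` on the first call for each signature, caches the resulting GraphModule, and applies an optional `transform` (standing in for the quantization configuration) to every new graph. It assumes PyTorch 2.x, where `symbolic_trace` keeps specialized arguments as guarded placeholders, so all arguments are passed to the cached graph in signature order.

```python
import inspect

import torch
from torch import fx


class JitTracedModule(torch.nn.Module):
    """Hypothetical JIT-tracing wrapper: trace lazily, retrace only on a new signature."""

    def __init__(self, module: torch.nn.Module, transform=None) -> None:
        super().__init__()
        self.module = module
        self.transform = transform  # applied to every new GraphModule, e.g. quantization
        self.graphs = {}            # input signature -> cached GraphModule
        self.num_traces = 0

    def forward(self, *args, **kwargs):
        bound = inspect.signature(self.module.forward).bind(*args, **kwargs)
        bound.apply_defaults()
        # Tensors stay symbolic; every other argument is baked into the traced graph.
        concrete = {k: v for k, v in bound.arguments.items() if not isinstance(v, torch.Tensor)}
        key = tuple(sorted((k, repr(v)) for k, v in concrete.items()))
        if key not in self.graphs:  # first call, or a signature that changes the graph
            gm = fx.symbolic_trace(self.module, concrete_args=concrete)
            if self.transform is not None:
                gm = self.transform(gm)
            self.graphs[key] = gm
            self.num_traces += 1
        # Specialized arguments remain placeholders, so pass everything in signature order.
        return self.graphs[key](*bound.arguments.values())
```

With a module like the CustomNet toy example later in this article, calling `m(x)` and then `m(-x)` traces once, while `m(x, use_gelu=True)` and `m(x, old_x=x, use_gelu=True)` each add a new trace, as counted by `num_traces`.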

Note that the DmxConfig can be modified at any time; subsequent forward passes apply the updated DmxConfig to the GraphModule.
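One way to keep the configuration decoupled from the graph is to have each cast look up its format at call time. In this hypothetical sketch, `ConfigurableCast` holds a reference to a shared dictionary (`config`) rather than a fixed format, so editing the dictionary changes later forward passes without retracing or rebuilding any GraphModule. The integer format is a simple per-tensor one chosen for illustration, not one of the toolkit’s formats.

```python
import torch


class ConfigurableCast(torch.nn.Module):
    """Hypothetical cast op that reads its format from a shared config on every call."""

    def __init__(self, config: dict, key: str) -> None:
        super().__init__()
        self.config = config  # shared mapping, e.g. {"layer.input": 8}; missing key = no cast
        self.key = key

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        bits = self.config.get(self.key)
        if bits is None:
            return x
        # Symmetric per-tensor integer fake quantization (quantize, then dequantize).
        qmax = 2 ** (bits - 1) - 1
        scale = x.detach().abs().max().clamp(min=1e-8) / qmax
        return torch.clamp(torch.round(x / scale), -qmax, qmax) * scale
```

Since the lookup happens inside `forward`, any number of traced graphs that contain the same cast module pick up a config change on their next call.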

Toy Example

Let’s summarize the flow of dmx.compressor with an example. Consider a custom network containing a linear layer, with two optional arguments in its forward function. The computation graph of this model depends on the input arguments, so no single static graph can be determined beforehand.

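from typing import Optional

import torch

DIM = 16  # feature size used throughout this example; any positive integer works


class CustomNet(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.layer = torch.nn.Linear(DIM, DIM)

    def forward(
        self,
        x: torch.Tensor,
        old_x: Optional[torch.Tensor] = None,
        use_gelu: bool = False,
    ) -> torch.Tensor:
        x = self.layer(x)
        # The activation depends on a Python bool, so each value yields a different graph.
        if use_gelu:
            x = torch.nn.functional.gelu(x)
        else:
            x = torch.nn.functional.relu(x)
        # The residual add exists only when old_x is provided.
        if old_x is not None:
            x = x + old_x  # out-of-place, so autograd can still backpropagate through relu
        return x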
Dynamic computation graph for toy example

When the dmx.compressor API is invoked, the process begins with a CustomNet instance. Wrapping this instance in a DmxModel does not yet trigger FX tracing and transformation. The configuration rules are registered with the model by calling its configure method.

model = CustomNet()
model = DmxModel.from_torch(model)
x = torch.rand(1, DIM)
rules = (
  DmxConfigRule(
    module_types=(Linear,),
    module_config=dict(
      input_formats=[format.MXINT8_K64],
      weight_format=format.MXINT4_K64,
    ),
  ),
)
model.configure(None, *rules)
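Conceptually, a configuration rule selects modules and writes format settings into them. The sketch below shows that matching logic with a hypothetical `Rule` dataclass and `apply_rules` helper; it is not the `DmxConfigRule` API, and the attribute names in `module_config` are placeholders chosen by the caller.

```python
from dataclasses import dataclass, field

import torch


@dataclass
class Rule:
    """Hypothetical, simplified analogue of a config rule: match modules by type."""

    module_types: tuple
    module_config: dict = field(default_factory=dict)


def apply_rules(model: torch.nn.Module, *rules: Rule) -> list:
    """Write each matching rule's settings onto modules; return the configured names."""
    configured = []
    for name, module in model.named_modules():
        matched = False
        for rule in rules:
            if isinstance(module, rule.module_types):
                for attr, value in rule.module_config.items():
                    setattr(module, attr, value)  # later rules override earlier ones
                matched = True
        if matched:
            configured.append(name)
    return configured
```

For example, applying `Rule((torch.nn.Linear,), {"weight_bits": 4})` to a `Sequential` of two linear layers around a ReLU configures only the two linear layers and returns their names.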

When the model’s forward function is called for the first time, the input arguments are inspected, tracing is triggered, and a static GraphModule is built for this input signature.

> model(x)
tensor([[0.0000, 1.1791, …

To make things clearer, we provide a tool to visualize the GraphModule.

model.visualize_graph()
GraphModule visualization given input x

If subsequent inputs do not require a different graph, retracing is not triggered and the same GraphModule is reused.

> model(-x)
tensor([[1.1308, 0.5489, …

However, inputs that result in a different computation graph, such as use_gelu=True, trigger retracing.

> model(x, use_gelu=True)
tensor([[-1.6742e-01, 1.0386e+00, …

Passing old_x as a tensor instead of None would also trigger retracing.

> model(x, old_x=x, use_gelu=True)
tensor([[1.3736, 0.7451, …

Please visit our GitHub repo at https://github.com/d-matrix-ai/dmx-compressor. For questions and feedback, please open a GitHub issue. Community contributions are also greatly welcomed. For more information regarding d-Matrix, please visit https://www.d-matrix.ai/