ソースを参照

Update quantization docs and scripts

Sengxian 3 年 前
コミット
28e449b79f

+ 33 - 8
README.md

@@ -7,24 +7,42 @@
 
 # GLM-130B: An Open Bilingual Pre-Trained Model
 
-GLM-130B is an open bilingual (English & Chinese) bidirectional dense model with 130 billion parameters, pre-trained using the algorithm of [General Language Model (GLM)](https://aclanthology.org/2022.acl-long.26). It is designed to support inference tasks with the 130B parameters on **a single A100 (40G * 8)** or **V100 (32G * 8) server**. As of July 3rd, 2022, GLM-130B has been trained on over 400 billion text tokens (200B each for Chinese and English) and it has the following unique features:
+GLM-130B is an open bilingual (English & Chinese) bidirectional dense model with 130 billion parameters, pre-trained using the algorithm of [General Language Model (GLM)](https://aclanthology.org/2022.acl-long.26). It is designed to support inference tasks with the 130B parameters on **a single A100 (40G * 8)** or **V100 (32G * 8) server**. With INT4 quantization, the  hardware requirements can further be reduced to **a single server with 4 * RTX 3090 (24G)** with **almost no performance degradation**. As of July 3rd, 2022, GLM-130B has been trained on over 400 billion text tokens (200B each for Chinese and English) and it has the following unique features:
  
 - **Bilingual:** supports both English and Chinese. 
 - **Performance (EN):** better than GPT-3 175B (+4.0%), OPT-175B (+5.5%), and BLOOM-176B (+13.0%) on LAMBADA and slightly better than GPT-3 175B (+0.9%) on MMLU.
 - **Performance (CN):** significantly better than ERNIE TITAN 3.0 260B on 7 zero-shot CLUE datasets (+24.26%) and 5 zero-shot FewCLUE datasets (+12.75%). 
 - **Fast Inference:** supports fast inference on both [SAT](https://github.com/THUDM/SwissArmyTransformer) and [FasterTransformer](https://github.com/NVIDIA/FasterTransformer) (up to 2.5X faster) with a single A100 server.
 - **Reproducibility:** all results (30+ tasks) can be easily reproduced with open-sourced code and model checkpoints.
-- **Cross-Platform:** supports training and inference on NVIDIA, Hygon DCU, Ascend 910, and Sunway (Will be released soon).
+- **Cross-Platform:** supports training and inference on NVIDIA, Hygon DCU, Ascend 910, and Sunway (Will be released soon).
+
+## News
+
+- **2022.08.24:** We are proud to publish the quantized version for GLM-130B.  While preserving the activation precision as FP16, the model weights can be quantized to as low as **INT4 with almost no degradation of performance**, further reducing the hardware requirements of the GLM-130B to **a single server with 4 * RTX 3090 (24G)**! See [Quantization of GLM-130B](docs/quantization.md) for details.
 
 ## Getting Started
 
-### Environment Setup
+### Environment Setup
+
+#### Hardware
+
+| **Hardware**    | **GPU Memory** | **Quantization** | **Weight Offload** |
+| --------------- | -------------- | ---------------- | ------------------ |
+| 8 * A100        | 40 GB          | No               | No                 |
+| 8 * V100        | 32 GB          | No               | Yes (BMInf)        |
+| 8 * V100        | 32 GB          | INT8             | No                 |
+| 8 * RTX 3090    | 24 GB          | INT8             | No                 |
+| 4 * RTX 3090    | 24 GB          | INT4             | No                 |
+| 8 * RTX 2080 Ti | 11 GB          | INT4             | Yes (BMInf)        |
+
+It is recommended to use the an A100 (40G * 8) server, as all GLM-130B evaluation results (~30 tasks) reported can be easily reproduced with a single A100 server in about half a day. With INT8/INT4 quantization, efficient inference on **a single server with 4 * RTX 3090 (24G)** is possible, see [Quantization of GLM-130B](docs/quantization.md) for details. Combining quantization and weight offloading techniques, GLM-130B can also be inferenced on servers with even more smaller GPU memory, e.g. 8 * RTX 2080 Ti, see [Low-Resource Inference](docs/low-resource-inference.md) for details.
+
+#### Software
 
 The GLM-130B code is built on the top of [SAT](https://github.com/THUDM/SwissArmyTransformer). We recommend using [Miniconda](https://docs.conda.io/en/latest/miniconda.html) to manage your environment and installing additional dependencies via `pip install -r requirements.txt`. Here are the recommended environment configurations:
 
-- Python 3.9+ / CUDA 11+ / PyTorch 1.10+ / DeepSpeed 0.6+ / Apex (**installation with CUDA and C++ extensions is required, see [here](https://github.com/NVIDIA/apex/#linux)**)
-    
-It is recommended to use the an A100 (40G * 8) server, as all GLM-130B evaluation results (~30 tasks) reported can be easily reproduced with a single A100 server in about half a day. GLM-130B can also be inferenced on servers with smaller GPU memory, such as a V100 (32G * 8) server. See [Low-Resource Inference](docs/low-resource-inference.md) for details.
+- Python 3.9+ / CUDA 11+ / PyTorch 1.10+ / DeepSpeed 0.6+ / Apex (**installation with CUDA and C++ extensions is required, see [here](https://github.com/NVIDIA/apex/#linux)**)
+- SwissArmyTransformer>=0.2.11 is required for quantization
 
 Download the GLM-130B’s model checkpoint from [here](https://docs.google.com/forms/d/e/1FAIpQLSehr5Dh_i3TwACmFFi8QEgIVNYGmSPwV0GueIcsUev0NEfUug/viewform?usp=sf_link), make sure all 60 chunks are downloaded completely, then use the following command to merge them into a single archive file and extract it:
 
@@ -33,7 +51,14 @@ cat glm-130b-sat.tar.part_* > glm-130b-sat.tar
 tar xvf glm-130b-sat.tar
 ```
 
-Set `CHECKPOINT_PATH` in `configs/model_glm_130b.sh` to the path of the extracted folder. Since the checkpoint file is up to 260G, it is recommended to use the SSD or RAM disk to reduce the checkpoint loading time.
+Set `CHECKPOINT_PATH` in `configs/model_glm_130b.sh` to the path of the extracted folder. Since the checkpoint file is up to 260G, it is recommended to use the SSD or RAM disk to reduce the checkpoint loading time. Since the checkpoint we distribute is in 8-way tensor parallel, a conversion scripts is also provided if you need to change the tensor parallel dimension.
+
+```bash
+python tools/convert_tp.py \
+    --input-folder <SRC_CKPT_PATH>  \
+    --output-folder <DST_CKPT_PATH> \
+    --target-tp <TARGET_TP>
+```
 
 ### Left-To-Right Generation / Blank Filling
 
@@ -130,7 +155,7 @@ See [Evaluate Your Own Tasks](docs/evaluate-your-own-tasks.md) for details on ho
 
 ### 2.5X faster Inference using FasterTransformer
 
-- By adapting the GLM-130B model to [FasterTransfomer](https://github.com/NVIDIA/FasterTransformer), a highly optimized transformer model library by NVIDIA, we can reach up to 2.5X speedup on generation, see [Inference with FasterTransformer](docs/inference-with-fastertransformer.md) for details.
+By adapting the GLM-130B model to [FasterTransfomer](https://github.com/NVIDIA/FasterTransformer), a highly optimized transformer model library by NVIDIA, we can reach up to 2.5X speedup on generation, see [Inference with FasterTransformer](docs/inference-with-fastertransformer.md) for details.
 
 ## What is GLM-130B
 

+ 18 - 0
configs/model_glm_130b_2080ti.sh

@@ -0,0 +1,18 @@
+MODEL_TYPE="glm-130b"
+CHECKPOINT_PATH="<your checkpoint path>"
+MP_SIZE=8
+MODEL_ARGS="--model-parallel-size ${MP_SIZE} \
+            --num-layers 70 \
+            --hidden-size 12288 \
+            --inner-hidden-size 32768 \
+            --vocab-size 150528 \
+            --num-attention-heads 96 \
+            --max-sequence-length 2048 \
+            --tokenizer-type icetk-glm-130B \
+            --layernorm-order post \
+            --quantization-bit-width 4 \
+            --load ${CHECKPOINT_PATH} \
+            --skip-init \
+            --fp16 \
+            --bminf \
+            --bminf-memory-limit 6"

BIN
docs/media/16613396005977.jpg


+ 54 - 0
docs/quantization.md

@@ -0,0 +1,54 @@
+# Quantization of GLM-130B
+
+## Usage
+
+> Please note that SwissArmyTransformer>=0.2.11 is required for quantization
+
+Set `CHECKPOINT_PATH` in `configs/model_glm_130b_{int4/int8}.sh` to your local checkpoint folder. The model will first be initialized from the FP16 checkpoint on the CPU memory, then dynamically quantized and transferred to the GPU memory. So please make sure you have enough CPU memory (>260GB) to store the FP16 model weights.
+
+You need to pay attention to the tensor parallel dimension of the model checkpoint, we only provide the checkpoint in 8-way tensor parallel, i.e. 8 GPUs store a whole model. If you need to do inference on a small number of GPUs, e.g. 4 * RTX 3090 GPUs with INT4 precision, you first need to convert the checkpoint to 4-way tensor parallel using the following command and modify `MP_SIZE` in corresponding model config file.
+
+```bash
+python tools/convert_tp.py \
+    --input-folder <SRC_CKPT_PATH>  \
+    --output-folder <DST_CKPT_PATH> \
+    --target-tp 4
+```
+
+Finally, change the model config file from `configs/model_glm_130b.sh` to `configs/model_glm_130b_{int4/int8}.sh` in your scripts (e.g. `scripts/generate.sh`), then run your scripts just as normal.
+
+## Evaluation Results
+
+|   | **MMLU(Accuracy↑)** | **LAMBADA(Accuracy↑  )** | **WikiText-2(PPL↓)** | **WikiText-103(PPL↓)** | **PTB(PPL↓)** |
+| ---- | -------- | ----------- | ------------------- | --------------------- | ------------ |
+| FP16 | 44.751   | 80.206      | 10.901              | 10.759                | 18.964       |
+| INT8 | 44.709   | 80.206      | 10.904              | 10.763                | 18.994       |
+| INT4 | 44.801   | 79.468      | 11.167              | 11.046                | 19.535       |
+
+## Benchmark
+
+> TODO: More benchmark to add (8 * V100, 8 * 3090, 4 * A100)
+
+| **Hardware** | **GPU Memory** | **Precison** | **512** | **1024** | **2048** |
+| ------------ | -------------- | ------------ | ------- | -------- | -------- |
+| 8 * A100     | 40 GB          | FP16         | 45.21 s  | 89.00 s  | 179.22 s |
+| 4 * RTX 3090 | 24 GB          | INT4         | 138.66 s | 292.69 s | 649.64 s |
+
+The above results in the table is tests with SAT. Using FasterTransformer can speed up more than 2X, as detailed in [Inference with FasterTransformer](../docs/inference-with-fastertransformer.md).
+
+
+## Details
+
+Typical methods quantize both model weights and activations to INT8, enabling the INT8 matrix multiplication kernel for efficiency. However, we found that there are outliers in GLM-130B's activations, making it hard to reduce the precision of activations. 
+
+Almost at the same time, researchers from [Meta AI](https://arxiv.org/abs/2208.07339) also found the emergent outliers issue in large-scale transformers (>6.8B), which is consistent with our observations on GLM-130B. They conducted an in-depth analysis and found that the outliers make up only about 0.1% of all feature dimensions, so it's possible to make a decomposition for matrix multiplication that focuses on high precision multiplication for these particular dimensions.
+
+| ![](media/16613396005977.jpg) | 
+|:--:| 
+| *Distribution of outliers (the white ones) in GLM-130B's activation* |
+
+Unfortunately, the outliers in GLM-130B can sometimes make up at most 30% of the feature dimension, possibly because we used GLU as a variant of FFN. Therefore, a mixed-precision decomposition for matmul can be much less efficient than a single FP16 matmul. After a few weeks of trial, we finally decided to keep the precision of activations to FP16 and only consider the quantization of model weights. In that case, the quantized model parameters are dynamically converted to FP16 precision at runtime, introducing a small computational overhead but greatly reducing GPU memory requirements for storing model weights.
+
+We quantized all linear layers as they take up most of the model parameters. All model weights, excluding input/output embedding, layernorm and bias terms are quantized using vector-wise symmetric quantization. At the quantization precision of INT4, two INT4 weights are compressed into one INT8 weight for saving GPU memory usage, so that only 70GB of GPU memory approximately is required for INT4 model weights.
+
+

+ 7 - 4
initialize.py

@@ -47,7 +47,7 @@ def initialize_model_and_tokenizer(args):
     model = GLM130B(args).half()
 
     if args.from_quantized_checkpoint:
-        assert not args.bminf and args.quantization_bit_width is not None
+        assert args.quantization_bit_width is not None
         # Quantize model before moving to GPU
         model = quantize(model, args.quantization_bit_width)
 
@@ -59,15 +59,18 @@ def initialize_model_and_tokenizer(args):
     if torch.distributed.get_rank() == 0:
         print(f"> Checkpoint loaded in {time.time() - start:.1f}s")
 
+    if args.quantization_bit_width is not None and not args.from_quantized_checkpoint:
+        # Quantize model before moving to GPU
+        model = quantize(model, args.quantization_bit_width)
+
     if args.bminf:
         import bminf
 
+        if torch.distributed.get_rank() == 0:
+            print(f"> BMInf activated, memory limit: {args.bminf_memory_limit} GB")
         with torch.cuda.device(args.device):
             model = bminf.wrapper(model, quantization=False, memory_limit=args.bminf_memory_limit << 30)
     else:
-        if args.quantization_bit_width is not None and not args.from_quantized_checkpoint:
-            # Quantize model before moving to GPU
-            model = quantize(model, args.quantization_bit_width)
         model = model.to(args.device)
 
     torch.cuda.empty_cache()

+ 4 - 4
quantization/layers.py

@@ -30,8 +30,8 @@ class QuantizedColumnParallelLinear(ColumnParallelLinear):
             if weight_bit_width == 4:
                 self.weight = compress_int4_weight(self.weight)
 
-        self.weight = Parameter(self.weight, requires_grad=False)
-        self.weight_scale = Parameter(self.weight_scale, requires_grad=False)
+        self.weight = Parameter(self.weight.to(kwargs["device"]), requires_grad=False)
+        self.weight_scale = Parameter(self.weight_scale.to(kwargs["device"]), requires_grad=False)
 
     def forward(self, input_):
         # Set up backprop all-reduce.
@@ -67,8 +67,8 @@ class QuantizedRowParallelLinear(RowParallelLinear):
             if weight_bit_width == 4:
                 self.weight = compress_int4_weight(self.weight)
 
-        self.weight = Parameter(self.weight, requires_grad=False)
-        self.weight_scale = Parameter(self.weight_scale, requires_grad=False)
+        self.weight = Parameter(self.weight.to(kwargs["device"]), requires_grad=False)
+        self.weight_scale = Parameter(self.weight_scale.to(kwargs["device"]), requires_grad=False)
 
     def forward(self, input_):
         # Set up backprop all-reduce.

+ 1 - 1
requirements.txt

@@ -1,4 +1,4 @@
-SwissArmyTransformer>=0.2.11
+SwissArmyTransformer>=0.2.12
 icetk
 apex
 scipy