
Merge pull request #22 from THUDM/quantization

Add INT8 and INT4 quantization
Aohan Zeng 3 years ago
parent
commit
0bdb6d2a92

+ 30 - 5
README.md

@@ -7,7 +7,7 @@
 
 # GLM-130B: An Open Bilingual Pre-Trained Model
 
-GLM-130B is an open bilingual (English & Chinese) bidirectional dense model with 130 billion parameters, pre-trained using the algorithm of [General Language Model (GLM)](https://aclanthology.org/2022.acl-long.26). It is designed to support inference tasks with the 130B parameters on **a single A100 (40G * 8)** or **V100 (32G * 8) server**. As of July 3rd, 2022, GLM-130B has been trained on over 400 billion text tokens (200B each for Chinese and English) and it has the following unique features:
+GLM-130B is an open bilingual (English & Chinese) bidirectional dense model with 130 billion parameters, pre-trained using the algorithm of [General Language Model (GLM)](https://aclanthology.org/2022.acl-long.26). It is designed to support inference tasks with the 130B parameters on **a single A100 (40G * 8)** or **V100 (32G * 8) server**. With INT4 quantization, the hardware requirements can be further reduced to **a single server with 4 * RTX 3090 (24G)** with **almost no performance degradation**. As of July 3rd, 2022, GLM-130B has been trained on over 400 billion text tokens (200B each for Chinese and English) and it has the following unique features:
  
 - **Bilingual:** supports both English and Chinese. 
 - **Performance (EN):** better than GPT-3 175B (+4.0%), OPT-175B (+5.5%), and BLOOM-176B (+13.0%) on LAMBADA and slightly better than GPT-3 175B (+0.9%) on MMLU.
@@ -16,17 +16,35 @@ GLM-130B is an open bilingual (English & Chinese) bidirectional dense model with
 - **Reproducibility:** all results (30+ tasks) can be easily reproduced with open-sourced code and model checkpoints.
 - **Cross-Platform:** supports training and inference on NVIDIA, Hygon DCU, Ascend 910, and Sunway (Will be released soon).
 
+## News
+
+- 🌟 **[2022.08.24]** We are proud to release the quantized version of GLM-130B. While keeping the activations in FP16, the model weights can be quantized to as low as **INT4 with almost no degradation of performance**, further reducing the hardware requirements of GLM-130B to **a single server with 4 * RTX 3090 (24G)**! See [Quantization of GLM-130B](docs/quantization.md) for details.
+
 For smaller models, please find [monolingual GLMs](https://github.com/THUDM/GLM) (English: 10B/2B/515M/410M/335M/110M, Chinese: 10B/335M) and a [1B multilingual GLM](https://github.com/THUDM/Multilingual-GLM) (104 languages).
 
 ## Getting Started
 
 ### Environment Setup
 
+#### Hardware
+
+| **Hardware**    | **GPU Memory** | **Quantization** | **Weight Offload** |
+| --------------- | -------------- | ---------------- | ------------------ |
+| 8 * A100        | 40 GB          | No               | No                 |
+| 8 * V100        | 32 GB          | No               | Yes (BMInf)        |
+| 8 * V100        | 32 GB          | INT8             | No                 |
+| 8 * RTX 3090    | 24 GB          | INT8             | No                 |
+| 4 * RTX 3090    | 24 GB          | INT4             | No                 |
+| 8 * RTX 2080 Ti | 11 GB          | INT4             | Yes (BMInf)        |
+
+It is recommended to use an A100 (40G * 8) server, as all reported GLM-130B evaluation results (~30 tasks) can be easily reproduced with a single A100 server in about half a day. With INT8/INT4 quantization, efficient inference on **a single server with 4 * RTX 3090 (24G)** is possible; see [Quantization of GLM-130B](docs/quantization.md) for details. By combining quantization and weight offloading, GLM-130B can also run inference on servers with even smaller GPU memory, e.g. 8 * RTX 2080 Ti (11G); see [Low-Resource Inference](docs/low-resource-inference.md) for details.
+
+#### Software
+
 The GLM-130B code is built on top of [SAT](https://github.com/THUDM/SwissArmyTransformer). We recommend using [Miniconda](https://docs.conda.io/en/latest/miniconda.html) to manage your environment and installing additional dependencies via `pip install -r requirements.txt`. Here are the recommended environment configurations (a minimal setup sketch follows the list below):
 
 - Python 3.9+ / CUDA 11+ / PyTorch 1.10+ / DeepSpeed 0.6+ / Apex (**installation with CUDA and C++ extensions is required, see [here](https://github.com/NVIDIA/apex/#linux)**)
-    
-It is recommended to use the an A100 (40G * 8) server, as all GLM-130B evaluation results (~30 tasks) reported can be easily reproduced with a single A100 server in about half a day. GLM-130B can also be inferenced on servers with smaller GPU memory, such as a V100 (32G * 8) server. See [Low-Resource Inference](docs/low-resource-inference.md) for details.
+- SwissArmyTransformer>=0.2.11 is required for quantization
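+
+A minimal setup sketch (the environment name and the Apex build invocation below are illustrative; refer to the Apex link above for the authoritative install instructions):
+
+```bash
+conda create -n glm-130b python=3.9 -y
+conda activate glm-130b
+pip install -r requirements.txt
+# Apex needs to be built with CUDA and C++ extensions (see its README)
+git clone https://github.com/NVIDIA/apex
+cd apex
+pip install -v --disable-pip-version-check --no-cache-dir \
+    --global-option="--cpp_ext" --global-option="--cuda_ext" ./
+```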
 
 Download the GLM-130B model checkpoint from [here](https://docs.google.com/forms/d/e/1FAIpQLSehr5Dh_i3TwACmFFi8QEgIVNYGmSPwV0GueIcsUev0NEfUug/viewform?usp=sf_link), make sure all 60 chunks are downloaded completely, then use the following command to merge them into a single archive file and extract it:
 
@@ -35,7 +53,14 @@ cat glm-130b-sat.tar.part_* > glm-130b-sat.tar
 tar xvf glm-130b-sat.tar
 ```
 
-Set `CHECKPOINT_PATH` in `configs/model_glm_130b.sh` to the path of the extracted folder. Since the checkpoint file is up to 260G, it is recommended to use the SSD or RAM disk to reduce the checkpoint loading time.
+Set `CHECKPOINT_PATH` in `configs/model_glm_130b.sh` to the path of the extracted folder. Since the checkpoint file is up to 260G, it is recommended to use an SSD or RAM disk to reduce the checkpoint loading time. The checkpoint we distribute is in 8-way tensor parallel, so a conversion script is also provided if you need to change the tensor parallel dimension.
+
+```bash
+python tools/convert_tp.py \
+    --input-folder <SRC_CKPT_PATH>  \
+    --output-folder <DST_CKPT_PATH> \
+    --target-tp <TARGET_TP>
+```
 
 ### Left-To-Right Generation / Blank Filling
 
@@ -132,7 +157,7 @@ See [Evaluate Your Own Tasks](docs/evaluate-your-own-tasks.md) for details on ho
 
 ### 2.5X faster Inference using FasterTransformer
 
-- By adapting the GLM-130B model to [FasterTransfomer](https://github.com/NVIDIA/FasterTransformer), a highly optimized transformer model library by NVIDIA, we can reach up to 2.5X speedup on generation, see [Inference with FasterTransformer](docs/inference-with-fastertransformer.md) for details.
+By adapting the GLM-130B model to [FasterTransformer](https://github.com/NVIDIA/FasterTransformer), a highly optimized transformer model library by NVIDIA, we can reach up to a 2.5X speedup on generation; see [Inference with FasterTransformer](docs/inference-with-fastertransformer.md) for details.
 
 ## What is GLM-130B
 

+ 1 - 1
configs/model_glm_130b.sh

@@ -1,5 +1,5 @@
 MODEL_TYPE="glm-130b"
-CHECKPOINT_PATH="/thudm/workspace/hanyu/SwissArmyTransformer/data/ckpt/iter_0049300"
+CHECKPOINT_PATH="<your checkpoint path>"
 MP_SIZE=8
 MODEL_ARGS="--model-parallel-size ${MP_SIZE} \
             --num-layers 70 \

+ 18 - 0
configs/model_glm_130b_2080ti.sh

@@ -0,0 +1,18 @@
+MODEL_TYPE="glm-130b"
+CHECKPOINT_PATH="<your checkpoint path>"
+MP_SIZE=8
+MODEL_ARGS="--model-parallel-size ${MP_SIZE} \
+            --num-layers 70 \
+            --hidden-size 12288 \
+            --inner-hidden-size 32768 \
+            --vocab-size 150528 \
+            --num-attention-heads 96 \
+            --max-sequence-length 2048 \
+            --tokenizer-type icetk-glm-130B \
+            --layernorm-order post \
+            --quantization-bit-width 4 \
+            --load ${CHECKPOINT_PATH} \
+            --skip-init \
+            --fp16 \
+            --bminf \
+            --bminf-memory-limit 6"

+ 16 - 0
configs/model_glm_130b_int4.sh

@@ -0,0 +1,16 @@
+MODEL_TYPE="glm-130b"
+CHECKPOINT_PATH="<your checkpoint path>"
+MP_SIZE=4
+MODEL_ARGS="--model-parallel-size ${MP_SIZE} \
+            --num-layers 70 \
+            --hidden-size 12288 \
+            --inner-hidden-size 32768 \
+            --vocab-size 150528 \
+            --num-attention-heads 96 \
+            --max-sequence-length 2048 \
+            --tokenizer-type icetk-glm-130B \
+            --layernorm-order post \
+            --quantization-bit-width 4 \
+            --load ${CHECKPOINT_PATH} \
+            --skip-init \
+            --fp16"

+ 16 - 0
configs/model_glm_130b_int8.sh

@@ -0,0 +1,16 @@
+MODEL_TYPE="glm-130b"
+CHECKPOINT_PATH="<your checkpoint path>"
+MP_SIZE=8
+MODEL_ARGS="--model-parallel-size ${MP_SIZE} \
+            --num-layers 70 \
+            --hidden-size 12288 \
+            --inner-hidden-size 32768 \
+            --vocab-size 150528 \
+            --num-attention-heads 96 \
+            --max-sequence-length 2048 \
+            --tokenizer-type icetk-glm-130B \
+            --layernorm-order post \
+            --quantization-bit-width 8 \
+            --load ${CHECKPOINT_PATH} \
+            --skip-init \
+            --fp16"

+ 22 - 0
cuda/Makefile

@@ -0,0 +1,22 @@
+NVCC=nvcc
+OPTIONS=-gencode arch=compute_61,code=sm_61 \
+		-gencode arch=compute_62,code=sm_62 \
+		-gencode arch=compute_70,code=sm_70 \
+		-gencode arch=compute_72,code=sm_72 \
+		-gencode arch=compute_75,code=sm_75 \
+		-gencode arch=compute_80,code=sm_80 \
+		-gencode arch=compute_86,code=sm_86
+
+TARGETS=$(patsubst %.cu, %.fatbin, $(wildcard *.cu))
+
+all: $(TARGETS)
+
+%.fatbin: %.cu
+	$(NVCC) -fatbin $^ $(OPTIONS) -o $@
+
+.PHONY: clean copy
+clean:
+	rm $(TARGETS)
+
+copy:
+	cp $(TARGETS) ../kernels/

+ 81 - 0
cuda/quantization.cu

@@ -0,0 +1,81 @@
+#include <cuda_fp16.h>
+
+template<typename T>
+__device__ void
+int4WeightExtractionDevice(const int8_t* weight,
+                                const T* scale_list,
+                                T* output,
+                                const int n,
+                                const int k)
+{
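+    // One CUDA block handles one row of the weight matrix (blockIdx.x); each thread unpacks
+    // packed bytes into two signed INT4 values (high/low nibble via sign-extending shifts)
+    // and rescales them with the row's quantization scale.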
+    for(int i = blockIdx.x * k + threadIdx.x; i < blockIdx.x * k + k; i += blockDim.x){
+        int8_t original = weight[i];
+        int8_t high = original >> 4;
+        int8_t low = original << 4; low = low >> 4;
+        output[i * 2] = T(high) * scale_list[blockIdx.x];
+        output[i * 2 + 1] = T(low) * scale_list[blockIdx.x];
+    }
+}
+
+__device__ void
+int4WeightCompressionDevice(const int8_t* input,
+                                int8_t* output,
+                                const int n,
+                                const int k)
+{
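+    // Packs two INT4 values (stored one per INT8 element) into a single byte:
+    // first value in the high nibble, second value in the low nibble.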
+    for(int i = blockIdx.x * k + threadIdx.x; i < blockIdx.x * k + k; i += blockDim.x){
+        output[i] = (input[i * 2] << 4) | (input[i * 2 + 1] & 0b00001111);
+    }
+}
+
+template<typename T>
+__device__ void
+int8WeightExtractionDevice(const int8_t* weight,
+                                const T* scale_list,
+                                T* output,
+                                const int n,
+                                const int k)
+{
+    for(int i = blockIdx.x * k + threadIdx.x; i < blockIdx.x * k + k; i += blockDim.x){
+        output[i] = T(weight[i]) * scale_list[blockIdx.x];
+    }
+}
+
+extern "C" __global__ void int4WeightExtractionHalf(const int8_t* weight,
+                                const half* scale_list,
+                                half* output,
+                                const int n,
+                                const int k){
+                                    int4WeightExtractionDevice<half>(weight, scale_list, output, n, k);
+                                }
+
+extern "C" __global__ void int4WeightExtractionFloat(const int8_t* weight,
+                                const float* scale_list,
+                                float* output,
+                                const int n,
+                                const int k){
+                                    int4WeightExtractionDevice<float>(weight, scale_list, output, n, k);
+                                }
+
+extern "C" __global__ void int8WeightExtractionHalf(const int8_t* weight,
+                                const half* scale_list,
+                                half* output,
+                                const int n,
+                                const int k){
+                                    int8WeightExtractionDevice<half>(weight, scale_list, output, n, k);
+                                }
+
+extern "C" __global__ void int8WeightExtractionFloat(const int8_t* weight,
+                                const float* scale_list,
+                                float* output,
+                                const int n,
+                                const int k){
+                                    int8WeightExtractionDevice<float>(weight, scale_list, output, n, k);
+                                }
+
+extern "C" __global__ void int4WeightCompression(const int8_t* input,
+                                int8_t* output,
+                                const int n,
+                                const int k){
+                                    int4WeightCompressionDevice(input, output, n, k);
+                                }

BIN
docs/media/16613396005977.jpg


+ 54 - 0
docs/quantization.md

@@ -0,0 +1,54 @@
+# Quantization of GLM-130B
+
+## Usage
+
+> Please note that SwissArmyTransformer>=0.2.11 is required for quantization
+
+Set `CHECKPOINT_PATH` in `configs/model_glm_130b_{int4/int8}.sh` to your local checkpoint folder. The model will first be initialized from the FP16 checkpoint in CPU memory, then dynamically quantized and transferred to GPU memory, so please make sure you have enough CPU memory (>260GB) to hold the FP16 model weights.
+
+Pay attention to the tensor parallel dimension of the model checkpoint: we only provide the checkpoint in 8-way tensor parallel, i.e. 8 GPUs store a whole model. If you need to do inference on a smaller number of GPUs, e.g. 4 * RTX 3090 GPUs with INT4 precision, you first need to convert the checkpoint to 4-way tensor parallel using the following command and modify `MP_SIZE` in the corresponding model config file.
+
+```bash
+python tools/convert_tp.py \
+    --input-folder <SRC_CKPT_PATH>  \
+    --output-folder <DST_CKPT_PATH> \
+    --target-tp 4
+```
+
+Finally, change the model config file from `configs/model_glm_130b.sh` to `configs/model_glm_130b_{int4/int8}.sh` in your scripts (e.g. `scripts/generate.sh`), then run your scripts as usual.
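+
+For example, to use the INT4 configuration with the generation script (a sketch; the `sed` one-liner is illustrative, you can also edit the config path in `scripts/generate.sh` by hand):
+
+```bash
+sed -i 's#configs/model_glm_130b.sh#configs/model_glm_130b_int4.sh#' scripts/generate.sh
+bash scripts/generate.sh
+```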
+
+## Evaluation Results
+
+| **Precision** | **MMLU (Accuracy↑)** | **LAMBADA (Accuracy↑)** | **WikiText-2 (PPL↓)** | **WikiText-103 (PPL↓)** | **PTB (PPL↓)** |
+| ---- | -------- | ----------- | ------------------- | --------------------- | ------------ |
+| FP16 | 44.751   | 80.206      | 10.901              | 10.759                | 18.964       |
+| INT8 | 44.709   | 80.206      | 10.904              | 10.763                | 18.994       |
+| INT4 | 44.801   | 79.468      | 11.167              | 11.046                | 19.535       |
+
+## Space and Speed Benchmark
+
+> TODO: More benchmarks to be added (8 * V100, 8 * 3090, 4 * A100)
+
+| **Hardware** | **GPU Memory** | **Precision** | **512**  | **1024** | **2048** |
+| ------------ | -------------- | ------------- | -------- | -------- | -------- |
+| 8 * A100     | 40 GB          | FP16          | 45.21 s  | 89.00 s  | 179.22 s |
+| 4 * RTX 3090 | 24 GB          | INT4          | 138.66 s | 292.69 s | 649.64 s |
+
+The results in the table above are measured with SAT. Using FasterTransformer can speed up inference by more than 2X, as detailed in [Inference with FasterTransformer](../docs/inference-with-fastertransformer.md).
+
+
+## Details
+
+Typical methods quantize both model weights and activations to INT8, enabling INT8 matrix multiplication kernels for efficiency. However, we found that there are outliers in GLM-130B's activations, making it hard to reduce their precision.
+
+Concurrently, researchers from [Meta AI](https://arxiv.org/abs/2208.07339) also found the emergent outlier issue in large-scale transformers (>6.8B parameters), which is consistent with our observations on GLM-130B. They conducted an in-depth analysis and found that the outliers make up only about 0.1% of all feature dimensions, so it is possible to decompose the matrix multiplication and handle these particular dimensions in high precision.
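+
+As an illustration of that decomposition, here is a minimal sketch (function and argument names are ours, kept in FP32 for portability; a real implementation runs the regular path with an INT8 GEMM kernel):
+
+```python
+import torch
+
+def decomposed_matmul(x, w, outlier_dims):
+    # Outlier feature dimensions stay in high precision; the remaining dimensions are
+    # quantized to INT8 with symmetric absmax scaling before the matrix multiplication.
+    keep = torch.zeros(x.size(-1), dtype=torch.bool, device=x.device)
+    keep[outlier_dims] = True
+    out = x[:, keep] @ w[keep, :]                     # high-precision path (few dimensions)
+    x_r, w_r = x[:, ~keep], w[~keep, :]               # low-precision path (the rest)
+    s_x = x_r.abs().amax() / 127
+    s_w = w_r.abs().amax() / 127
+    x_q = torch.round(x_r / s_x).to(torch.int8)
+    w_q = torch.round(w_r / s_w).to(torch.int8)
+    out += (x_q.float() @ w_q.float()) * (s_x * s_w)  # real kernels use an INT8 GEMM here
+    return out
+```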
+
+| ![](media/16613396005977.jpg) | 
+|:--:| 
+| *Distribution of outliers (the white ones) in GLM-130B's activations* |
+
+Unfortunately, the outliers in GLM-130B can sometimes make up as much as 30% of the feature dimensions, possibly because we used GLU as a variant of FFN. Therefore, a mixed-precision decomposition for matmul would be much less efficient than a single FP16 matmul. After a few weeks of trials, we finally decided to keep the activations in FP16 and only quantize the model weights. In this scheme, the quantized model parameters are dynamically converted to FP16 precision at runtime, introducing a small computational overhead but greatly reducing the GPU memory required to store the model weights.
+
+We quantize only the linear layers, as they take up most of the model parameters. All their weights, excluding the input/output embeddings, layernorm, and bias terms, are quantized using vector-wise symmetric quantization. At INT4 precision, two INT4 weights are packed into one INT8 value to save GPU memory, so that only about 70GB of GPU memory is required to hold the INT4 model weights.
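+
+A minimal sketch of this scheme (function names are ours; the production path uses the CUDA kernels in `cuda/` and the quantized layers in `quantization/` added by this PR):
+
+```python
+import torch
+
+def quantize_weight(weight: torch.Tensor, bit_width: int):
+    # Vector-wise symmetric quantization: one scale per output row, mapping the row's
+    # absolute maximum onto the largest representable integer level.
+    assert bit_width in (4, 8)
+    scale = weight.abs().max(dim=-1).values / (2 ** (bit_width - 1) - 1)
+    q = torch.round(weight / scale[:, None]).to(torch.int8)
+    if bit_width == 4:
+        # Pack two INT4 values into one INT8 byte (high nibble | low nibble).
+        q = (q[:, 0::2] << 4) | (q[:, 1::2] & 0x0F)
+    return q, scale.half()
+
+def dequantize_to_half(q: torch.Tensor, scale: torch.Tensor, bit_width: int):
+    # Runtime conversion back to FP16, mirroring the extraction kernels above.
+    if bit_width == 4:
+        high = q >> 4                              # arithmetic shift keeps the sign
+        low = q & 0x0F
+        low = torch.where(low > 7, low - 16, low)  # sign-extend the low nibble
+        q = torch.stack((high, low), dim=-1).view(q.size(0), -1)
+    return q.half() * scale[:, None]
+```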
+
+

+ 6 - 2
evaluation/__init__.py

@@ -1,7 +1,11 @@
 from .configs import *
 from .model import ModelForEvaluation
-from .tasks import BaseTask, GenerationTask, MultiChoiceTask
+from .tasks import BaseTask, GenerationTask, MultiChoiceTask, LanguageModelTask
 from .metrics import qa_evaluate
 from .utils import print_rank_0
 
-DEFAULT_CLASS = {TaskType.GENERATION: GenerationTask, TaskType.MULTICHOICE: MultiChoiceTask}
+DEFAULT_CLASS = {
+    TaskType.GENERATION: GenerationTask,
+    TaskType.MULTICHOICE: MultiChoiceTask,
+    TaskType.LANGUAGE_MODEL: LanguageModelTask,
+}

+ 9 - 0
evaluation/configs.py

@@ -8,6 +8,7 @@ from typing import Optional, List, Dict
 class TaskType(Enum):
     MULTICHOICE = "mul"
     GENERATION = "gen"
+    LANGUAGE_MODEL = "lm"
     OTHER = "other"
 
 
@@ -51,3 +52,11 @@ class GenerationTaskConfig(BaseConfig):
 
     def __post_init__(self):
         assert self.micro_batch_size == 1, "Only support micro batch size = 1 for generation task"
+
+
+@dataclass
+class LanguageModelTaskConfig(BaseConfig):
+    module = "evaluation.LanguageModelTask"
+    metrics: List[str] = field(default_factory=lambda: ["PPL"])
+
+    generation_length: int = 256  # Generated length in each window

+ 72 - 10
evaluation/dataset.py

@@ -1,15 +1,19 @@
 import os
+import math
 import json
 
 import numpy as np
 import torch
 
+from typing import List, Union
 from abc import ABC, abstractmethod
 from scipy.linalg import block_diag
+from itertools import accumulate
+from bisect import bisect_right
 
 from SwissArmyTransformer import get_tokenizer
 
-from .configs import BaseConfig, MultiChoiceTaskConfig, GenerationTaskConfig
+from .configs import BaseConfig, MultiChoiceTaskConfig, GenerationTaskConfig, LanguageModelTaskConfig
 from .utils import get_tokenized_input
 
 
@@ -35,21 +39,19 @@ class EvaluationDataset(torch.utils.data.Dataset, ABC):
     If [MASK] not in context, will append [MASK] after text
     """
 
-    def __init__(self, path, config: BaseConfig):
-        self.path = path
+    def __init__(self, path: Union[str, List[str]], config: BaseConfig):
+        self.path = path if isinstance(path, list) else [path]
         self.config = config
         self.max_seq_length = self.config.max_seq_length
         self.dtype = np.int64
 
-        tokenizer = get_tokenizer(tokenizer_type="icetk-glm-130B")
-        self.mask_id = tokenizer.get_command("[MASK]")
-        self.gmask_id = tokenizer.get_command("[gMASK]")
+        self.tokenizer = get_tokenizer()
+        self.mask_id = self.tokenizer.get_command("[MASK]")
+        self.gmask_id = self.tokenizer.get_command("[gMASK]")
 
         self.data = []
-        with open(os.path.join(path), "r", encoding="utf-8") as file:
-            for line in file:
-                item = json.loads(line)
-                self.data.append(self.process_single_item(item))
+        for p in self.path:
+            self.process_single_file(p)
 
     @property
     def has_collate_fn(self) -> bool:
@@ -58,6 +60,12 @@ class EvaluationDataset(torch.utils.data.Dataset, ABC):
     def collate_fn(self, samples):
         return None
 
+    def process_single_file(self, path):
+        with open(os.path.join(path), "r", encoding="utf-8") as file:
+            for line in file:
+                item = json.loads(line)
+                self.data.append(self.process_single_item(item))
+
     @abstractmethod
     def process_single_item(self, item) -> dict:
         pass
@@ -257,3 +265,57 @@ class MultiChoiceTaskDataset(EvaluationDataset):
         )
         sample["label"] = item["label"]
         return sample
+
+
+class LanguageModelTaskDataset(EvaluationDataset):
+    config: LanguageModelTaskConfig
+
+    def process_single_file(self, path):
+        with open(os.path.join(path), "r", encoding="utf-8") as file:
+            raw_text = file.read()
+            tokens = self.tokenizer.tokenize(raw_text)
+            self.data.append(
+                {
+                    "raw_text": tokens,
+                    "num_original_tokens": len(raw_text.strip().split(" ")),
+                    "num_sequences": max(
+                        math.ceil(
+                            max(len(tokens) - (self.config.max_seq_length - 1), 0) / self.config.generation_length
+                        )
+                        + 1,
+                        1,
+                    ),
+                }
+            )
+
+    def process_single_item(self, item):
+        pass
+
+    def __len__(self):
+        return self.data[0]["num_sequences"]
+
+    def __getitem__(self, idx):
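+        # Sliding-window language modeling: each window scores `generation_length` new tokens;
+        # the preceding tokens serve as a prompt (bidirectional unless `unidirectional`) and
+        # are excluded from the loss via `loss_masks`.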
+        start_idx = idx * self.config.generation_length
+        end_idx = start_idx + self.config.max_seq_length - 1  # for additional [gMASK]
+        tokens = self.data[0]["raw_text"][start_idx:end_idx]
+
+        mask_id = self.gmask_id if self.config.use_task_mask else self.mask_id
+        sop_id = self.tokenizer.get_command("sop")
+
+        if idx == 0 or self.config.unidirectional:
+            prompt, text = tokens[:1], tokens[1:]
+        else:
+            prompt_length = self.config.max_seq_length - 1 - self.config.generation_length
+            prompt, text = tokens[:prompt_length], tokens[prompt_length:]
+
+        seq_length = len(prompt) + len(text) + 1
+        attention_mask = np.tril(np.ones((seq_length, seq_length), dtype=np.int64))
+        attention_mask[: len(prompt) + 1, : len(prompt) + 1] = 1
+
+        return {
+            "tokens": np.array(prompt + [mask_id, sop_id] + text[:-1], dtype=np.int64),
+            "targets": np.array(prompt + [mask_id] + text, dtype=np.int64),
+            "position_ids": np.arange(0, seq_length, dtype=np.int64),
+            "attention_mask": attention_mask < 0.5,
+            "loss_masks": np.array([0] * (len(prompt) + 1) + [1] * len(text), dtype=np.int64),
+        }

+ 11 - 2
evaluation/metrics.py

@@ -1,7 +1,11 @@
-import string
 import re
+import math
+import string
 import functools
 
+import numpy as np
+
+from typing import Tuple, List
 from collections import Counter
 
 from SwissArmyTransformer import get_tokenizer
@@ -79,4 +83,9 @@ def qa_evaluate(predictions, examples, metric):
 qa_exact_match = functools.partial(qa_evaluate, metric=exact_match_score)
 qa_f1 = functools.partial(qa_evaluate, metric=f1_score)
 
-DEFAULT_METRICS = {"EM": qa_exact_match, "F1": qa_f1, "Accuracy": accuracy_metric}
+
+def calculate_perplexity(loss: List[float], data):
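+    # Exponentiate the total loss normalized by the original whitespace-split token count
+    # (the usual WikiText/PTB convention); the exponent is capped at 20 for numerical safety.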
+    return math.exp(min(20, np.sum(loss) / data[0]["num_original_tokens"]))
+
+
+DEFAULT_METRICS = {"EM": qa_exact_match, "F1": qa_f1, "Accuracy": accuracy_metric, "PPL": calculate_perplexity}

+ 22 - 0
evaluation/model.py

@@ -3,6 +3,7 @@ import torch
 from typing import List, Union
 
 from SwissArmyTransformer.generation.autoregressive_sampling import filling_sequence
+from SwissArmyTransformer.mpu import vocab_parallel_cross_entropy
 
 
 class ModelForEvaluation(torch.nn.Module):
@@ -86,3 +87,24 @@ class ModelForEvaluation(torch.nn.Module):
             output_targets.append(line)
 
         return output_targets if return_all_beams else output_targets[0]
+
+    def calculate_loss(self, batch) -> List[float]:
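+        # Sum per-token cross-entropy over the scored positions of each sample, using the
+        # vocab-parallel loss so full logits never need to be gathered across model-parallel ranks.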
+        tokens, position_ids, attention_mask = self.process_data(batch)
+        targets, loss_masks = (
+            batch["targets"].to(device=torch.cuda.current_device()).long(),
+            batch["loss_masks"].to(device=torch.cuda.current_device()).long(),
+        )
+
+        original_parallel_output = self.model.transformer.parallel_output
+        self.model.transformer.parallel_output = True
+        self.model.eval()
+
+        with torch.no_grad():
+            logits, *output_per_layers = self.model(tokens, position_ids, attention_mask, log_attention_weights=None)
+            losses = vocab_parallel_cross_entropy(logits.contiguous().float(), targets)
+            loss = torch.sum(losses * loss_masks, dim=-1)
+
+        self.model.transformer.parallel_output = original_parallel_output
+
+        # return list(zip(loss.tolist(), loss_masks.sum(dim=-1).tolist()))
+        return loss.tolist()

+ 16 - 2
evaluation/tasks.py

@@ -13,9 +13,9 @@ from SwissArmyTransformer.generation.sampling_strategies import BaseStrategy
 from SwissArmyTransformer.tokenization.icetk_glm_130B.ice_tokenizer import _IceTokenizer
 
 from generation import BeamSearchStrategy
-from .configs import BaseConfig, GenerationTaskConfig, MultiChoiceTaskConfig
+from .configs import BaseConfig, GenerationTaskConfig, MultiChoiceTaskConfig, LanguageModelTaskConfig
 from .model import ModelForEvaluation
-from .dataset import EvaluationDataset, GenerationTaskDataset, MultiChoiceTaskDataset
+from .dataset import EvaluationDataset, GenerationTaskDataset, MultiChoiceTaskDataset, LanguageModelTaskDataset
 from .utils import build_data_loader, gather_result, print_rank_0
 from .metrics import DEFAULT_METRICS
 
@@ -205,3 +205,17 @@ class MultiChoiceTask(BaseTask, ABC):
     def predict_single_batch(self, batch) -> List[int]:
         log_probs = self.model.cond_log_prob(batch)
         return [np.argmax(log_probs_single).item() for log_probs_single in log_probs]
+
+
+class LanguageModelTask(BaseTask, ABC):
+    config: LanguageModelTaskConfig
+
+    @classmethod
+    def config_class(cls):
+        return LanguageModelTaskConfig
+
+    def build_dataset(self, relative_path):
+        return LanguageModelTaskDataset(join(self.config.path, relative_path), self.config)
+
+    def predict_single_batch(self, batch) -> List[float]:
+        return self.model.calculate_loss(batch)

+ 32 - 6
initialize.py

@@ -2,6 +2,8 @@ import argparse
 import torch
 import time
 
+from quantization import quantize
+
 from SwissArmyTransformer import get_args, get_tokenizer
 from SwissArmyTransformer.arguments import initialize_distributed
 from SwissArmyTransformer.training import load_checkpoint
@@ -17,9 +19,17 @@ def add_bminf_args(parser):
     return parser
 
 
+def add_quantization_args(parser):
+    group = parser.add_argument_group("Quantization")
+
+    group.add_argument("--quantization-bit-width", type=int, default=None)
+    group.add_argument("--from-quantized-checkpoint", action="store_true", help="Loading from a quantized checkpoint")
+
+
 def initialize(extra_args_provider):
     parser = argparse.ArgumentParser(add_help=False)
     add_bminf_args(parser)
+    add_quantization_args(parser)
     GLM130B.add_model_specific_args(parser)
     extra_args_provider(parser)
     known, args_list = parser.parse_known_args()
@@ -33,21 +43,37 @@ def initialize(extra_args_provider):
 def initialize_model_and_tokenizer(args):
     tokenizer = get_tokenizer(args)
 
+    # Initialize model
     model = GLM130B(args).half()
-    if args.bminf:
-        import bminf
 
-        with torch.cuda.device(args.device):
-            model = bminf.wrapper(model, quantization=False, memory_limit=args.bminf_memory_limit << 30)
-    else:
-        model = model.to(args.device)
+    if args.from_quantized_checkpoint:
+        assert args.quantization_bit_width is not None
+        # Quantize model before moving to GPU
+        model = quantize(model, args.quantization_bit_width)
 
+    # Load checkpoint
     torch.distributed.barrier()
     start = time.time()
     load_checkpoint(model, args)
     torch.distributed.barrier()
     if torch.distributed.get_rank() == 0:
         print(f"> Checkpoint loaded in {time.time() - start:.1f}s")
+
+    if args.quantization_bit_width is not None and not args.from_quantized_checkpoint:
+        # Quantize model before moving to GPU
+        model = quantize(model, args.quantization_bit_width)
+
+    if args.bminf:
+        import bminf
+
+        if torch.distributed.get_rank() == 0:
+            print(f"> BMInf activated, memory limit: {args.bminf_memory_limit} GB")
+        with torch.cuda.device(args.device):
+            model = bminf.wrapper(model, quantization=False, memory_limit=args.bminf_memory_limit << 30)
+    else:
+        model = model.to(args.device)
+
+    torch.cuda.empty_cache()
     model.eval()
 
     # generate rotary embedding cache

+ 99 - 0
kernels/__init__.py

@@ -0,0 +1,99 @@
+import pkg_resources
+import torch
+import ctypes
+
+from typing import List
+from cpm_kernels.kernels.base import LazyKernelCModule, KernelFunction, round_up
+
+RESOURCE_PACKAGE_NAME = __name__
+
+
+class Kernel:
+    def __init__(self, filename: str, function_names: List[str]):
+        filename = filename + ".fatbin"
+        if not pkg_resources.resource_exists(RESOURCE_PACKAGE_NAME, filename):
+            raise RuntimeError("File `%s` not found in `%s`" % (filename, RESOURCE_PACKAGE_NAME))
+        self.filename = filename
+        self.code = pkg_resources.resource_string(RESOURCE_PACKAGE_NAME, filename)
+        self._function_names = function_names
+        self._cmodule = LazyKernelCModule(self.code)
+
+        for name in self._function_names:
+            setattr(self, name, KernelFunction(self._cmodule, name))
+
+
+kernels = Kernel(
+    "quantization",
+    [
+        "int4WeightCompression",
+        "int4WeightExtractionFloat",
+        "int4WeightExtractionHalf",
+        "int8WeightExtractionFloat",
+        "int8WeightExtractionHalf",
+    ],
+)
+
+
+def compress_int4_weight(weight: torch.Tensor):  # (n, m)
+    with torch.cuda.device(weight.device):
+        n, m = weight.size(0), weight.size(1)
+        assert m % 2 == 0
+        m = m // 2
+        out = torch.empty(n, m, dtype=torch.int8, device="cuda")
+        stream = torch.cuda.current_stream()
+
+        gridDim = (n, 1, 1)
+        blockDim = (min(round_up(m, 32), 1024), 1, 1)
+
+        kernels.int4WeightCompression(
+            gridDim,
+            blockDim,
+            0,
+            stream,
+            [ctypes.c_void_p(weight.data_ptr()), ctypes.c_void_p(out.data_ptr()), ctypes.c_int32(n), ctypes.c_int32(m)],
+        )
+        return out
+
+
+def extract_weight_to_half(weight: torch.Tensor, scale_list: torch.Tensor, source_bit_width: int):
+    if source_bit_width == 8:
+        func = kernels.int8WeightExtractionHalf
+    elif source_bit_width == 4:
+        func = kernels.int4WeightExtractionHalf
+    else:
+        assert False, "Unsupported bit-width"
+
+    with torch.cuda.device(weight.device):
+        n, m = weight.size(0), weight.size(1)
+        out = torch.empty(n, m * (8 // source_bit_width), dtype=torch.half, device="cuda")
+        stream = torch.cuda.current_stream()
+
+        gridDim = (n, 1, 1)
+        blockDim = (min(round_up(m, 32), 1024), 1, 1)
+
+        func(
+            gridDim,
+            blockDim,
+            0,
+            stream,
+            [
+                ctypes.c_void_p(weight.data_ptr()),
+                ctypes.c_void_p(scale_list.data_ptr()),
+                ctypes.c_void_p(out.data_ptr()),
+                ctypes.c_int32(n),
+                ctypes.c_int32(m),
+            ],
+        )
+        return out
+
+
+if __name__ == "__main__":
+    weight = torch.randn(4, 32).to(torch.int8).cuda()
+    scale = torch.ones(weight.size(0)).to(torch.half).cuda()
+
+    print(weight)
+    b = compress_int4_weight(weight)
+    print(b)
+
+    a = extract_weight_to_half(b, scale, source_bit_width=4)
+    print(a)

BIN
kernels/quantization.fatbin


+ 63 - 0
quantization/__init__.py

@@ -0,0 +1,63 @@
+import torch
+
+from .layers import QuantizedColumnParallelLinear
+from .layers import QuantizedRowParallelLinear
+
+
+def quantize(model, weight_bit_width):
+    """Replace fp16 linear with quantized linear"""
+
+    if torch.distributed.get_rank() == 0:
+        print(f"> Quantizing model weight to {weight_bit_width} bits")
+
+    for layer in model.transformer.layers:
+        layer.attention.query_key_value = QuantizedColumnParallelLinear(
+            weight_bit_width=weight_bit_width,
+            weight=layer.attention.query_key_value.weight.to(torch.cuda.current_device()),
+            input_size=layer.attention.query_key_value.input_size,
+            output_size=layer.attention.query_key_value.output_size,
+            bias=True,
+            gather_output=False,
+            params_dtype=torch.half,
+            name="query_key_value",
+            skip_init=True,
+            device=layer.attention.query_key_value.weight.device,
+        )
+        layer.attention.dense = QuantizedRowParallelLinear(
+            weight_bit_width=weight_bit_width,
+            weight=layer.attention.dense.weight.to(torch.cuda.current_device()),
+            input_size=layer.attention.dense.input_size,
+            output_size=layer.attention.dense.output_size,
+            bias=True,
+            input_is_parallel=True,
+            params_dtype=torch.half,
+            name="dense",
+            skip_init=True,
+            device=layer.attention.dense.weight.device,
+        )
+        layer.mlp.dense_h_to_4h = QuantizedColumnParallelLinear(
+            weight_bit_width=weight_bit_width,
+            weight=layer.mlp.dense_h_to_4h.weight.to(torch.cuda.current_device()),
+            input_size=layer.mlp.dense_h_to_4h.input_size,
+            output_size=layer.mlp.dense_h_to_4h.output_size,
+            bias=True,
+            gather_output=False,
+            params_dtype=torch.half,
+            name="dense_h_to_4h",
+            skip_init=True,
+            device=layer.mlp.dense_h_to_4h.weight.device,
+        )
+        layer.mlp.dense_4h_to_h = QuantizedRowParallelLinear(
+            weight_bit_width=weight_bit_width,
+            weight=layer.mlp.dense_4h_to_h.weight.to(torch.cuda.current_device()),
+            input_size=layer.mlp.dense_4h_to_h.input_size,
+            output_size=layer.mlp.dense_4h_to_h.output_size,
+            bias=True,
+            input_is_parallel=True,
+            params_dtype=torch.half,
+            name="dense_h_to_4h",
+            skip_init=True,
+            device=layer.mlp.dense_4h_to_h.weight.device,
+        )
+
+    return model

+ 26 - 0
quantization/functional.py

@@ -0,0 +1,26 @@
+import torch
+
+from kernels import extract_weight_to_half
+
+
+class W8A16Linear(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, inp: torch.Tensor, quant_w: torch.Tensor, scale_w: torch.Tensor, weight_bit_width):
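+        # Dequantize the INT4/INT8 weight to FP16 on the fly, run a standard FP16 GEMM,
+        # and keep the quantized weight and scale for the backward pass.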
+        ctx.inp_shape = inp.size()
+        ctx.weight_shape = quant_w.size()
+        ctx.weight_bit_width = weight_bit_width
+        out_features = quant_w.size(0)
+        inp = inp.contiguous().view(-1, inp.size(-1))
+        weight = extract_weight_to_half(quant_w, scale_w, weight_bit_width)
+        output = inp.mm(weight.t())
+        ctx.save_for_backward(inp, quant_w, scale_w)
+        return output.view(*(ctx.inp_shape[:-1] + (out_features,)))
+
+    @staticmethod
+    def backward(ctx, grad_output: torch.Tensor):
+        inp, quant_w, scale_w = ctx.saved_tensors
+        weight = extract_weight_to_half(quant_w, scale_w, ctx.weight_bit_width)
+        grad_output = grad_output.contiguous().view(-1, weight.size(0))
+        grad_input = grad_output.mm(weight)
+        grad_weight = grad_output.t().mm(inp)
+        return grad_input.view(ctx.inp_shape), grad_weight.view(ctx.weight_shape), None

+ 87 - 0
quantization/layers.py

@@ -0,0 +1,87 @@
+import torch
+from torch.nn.parameter import Parameter
+
+from SwissArmyTransformer.mpu import copy_to_model_parallel_region
+from SwissArmyTransformer.mpu import gather_from_model_parallel_region
+from SwissArmyTransformer.mpu import reduce_from_model_parallel_region
+from SwissArmyTransformer.mpu import scatter_to_model_parallel_region
+from SwissArmyTransformer.mpu import ColumnParallelLinear, RowParallelLinear
+
+from .functional import W8A16Linear
+from kernels import compress_int4_weight
+
+
+class QuantizedColumnParallelLinear(ColumnParallelLinear):
+    def __init__(self, weight_bit_width: int, weight=None, *args, **kwargs):
+        super(QuantizedColumnParallelLinear, self).__init__(*args, **kwargs)
+        self.weight_bit_width = weight_bit_width
+
+        shape = self.weight.shape
+        del self.weight
+
+        if weight is None:
+            self.weight = torch.empty(
+                shape[0], shape[1] * weight_bit_width // 8, dtype=torch.int8, device=kwargs["device"]
+            )
+            self.weight_scale = torch.empty(shape[0], dtype=kwargs["params_dtype"], device=kwargs["device"])
+        else:
+            self.weight_scale = (weight.abs().max(dim=-1).values / ((2 ** (weight_bit_width - 1)) - 1)).half()
+            self.weight = torch.round(weight / self.weight_scale[:, None]).to(torch.int8)
+            if weight_bit_width == 4:
+                self.weight = compress_int4_weight(self.weight)
+
+        self.weight = Parameter(self.weight.to(kwargs["device"]), requires_grad=False)
+        self.weight_scale = Parameter(self.weight_scale.to(kwargs["device"]), requires_grad=False)
+
+    def forward(self, input_):
+        # Set up backprop all-reduce.
+        input_parallel = copy_to_model_parallel_region(input_)
+        # Matrix multiply.
+        output_parallel = W8A16Linear.apply(input_parallel, self.weight, self.weight_scale, self.weight_bit_width)
+        if self.bias is not None:
+            output_parallel = output_parallel + self.bias
+        if self.gather_output:
+            # All-gather across the partitions.
+            output = gather_from_model_parallel_region(output_parallel)
+        else:
+            output = output_parallel
+        return output
+
+
+class QuantizedRowParallelLinear(RowParallelLinear):
+    def __init__(self, weight_bit_width: int, weight=None, *args, **kwargs):
+        super(QuantizedRowParallelLinear, self).__init__(*args, **kwargs)
+        self.weight_bit_width = weight_bit_width
+
+        shape = self.weight.shape
+        del self.weight
+
+        if weight is None:
+            self.weight = torch.empty(
+                shape[0], shape[1] * weight_bit_width // 8, dtype=torch.int8, device=kwargs["device"]
+            )
+            self.weight_scale = torch.empty(shape[0], dtype=kwargs["params_dtype"], device=kwargs["device"])
+        else:
+            self.weight_scale = (weight.abs().max(dim=-1).values / ((2 ** (weight_bit_width - 1)) - 1)).half()
+            self.weight = torch.round(weight / self.weight_scale[:, None]).to(torch.int8)
+            if weight_bit_width == 4:
+                self.weight = compress_int4_weight(self.weight)
+
+        self.weight = Parameter(self.weight.to(kwargs["device"]), requires_grad=False)
+        self.weight_scale = Parameter(self.weight_scale.to(kwargs["device"]), requires_grad=False)
+
+    def forward(self, input_):
+        # Set up backprop all-reduce.
+        if self.input_is_parallel:
+            input_parallel = input_
+        else:
+            input_parallel = scatter_to_model_parallel_region(input_)
+        # Matrix multiply.
+        output_parallel = W8A16Linear.apply(input_parallel, self.weight, self.weight_scale, self.weight_bit_width)
+        # All-reduce across all the partitions.
+        output_ = reduce_from_model_parallel_region(output_parallel)
+        if self.bias is not None:
+            output = output_ + self.bias
+        else:
+            output = output_
+        return output

+ 3 - 2
requirements.txt

@@ -1,5 +1,6 @@
-SwissArmyTransformer>=0.2.11
+SwissArmyTransformer>=0.2.12
 icetk
 apex
 scipy
-dataclass_wizard
+dataclass_wizard
+cpm_kernels

+ 8 - 0
tasks/language-modeling/ptb.yaml

@@ -0,0 +1,8 @@
+name: "Penn Treebank"
+type: "lm"
+path: "ptbdataset"
+file-pattern:
+  test: "**/ptb.test.txt"
+
+generation-length: 256
+use_task_mask: true

+ 8 - 0
tasks/language-modeling/wikitext-103.yaml

@@ -0,0 +1,8 @@
+name: "WikiText-103"
+type: "lm"
+path: "wikitext-103"
+file-pattern:
+  test: "**/wiki.test.tokens"
+
+generation-length: 256
+use_task_mask: true

+ 8 - 0
tasks/language-modeling/wikitext-2.yaml

@@ -0,0 +1,8 @@
+name: "WikiText-2"
+type: "lm"
+path: "wikitext-2"
+file-pattern:
+  test: "**/wiki.test.tokens"
+
+generation-length: 256
+use_task_mask: true

+ 118 - 0
tools/convert_tp.py

@@ -0,0 +1,118 @@
+import os
+import torch
+import argparse
+import glob
+
+SEQUENTIAL_LAYERS = [
+    "input_layernorm.weight",
+    "input_layernorm.bias",
+    "attention.dense.bias",
+    "post_attention_layernorm.weight",
+    "post_attention_layernorm.bias",
+    "mlp.dense_4h_to_h.bias",
+    "attention.rotary_emb.inv_freq",
+    "final_layernorm.weight",
+    "final_layernorm.bias",
+]
+
+GLU_LAYERS = [
+    "mlp.dense_h_to_4h.weight",
+    "mlp.dense_h_to_4h.bias",
+]
+
+QUANTIZED_LAYERS = [
+    "attention.dense.weight",
+    "attention.query_key_value.weight",
+    "mlp.dense_h_to_4h.weight",
+    "mlp.dense_4h_to_h.weight",
+]
+
+LAYER_CONCAT_DIM = {"attention.dense.weight": 1, "mlp.dense_4h_to_h.weight": 1}
+
+
+def parse_arguments():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--input-folder", default=None, type=str, help="Input SAT checkpoint folder")
+    parser.add_argument("--output-folder", default=None, type=str, help="Output SAT checkpoint folder")
+    parser.add_argument("--target-tp", default=4, type=int, help="Target TP degree")
+    return parser.parse_args()
+
+
+def merge_weights(key, sd_list, tp_index, original_tp, target_tp, cat_dim, is_glu):
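+    # When reducing the TP degree, partitions are concatenated along `cat_dim`; GLU layers hold
+    # two stacked halves per partition, so their chunks are regrouped before concatenation.
+    # When increasing the TP degree, the single source partition is sliced instead.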
+    if original_tp >= target_tp:
+        if is_glu:
+            if original_tp > target_tp:
+                num_part = original_tp // target_tp
+                assert len(sd_list) == num_part
+                part1, part2 = [], []
+                for i in range(len(sd_list)):
+                    chunks = torch.chunk(sd_list[i][key], 2, dim=cat_dim)
+                    part1.append(chunks[0])
+                    part2.append(chunks[1])
+                merged_sd = torch.cat(part1 + part2, dim=cat_dim)
+            else:
+                merged_sd = sd_list[0][key]
+        else:
+            merged_sd = torch.cat([sd[key] for sd in sd_list], dim=cat_dim)
+    else:
+        assert len(sd_list) == 1
+        num_part = target_tp // original_tp
+        if is_glu:
+            offset = tp_index % num_part
+            chunks = torch.chunk(sd_list[0][key], num_part * 2, dim=cat_dim)
+            merged_sd = torch.cat([chunks[offset], chunks[num_part + offset]], dim=cat_dim)
+        else:
+            # without clone, torch will save entire tensor
+            merged_sd = torch.chunk(sd_list[0][key], num_part, dim=cat_dim)[tp_index % num_part].clone()
+    return merged_sd
+
+
+def create_checkpoint(sd_list, tp_index, original_tp, target_tp):
+    new_sd = {}
+    for key in sd_list[0].keys():
+        name = ".".join(key.split(".")[3 if key.startswith("transformer.layers") else 1 :])
+        if name in SEQUENTIAL_LAYERS:
+            new_sd[key] = sd_list[0][key]
+        else:
+            new_sd[key] = merge_weights(
+                key,
+                sd_list,
+                tp_index=tp_index,
+                original_tp=original_tp,
+                target_tp=target_tp,
+                cat_dim=LAYER_CONCAT_DIM.get(name, 0),
+                is_glu=name in GLU_LAYERS,
+            )
+    new_sd = {"module": new_sd}
+    return new_sd
+
+
+def main(args):
+    iteration = open(os.path.join(args.input_folder, "latest"), "r").read()
+    original_tp = len(glob.glob(os.path.join(args.input_folder, iteration, "mp_rank_*_model_states.pt")))
+    print(f"Iteration {iteration} from {args.input_folder} to {args.output_folder}")
+    os.makedirs(args.output_folder, exist_ok=True)
+    with open(os.path.join(args.output_folder, "latest"), "w") as file:
+        file.write(str(iteration))
+    os.makedirs(os.path.join(args.output_folder, iteration), exist_ok=True)
+
+    for i in range(0, args.target_tp):
+        save_path = os.path.join(args.output_folder, iteration, f"mp_rank_{i:02}_model_states.pt")
+        print(f"Processing {save_path}")
+        num_parts = original_tp // args.target_tp
+        sd_list = [
+            torch.load(
+                os.path.join(args.input_folder, iteration, f"mp_rank_{j:02}_model_states.pt"), map_location="cpu"
+            )["module"]
+            for j in (
+                range(i * num_parts, (i + 1) * num_parts)
+                if args.target_tp <= original_tp
+                else [i // (args.target_tp // original_tp)]
+            )
+        ]
+        torch.save(create_checkpoint(sd_list, i, original_tp, args.target_tp), save_path)
+
+
+if __name__ == "__main__":
+    args = parse_arguments()
+    main(args)