Update benchmark

Sengxian · 2 years ago
commit 9096aadd94
4 changed files with 63 additions and 7 deletions
  1. benchmark.py (+20 -0)
  2. docs/quantization.md (+10 -3)
  3. initialize.py (+13 -4)
  4. scripts/benchmark.sh (+20 -0)

+ 20 - 0
benchmark.py

@@ -0,0 +1,20 @@
+import torch
+import time
+from initialize import initialize, initialize_model_and_tokenizer
+
+if __name__ == "__main__":
+    args = initialize(extra_args_provider=lambda parser: None)
+    model, tokenizer = initialize_model_and_tokenizer(args)
+
+    for seq_len in [512, 1024, 2048]:
+        torch.distributed.barrier()
+        start = time.time()
+        with torch.no_grad():
+            _, *_ = model(
+                torch.ones(1, seq_len, device=torch.cuda.current_device(), dtype=torch.int64),
+                torch.arange(seq_len, device=torch.cuda.current_device(), dtype=torch.int64).view(1, -1),
+                torch.randn(1, 1, seq_len, seq_len, device=torch.cuda.current_device()) < 0.5,
+            )
+        torch.distributed.barrier()
+        if torch.distributed.get_rank() == 0:
+            print(f"Encode {seq_len}: {(time.time() - start) * 1000:.2f} ms")
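The new `benchmark.py` brackets each encode pass with `torch.distributed.barrier()` and measures wall-clock time with `time.time()`, feeding the model dummy token ids, position ids, and a random boolean attention mask. For a single-process setting, a rough sketch of the same measurement using CUDA events is shown below; this is an illustration, not part of the commit, and the `time_encode` helper with its warm-up/iteration parameters is hypothetical.

```python
import torch

def time_encode(model, seq_len, warmup=1, iters=3):
    """Illustrative encode-pass timing with CUDA events (single process).

    `model` is assumed to accept (tokens, position_ids, attention_mask),
    the same dummy inputs that benchmark.py builds above.
    """
    device = torch.cuda.current_device()
    tokens = torch.ones(1, seq_len, device=device, dtype=torch.int64)
    position_ids = torch.arange(seq_len, device=device, dtype=torch.int64).view(1, -1)
    attention_mask = torch.randn(1, 1, seq_len, seq_len, device=device) < 0.5

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    timings = []
    with torch.no_grad():
        for i in range(warmup + iters):
            start.record()
            model(tokens, position_ids, attention_mask)
            end.record()
            torch.cuda.synchronize()  # wait until both recorded events have completed
            if i >= warmup:           # discard warm-up iterations
                timings.append(start.elapsed_time(end))  # milliseconds
    return sum(timings) / len(timings)
```

Averaging a few timed iterations after a warm-up pass generally gives steadier numbers than a single timed call.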

+ 10 - 3
docs/quantization.md

@@ -27,8 +27,6 @@ Finally, change the model config file from `configs/model_glm_130b.sh` to `confi
 
 ## Space and Speed Benchmark
 
-> TODO: More benchmark to add
-
 | **Hardware** | **GPU Memory** | **Precision** | **512**  | **1024** | **2048** |
 | ------------ | -------------- | ------------ | -------- | -------- | -------- |
 | 8 * A100     | 40 GB          | FP16         | 45.21 s  | 89.00 s  | 179.22 s |
@@ -37,8 +35,17 @@ Finally, change the model config file from `configs/model_glm_130b.sh` to `confi
 | 8 * RTX 2080 Ti | 11 GB | INT4 | 117.39 s | 240.96 s | 528.66 s |
 
 
-The above results in the table is tests with SAT. Using FasterTransformer can speed up more than 2X, as detailed in [Inference with FasterTransformer](../docs/inference-with-fastertransformer.md).
+The results in the table above are measured with SAT. Using FasterTransformer speeds up inference by more than 2X, as shown in the table below; detailed usage is described in [Inference with FasterTransformer](../docs/inference-with-fastertransformer.md).
 
+| **Hardware**    | **GPU Memory** | **Precision** | **128** Encode / Decode | **512** Encode / Decode | **1024** Encode / Decode | **2048** Encode / Decode |
+| --------------- | -------------- | ------------ | ----------------------- | ----------------------- | ------------------------ | ------------------------ |
+| 8 * A100        | 40 GB          | INT4         | 145 ms / 4.29 s         | 183 ms / 17.7 s         | 313 ms / 37.8 s          | 495 ms / 86.0 s          |
+| 4 * A100        | 80 GB          | INT4         | 174 ms / 6.62 s         | 272 ms / 27.1 s         | 439 ms / 56.2 s          | 810 ms / 123 s           |
+| 8 * V100        | 32 GB          | INT4         | 309 ms / 6.97 s         | 666 ms / 28.1 s         | 1208 ms / 58.4 s         | 2304 ms / 125 s          |
+| 4 * V100        | 32 GB          | INT4         | 448 ms / 11.4 s         | 843 ms / 45.87 s        | 1488 ms / 93.5 s         | 2803 ms / 196 s          |
+| 8 * RTX 3090    | 24 GB          | INT4         | 283 ms / 5.07 s         | 915 ms / 20.5 s         | 1793 ms / 42.7 s         | 3477 ms / 90.3 s         |
+| 4 * RTX 3090    | 24 GB          | INT4         | 374 ms / 8.16 s         | 1300 ms / 32.3 s        | OOM / 66.5 s             | OOM / 150 s              |
+| 8 * RTX 2080 Ti | 11 GB          | INT4         | 392 ms / 6.77 s         | 1044 ms / 27.29 s       | OOM / 56.02 s            | OOM / OOM                |
 
 ## Details
 

+ 13 - 4
initialize.py

@@ -77,13 +77,22 @@ def initialize_model_and_tokenizer(args):
     model.eval()
 
     # generate rotary embedding cache
+    original_parallel_output = model.transformer.parallel_output
+    model.transformer.parallel_output = True
     with torch.no_grad():
         _, *_ = model(
-            torch.ones(1, 1, device=torch.cuda.current_device(), dtype=torch.int64),
-            torch.ones(1, 1, device=torch.cuda.current_device(), dtype=torch.int64) * args.max_sequence_length,
-            torch.ones(1, 1, 1, 1, device=torch.cuda.current_device(), dtype=torch.bool),
+            torch.ones(1, args.max_sequence_length, device=torch.cuda.current_device(), dtype=torch.int64),
+            torch.arange(args.max_sequence_length, device=torch.cuda.current_device(), dtype=torch.int64).view(1, -1),
+            torch.randn(
+                1,
+                1,
+                args.max_sequence_length,
+                args.max_sequence_length,
+                device=torch.cuda.current_device(),
+            )
+            < 0.5,
         )
-
+    model.transformer.parallel_output = original_parallel_output
     torch.distributed.barrier()
 
     return model, tokenizer
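The warm-up forward pass in `initialize.py` previously used a length-1 dummy input; it now runs at the full `args.max_sequence_length` (temporarily forcing `parallel_output` on and restoring it afterwards), so the rotary embedding cache is populated for every position before any benchmarking or serving begins. For intuition only, a generic sketch of the kind of cos/sin cache such a pass fills is shown below; it follows the standard RoPE formulation, not GLM-130B's actual implementation, and the helper name is hypothetical.

```python
import torch

def build_rotary_cache(head_dim, max_seq_len, base=10000.0, device="cuda"):
    """Generic RoPE cos/sin cache precomputed for all positions up front."""
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, device=device).float() / head_dim))
    positions = torch.arange(max_seq_len, device=device).float()
    freqs = torch.outer(positions, inv_freq)  # (max_seq_len, head_dim // 2)
    return freqs.cos(), freqs.sin()
```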

+ 20 - 0
scripts/benchmark.sh

@@ -0,0 +1,20 @@
+#!/bin/bash
+
+script_path=$(realpath $0)
+script_dir=$(dirname $script_path)
+main_dir=$(dirname $script_dir)
+
+source "${main_dir}/configs/model_glm_130b.sh"
+
+ARGS="${main_dir}/benchmark.py \
+       --mode inference \
+       $MODEL_ARGS"
+
+TIMESTAMP=$(date +'%Y.%m.%d-%H:%M:%S')
+EXP_NAME=${TIMESTAMP}
+
+mkdir -p logs
+
+run_cmd="torchrun --nproc_per_node $MP_SIZE ${ARGS}"
+echo $run_cmd
+eval ${run_cmd} 2>&1 | tee logs/${EXP_NAME}.log
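
With the model config in place, the benchmark would typically be launched as `bash scripts/benchmark.sh`: the script sources `configs/model_glm_130b.sh` for `$MODEL_ARGS` and `$MP_SIZE`, runs `benchmark.py` in inference mode under `torchrun` with one process per model-parallel rank, and tees the output to `logs/<timestamp>.log`.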