Sengxian 3 lat temu
commit
737be7c740
68 zmienionych plików z 2574 dodań i 0 usunięć
  1. 5 0
      .gitignore
  2. 201 0
      LICENSE
  3. 33 0
      MODEL_LICENSE
  4. 53 0
      README.md
  5. 55 0
      README_zh.md
  6. 15 0
      configs/model_glm_130b.sh
  7. 17 0
      configs/model_glm_130b_v100.sh
  8. 86 0
      docs/evaluate-your-own-tasks.md
  9. 103 0
      docs/inference-with-fastertransformer.md
  10. 28 0
      docs/low-resource-inference.md
  11. 67 0
      evaluate.py
  12. 7 0
      evaluation/__init__.py
  13. 53 0
      evaluation/configs.py
  14. 259 0
      evaluation/dataset.py
  15. 82 0
      evaluation/metrics.py
  16. 88 0
      evaluation/model.py
  17. 207 0
      evaluation/tasks.py
  18. 67 0
      evaluation/utils.py
  19. 210 0
      generate.py
  20. 1 0
      generation/__init__.py
  21. 132 0
      generation/strategies.py
  22. 63 0
      initialize.py
  23. 5 0
      logs/README.md
  24. 251 0
      logs/main-log.md
  25. 5 0
      requirements.txt
  26. BIN
      resources/03DF31017FE184DB45D41DFFC6F80EF0.png
  27. BIN
      resources/33872E48D3539EA132B74BCF5EFF458F.png
  28. BIN
      resources/49BF334CB352BAA19F7D55460B1DBCA9.gif
  29. BIN
      resources/7CB441707D1035B2890AA2164C5B6EAC.png
  30. BIN
      resources/7D6433A42D189E2E6FBC62BE066BCE91.png
  31. BIN
      resources/849024E93FA85347F7F6443932911922.png
  32. BIN
      resources/AE18F14396E2D22BC0BC8DD77EFD3414.png
  33. BIN
      resources/E42321373D22DE198231279B5856BB42.png
  34. BIN
      resources/F48B69263360688CCA21E915F4B1A98B.png
  35. 70 0
      resources/multitask_list.txt
  36. 23 0
      scripts/evaluate.sh
  37. 28 0
      scripts/evaluate_multiple_node.sh
  38. 38 0
      scripts/generate.sh
  39. 6 0
      tasks/bloom/glue_cola.yaml
  40. 7 0
      tasks/bloom/glue_mnli.yaml
  41. 6 0
      tasks/bloom/glue_qnli.yaml
  42. 6 0
      tasks/bloom/glue_wnli.yaml
  43. 7 0
      tasks/bloom/math_qa.yaml
  44. 6 0
      tasks/bloom/mc_taco.yaml
  45. 7 0
      tasks/bloom/openbook_qa.yaml
  46. 6 0
      tasks/bloom/pubmed_qa.yaml
  47. 6 0
      tasks/bloom/superglue_axb.yaml
  48. 6 0
      tasks/bloom/superglue_axg.yaml
  49. 4 0
      tasks/chinese/clue/afqmc.yaml
  50. 4 0
      tasks/chinese/clue/c3.yaml
  51. 4 0
      tasks/chinese/clue/cluewsc.yaml
  52. 4 0
      tasks/chinese/clue/cmnli.yaml
  53. 3 0
      tasks/chinese/clue/cmrc2018.yaml
  54. 4 0
      tasks/chinese/clue/csl.yaml
  55. 3 0
      tasks/chinese/clue/drcd.yaml
  56. 4 0
      tasks/chinese/clue/ocnli.yaml
  57. 7 0
      tasks/chinese/fewclue/bustm.yaml
  58. 7 0
      tasks/chinese/fewclue/chidf.yaml
  59. 7 0
      tasks/chinese/fewclue/cluewscf.yaml
  60. 7 0
      tasks/chinese/fewclue/cslf.yaml
  61. 7 0
      tasks/chinese/fewclue/eprstmt.yaml
  62. 7 0
      tasks/chinese/fewclue/ocnlif.yaml
  63. 13 0
      tasks/lambada/lambada-unidirectional.yaml
  64. 12 0
      tasks/lambada/lambada.yaml
  65. 20 0
      tasks/lambada/strategy.py
  66. 54 0
      tasks/lambada/task.py
  67. 10 0
      tasks/mmlu/mmlu.yaml
  68. 78 0
      tasks/mmlu/task.py

+ 5 - 0
.gitignore

@@ -0,0 +1,5 @@
+data
+__pycache__
+samples
+.DS_Store
+.idea

+ 201 - 0
LICENSE

@@ -0,0 +1,201 @@
+                                 Apache License
+                           Version 2.0, January 2004
+                        http://www.apache.org/licenses/
+
+   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+   1. Definitions.
+
+      "License" shall mean the terms and conditions for use, reproduction,
+      and distribution as defined by Sections 1 through 9 of this document.
+
+      "Licensor" shall mean the copyright owner or entity authorized by
+      the copyright owner that is granting the License.
+
+      "Legal Entity" shall mean the union of the acting entity and all
+      other entities that control, are controlled by, or are under common
+      control with that entity. For the purposes of this definition,
+      "control" means (i) the power, direct or indirect, to cause the
+      direction or management of such entity, whether by contract or
+      otherwise, or (ii) ownership of fifty percent (50%) or more of the
+      outstanding shares, or (iii) beneficial ownership of such entity.
+
+      "You" (or "Your") shall mean an individual or Legal Entity
+      exercising permissions granted by this License.
+
+      "Source" form shall mean the preferred form for making modifications,
+      including but not limited to software source code, documentation
+      source, and configuration files.
+
+      "Object" form shall mean any form resulting from mechanical
+      transformation or translation of a Source form, including but
+      not limited to compiled object code, generated documentation,
+      and conversions to other media types.
+
+      "Work" shall mean the work of authorship, whether in Source or
+      Object form, made available under the License, as indicated by a
+      copyright notice that is included in or attached to the work
+      (an example is provided in the Appendix below).
+
+      "Derivative Works" shall mean any work, whether in Source or Object
+      form, that is based on (or derived from) the Work and for which the
+      editorial revisions, annotations, elaborations, or other modifications
+      represent, as a whole, an original work of authorship. For the purposes
+      of this License, Derivative Works shall not include works that remain
+      separable from, or merely link (or bind by name) to the interfaces of,
+      the Work and Derivative Works thereof.
+
+      "Contribution" shall mean any work of authorship, including
+      the original version of the Work and any modifications or additions
+      to that Work or Derivative Works thereof, that is intentionally
+      submitted to Licensor for inclusion in the Work by the copyright owner
+      or by an individual or Legal Entity authorized to submit on behalf of
+      the copyright owner. For the purposes of this definition, "submitted"
+      means any form of electronic, verbal, or written communication sent
+      to the Licensor or its representatives, including but not limited to
+      communication on electronic mailing lists, source code control systems,
+      and issue tracking systems that are managed by, or on behalf of, the
+      Licensor for the purpose of discussing and improving the Work, but
+      excluding communication that is conspicuously marked or otherwise
+      designated in writing by the copyright owner as "Not a Contribution."
+
+      "Contributor" shall mean Licensor and any individual or Legal Entity
+      on behalf of whom a Contribution has been received by Licensor and
+      subsequently incorporated within the Work.
+
+   2. Grant of Copyright License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      copyright license to reproduce, prepare Derivative Works of,
+      publicly display, publicly perform, sublicense, and distribute the
+      Work and such Derivative Works in Source or Object form.
+
+   3. Grant of Patent License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      (except as stated in this section) patent license to make, have made,
+      use, offer to sell, sell, import, and otherwise transfer the Work,
+      where such license applies only to those patent claims licensable
+      by such Contributor that are necessarily infringed by their
+      Contribution(s) alone or by combination of their Contribution(s)
+      with the Work to which such Contribution(s) was submitted. If You
+      institute patent litigation against any entity (including a
+      cross-claim or counterclaim in a lawsuit) alleging that the Work
+      or a Contribution incorporated within the Work constitutes direct
+      or contributory patent infringement, then any patent licenses
+      granted to You under this License for that Work shall terminate
+      as of the date such litigation is filed.
+
+   4. Redistribution. You may reproduce and distribute copies of the
+      Work or Derivative Works thereof in any medium, with or without
+      modifications, and in Source or Object form, provided that You
+      meet the following conditions:
+
+      (a) You must give any other recipients of the Work or
+          Derivative Works a copy of this License; and
+
+      (b) You must cause any modified files to carry prominent notices
+          stating that You changed the files; and
+
+      (c) You must retain, in the Source form of any Derivative Works
+          that You distribute, all copyright, patent, trademark, and
+          attribution notices from the Source form of the Work,
+          excluding those notices that do not pertain to any part of
+          the Derivative Works; and
+
+      (d) If the Work includes a "NOTICE" text file as part of its
+          distribution, then any Derivative Works that You distribute must
+          include a readable copy of the attribution notices contained
+          within such NOTICE file, excluding those notices that do not
+          pertain to any part of the Derivative Works, in at least one
+          of the following places: within a NOTICE text file distributed
+          as part of the Derivative Works; within the Source form or
+          documentation, if provided along with the Derivative Works; or,
+          within a display generated by the Derivative Works, if and
+          wherever such third-party notices normally appear. The contents
+          of the NOTICE file are for informational purposes only and
+          do not modify the License. You may add Your own attribution
+          notices within Derivative Works that You distribute, alongside
+          or as an addendum to the NOTICE text from the Work, provided
+          that such additional attribution notices cannot be construed
+          as modifying the License.
+
+      You may add Your own copyright statement to Your modifications and
+      may provide additional or different license terms and conditions
+      for use, reproduction, or distribution of Your modifications, or
+      for any such Derivative Works as a whole, provided Your use,
+      reproduction, and distribution of the Work otherwise complies with
+      the conditions stated in this License.
+
+   5. Submission of Contributions. Unless You explicitly state otherwise,
+      any Contribution intentionally submitted for inclusion in the Work
+      by You to the Licensor shall be under the terms and conditions of
+      this License, without any additional terms or conditions.
+      Notwithstanding the above, nothing herein shall supersede or modify
+      the terms of any separate license agreement you may have executed
+      with Licensor regarding such Contributions.
+
+   6. Trademarks. This License does not grant permission to use the trade
+      names, trademarks, service marks, or product names of the Licensor,
+      except as required for reasonable and customary use in describing the
+      origin of the Work and reproducing the content of the NOTICE file.
+
+   7. Disclaimer of Warranty. Unless required by applicable law or
+      agreed to in writing, Licensor provides the Work (and each
+      Contributor provides its Contributions) on an "AS IS" BASIS,
+      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+      implied, including, without limitation, any warranties or conditions
+      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+      PARTICULAR PURPOSE. You are solely responsible for determining the
+      appropriateness of using or redistributing the Work and assume any
+      risks associated with Your exercise of permissions under this License.
+
+   8. Limitation of Liability. In no event and under no legal theory,
+      whether in tort (including negligence), contract, or otherwise,
+      unless required by applicable law (such as deliberate and grossly
+      negligent acts) or agreed to in writing, shall any Contributor be
+      liable to You for damages, including any direct, indirect, special,
+      incidental, or consequential damages of any character arising as a
+      result of this License or out of the use or inability to use the
+      Work (including but not limited to damages for loss of goodwill,
+      work stoppage, computer failure or malfunction, or any and all
+      other commercial damages or losses), even if such Contributor
+      has been advised of the possibility of such damages.
+
+   9. Accepting Warranty or Additional Liability. While redistributing
+      the Work or Derivative Works thereof, You may choose to offer,
+      and charge a fee for, acceptance of support, warranty, indemnity,
+      or other liability obligations and/or rights consistent with this
+      License. However, in accepting such obligations, You may act only
+      on Your own behalf and on Your sole responsibility, not on behalf
+      of any other Contributor, and only if You agree to indemnify,
+      defend, and hold each Contributor harmless for any liability
+      incurred by, or claims asserted against, such Contributor by reason
+      of your accepting any such warranty or additional liability.
+
+   END OF TERMS AND CONDITIONS
+
+   APPENDIX: How to apply the Apache License to your work.
+
+      To apply the Apache License to your work, attach the following
+      boilerplate notice, with the fields enclosed by brackets "[]"
+      replaced with your own identifying information. (Don't include
+      the brackets!)  The text should be enclosed in the appropriate
+      comment syntax for the file format. We also recommend that a
+      file or class name and description of purpose be included on the
+      same "printed page" as the copyright notice for easier
+      identification within third-party archives.
+
+   Copyright Aohan Zeng
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.

+ 33 - 0
MODEL_LICENSE

@@ -0,0 +1,33 @@
+The GLM-130B License
+
+1. Definitions
+
+“Licensor” means the GLM-130B Model Team that distributes its Software.
+
+“Software” means the GLM-130B model parameters made available under this license.
+
+2. License Grant
+
+Subject to the terms and conditions of this License, the Licensor hereby grants to you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable, royalty-free copyright license to use the Software solely for your non-commercial research purposes.
+
+The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
+
+3. Restriction
+
+You will not use, copy, modify, merge, publish, distribute, reproduce, or create derivative works of the Software, in whole or in part, for any commercial, military, or illegal purposes.
+
+You will not use the Software for any act that may undermine China's national security and national unity, harm the public interest of society, or infringe upon the rights and interests of human beings.
+
+4. Disclaimer
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+
+5. Limitation of Liability
+
+EXCEPT TO THE EXTENT PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER BASED IN TORT, NEGLIGENCE, CONTRACT, LIABILITY, OR OTHERWISE WILL ANY LICENSOR BE LIABLE TO YOU FOR ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES, OR ANY OTHER COMMERCIAL LOSSES, EVEN IF THE LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES.
+
+6. Dispute Resolution
+
+This license shall be governed and construed in accordance with the laws of People’s Republic of China. Any dispute arising from or in connection with this License shall be submitted to Haidian District People's Court in Beijing.
+
+Note that the license is subject to update to a more comprehensive version.  For any questions related to the license and copyright, please contact us at glm-130b@googlegroups.com.

Plik diff jest za duży
+ 53 - 0
README.md


Plik diff jest za duży
+ 55 - 0
README_zh.md


+ 15 - 0
configs/model_glm_130b.sh

@@ -0,0 +1,15 @@
+MODEL_TYPE="glm-130b"
+CHECKPOINT_PATH="/thudm/workspace/hanyu/SwissArmyTransformer/data/ckpt/iter_0049300"
+MP_SIZE=8
+MODEL_ARGS="--model-parallel-size ${MP_SIZE} \
+            --num-layers 70 \
+            --hidden-size 12288 \
+            --inner-hidden-size 32768 \
+            --vocab-size 150528 \
+            --num-attention-heads 96 \
+            --max-sequence-length 2048 \
+            --tokenizer-type icetk-glm-130B \
+            --layernorm-order post \
+            --load ${CHECKPOINT_PATH} \
+            --skip-init \
+            --fp16"

+ 17 - 0
configs/model_glm_130b_v100.sh

@@ -0,0 +1,17 @@
+MODEL_TYPE="glm-130b"
+CHECKPOINT_PATH="<your checkpoint path>"
+MP_SIZE=8
+MODEL_ARGS="--model-parallel-size ${MP_SIZE} \
+            --num-layers 70 \
+            --hidden-size 12288 \
+            --inner-hidden-size 32768 \
+            --vocab-size 150528 \
+            --num-attention-heads 96 \
+            --max-sequence-length 2048 \
+            --tokenizer-type icetk-glm-130B \
+            --layernorm-order post \
+            --load ${CHECKPOINT_PATH} \
+            --skip-init \
+            --fp16 \
+            --bminf \
+            --bminf-memory-limit 25"

+ 86 - 0
docs/evaluate-your-own-tasks.md

@@ -0,0 +1,86 @@
+# Evaluate Your Own Tasks
+
+## YAML file for tasks
+
+We use the YAML file to define tasks, this allows us to easily evaluate multiple tasks at a single run and configure them independently. Specifically, you can add multiple tasks or folders  at a time for evaluation, and the script will automatically collect all YAML files under those folders recursively.
+
+```
+# Single node
+bash scripts/evaluate.sh task1.yaml task2.yaml dir1 dir2 ...
+# Multi node
+bash scripts/evaluate_multiple_node.sh task1.yaml task2.yaml dir1 dir2 ...
+```
+
+We support two types of evaluation tasks: multi-choice and generation. The YAML config options for both tasks are defined in `evaluation/configs.py`. Basically, all types of tasks share common configs defining task information:
+
+```yaml
+name: 'glue_cola'  # Task Name
+type: 'mul'  # Task type, 'gen' (generate) or 'mul' (multiple choice)
+path: 'bloom/glue_cola'  # task data path relative to DATA_PATH in 'evaluate.sh'
+use_task_mask: False # Whether use [gMASK] for evaluation
+unidirectional: False # Whether use unidirectional attention
+max_seq_length: 2048  # Max sequence length
+file-pattern: # Organize jsonl file in groups
+  validation: "**/validation.jsonl" # Will search for all file named 'validation.jsonl' in `DATA_PATH/bloom/glue_cola` using glob.glob()
+micro-batch-size: 30 # 'gen' task only support mbs = 1 for now
+```
+
+See configuration details for multi-choice and generation tasks in `evaluation/configs.py`.
+
+## Data format for tasks
+
+We recommend organizing the task data in the following structure and setup up two groups named "validation" and "test" in the `file-pattern` config so that it becomes very easy to evaluate different prompts on both validation and test sets independently.
+
+```bash
+DATA_PATH
+└── task_name
+    ├── prompt_1
+    │   ├── test.jsonl
+    │   └── val.jsonl
+    ├── prompt_2
+    │   ├── test.jsonl
+    │   └── val.jsonl
+    └── prompt_3
+        ├── test.jsonl
+        └── val.jsonl
+```
+
+The evaluation data for each prompt are organized into jsonline format. For multi-choice tasks, the format of each line of JSON should be
+
+```json
+{
+    "inputs_pretokenized": "Context and question here",
+    "choices_pretokenized": ["Choice 1", "Choice 2", "Choice 3"],
+    "label": int
+}
+```
+
+The default metric for the multi-choice task is Accuracy.
+
+For the generation task, the format of each line of JSON should be
+
+```json
+{
+    "inputs_pretokenized": "Context and question here",
+    "targets_pretokenized": ["Target 1", "Target 2", "Target 3"],
+    "label": int
+}
+```
+
+The default metrics for the generation task are EM(Exact-Match) and F1. Given inputs, the sequence generated by the model will be metricized separately from all targets and the highest value will be taken.
+
+
+## Implement Your Metrics
+
+You can customize your evaluation metrics function and add it to `DEFAULT_METRICS` in `generation/metrics.py`, and then you can specify `metric: ['Your metric name']` in the task YAML file.
+
+## Fully customize the evaluation process
+
+By default, we implement classes named `MultiChoiceTask` and `GenerationTask` in `evaluation/tasks.py` for multi-choice tasks and generation tasks, respectively. 
+
+You can implement a new task class and inherit from one of these two classes, and implement the `process_single_batch` function to define how to process a batch of inputs and get the predictions. Following [Big-Bench](https://github.com/google/BIG-bench/#creating-the-task), we implemented two methods you can use for your evaluation:
+
+- `model.cond_log_prob()`: Compute the probabilities of provided model outputs for given inputs.
+- `model.generate_text()`: Generate text for given inputs.
+
+Once you have created the new task class, you need to specify the relative path to import the class in the `module` field of the task YAML file.  See `tasks/lambada/tasks.py` and `tasks/lambada/lambada.yaml` for how we customize the beam search generation strategy for LAMBADA tasks and configure the YAML file.

+ 103 - 0
docs/inference-with-fastertransformer.md

@@ -0,0 +1,103 @@
+# Inference with FasterTransformer
+
+[FasterTransformer](https://github.com/NVIDIA/FasterTransformer) provides a script and recipe to run the highly optimized transformer-based encoder and decoder component, and it is tested and maintained by NVIDIA.
+
+We adapted the GLM-130B based on Fastertransformer for fast inference, with details in [benchmark](#benchmark) section.
+
+## Setup
+
+### Requirements
+
+- CMake >= 3.13 for PyTorch
+- CUDA 11.0 or newer version
+- NCCL 2.10 or newer version
+- Python 3 is recommended because some features are not supported in python 2
+- PyTorch: Verify on 1.11.0, >= 1.8.0 should work.
+
+All the packages can be installed using conda.
+
+```bash
+conda install -y cmake numpy pybind11 pytorch torchvision cudatoolkit-dev cudnn
+cp -r $CONDA_PREFIX/lib/libcudnn* /usr/local/cuda/lib64/
+cp -r $CONDA_PREFIX/include/cudnn*.h /usr/local/cuda/include/
+```
+
+GLM-130B is trained with FP16 precision, a total of 260G of GPU memory is required to store model weights. The model is tested with 8 * 40G A100s.
+
+### Build
+
+Get the code and install all dependencies:
+
+```bash
+git clone https://github.com/THUDM/FasterTransformer.git
+mkdir -p FasterTransformer/build
+cd FasterTransformer/build
+git submodule init && git submodule update
+pip3 install fire jax jaxlib icetk
+```
+
+Note: the `xx` of `-DSM=xx` in following scripts means the compute capability of your GPU. For example, 60 (P40) or 61 (P4) or 70 (V100) or 75(T4) or 80 (A100).  Default setting is including 70, 75, 80 and 86.
+
+```bash
+cmake -DSM=80 -DCMAKE_BUILD_TYPE=Release -DBUILD_PYT=ON -DBUILD_MULTI_GPU=ON ..
+make -j
+```
+
+### Download the Model
+
+See [Get Model](/README.md#environment-setup).
+
+The original checkpoint compatible with [SAT](https://github.com/THUDM/SwissArmyTransformer), but each time the model is initialized it needs to be extracted, which costs time. So we provide a script `FasterTransformer/examples/pytorch/glm/utils/glm_ckpt_convert.py` to extract the downloaded checkpoint.
+
+For example:
+
+```bash
+# convert SAT checkpoint to FT checkpoint
+python3 ../examples/pytorch/glm/utils/glm_ckpt_convert.py -i global_step20000/iter_0020000 -o ft_output -i_g 8
+```
+
+### Run GLM-130B
+
+Generate the `gemm_config.in` file.
+
+```bash
+# ./bin/gpt_gemm <batch_size> <beam_width> <max_input_len> <head_number> <size_per_head> <inter_size> <vocab_size> <data_type> <tensor_para_size>
+./bin/gpt_gemm 1 1 128 96 128 49152 150528 1 8
+```
+
+Running GLM_130B in Pytorch.
+
+```bash
+bash ../examples/pytorch/glm/benchmark-generation.sh
+```
+
+You need to check and edit this file to set arguments such as the checkpoint's load path.
+
+When running GLM_130B, pay special attention to the following arguments:
+
+1. `--sat-ckpt-dir` is the path to the original downloaded checkpoint, compatible with SwissArmyTransformer.
+2. `--ft-ckpt-dir` is the path to the extracted checkpoint. It is faster to load, but you have to run `examples/pytorch/glm/utils/glm_ckpt_convert.py` to convert the downloaded checkpoint.
+3. `--n-inference-gpus` number of GPUs used for inference, defaults to 8. The binary model parameters are saved to `${output-dir}/${n-inference-gpus}-gpu/`
+4. `--sample-input-file` everyline is a batch, you can set `max_batch_size` to get multiple generations at one time, however, you need to ensure that all inputs are of the same length after being converted to tokens, otherwise only the longest sentence will get a right output.
+
+## Optimization methods
+
+Optimization in GLM_130B are similar to optimization in GPT and GPT-J, describing in the [FasterTransformer/gpt_guide.md](https://github.com/NVIDIA/FasterTransformer/blob/main/docs/gpt_guide.md). Meanwhile, some of the operators are differ from GPT, such as the implementation of RotaryEmbedding, and the use of GeGLU, so we add them additionally into FasterTransformer.
+
+## Benchmark
+
+- Hardware: DGX-A100(8 * 40G)
+
+## Encode
+
+| **Sequence Len**   | 512    | 1024   | 2048   |
+| ---------- | ------ | ------ | ------ |
+| Megatron   | 145 ms | 250 ms | 453 ms |
+| FasterTransformer | 120 ms | 220 ms | OOM  |
+
+## Decode
+
+| **Sequence Len**  | 512     | 1024    | 2048     |
+| ---------- | ------- | ------- | -------- |
+| Megatron   | 45.21 s | 89.00 s | 179.22 s |
+| FasterTransformer | 18.77 s | 39.81 s | 89.88 s  |

+ 28 - 0
docs/low-resource-inference.md

@@ -0,0 +1,28 @@
+# Low-resource Inference with BMInf
+
+GLM-130B is trained with 4-way tensor parallel and 8-way pipeline parallel for efficiency. Then the checkpoint is converted into a 8-way tensor parallel one in order to inference the model in a single node. GLM-130B has 130 billion parameters in FP16 precision, a total of 260G of GPU memory is required to store model weights. The DGX-A100 server has 8 A100s and provides an amount of 320G of GPU memory (640G for 80G A100 version)  so it suits GLM-130B well. 
+
+However, a server with 8 * 32G V100 only provides an amount of 256G of GPU memory, which indicates that the full loading of model weights is not possible. Fortunately, with the swap-in-and-out feature between CPU and GPU memory provided by the [BMInf](https://github.com/OpenBMB/BMInf) library, GLM-130B can still run on servers with a smaller amount of GPU memory. After joint debugging with the BMInf team, we achieved a resonable evaluation efficiency on DGX-1 servers with 8 * 32G V100 by carefully overlapping computation and communication, see the [benchmark section](#benchmark) for details.
+
+We have integrated BMInf into our codebase, just install BMInf via `pip install bminf`, and change the model configuration file from `configs/model_glm_130b.sh` to `configs/model_glm_130b_v100.sh` in your launch shell script. The default BMInf config is for V100 servers, you can also adjust the maximum memory the model weights can occupy on one GPU by setting `--bminf-memory-limit` according to your GPU memory in the model config file.
+
+## Benchmark
+
+### Evaluation
+
+- CoLA task on the validation set
+- Micro Batch Size = 30
+- BMInf: 25GB model weights in GPU memory limit by: `--bminf-memory-limit 25`
+
+|                | Peak GPU Memory | Time   |
+| -------------- | ---------- | ------ |
+| A100-SAT       | 40.3 G     | 74.6 s |
+| V100-SAT       | OOM        | OOM    |
+| V100-SAT-BMInf | 32.3 G     | 196.0 s |
+
+The `micro-batch-size` config in task YAML files is configured according to the maximum utilization of the DGX-A100 server. If you encounter an OOM error on the V100 server, please adjust the `micro-batch-size` appropriately.
+
+### Text generation
+
+In text generation, due to the small amount of calculation per model forward (usually <10 tokens/forward using beam search strategy), the communication between the CPU and GPU memory becomes the bottleneck. With the help of the BMInf team, we did an in-depth profile on our V100 server. Given a 25GB model weight limit per GPU, a total of 13 layers need to be copied from CPU to GPU for a single forward, each layer will take about 75ms on IO, indicating that the real IO speed between CPU and GPU is `260GB / 70 / 8 / 75ms = 6.19GB/s`. Our V100 server uses PCI-E 3.0 and two V100s share a switch, so the theoretical bandwidth for each GPU is 8GB/s, close to our profiling results. A server with PCI-E 4.0 will greatly reduce the IO time. Even that, long text generation tokens can still take several minutes so **we do not recommend using V100 servers in text generation scenario**. For this, we are working on INT8 quantization so that GLM-130B can even fit a single RTX-3090 server (24G * 8).
+

+ 67 - 0
evaluate.py

@@ -0,0 +1,67 @@
+import time
+import importlib
+
+from os.path import join, isdir, isfile, relpath
+from glob import glob
+
+from evaluation import BaseConfig, ModelForEvaluation, DEFAULT_CLASS, print_rank_0
+from initialize import initialize, initialize_model_and_tokenizer
+
+
+def add_evaluation_specific_args(parser):
+    """Arguments for evaluation"""
+    group = parser.add_argument_group("evaluation", "Evaluation configurations")
+
+    # Task
+    group.add_argument("--task", nargs="+", default=[], help="All task config to evaluation")
+    group.add_argument("--data-path", type=str, required=True, help="Data dir path for all tasks")
+    return parser
+
+
+def find_all_tasks(all_task_config_path):
+    tasks = []
+    for task in all_task_config_path:
+        if isdir(task):
+            tasks += [relpath(path, ".") for path in glob(join(task, "**/*.yaml"), recursive=True)]
+        elif isfile(task):
+            tasks.append(task)
+    return tasks
+
+
+def evaluate_all_tasks(data_path, model, tokenizer, all_task_config_path, task_classes):
+    for config_path, task_class in zip(all_task_config_path, task_classes):
+        config = task_class.config_class().from_yaml_file(config_path)
+        config.path = join(data_path, config.path)
+        task = task_class(model, tokenizer, config)
+        task.evaluate()
+
+
+def main():
+    args = initialize(extra_args_provider=add_evaluation_specific_args)
+    args.task = find_all_tasks(args.task)
+
+    task_classes = []
+    print_rank_0("> Loading task configs")
+    for task_config_path in args.task:
+        config = BaseConfig.from_yaml_file(task_config_path)
+        if config.module:
+            path = ".".join(config.module.split(".")[:-1])
+            module = importlib.import_module(path)
+            class_name = config.module.split(".")[-1]
+            task_class = getattr(module, class_name)
+            task_classes.append(task_class)
+        else:
+            task_classes.append(DEFAULT_CLASS[config.type])
+        print_rank_0(f"    Task {config.name} loaded from config {task_config_path}")
+    print_rank_0(f"> Successfully load {len(task_classes)} task{'s' if len(task_classes) > 1 else ''}")
+
+    model, tokenizer = initialize_model_and_tokenizer(args)
+    model = ModelForEvaluation(model)
+
+    start = time.time()
+    evaluate_all_tasks(args.data_path, model, tokenizer, args.task, task_classes)
+    print_rank_0(f"Finish {len(task_classes)} task{'s' if len(task_classes) > 1 else ''} in {time.time() - start:.1f}s")
+
+
+if __name__ == "__main__":
+    main()

+ 7 - 0
evaluation/__init__.py

@@ -0,0 +1,7 @@
+from .configs import *
+from .model import ModelForEvaluation
+from .tasks import BaseTask, GenerationTask, MultiChoiceTask
+from .metrics import qa_evaluate
+from .utils import print_rank_0
+
+DEFAULT_CLASS = {TaskType.GENERATION: GenerationTask, TaskType.MULTICHOICE: MultiChoiceTask}

+ 53 - 0
evaluation/configs.py

@@ -0,0 +1,53 @@
+from __future__ import annotations
+from dataclass_wizard import YAMLWizard
+from dataclasses import dataclass, field
+from enum import Enum
+from typing import Optional, List, Dict
+
+
+class TaskType(Enum):
+    MULTICHOICE = "mul"
+    GENERATION = "gen"
+    OTHER = "other"
+
+
+@dataclass
+class BaseConfig(YAMLWizard):
+    name: str  # Task name
+    type: TaskType  # Task type
+    path: str  # task data path relative to DATA_PATH
+
+    module: Optional[str] = None  # Custom task module file, optional
+    metrics: List[str] = field(default_factory=list)  # Evaluation metrics
+
+    use_task_mask: bool = False  # Whether to use [gMASK] for evaluation
+    use_multitask_encoding: bool = False  # Not supported now
+    unidirectional: bool = False  # Whether to use unidirectional attention
+    max_seq_length: int = 2048  # Max sequence length
+    file_pattern: str | Dict[str, str] = "**/*.json*"  # Organize data file in groups
+
+    micro_batch_size: int = 1  # 'gen' task only support mbs = 1 for now
+
+    def __post_init__(self):
+        assert self.use_task_mask or not self.unidirectional, "[MASK] doesn't support unidirectional attention"
+
+
+@dataclass
+class MultiChoiceTaskConfig(BaseConfig):
+    module = "evaluation.MultiChoiceTask"
+    metrics: List[str] = field(default_factory=lambda: ["Accuracy"])
+
+
+@dataclass
+class GenerationTaskConfig(BaseConfig):
+    module = "evaluation.GenerationTask"
+    metrics: List[str] = field(default_factory=lambda: ["EM", "F1"])
+    sampling_strategy: str = "BaseStrategy"
+    num_beams: int = 4
+    length_penalty: float = 1.0
+    no_repeat_ngram_size: int = 3
+    min_gen_length: int = 0
+    max_gen_length: int = 128
+
+    def __post_init__(self):
+        assert self.micro_batch_size == 1, "Only support micro batch size = 1 for generation task"

+ 259 - 0
evaluation/dataset.py

@@ -0,0 +1,259 @@
+import os
+import json
+
+import numpy as np
+import torch
+
+from abc import ABC, abstractmethod
+from scipy.linalg import block_diag
+
+from SwissArmyTransformer import get_tokenizer
+
+from .configs import BaseConfig, MultiChoiceTaskConfig, GenerationTaskConfig
+from .utils import get_tokenized_input
+
+
+def pad_batch(tokens, position_ids, attention_mask, max_seq_length):
+    attention_mask = np.pad(
+        attention_mask,
+        pad_width=((0, max_seq_length - len(tokens)),),
+        mode="constant",
+        constant_values=0,
+    )
+    tokens = np.concatenate((tokens, np.zeros(max_seq_length - len(tokens), dtype=np.int64)))
+    position_ids = np.concatenate((position_ids, np.zeros(max_seq_length - len(position_ids), dtype=np.int64)))
+    return tokens, position_ids, attention_mask
+
+
+class EvaluationDataset(torch.utils.data.Dataset, ABC):
+    """
+    Jsonlines of {
+        "text": context
+        "choices": [choice_id1,...], if not None, len(target) == 1
+        "label": If generation task -1, else [0, len(choices))
+    }
+    If [MASK] not in context, will append [MASK] after text
+    """
+
+    def __init__(self, path, config: BaseConfig):
+        self.path = path
+        self.config = config
+        self.max_seq_length = self.config.max_seq_length
+        self.dtype = np.int64
+
+        tokenizer = get_tokenizer(tokenizer_type="icetk-glm-130B")
+        self.mask_id = tokenizer.get_command("[MASK]")
+        self.gmask_id = tokenizer.get_command("[gMASK]")
+
+        self.data = []
+        with open(os.path.join(path), "r", encoding="utf-8") as file:
+            for line in file:
+                item = json.loads(line)
+                self.data.append(self.process_single_item(item))
+
+    @property
+    def has_collate_fn(self) -> bool:
+        return False
+
+    def collate_fn(self, samples):
+        return None
+
+    @abstractmethod
+    def process_single_item(self, item) -> dict:
+        pass
+
+    def __len__(self):
+        return len(self.data)
+
+
+class GenerationTaskDataset(EvaluationDataset):
+    config: GenerationTaskConfig
+
+    def process_single_item(self, item):
+        text, targets = get_tokenized_input(item, "inputs"), get_tokenized_input(item, "targets")
+        if len(text) + self.config.max_gen_length + 2 > self.config.max_seq_length:
+            text_length = self.config.max_seq_length - self.config.max_gen_length - 2
+            text = text[len(text) - text_length : len(text)]
+        return {"text": text, "targets": targets}
+
+    @staticmethod
+    def build_generation_sample(text, max_gen_length, use_task_mask, unidirectional=True):
+        tokenizer = get_tokenizer()
+
+        sop_id = tokenizer.get_command("sop")
+        mask_id = tokenizer.get_command("[gMASK]") if use_task_mask else tokenizer.get_command("[MASK]")
+
+        token = np.array(text, dtype=np.int64)
+
+        blank_filling = mask_id in text
+        if blank_filling:
+            assert not unidirectional, "Unidirectional attention doesn't support blank filling"
+            assert not use_task_mask, "Unidirectional attention doesn't support task mask"
+            mask_position = text.index(mask_id)
+            token = np.concatenate((token, [sop_id]))
+        else:
+            mask_position = len(token)
+            if unidirectional:
+                token = np.concatenate(([mask_id, sop_id], token))
+            else:
+                token = np.concatenate((token, [mask_id, sop_id]))
+        context_length = len(token)
+        max_seq_length = context_length + max_gen_length
+
+        position_id = np.arange(0, max_seq_length, dtype=np.int64)
+        if not use_task_mask:
+            position_id[context_length - 1 :] = mask_position
+
+        attention_mask = np.tril(np.ones((max_seq_length, max_seq_length), dtype=np.int64))
+        if not unidirectional:
+            attention_mask[: context_length - 1, : context_length - 1] = 1
+
+        item = {
+            "tokens": np.concatenate((token, np.zeros(max_seq_length - len(token), dtype=np.int64))),
+            "position_ids": position_id,
+            "attention_mask": attention_mask < 0.5,
+            "context_length": context_length,
+        }
+        return item
+
+    def __getitem__(self, idx):
+        item = self.data[idx]
+        sample = self.build_generation_sample(
+            item["text"],
+            max_gen_length=self.config.max_gen_length,
+            use_task_mask=self.config.use_task_mask,
+            unidirectional=self.config.unidirectional,
+        )
+        sample["targets"] = [np.array(target, dtype=self.dtype) for target in item["targets"]]
+        return sample
+
+
+class MultiChoiceTaskDataset(EvaluationDataset):
+    config: MultiChoiceTaskConfig
+
+    def __init__(self, path, config: MultiChoiceTaskConfig):
+        self.is_single_token = True  # set to False later in process_single_item func
+        super().__init__(path, config)
+
+    @property
+    def has_collate_fn(self) -> bool:
+        return True
+
+    def collate_fn(self, samples):
+        TILE = 32
+        length_to_pad = (max(map(lambda spl: len(spl["token"]), samples)) + TILE - 1) // TILE * TILE
+
+        token_batch, position_id_batch, attention_mask_batch = [], [], []
+        choices_batch, choice_target_ids_batch = [], []
+
+        for sample in samples:
+            token, position_id, attention_mask = pad_batch(
+                sample["token"], sample["position_id"], sample["attention_mask"], length_to_pad
+            )
+            token_batch.append(token)
+            position_id_batch.append(position_id)
+            attention_mask_batch.append(attention_mask)
+            choices_batch.append(sample["choices"])
+            choice_target_ids_batch.append(sample["choice_target_ids"])
+
+        return {
+            "tokens": torch.tensor(np.array(token_batch), dtype=torch.int64),
+            "position_ids": torch.tensor(np.array(position_id_batch), dtype=torch.int64),
+            "attention_mask": torch.tensor(np.array(attention_mask_batch), dtype=torch.int64) < 0.5,
+            "choices": choices_batch,
+            "choice_target_ids": choice_target_ids_batch,
+            "is_single_token": self.is_single_token,
+        }
+
+    def process_single_item(self, item):
+        text, choices, label = get_tokenized_input(item, "inputs"), get_tokenized_input(item, "choices"), item["label"]
+
+        tgt_seq_length = sum([len(choice) for choice in choices])
+        if tgt_seq_length == len(choices):
+            # For single token, we only insert one [sop]
+            tgt_seq_length = 1
+
+        assert tgt_seq_length < self.config.max_seq_length
+        if len(text) + tgt_seq_length + 2 > self.config.max_seq_length:
+            text_length = self.config.max_seq_length - tgt_seq_length - 2
+            text = text[len(text) - text_length : len(text)]
+
+        assert not (
+            self.mask_id in text and self.config.use_multitask_encoding
+        ), "Unified multitask encoding don't support blank filling"
+
+        if tgt_seq_length != 1:
+            self.is_single_token = False
+
+        return {
+            "text": text,
+            "choices": choices,
+            "label": label,
+        }
+
+    @staticmethod
+    def build_multiple_choice_sample(text, choices, is_single_token, unified_multitask_encoding=False):
+        tokenizer = get_tokenizer()
+
+        sop_id = tokenizer.get_command("sop")
+        mask_id = tokenizer.get_command("[MASK]")
+
+        token = np.array(text, dtype=np.int64)
+        target = np.array(text, dtype=np.int64)
+        position_id = np.arange(len(text), dtype=np.int64)
+        choice_target_id = []
+
+        blank_filling = mask_id in text
+        if not blank_filling:
+            mask_position = len(token)
+            token = np.concatenate((token, [mask_id]))
+            target = np.concatenate((target, [mask_id]))
+            position_id = np.concatenate((position_id, [mask_position]))
+        else:
+            mask_position = text.index(mask_id)
+
+        division = len(token)
+        attention_mask = [np.ones((len(token), len(token)), dtype=np.int64)]
+
+        for choice in choices:
+            position_id = np.concatenate(
+                (
+                    position_id,
+                    [mask_position] * len(choice)
+                    if blank_filling or not unified_multitask_encoding
+                    else np.arange(mask_position, mask_position + len(choice), dtype=np.int64),
+                )
+            )
+            choice_target_id.append(np.arange(len(token), len(token) + len(choice), dtype=np.int64))
+            attention_mask.append(np.tril(np.ones((len(choice), len(choice)), dtype=np.int64)))
+            token = np.concatenate((token, [sop_id], choice[:-1]))
+            target = np.concatenate((target, choice))
+
+            if is_single_token:
+                break
+
+        attention_mask = block_diag(*attention_mask)
+        attention_mask[: len(token), :division] = 1
+
+        if is_single_token:
+            choices = np.array(choices, dtype=np.int64).squeeze().tolist()
+
+        item = {
+            "token": token,
+            "position_id": position_id,
+            "attention_mask": attention_mask,
+            "choices": choices,
+            "choice_target_ids": choice_target_id[0] if is_single_token else choice_target_id,
+        }
+        return item
+
+    def __getitem__(self, idx):
+        item = self.data[idx]
+        sample = self.build_multiple_choice_sample(
+            item["text"],
+            item["choices"],
+            is_single_token=self.is_single_token,
+            unified_multitask_encoding=self.config.use_multitask_encoding,
+        )
+        sample["label"] = item["label"]
+        return sample

+ 82 - 0
evaluation/metrics.py

@@ -0,0 +1,82 @@
+import string
+import re
+import functools
+
+from collections import Counter
+
+from SwissArmyTransformer import get_tokenizer
+
+
+def accuracy_metric(predictions, examples):
+    count = 0
+    num_predictions = max(len(predictions), 1)
+    assert len(predictions) == len(examples)
+    for prediction, example in zip(predictions, examples):
+        count += prediction == example["label"]
+    return count * 100.0 / num_predictions
+
+
+def normalize_answer(s):
+    """Lower text and remove punctuation, articles and extra whitespace."""
+
+    def remove_articles(text):
+        return re.sub(r"\b(a|an|the)\b", " ", text)
+
+    def white_space_fix(text):
+        return " ".join(text.split())
+
+    def remove_punc(text):
+        exclude = set(string.punctuation)
+        return "".join(ch for ch in text if ch not in exclude)
+
+    def lower(text):
+        return text.lower()
+
+    return white_space_fix(remove_articles(remove_punc(lower(s))))
+
+
+def f1_score(prediction, ground_truth):
+    prediction_tokens = normalize_answer(prediction).split()
+    ground_truth_tokens = normalize_answer(ground_truth).split()
+    common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
+    num_same = sum(common.values())
+    if num_same == 0:
+        return 0
+    precision = 1.0 * num_same / len(prediction_tokens)
+    recall = 1.0 * num_same / len(ground_truth_tokens)
+    f1 = (2 * precision * recall) / (precision + recall)
+    return f1
+
+
+def exact_match_score(prediction, ground_truth):
+    return normalize_answer(prediction) == normalize_answer(ground_truth)
+
+
+def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
+    if not ground_truths:
+        return 0.0
+    scores_for_ground_truths = []
+    for ground_truth in ground_truths:
+        score = metric_fn(prediction, ground_truth)
+        scores_for_ground_truths.append(score)
+    return max(scores_for_ground_truths)
+
+
+def qa_evaluate(predictions, examples, metric):
+    assert len(examples) == len(predictions)
+    tokenizer = get_tokenizer()
+
+    score = 0.0
+    for example, prediction in zip(examples, predictions):
+        ground_truths = [tokenizer.tokenizer.decode(target) for target in example["targets"]]
+        prediction = tokenizer.tokenizer.decode(prediction)
+        if ground_truths:
+            score += metric_max_over_ground_truths(metric, prediction, ground_truths)
+    score = 100.0 * score / len(predictions)
+    return score
+
+
+qa_exact_match = functools.partial(qa_evaluate, metric=exact_match_score)
+qa_f1 = functools.partial(qa_evaluate, metric=f1_score)
+
+DEFAULT_METRICS = {"EM": qa_exact_match, "F1": qa_f1, "Accuracy": accuracy_metric}

+ 88 - 0
evaluation/model.py

@@ -0,0 +1,88 @@
+import torch
+
+from typing import List, Union
+
+from SwissArmyTransformer.generation.autoregressive_sampling import filling_sequence
+
+
+class ModelForEvaluation(torch.nn.Module):
+    def __init__(self, model):
+        super().__init__()
+
+        self.model = model
+
+    @staticmethod
+    def process_data(batch):
+        return (
+            batch["tokens"].to(device=torch.cuda.current_device()).long(),
+            batch["position_ids"].to(device=torch.cuda.current_device()).long(),
+            batch["attention_mask"].to(device=torch.cuda.current_device()).bool().unsqueeze(1),
+        )
+
+    def cond_log_prob(self, batch) -> List[List[float]]:
+        """
+        @return: Conditional log probability of each option
+        """
+        tokens, position_ids, attention_mask = self.process_data(batch)
+        choices_batch, choice_target_ids_batch = batch["choices"], batch["choice_target_ids"]
+        is_single_token = batch["is_single_token"]
+
+        self.model.eval()
+        with torch.no_grad():
+            logits, *output_per_layers = self.model(tokens, position_ids, attention_mask, log_attention_weights=None)
+            logits_batch = torch.nn.functional.log_softmax(logits, dim=-1)
+
+        # output: [b, sq, vocab]
+        log_probs = []
+
+        if is_single_token:  # Single token
+            for logits, choices, choice_target_ids in zip(logits_batch, choices_batch, choice_target_ids_batch):
+                log_probs.append(logits[choice_target_ids[0], choices].tolist())
+        else:  # Multi token
+            for output, choices, choice_target_ids in zip(logits_batch, choices_batch, choice_target_ids_batch):
+                log_probs_single = []
+                for choice, choice_target_id in zip(choices, choice_target_ids):
+                    tmp = output[choice_target_id, choice]
+                    log_probs_single.append(tmp.sum().tolist())
+                log_probs.append(log_probs_single)
+        return log_probs
+
+    def generate_text(self, sample, strategy, return_all_beams=False) -> Union[List[int], List[List[int]]]:
+        """
+        @return: A list of text model generated, sorted by score in descending order
+        """
+
+        seq = torch.squeeze(sample["tokens"].to(device=torch.cuda.current_device()).long())
+        context_length = sample["context_length"].to(device=torch.cuda.current_device()).long()
+        seq[context_length:] = -1
+
+        def get_masks_and_position_ids(seq):
+            tokens = seq.unsqueeze(0)
+            attention_mask = sample["attention_mask"].to(device=torch.cuda.current_device()).bool().unsqueeze(1)
+            position_ids = sample["position_ids"].to(device=torch.cuda.current_device()).long()
+            return tokens, attention_mask, position_ids
+
+        self.model.eval()
+        with torch.no_grad():
+            output = filling_sequence(
+                self.model,
+                seq,
+                get_masks_and_position_ids=get_masks_and_position_ids,
+                batch_size=strategy.num_beams if hasattr(strategy, "num_beams") else 1,
+                strategy=strategy,
+            )[0]
+
+        if isinstance(output, torch.Tensor):  # different strategies
+            output = list(output)
+
+        output_targets = []
+
+        for line in output:
+            line = line.tolist()
+            unfinished = line.index(-1) if -1 in line else len(line)
+            if line[unfinished - 1] in strategy.end_tokens:
+                unfinished -= 1
+            line = line[context_length:unfinished]
+            output_targets.append(line)
+
+        return output_targets if return_all_beams else output_targets[0]

+ 207 - 0
evaluation/tasks.py

@@ -0,0 +1,207 @@
+import torch
+import time
+import numpy as np
+import torch.distributed as dist
+
+from typing import Dict, Callable, Type, Tuple, List, Any
+from abc import ABC, abstractmethod
+from glob import glob
+from os.path import join, relpath
+from collections import defaultdict
+
+from SwissArmyTransformer.generation.sampling_strategies import BaseStrategy
+from SwissArmyTransformer.tokenization.icetk_glm_130B.ice_tokenizer import _IceTokenizer
+
+from generation import BeamSearchStrategy
+from .configs import BaseConfig, GenerationTaskConfig, MultiChoiceTaskConfig
+from .model import ModelForEvaluation
+from .dataset import EvaluationDataset, GenerationTaskDataset, MultiChoiceTaskDataset
+from .utils import build_data_loader, gather_result, print_rank_0
+from .metrics import DEFAULT_METRICS
+
+
+class BaseTask(ABC):
+    model: ModelForEvaluation
+    tokenizer: _IceTokenizer
+    config: BaseConfig
+    file_groups: Dict[str, List[str]]
+
+    @classmethod
+    def config_class(cls) -> Type[BaseConfig]:
+        return BaseConfig
+
+    @property
+    def metrics(self) -> Dict[str, Callable]:
+        return {metric: DEFAULT_METRICS[metric] for metric in self.config.metrics}
+
+    def __init__(self, model: ModelForEvaluation, tokenizer: _IceTokenizer, config: BaseConfig):
+        self.model = model
+        self.tokenizer = tokenizer
+        self.config = config
+        self.config.metrics = list(self.metrics.keys())
+
+        self.file_groups = self.get_file_groups()
+        self.verbose = dist.get_rank() == 0
+
+    def get_file_groups(self):
+        pattern_group = {}
+        if isinstance(self.config.file_pattern, str):
+            pattern_group["all"] = self.config.file_pattern
+        else:
+            pattern_group = self.config.file_pattern
+        return {
+            name: [
+                relpath(path, start=self.config.path)
+                for path in sorted(glob(join(self.config.path, pattern), recursive=True))
+            ]
+            for name, pattern in pattern_group.items()
+        }
+
+    def evaluate(self):
+        dist.barrier()
+        start = time.time()
+        print_rank_0("\n")
+        print_rank_0(f"{self.config}")
+        print_rank_0(f"Evaluating task {self.config.name}:")
+
+        result_dict_all = {}
+
+        for group_name, filelist in self.file_groups.items():
+            print_rank_0(f"    Evaluating group {group_name}:")
+
+            result_dict_group = {}
+            for file in filelist:
+                dataset = self.build_dataset(file)
+                dataloader = build_data_loader(
+                    dataset,
+                    micro_batch_size=self.config.micro_batch_size,
+                    num_workers=1,
+                    drop_last=False,
+                    collate_fn=dataset.collate_fn if dataset.has_collate_fn else None,
+                )
+
+                prediction = []
+                with torch.no_grad():
+                    for _, batch in enumerate(dataloader):
+                        prediction.append(self.predict_single_batch(batch))
+
+                prediction = gather_result(prediction, len(dataset), self.config.micro_batch_size)
+                result_dict = {key: metric(prediction, dataset.data) for key, metric in self.metrics.items()}
+                result_dict_group[file] = (result_dict, len(dataset))
+
+                if self.verbose:
+                    self.report_single_metrics(file, result_dict)
+
+            result_dict_all[group_name] = result_dict_group
+
+        print_rank_0(f"Evaluation results of task {self.config.name}:")
+
+        if self.verbose:
+            for group_name, result_dict_group in result_dict_all.items():
+                self.report_group_metrics(group_name, result_dict_group)
+            self.report_overall_metrics(
+                {k: v for result_dict_group in result_dict_all.values() for k, v in result_dict_group.items()},
+            )
+
+        print_rank_0(f"Finish task {self.config.name} in {time.time() - start:.1f}s.")
+
+    def report_single_metrics(self, file: str, result_dict: Dict[str, float]):
+        output_str = f"        Finish {file}"
+        for key, value in result_dict.items():
+            output_str += f", {key} = {value:.3f}"
+        print_rank_0(output_str)
+
+    @staticmethod
+    def calc_group_metrics(result_dict_group: Dict[str, Tuple[Dict[str, float], int]]):
+        metrics_dict = defaultdict(lambda: [])
+        weight = []
+        for file, (result_dict, length) in result_dict_group.items():
+            for key, value in result_dict.items():
+                metrics_dict[key].append(value)
+            weight.append(length)
+        return {
+            name: {
+                "max": np.max(value),
+                "median": np.median(value),
+                "average": np.average(value, weights=weight),
+            }
+            for name, value in metrics_dict.items()
+        }
+
+    def report_group_metrics(self, group_name, result_dict_group: Dict[str, Tuple[Dict[str, float], int]], level=1):
+        stats_dict = self.calc_group_metrics(result_dict_group)
+        if len(stats_dict) == 1:
+            name, stats = next(iter(stats_dict.items()))
+            print_rank_0(
+                "    " * level + f"Group {group_name} {name}: max = {stats['max']:.3f}, "
+                f"median = {stats['median']:.3f}, average = {stats['average']:.3f}"
+            )
+        else:
+            print_rank_0("    " * level + f"  Group {group_name}: ")
+            for name, stats in stats_dict.items():
+                print(
+                    "    " * (level + 1) + f"Metric {name}: max = {stats['max']:.3f}, "
+                    f"median = {stats['median']:.3f}, average = {stats['average']:.3f}"
+                )
+
+    def report_overall_metrics(self, result_dict_all: Dict[str, Tuple[Dict[str, float], int]]):
+        pass
+
+    @abstractmethod
+    def predict_single_batch(self, batch) -> List[Any]:
+        pass
+
+    @abstractmethod
+    def build_dataset(self, relative_path: str) -> EvaluationDataset:
+        pass
+
+
+class GenerationTask(BaseTask, ABC):
+    config: GenerationTaskConfig
+
+    @classmethod
+    def config_class(cls):
+        return GenerationTaskConfig
+
+    def build_dataset(self, relative_path):
+        return GenerationTaskDataset(join(self.config.path, relative_path), self.config)
+
+    def __init__(self, model: ModelForEvaluation, tokenizer: _IceTokenizer, config: GenerationTaskConfig):
+        super(GenerationTask, self).__init__(model, tokenizer, config)
+
+        end_tokens = [tokenizer.get_command("eop"), tokenizer.get_command("eos")]
+        if self.config.sampling_strategy == "BaseStrategy":
+            self.strategy = BaseStrategy(temperature=1.0, top_k=1, end_tokens=end_tokens)
+        elif self.config.sampling_strategy == "BeamSearchStrategy":
+            self.strategy = BeamSearchStrategy(
+                self.config.num_beams,
+                length_penalty=self.config.length_penalty,
+                consider_end=True,
+                end_tokens=end_tokens,
+                no_repeat_ngram_size=self.config.no_repeat_ngram_size,
+                min_gen_length=self.config.min_gen_length,
+                deterministic=True,  # For evaluation, we need a determined generation strategy
+            )
+        else:
+            raise ValueError(f"unknown strategy {self.config.sampling_strategy}")
+
+    def predict_single_batch(self, batch) -> List[List[int]]:
+        # micro batch size = 1 for generation task,
+        # but we still need to return a list of predictions for consistency
+        output = self.model.generate_text(batch, self.strategy, return_all_beams=False)
+        return [output]
+
+
+class MultiChoiceTask(BaseTask, ABC):
+    config: MultiChoiceTaskConfig
+
+    @classmethod
+    def config_class(cls):
+        return MultiChoiceTaskConfig
+
+    def build_dataset(self, relative_path):
+        return MultiChoiceTaskDataset(join(self.config.path, relative_path), self.config)
+
+    def predict_single_batch(self, batch) -> List[int]:
+        log_probs = self.model.cond_log_prob(batch)
+        return [np.argmax(log_probs_single).item() for log_probs_single in log_probs]

+ 67 - 0
evaluation/utils.py

@@ -0,0 +1,67 @@
+import torch
+import torch.distributed as dist
+
+from SwissArmyTransformer import mpu, get_tokenizer
+
+
+def print_rank_0(*args, **kwargs):
+    if torch.distributed.get_rank() == 0:
+        print(*args, **kwargs)
+
+
+def build_data_loader(dataset, micro_batch_size, num_workers, drop_last, collate_fn=None):
+    # Sampler.
+    world_size = mpu.get_data_parallel_world_size()
+    rank = mpu.get_data_parallel_rank()
+    sampler = torch.utils.data.distributed.DistributedSampler(
+        dataset, num_replicas=world_size, rank=rank, shuffle=False
+    )
+
+    # Data loader. Note that batch size is the per GPU batch size.
+    data_loader = torch.utils.data.DataLoader(
+        dataset,
+        batch_size=micro_batch_size,
+        sampler=sampler,
+        shuffle=False,
+        num_workers=num_workers,
+        drop_last=drop_last,
+        pin_memory=True,
+        collate_fn=collate_fn,
+    )
+
+    return data_loader
+
+
+def gather_result(prediction, total_length, micro_batch_size):
+    """
+    @param prediction: Local predictions with order defined by distributed sampler
+    @param total_length: Total sample num
+    @return: [sample_0, sample_1, ..., sample_{total_length-1}]
+    """
+    torch.cuda.empty_cache()
+    world_size = mpu.get_data_parallel_world_size()
+    prediction_gathered = [None for _ in range(world_size)]
+    dist.all_gather_object(prediction_gathered, prediction, group=mpu.get_data_parallel_group())
+    prediction = []
+    for i in range(len(prediction_gathered[0])):
+        for j in range(micro_batch_size):
+            for k in range(world_size):
+                if j < len(prediction_gathered[k][i]):
+                    prediction.append(prediction_gathered[k][i][j])
+    prediction = prediction[:total_length]
+    return prediction
+
+
+def get_tokenized_input(item, key):
+    if key in item:
+        return item[key]
+    tokenizer = get_tokenizer()
+    pretokenized_key = key + "_pretokenized"
+    assert pretokenized_key in item
+    if isinstance(item[pretokenized_key], list):
+        result = []
+        for raw in item[pretokenized_key]:
+            result.append(tokenizer.tokenize(raw))
+        return result
+    else:
+        return tokenizer.tokenize(item[pretokenized_key])

+ 210 - 0
generate.py

@@ -0,0 +1,210 @@
+import os
+import torch
+import stat
+import re
+
+from functools import partial
+from typing import List, Tuple
+
+from SwissArmyTransformer import mpu
+from SwissArmyTransformer.generation.autoregressive_sampling import filling_sequence
+from SwissArmyTransformer.generation.sampling_strategies import BaseStrategy
+from generation import BeamSearchStrategy
+from SwissArmyTransformer.generation.utils import timed_name, generate_continually
+from initialize import initialize, initialize_model_and_tokenizer
+
+
+def add_generation_specific_args(parser):
+    parser.add_argument("--sampling-strategy", type=str, default="BaseStrategy", help="Type of sampling strategy.")
+    parser.add_argument("--min-gen-length", type=int, default=0, help="The minimum length each blank should generate.")
+    parser.add_argument(
+        "--print-all-beams", action="store_true", help="Print all output generated by beam search strategy."
+    )
+
+
+def isEnglish(s):
+    try:
+        s.encode(encoding="utf-8").decode("ascii")
+    except UnicodeDecodeError:
+        return False
+    else:
+        return True
+
+
+def get_masks_and_position_ids(seq, mask_position, context_length, gmask=False):
+    tokens = seq.unsqueeze(0)
+
+    attention_mask = torch.ones((1, len(seq), len(seq)), device=tokens.device)
+    attention_mask.tril_()
+    attention_mask[..., : context_length - 1] = 1
+    attention_mask.unsqueeze_(1)
+    attention_mask = (attention_mask < 0.5).bool()
+
+    position_ids = torch.arange(len(seq), dtype=torch.long, device=tokens.device)
+    if not gmask:
+        position_ids[context_length - 1 :] = mask_position
+
+    position_ids = position_ids.unsqueeze(0)
+
+    return tokens, attention_mask, position_ids
+
+
+def fill_blanks(raw_text: str, model, tokenizer, strategy) -> Tuple[List[str], List[str], List[List[str]]]:
+    # add MASK
+    generation_mask = "[MASK]" if "[MASK]" in raw_text else "[gMASK]"
+    use_gmask = "[MASK]" not in raw_text
+
+    mask_pattern = r"\[g?MASK\]"
+    text_list = re.split(mask_pattern, raw_text)
+    pattern_list = re.compile(mask_pattern).findall(raw_text)
+    seq = []
+    for i in range(len(pattern_list)):
+        pattern = pattern_list[i]
+        sub_text = text_list[i]
+        seq.extend(tokenizer.tokenize(sub_text))
+        seq.append(tokenizer.get_command(pattern))
+
+    seq.extend(tokenizer.tokenize(text_list[-1]))
+
+    if "MASK]" not in raw_text:
+        seq += [tokenizer.get_command(generation_mask)]
+        raw_text += " " + generation_mask
+    if not raw_text.endswith("MASK]"):
+        seq = seq + [tokenizer.get_command("eos")]
+    if mpu.get_model_parallel_rank() == 0:
+        print("\nInput: {}\n".format(raw_text))
+    if len(seq) > args.max_sequence_length:
+        raise ValueError("text too long.")
+
+    # generation
+    is_english = isEnglish(raw_text)
+    output_list = [seq]
+    num_output = args.num_beams if args.sampling_strategy == "BeamSearchStrategy" else 1
+    last_pos, answers, answers_with_style, blanks = (
+        [0] * num_output,
+        ["" for _ in range(num_output)],
+        ["" for _ in range(num_output)],
+        [[] for _ in range(num_output)],
+    )
+
+    # continually detect the first mark position
+    while True:
+        seq = output_list[0]
+        # detect mask position
+        mask_token = tokenizer.get_command(generation_mask)
+        if mask_token not in seq:
+            break
+        mask_position = seq.index(mask_token)
+
+        output_list = []
+
+        input_seq = torch.cuda.LongTensor(
+            seq + [tokenizer.get_command("sop")] + [-1] * (args.out_seq_length - len(seq) - 1),
+            device=args.device,
+        )
+        output, _ = filling_sequence(
+            model,
+            input_seq,
+            batch_size=num_output,
+            strategy=strategy,
+            log_attention_weights=None,
+            get_masks_and_position_ids=partial(
+                get_masks_and_position_ids,
+                mask_position=mask_position,
+                context_length=len(seq) + 1,
+                gmask=use_gmask,
+            ),
+        )
+        if isinstance(output, torch.Tensor):  # different strategies
+            output = list(output)
+
+        output_list.extend(output)
+
+        # clip -1s and fill back generated things into seq
+        for i in range(len(output_list)):
+            output = output_list[i].tolist()
+            try:
+                unfinished = output.index(-1)
+            except ValueError:
+                unfinished = len(output)
+            if output[unfinished - 1] in strategy.end_tokens:
+                unfinished -= 1
+            bog = output.index(tokenizer.get_command("sop"))
+
+            prefix = tokenizer.detokenize(output[last_pos[i] : mask_position])
+            blank = tokenizer.detokenize(output[bog + 1 : unfinished])
+            answers_with_style[i] += (
+                prefix
+                + (" " if is_english else "")
+                + ("\033[4m" if use_gmask else "\x1b[0;32m\033[4m")
+                + blank
+                + ("\033[0m" if use_gmask else "\033[0m\x1b[0m")
+                + (" " if is_english else "")
+            )
+            blanks[i].append(blank)
+            last_pos[i] = mask_position + unfinished - (bog + 1)
+            output_list[i] = output[:mask_position] + output[bog + 1 : unfinished] + output[mask_position + 1 : bog]
+
+    for i, output in enumerate(output_list):
+        if output[-1] == tokenizer.get_command("eos"):
+            output = output[:-1]
+        answers_with_style[i] += tokenizer.detokenize(output[last_pos[i] :])
+        answers[i] = tokenizer.detokenize(output)
+
+    return answers, answers_with_style, blanks
+
+
+def main(args):
+    model, tokenizer = initialize_model_and_tokenizer(args)
+
+    end_tokens = [tokenizer.get_command("eop"), tokenizer.get_command("eos")]
+
+    if args.sampling_strategy == "BaseStrategy":
+        strategy = BaseStrategy(temperature=args.temperature, top_k=args.top_k, top_p=args.top_p, end_tokens=end_tokens)
+    elif args.sampling_strategy == "BeamSearchStrategy":
+        strategy = BeamSearchStrategy(
+            args.num_beams,
+            length_penalty=args.length_penalty,
+            consider_end=True,
+            end_tokens=end_tokens,
+            no_repeat_ngram_size=args.no_repeat_ngram_size,
+            min_gen_length=args.min_gen_length,
+        )
+    else:
+        raise ValueError(f"unknown strategy {args.sampling_strategy}")
+
+    def process(raw_text):
+        if args.with_id:
+            query_id, raw_text = raw_text.split("\t")
+
+        answers, answers_with_style, blanks = fill_blanks(raw_text, model, tokenizer, strategy)
+
+        # save
+        if args.with_id:
+            full_path = os.path.join(args.output_path, query_id + ".txt")
+        else:
+            prefix = raw_text.replace("/", "")[:20]
+            full_path = timed_name(prefix, ".txt", args.output_path)
+        if mpu.get_model_parallel_rank() == 0:
+            if args.print_all_beams and len(answers) > 1:
+                for idx, answer_with_style in enumerate(answers_with_style):
+                    print(f"Output beam {idx}:", answer_with_style)  # print the first.
+                    if len(answer_with_style) > 120:
+                        print("")
+            else:
+                print(f"Output:", answers_with_style[0])  # print the first.
+            with open(full_path, "w", encoding="utf-8") as fout:
+                for answer in answers:
+                    fout.write(answer + "\n")
+
+            os.chmod(full_path, stat.S_IRWXO + stat.S_IRWXG + stat.S_IRWXU)
+
+    os.makedirs(args.output_path, exist_ok=True)
+    generate_continually(process, args.input_source)
+
+
+if __name__ == "__main__":
+    args = initialize(extra_args_provider=add_generation_specific_args)
+
+    with torch.no_grad():
+        main(args)

+ 1 - 0
generation/__init__.py

@@ -0,0 +1 @@
+from .strategies import BeamSearchStrategy

+ 132 - 0
generation/strategies.py

@@ -0,0 +1,132 @@
+import torch
+import torch.nn.functional as F
+
+
+class BeamSearchStrategy:
+    def __init__(
+        self,
+        num_beams,
+        length_penalty=1.0,
+        consider_end=False,
+        end_tokens=[],
+        invalid_slices=[],
+        no_repeat_ngram_size=0,
+        min_gen_length=0,
+        deterministic=False,
+    ):
+        self.num_beams = num_beams
+        self.length_penalty = length_penalty
+        self.end_tokens = end_tokens
+        self.ngram = no_repeat_ngram_size
+        self.min_gen_length = min_gen_length
+        self.invalid_slices = invalid_slices
+        self.consider_end = consider_end
+        self.deterministic = deterministic
+        self._init_cache()
+
+    def _init_cache(self):
+        self.end_beams = []  # list of LongTensors
+        self.end_beams_penalized_scores = []  # list of LongTensors
+        self.cached_beam_scores = 0  # [batch_size]
+        self.cached_beam_ngram_bans = [{} for i in range(self.num_beams)]
+        self.length_generated = 0
+        self.is_done = False
+
+    def _add_end_beams(self, score, beam):
+        score = score / ((5.0 + len(beam)) / 6) ** self.length_penalty  # Magic number for OpenNMT
+        for i in range(len(self.end_beams), -1, -1):
+            if i == 0 or score < self.end_beams_penalized_scores[i - 1]:
+                break
+        self.end_beams.insert(i, beam)
+        self.end_beams_penalized_scores.insert(i, score)
+
+        self.end_beams = self.end_beams[: self.num_beams]
+        self.end_beams_penalized_scores = self.end_beams_penalized_scores[: self.num_beams]
+
+    def forward(self, logits, tokens, mems):
+        batch_size, vocab_size = logits.shape
+        seq_len = tokens.shape[-1]
+        logits = logits.float()
+        for invalid_slice in self.invalid_slices:
+            logits[..., invalid_slice] = -65504
+        if self.min_gen_length > self.length_generated:
+            for end_token in self.end_tokens:
+                logits[..., end_token] = -65504
+        if self.ngram > 0 and seq_len > self.ngram:
+            for i in range(batch_size):
+                ngram_prefix = tokens[i, -(self.ngram - 1) :].tolist()  # TODO ngram=1
+                for banned_index in self.cached_beam_ngram_bans[i].get(tuple(ngram_prefix), []):
+                    logits[i, banned_index] = -65504
+
+        next_token_scores = F.log_softmax(logits, dim=-1)  # [batch_size, vocab_size]
+        prev_scores = self.cached_beam_scores
+        if isinstance(self.cached_beam_scores, torch.Tensor):
+            prev_scores = prev_scores[:, None].expand_as(next_token_scores)
+        next_token_scores = next_token_scores + prev_scores
+
+        next_token_scores = next_token_scores.view(batch_size * vocab_size)
+
+        probs = F.softmax(next_token_scores, dim=0)
+        if self.deterministic:
+            if mems.shape[1] < batch_size:  # First token
+                probs = probs[:vocab_size]
+            next_tokens = torch.topk(probs, k=(max(1, len(self.end_tokens)) + 1) * self.num_beams).indices  # [2*nb]
+        else:
+            next_tokens = torch.multinomial(
+                probs, num_samples=(max(1, len(self.end_tokens)) + 1) * self.num_beams
+            )  # [2*nb]
+        next_token_scores = next_token_scores[next_tokens]
+        next_token_scores, _indices = torch.sort(next_token_scores, descending=True, dim=0)
+        next_tokens = next_tokens[_indices]
+
+        next_indices = torch.div(next_tokens, vocab_size, rounding_mode="trunc")
+        next_tokens = next_tokens % vocab_size
+
+        # select out end beams or continue beams
+        if mems.shape[1] < batch_size:
+            mems = mems.expand(-1, batch_size, -1, -1)
+        beam_continue = []
+        scores_continue = []
+        bans_continue = []
+        mems_contiue = []
+        for i in range(len(next_tokens)):
+            beam = torch.cat((tokens[next_indices[i]], next_tokens[i : i + 1]))
+            if int(next_tokens[i]) in self.end_tokens:
+                self._add_end_beams(next_token_scores[i], beam)
+            elif len(beam_continue) < self.num_beams:
+                beam_continue.append(beam)
+                mems_contiue.append(mems[:, next_indices[i]])
+                # update caches
+                scores_continue.append(next_token_scores[i])
+                if self.ngram > 0:
+                    bans = self.cached_beam_ngram_bans[next_indices[i]].copy()
+                    ngram_prefix = tuple(tokens[next_indices[i], -(self.ngram - 1) :].tolist())  # TODO ngram=1
+                    bans[ngram_prefix] = bans.get(ngram_prefix, tuple()) + (next_tokens[i],)
+                    bans_continue.append(bans)
+            else:
+                break
+        tokens = torch.stack(beam_continue)
+        mems = torch.stack(mems_contiue, dim=1)
+        self.cached_beam_scores = torch.tensor(scores_continue, device=logits.device)
+        self.cached_beam_ngram_bans = bans_continue
+        self.length_generated += 1
+
+        if (
+            len(self.end_beams) == self.num_beams
+            and self.end_beams_penalized_scores[-1]
+            >= self.cached_beam_scores.max() / ((5.0 + (seq_len + 1)) / 6) ** self.length_penalty
+        ):  # We're done if none of current tokens will better than the worst in end_beams
+            self.is_done = True
+
+        return tokens, mems
+
+    def finalize(self, tokens, mems):
+        if self.consider_end:
+            for i in range(tokens.shape[0]):
+                self._add_end_beams(self.cached_beam_scores[i], tokens[i])
+            mems = None
+            ret = self.end_beams
+        else:
+            ret = tokens
+        self._init_cache()
+        return ret, mems

+ 63 - 0
initialize.py

@@ -0,0 +1,63 @@
+import argparse
+import torch
+import time
+
+from SwissArmyTransformer import get_args, get_tokenizer
+from SwissArmyTransformer.arguments import initialize_distributed
+from SwissArmyTransformer.training import load_checkpoint
+from SwissArmyTransformer.model import GLM130B
+
+
+def add_bminf_args(parser):
+    """Arguments for BMInf"""
+    group = parser.add_argument_group("BMInf")
+
+    group.add_argument("--bminf", action="store_true", help="Use BMInf to support low resource evaluation")
+    group.add_argument("--bminf-memory-limit", type=int, default=20, help="Max memory for model per GPU (in GB)")
+    return parser
+
+
+def initialize(extra_args_provider):
+    parser = argparse.ArgumentParser(add_help=False)
+    add_bminf_args(parser)
+    GLM130B.add_model_specific_args(parser)
+    extra_args_provider(parser)
+    known, args_list = parser.parse_known_args()
+    args = get_args(args_list)
+    args = argparse.Namespace(**vars(args), **vars(known))
+    args.do_train = False
+    initialize_distributed(args)
+    return args
+
+
+def initialize_model_and_tokenizer(args):
+    tokenizer = get_tokenizer(args)
+
+    model = GLM130B(args).half()
+    if args.bminf:
+        import bminf
+
+        with torch.cuda.device(args.device):
+            model = bminf.wrapper(model, quantization=False, memory_limit=args.bminf_memory_limit << 30)
+    else:
+        model = model.to(args.device)
+
+    torch.distributed.barrier()
+    start = time.time()
+    load_checkpoint(model, args)
+    torch.distributed.barrier()
+    if torch.distributed.get_rank() == 0:
+        print(f"> Checkpoint loaded in {time.time() - start:.1f}s")
+    model.eval()
+
+    # generate rotary embedding cache
+    with torch.no_grad():
+        _, *_ = model(
+            torch.ones(1, 1, device=torch.cuda.current_device(), dtype=torch.int64),
+            torch.ones(1, 1, device=torch.cuda.current_device(), dtype=torch.int64) * args.max_sequence_length,
+            torch.ones(1, 1, 1, 1, device=torch.cuda.current_device(), dtype=torch.bool),
+        )
+
+    torch.distributed.barrier()
+
+    return model, tokenizer

+ 5 - 0
logs/README.md

@@ -0,0 +1,5 @@
+# Training Logs
+
+`main-log.md` contains detailed information about each restart of training during GLM-130B training.
+
+Tensorboard logs is available at [here](https://cloud.tsinghua.edu.cn/f/614c06645a1743529a76/).

+ 251 - 0
logs/main-log.md

@@ -0,0 +1,251 @@
+# GLM-130B 训练日志
+
+## 模型信息
+
+- 130B:70 layers,12288 hidden size,32768 ffn hidden size, 150000 vocab size
+   - MP = 4, PP = 8
+- GLM + Rotary Positional Embedding + GeGLU + DeepNorm
+- FP32 softmax with QKV scaling(no PB-Relax)
+- Shrink embedding gradient with $\alpha=0.1$
+- Global batch size: 4224
+
+## 环境版本
+
+- PyTorch 1.11 / CUDA 11.3
+- LargeScale@400893da37bb5cbe22c29e41c02a052369cc72ce
+- DeepSpeed 0.6.1
+- apex@master
+
+## 测速
+
+- 96 nodes, BSZ=176 * 24=4224
+   - glm-130B-2022.05.05-19:34:16:134TFLOPS, 88.5s/iter, 48samples/s,
+- 96 nodes, BSZ=256 * 24=6144
+   - glm-130B-2022.05.05-19:43:13:141TFLOPS, 122.5s/iter, 50samples/s
+
+## 2022-05-06 04:00 开始训练
+
+- glm-130B-2022.05.05-19:53:15
+
+## 2022-05-07 20:14 节点故障
+
+坏掉 n30041, n30157 两个点,更改保存间隔为 100step,从 4000 step 开始训练
+
+- glm-130B-2022.05.07-13:44:59
+
+## 2022-05-10 00:00 提升 alpha
+
+加入 `--shrink-embedding-gradient-steps 6000 500` 从 6000 step 开始训练
+
+- glm-130B-2022.05.09-16:02:04
+
+## 2022-05-11 12:13 节点故障
+
+坏掉 n30115 节点,从 7300 step 开始训练
+
+- glm-130B-2022.05.11-05:55:32
+
+## 2022-05-20 00:03 节点故障
+
+坏掉 n30066 节点,从 15400 step 开始训练
+
+- glm-130B-2022.05.19-19:56:19
+
+再换一批节点,从 15600 step 开始训练
+
+- glm-130B-2022.05.20-01:58:57
+
+## 2022-05-21 12:40 换节点
+
+训练效率一直只有 127T 左右,怀疑之前加入的 n30076 存在问题,踢出后从 16600 step 开始训练,似乎不解决问题。
+
+## 2022-05-22 19:27 节点故障
+
+n30126 失联
+
+- glm-130B-2022.05.22-14:15:41
+
+## 2022-05-26 04:30 节点故障
+
+n30039 掉卡
+
+- glm-130B-2022.05.25-22:23:12
+
+
+## 2022-05-28 11:50 更换中英多任务数据(废除)
+
+从 22800 开始训练,换中英多任务数据
+
+- glm-130B-2022.05.28-03:52:26
+- events.out.tfevents.1653709957.9droa42ltcad5-0.1858.0(移除)
+
+## 2022-05-28 16:50 更换英文多任务数据(废除)
+
+换新的多任务数据 22900 左右出现 nan,挂掉训练,检查发现中文多任务数据噪声极大,从 22800 换成平衡后的 t0 原始数据开始训练
+
+- glm-130B-2022.05.28-09:18:12
+- events.out.tfevents.1653729502.9droa42ltcad5-0.5648.0(移除)
+
+## 2022-05-28 20:50 加入 warmup(废除)
+
+![Image.png](https://res.craft.do/user/full/97ed555f-7125-cca2-fd7d-9f1a0585132e/doc/A8DAC1A6-2A03-489A-8A11-BFAFFFEE3905/C850748B-92A4-4F9F-932F-AD22330895D6_2/E8MboG8vrTTb2N51FRhkb6wsB4eyrD77USmM992obQgz/Image.png)
+
+换上平衡后且不泄漏的 t0 原始数据开始训练仍然有问题,推测是平衡后一些任务占比变大,其实等价于加入新任务的情况,加入参数 `--warmup-samples-after-loading 2112000` warmup 500 步从 22800 开始训练
+
+- glm-130B-2022.05.28-12:57:24
+- events.out.tfevents.1653742654.9droa42ltcad5-0.7942.0(移除)
+
+## 2022-05-29 01:30 再次爆炸,换纯文本(废除)
+
+![Image.png](https://res.craft.do/user/full/97ed555f-7125-cca2-fd7d-9f1a0585132e/doc/A8DAC1A6-2A03-489A-8A11-BFAFFFEE3905/028DE014-00FE-4521-BEEB-EF3F61BB8DA1_2/mgYybTR1OLgPkBysqMiUgGYNyIg8OQnf1yXI66grBeMz/Image.png)
+
+- warmup 以后还是炸了,分析可能是 distribution 变动仍然太过剧烈,先换纯文本 + reshuffle 尝试训练,从 22800 加载
+- glm-130B-2022.05.28-18:05:33
+- events.out.tfevents.1653761143.9droa42ltcad5-0.9744.0(废除)
+- global_step23200_text
++ 配置文件
+
+## 2022-05-29 逐渐修改数据分布(废除)
+
+![Image.png](https://res.craft.do/user/full/97ed555f-7125-cca2-fd7d-9f1a0585132e/doc/A8DAC1A6-2A03-489A-8A11-BFAFFFEE3905/E2BC463F-E519-461E-B1B0-99551DA940BE_2/0ZqN22TLyqRTvqOy6JNLeixEy4TarDJEF7DOvdh3saIz/Image.png)
+
+![Image.png](https://res.craft.do/user/full/97ed555f-7125-cca2-fd7d-9f1a0585132e/doc/A8DAC1A6-2A03-489A-8A11-BFAFFFEE3905/9C7AC4B3-59AB-471A-872E-41CCBAE7E90D_2/0rpEmyAOcIkLyDGR2R4RQiBeUwbWIWiaHbHcwosx6yAz/Image.png)
+
+文本似乎能稳定,那么尝试逐渐平滑修改数据分布, 从 22800 开始,逐渐修改数据分布到 t0 平衡数据
+
+- glm-130B-2022.05.29-05:17:06
+- events.out.tfevents.1653801436.9droa42ltcad5-0.13868.0(废除)
+
+## 2022-05-29 22:40 逐渐修改数据分布并全面 warmup
+
+- 又挂了,分析可能是换新分布学习率也需要 warmup
+
+![Image.png](https://res.craft.do/user/full/97ed555f-7125-cca2-fd7d-9f1a0585132e/doc/A8DAC1A6-2A03-489A-8A11-BFAFFFEE3905/F5532A86-3AAC-4CCE-AC9B-A976B7736D7F_2/M4JZx5GYzNPuysPHXrn0R5Oo54rBhDwQxdErkOpFOhEz/Image.png)
+
+- 从 22800 开始训练,数据和 lr 都 warmup 2000 步,shrink embbeding graident 从 0.2 warmup 6000 步到 1
+- glm-130B-2022.05.29-17:35:45
+
+## 2022-05-30 14:00 挂节点
+
+更改了一下参数配置,发现之前 shrink embedding 的步数写错了(26850 步),现在改成 6000 步。升级了一下 lr auto warmup 的逻辑,写成绝对 samples 数量。从 global_step23200 开始
+
+我们发现这次训练卡在了数据加载,排查后发现是 Lustre 文件系统的故障,导致 2.3T 文本数据读不出来,且工程师无法修复;最终重新从移动硬盘拷贝了一次数据
+
+- glm-130B-2022.05.31-02:18:24
+
+## 2022.05.03 20:00 加 DeepStruct 数据
+
+- 维持原有 transform 过程不变,但直接加入 DeepStruct 数据,从 23500 开始
+
+## 2022-06-01 22:22 换清洗数据
+
+之前的多任务数据 t0 和 deepsturct 各有一个任务的 target 异常,重新清洗后更换,从 24500 开始
+
+- glm-130B-2022.06.01-14:24:33
+
+## 2022-06-02 12:00 节点故障
+
+- n30145 CPU 故障,从 25000 重启训练,lr 和 数据集已经 transfromer 完毕,所以配置直接去掉 warmup
+- glm-130B-2022.06.02-04:35:05
+
+## 2022-06-02 09:30 加入 multitask loss 打印
+
+25800steps 开始,加入 multitask loss 打印
+
+- glm-130B-2022.06.03-01:40:12
+
+## 2022-06-02 15:00 降低学习率,加入 gpt/bert loss 打印
+
+loss 降低比较慢,讨论可能是学习率太大了,26000steps 开始,学习率砍半
+
+- glm-130B-2022.06.03-07:26:16
+
+## 2022-06-06 17:00 集群维护
+
+集群从 9 点到 5 点升级驱动,从  开始训练
+
+- glm-130B-2022.06.06-10:00:39
+
+PS:观察到共享文件系统读取速度显著改善,现在加载 ckpt 几乎只需要 1 分钟
+
+## 2022-06-08 08:00 坏点
+
+- glm-130B-2022.06.08-00:00:37
+
+## 2022-06-09 13:30 训练卡住
+
+23100 开始恢复
+
+- glm-130B-2022.06.09-05:27:54
+
+## 2022-06-12 10:00 loss 爆炸
+
+33700 开始 loss 炸了,loss-scale 在 33710 左右突然下跌然后 loss 在 33740 左右爆炸
+
+- tensorboard 记录:glm-130B-33700
+
+![Image.png](https://res.craft.do/user/full/97ed555f-7125-cca2-fd7d-9f1a0585132e/doc/A8DAC1A6-2A03-489A-8A11-BFAFFFEE3905/C46C7CFE-1B79-491C-90FC-5A88AE90E9DF_2/7ICMyH8v6GhAgngz5bVaDKwzYjFPyk99Ax27R5w56wMz/Image.png)
+
+![Image.png](https://res.craft.do/user/full/97ed555f-7125-cca2-fd7d-9f1a0585132e/doc/A8DAC1A6-2A03-489A-8A11-BFAFFFEE3905/E56BCDE0-C798-429F-81E0-1A07CCB9BC0E_2/Ig2rfKnPmLadg39Jc38UEdK90LDxlAxoH0AxmAygxzAz/Image.png)
+
+- 从 33600 开始加载,shrink embedding gradient 1 → 0.5
+- glm-130B-2022.06.12-02:20:49
+
+## 2022-06-14 03:00 loss 爆炸
+
+35250 loss 又炸了,和 33700 的表现几乎一样,都是完全没有征兆突然爆炸
+
+tensorboard 记录:glm-130B-35250
+
+- 从 35200 开始加载,shrink embedding gradient 0.5 → 0.1
+- glm-130B-2022.06.14-02:28:21
+
+## 2022-06-19 00:10 节点故障
+
+n30085 挂了,从 39600 恢复
+
+- glm-130B-2022.06.18-17:49:53
+
+## 2022-06-20 09:10 loss 爆炸
+
+![Image.png](https://res.craft.do/user/full/97ed555f-7125-cca2-fd7d-9f1a0585132e/doc/A8DAC1A6-2A03-489A-8A11-BFAFFFEE3905/CA344108-3B01-469C-9ABE-C41002F76484_2/oEvBST5MP0I7S4qHmQUeE7DoPCsGFSrveAOOSyitSUwz/Image.png)
+
+![Image.png](https://res.craft.do/user/full/97ed555f-7125-cca2-fd7d-9f1a0585132e/doc/A8DAC1A6-2A03-489A-8A11-BFAFFFEE3905/FED0DE40-A710-4259-AE98-26BCB9568C7A_2/kH4FijsPDVJFzkbaxz7BiX0RZrul1Wrye6cE5EV8ZG0z/Image.png)
+
+- tensorboard 记录:glm-130B-40800
+- `--skip-train-iteration-range 40701-40900`
+- 从 40700 开始重新加载并跳过 40701-40900 数据
+- glm-130B-2022.06.20-03:36:13
+
+## 2022-06-22 10:40 梯度 spike
+
+![Image.png](https://res.craft.do/user/full/97ed555f-7125-cca2-fd7d-9f1a0585132e/doc/A8DAC1A6-2A03-489A-8A11-BFAFFFEE3905/0B7E0A0C-4B11-4F52-BF10-E6B11A533BEF_2/yb1zC07di9zux8jbAi15gpqlstGHXZyjyMBEjO0gNKUz/Image.png)
+
+![Image.png](https://res.craft.do/user/full/97ed555f-7125-cca2-fd7d-9f1a0585132e/doc/A8DAC1A6-2A03-489A-8A11-BFAFFFEE3905/1C60424A-0290-4070-9327-DF9DFD135020_2/XyVoPs1yMLIuzUyrDixSYfgjc2Y2Nuor20GCz0nSPkAz/Image.png)
+
+- grad 有点小 spike,看起来后续恢复了,但 loss 似乎遇到了比较大的波动
+- `--skip-train-iteration-range 40701-40900`
+- 从 42400 开始重新加载并跳过 42401-42600 数据
+- glm-130B-2022.06.22-02:38:20
+
+## 2022-06-22 21:00 梯度 spike
+
+![Image.png](https://res.craft.do/user/full/97ed555f-7125-cca2-fd7d-9f1a0585132e/doc/E406CC41-4180-4108-BCCF-5E727CEB8F09/1D7D801C-3226-4CB0-978C-F19B4DA46721_2/nmg9r87OFrdErZvY9xjiDIHvgPVLv39vy8ZVtGkj2H0z/Image.png)
+
+![Image.png](https://res.craft.do/user/full/97ed555f-7125-cca2-fd7d-9f1a0585132e/doc/E406CC41-4180-4108-BCCF-5E727CEB8F09/5F5CA3D6-AF58-4087-9806-1529D3A2EF6C_2/WSQqyBdv1rvzvNloXE6Ssql7GxMDoULU38FAQCv3778z/Image.png)
+
+- grad 又有 spike,但是 loss-scale 没有一降到底,推测应该可以恢复
+- 这几天的反复 spike,我们分析可能是后期 learning rate 降低太慢,将 min-lr 从 8e-6 调整到 4e-6
+- `--min-lr 4e-6`
+- 从 42700 加载开始训练
+- glm-130B-2022.06.22-13:03:53
+
+## 2022.06.26 16:00 节点故障
+
+- 节点 NVLink Error,重启训练
+- glm-130B-2022.06.26-13:13:51
+
+## 2022.06.29 00:00 恢复 position_id
+
+- 48100 从原先配置开始训练
+- glm-130B-2022.06.29-13:53:21

+ 5 - 0
requirements.txt

@@ -0,0 +1,5 @@
+SwissArmyTransformer>=0.2.11
+icetk
+apex
+scipy
+dataclass_wizard

BIN
resources/03DF31017FE184DB45D41DFFC6F80EF0.png


BIN
resources/33872E48D3539EA132B74BCF5EFF458F.png


BIN
resources/49BF334CB352BAA19F7D55460B1DBCA9.gif


BIN
resources/7CB441707D1035B2890AA2164C5B6EAC.png


BIN
resources/7D6433A42D189E2E6FBC62BE066BCE91.png


BIN
resources/849024E93FA85347F7F6443932911922.png


BIN
resources/AE18F14396E2D22BC0BC8DD77EFD3414.png


BIN
resources/E42321373D22DE198231279B5856BB42.png


BIN
resources/F48B69263360688CCA21E915F4B1A98B.png


+ 70 - 0
resources/multitask_list.txt

@@ -0,0 +1,70 @@
+super_glue/wsc.fixed
+winogrande/winogrande_xl
+super_glue/rte
+glue/mrpc
+glue/qqp
+paws/labeled_final
+ai2_arc/ARC_Challenge
+ai2_arc/ARC_Easy
+kilt_tasks/hotpot_qa
+trivia_qa/unfiltered
+web_questions
+wiki_qa
+adversarial_qa/dbidaf
+adversarial_qa/dbert
+adversarial_qa/droberta
+duorc/SelfRC
+duorc/ParaphraseRC
+ropes
+squad_v2
+super_glue/record
+quoref
+tydiqa
+cos_e/v1.11
+cosmos_qa
+dream
+openbookqa/main
+qasc
+quail
+quarel
+quartz
+race/high
+race/middle
+sciq
+social_i_qa
+super_glue/boolq
+super_glue/multirc
+wiki_hop/original
+wiqa
+piqa
+amazon_polarity
+app_reviews
+imdb
+rotten_tomatoes
+yelp_review_full
+super_glue/copa
+hellaswag
+common_gen
+wiki_bio
+cnndailymail/3.0.0
+gigaword
+multi_news
+samsum
+xsum
+ag_news
+dbpedia_14
+trec
+super_glue/wic
+tacred
+conll04 (joint entity relation extraction)
+nyt29 (joint entity relation extraction)
+ace2005 (joint entity relation extraction)
+ade (joint entity relation extraction)
+conll03 (named entity recognition)
+ontonotes (named entity recognition)
+genia (named entity recognition)
+conll05 (semantic role labeling)
+conll12 (semantic role labeling)
+propbank (semantic role labeling)
+ace05 (event extraction)
+multi_woz_2.1 (dialogue state tracking)

+ 23 - 0
scripts/evaluate.sh

@@ -0,0 +1,23 @@
+#!/bin/bash
+
+script_path=$(realpath $0)
+script_dir=$(dirname $script_path)
+main_dir=$(dirname $script_dir)
+
+source "${main_dir}/configs/model_glm_130b.sh"
+
+DATA_PATH="<your evaluation dataset base directory>"
+
+ARGS="${main_dir}/evaluate.py \
+       --mode inference \
+       --data-path $DATA_PATH \
+       --task $* \
+       $MODEL_ARGS"
+
+TIMESTAMP=$(date +'%Y.%m.%d-%H:%M:%S')
+EXP_NAME=${TIMESTAMP}
+
+mkdir -p logs
+
+run_cmd="torchrun --nproc_per_node $MP_SIZE ${ARGS}"
+eval ${run_cmd} 2>&1 | tee logs/${EXP_NAME}.log

+ 28 - 0
scripts/evaluate_multiple_node.sh

@@ -0,0 +1,28 @@
+#!/bin/bash
+
+NUM_WORKERS=16
+NUM_GPUS_PER_WORKER=8
+HOST_FILE_PATH="<your hostfile>"
+OPTIONS_NCCL="NCCL_IB_DISABLE=0 NCCL_NET_GDR_LEVEL=2 CUDA_LAUNCH_BLOCKING=0"
+
+script_path=$(realpath $0)
+script_dir=$(dirname $script_path)
+main_dir=$(dirname $script_dir)
+
+source "${main_dir}/configs/model_glm_130b.sh"
+
+DATA_PATH="<your evaluation dataset base directory>"
+
+ARGS="${main_dir}/evaluate.py \
+       --mode inference \
+       --data-path $DATA_PATH \
+       --task $* \
+       $MODEL_ARGS"
+
+TIMESTAMP=$(date +'%Y.%m.%d-%H:%M:%S')
+EXP_NAME=${TIMESTAMP}
+
+mkdir -p logs
+
+run_cmd="${OPTIONS_NCCL} deepspeed --num_nodes ${NUM_WORKERS} --num_gpus ${NUM_GPUS_PER_WORKER} --hostfile ${HOST_FILE_PATH} ${ARGS}"
+eval ${run_cmd} 2>&1 | tee logs/${EXP_NAME}.log

+ 38 - 0
scripts/generate.sh

@@ -0,0 +1,38 @@
+#!/bin/bash
+
+script_path=$(realpath $0)
+script_dir=$(dirname $script_path)
+main_dir=$(dirname $script_dir)
+
+source "${main_dir}/configs/model_glm_130b.sh"
+
+SEED=1234
+MAX_OUTPUT_LENGTH=256
+MIN_GEN_LENGTH=0
+# BeamSearchStrategy args
+NUM_BEAMS=4
+LENGTH_PENALTY=1.0
+NO_REPEAT_NGRAM=3
+# BaseStrategy args
+TEMP=0.9
+TOPK=1
+TOPP=0
+
+ARGS="${main_dir}/generate.py \
+       --seed $SEED \
+       --mode inference \
+       --sampling-strategy BeamSearchStrategy \
+       --out-seq-length $MAX_OUTPUT_LENGTH \
+       --min-gen-length $MIN_GEN_LENGTH \
+       --num-beams $NUM_BEAMS \
+       --length-penalty $LENGTH_PENALTY \
+       --no-repeat-ngram-size $NO_REPEAT_NGRAM \
+       --temperature $TEMP \
+       --top_k $TOPK \
+       --top_p $TOPP \
+       --output-path samples \
+       $MODEL_ARGS \
+       $*"
+
+run_cmd="torchrun --nproc_per_node $MP_SIZE ${ARGS}"
+eval ${run_cmd}

+ 6 - 0
tasks/bloom/glue_cola.yaml

@@ -0,0 +1,6 @@
+name: 'glue_cola'
+type: 'mul'
+path: 'bloom/glue_cola'
+file-pattern:
+  validation: "**/validation.jsonl"
+micro-batch-size: 30

+ 7 - 0
tasks/bloom/glue_mnli.yaml

@@ -0,0 +1,7 @@
+name: 'glue_mnli'
+type: 'mul'
+path: 'bloom/glue_mnli'
+file-pattern:
+  validation-matched: "**/validation_matched.jsonl"
+  validation-mismatched: "**/validation_mismatched.jsonl"
+micro_batch_size: 8

+ 6 - 0
tasks/bloom/glue_qnli.yaml

@@ -0,0 +1,6 @@
+name: 'glue_qnli'
+type: 'mul'
+path: 'bloom/glue_qnli'
+file-pattern:
+  validation: "**/validation.jsonl"
+micro_batch_size: 6

+ 6 - 0
tasks/bloom/glue_wnli.yaml

@@ -0,0 +1,6 @@
+name: 'glue_wnli'
+type: 'mul'
+path: 'bloom/glue_wnli'
+file-pattern:
+  validation: "**/validation.jsonl"
+micro_batch_size: 16

+ 7 - 0
tasks/bloom/math_qa.yaml

@@ -0,0 +1,7 @@
+name: 'math_qa'
+type: 'mul'
+path: 'bloom/math_qa'
+file-pattern:
+  validation: "**/validation.jsonl"
+  test: "**/test.jsonl"
+micro_batch_size: 6

+ 6 - 0
tasks/bloom/mc_taco.yaml

@@ -0,0 +1,6 @@
+name: 'mc_taco'
+type: 'gen'
+path: 'bloom/mc_taco'
+file-pattern:
+  validation: "**/validation_pp.jsonl"
+  test: "**/test_pp.jsonl"

+ 7 - 0
tasks/bloom/openbook_qa.yaml

@@ -0,0 +1,7 @@
+name: 'openbook_qa'
+type: 'mul'
+path: 'bloom/openbookqa_main'
+file-pattern:
+  test: "**/test.jsonl"
+  validation: "**/validation.jsonl"
+micro_batch_size: 18

+ 6 - 0
tasks/bloom/pubmed_qa.yaml

@@ -0,0 +1,6 @@
+name: 'pubmed_qa'
+type: 'mul'
+path: 'bloom/pubmed_qa_pqa_labeled'
+file-pattern:
+  train: "**/train.jsonl"
+micro_batch_size: 2

+ 6 - 0
tasks/bloom/superglue_axb.yaml

@@ -0,0 +1,6 @@
+name: 'superglue_axb'
+type: 'mul'
+path: 'bloom/super_glue_axb'
+file-pattern:
+  test: "**/test.jsonl"
+micro_batch_size: 16

+ 6 - 0
tasks/bloom/superglue_axg.yaml

@@ -0,0 +1,6 @@
+name: 'superglue_axg'
+type: 'mul'
+path: 'bloom/super_glue_axg'
+file-pattern:
+  test: "**/test.jsonl"
+micro_batch_size: 34

+ 4 - 0
tasks/chinese/clue/afqmc.yaml

@@ -0,0 +1,4 @@
+name: 'AFQMC'
+type: 'mul'
+path: 'CLUE/afqmc'
+micro_batch_size: 16

+ 4 - 0
tasks/chinese/clue/c3.yaml

@@ -0,0 +1,4 @@
+name: 'C3'
+type: 'mul'
+path: 'CLUE/c3'
+micro_batch_size: 2

+ 4 - 0
tasks/chinese/clue/cluewsc.yaml

@@ -0,0 +1,4 @@
+name: 'CLUEWSC2020'
+type: 'mul'
+path: 'CLUE/cluewsc'
+micro_batch_size: 18

+ 4 - 0
tasks/chinese/clue/cmnli.yaml

@@ -0,0 +1,4 @@
+name: 'CMNLI'
+type: 'mul'
+path: 'CLUE/cmnli'
+micro_batch_size: 16

+ 3 - 0
tasks/chinese/clue/cmrc2018.yaml

@@ -0,0 +1,3 @@
+name: "CMRC2018"
+type: "gen"
+path: "CLUE/cmrc2018"

+ 4 - 0
tasks/chinese/clue/csl.yaml

@@ -0,0 +1,4 @@
+name: 'CSL'
+type: 'mul'
+path: 'CLUE/csl'
+micro_batch_size: 3

+ 3 - 0
tasks/chinese/clue/drcd.yaml

@@ -0,0 +1,3 @@
+name: "DRCD"
+type: "gen"
+path: "CLUE/drcd"

+ 4 - 0
tasks/chinese/clue/ocnli.yaml

@@ -0,0 +1,4 @@
+name: 'OCNLI_50K'
+type: 'mul'
+path: 'CLUE/ocnli'
+micro_batch_size: 24

+ 7 - 0
tasks/chinese/fewclue/bustm.yaml

@@ -0,0 +1,7 @@
+name: 'BUSTM'
+type: 'mul'
+path: 'CLUE/bustm'
+file-pattern:
+  dev: "**/dev_few_all.jsonl"
+  test: "**/test_public.jsonl"
+micro_batch_size: 56

+ 7 - 0
tasks/chinese/fewclue/chidf.yaml

@@ -0,0 +1,7 @@
+name: 'CHIDF'
+type: 'mul'
+path: 'CLUE/chid-fc'
+file-pattern:
+  dev: "**/dev_few_all.jsonl"
+  test: "**/test_public.jsonl"
+micro_batch_size: 16

+ 7 - 0
tasks/chinese/fewclue/cluewscf.yaml

@@ -0,0 +1,7 @@
+name: 'CLUEWSCF'
+type: 'mul'
+path: 'CLUE/cluewsc-fc'
+file-pattern:
+  dev: "**/dev_few_all.jsonl"
+  test: "**/test_public.jsonl"
+micro_batch_size: 16

+ 7 - 0
tasks/chinese/fewclue/cslf.yaml

@@ -0,0 +1,7 @@
+name: 'CSLF'
+type: 'mul'
+path: 'CLUE/csl-fc'
+file-pattern:
+  dev: "**/dev_few_all.jsonl"
+  test: "**/test_public.jsonl"
+micro_batch_size: 2

+ 7 - 0
tasks/chinese/fewclue/eprstmt.yaml

@@ -0,0 +1,7 @@
+name: 'EPRSTMT'
+type: 'mul'
+path: 'CLUE/eprstmt-fc'
+file-pattern:
+  dev: "**/dev_few_all.jsonl"
+  test: "**/test_public.jsonl"
+micro_batch_size: 6

+ 7 - 0
tasks/chinese/fewclue/ocnlif.yaml

@@ -0,0 +1,7 @@
+name: 'OCNLIF'
+type: 'mul'
+path: 'CLUE/ocnli-fc'
+file-pattern:
+  dev: "**/dev_few_all.jsonl"
+  test: "**/test_public.jsonl"
+micro_batch_size: 24

+ 13 - 0
tasks/lambada/lambada-unidirectional.yaml

@@ -0,0 +1,13 @@
+name: "LAMBADA-unidirectional"
+type: "gen"
+module: "tasks.lambada.task.LAMBADA"
+path: "lambada/lambada"
+file-pattern:
+  test: "**/test.jsonl"
+  validation: "**/validation.jsonl"
+
+sampling_strategy: "BeamSearchStrategy"
+num_beams: 16
+max_gen_length: 5
+use_task_mask: true
+unidirectional: true

+ 12 - 0
tasks/lambada/lambada.yaml

@@ -0,0 +1,12 @@
+name: "LAMBADA"
+type: "gen"
+module: "tasks.lambada.task.LAMBADA"
+path: "lambada/lambada"
+file-pattern:
+  test: "**/test.jsonl"
+  validation: "**/validation.jsonl"
+
+sampling_strategy: "BeamSearchStrategy"
+num_beams: 16
+max_gen_length: 5
+use_task_mask: true

+ 20 - 0
tasks/lambada/strategy.py

@@ -0,0 +1,20 @@
+from generation import BeamSearchStrategy
+
+
+class BeamSearchStrategyForLAMBADA(BeamSearchStrategy):
+    def __init__(self, *args, banned_prefix=[], **kwargs):
+        super().__init__(*args, **kwargs)
+        self.banned_prefix = banned_prefix
+
+    def forward(self, logits, tokens, mems):
+        batch_size, vocab_size = logits.shape
+        logits = logits.float()
+        for prefix in self.banned_prefix:
+            if self.length_generated == len(prefix) - 1:
+                if len(prefix) == 1:
+                    logits[..., prefix[0]] = -65504
+                else:
+                    for i in range(batch_size):
+                        if tokens[i, -(len(prefix) - 1) :].tolist() == prefix[:-1]:
+                            logits[i, prefix[-1]] = -65504
+        return super().forward(logits, tokens, mems)

+ 54 - 0
tasks/lambada/task.py

@@ -0,0 +1,54 @@
+from string import punctuation
+from functools import partial
+from typing import List
+
+from evaluation import qa_evaluate, GenerationTask
+
+from .strategy import BeamSearchStrategyForLAMBADA
+
+
+def exact_match_score(prediction, ground_truth):
+    return prediction.strip() == ground_truth.strip()
+
+
+class LAMBADA(GenerationTask):
+    @property
+    def metrics(self):
+        return {"Accuracy": partial(qa_evaluate, metric=exact_match_score)}
+
+    def __init__(self, model, tokenizer, config_path):
+        super(LAMBADA, self).__init__(model, tokenizer, config_path)
+
+        if self.config.sampling_strategy == "BeamSearchStrategy":
+            banned_prefix = [[46010], [146337]]  # "'" and "``"
+            invalid_slices = [20068, 146010, 146337]
+            for p in punctuation:
+                pp = tokenizer.tokenize(p)
+                if len(pp) == 1:
+                    invalid_slices.append(pp[0])
+                banned_prefix.append(pp)
+            self.strategy = BeamSearchStrategyForLAMBADA(
+                self.config.num_beams,
+                length_penalty=self.config.length_penalty,
+                consider_end=True,
+                end_tokens=self.strategy.end_tokens,
+                invalid_slices=invalid_slices,
+                banned_prefix=banned_prefix,
+                no_repeat_ngram_size=self.config.no_repeat_ngram_size,
+                min_gen_length=self.config.min_gen_length,
+                deterministic=True,
+            )
+
+    def get_first_word_tokens(self, tokens):
+        text = self.tokenizer.tokenizer.decode(tokens).strip()
+        return self.tokenizer.tokenize(text.split(" ")[0])
+
+    def predict_single_batch(self, batch):
+        # micro batch size = 1 here, but we still need to return a list of predictions for consistency
+        outputs: List[List[int]] = self.model.generate_text(batch, self.strategy, return_all_beams=True)
+        for output in outputs:
+            text = self.tokenizer.tokenizer.decode(output).strip()
+            spl = text.split(" ")
+            if len(spl) >= 2 and spl[1] in punctuation:
+                return [self.get_first_word_tokens(output)]
+        return [self.get_first_word_tokens(outputs[0])]

+ 10 - 0
tasks/mmlu/mmlu.yaml

@@ -0,0 +1,10 @@
+name: "MMLU"
+type: "mul"
+module: "tasks.mmlu.task.MMLU"
+path: "MMLU"
+file-pattern:
+  stem: "stem/*.json"
+  social_sciences: "social_sciences/*.json"
+  humanities: "humanities/*.json"
+  other: "other/*.json"
+micro-batch-size: 1

+ 78 - 0
tasks/mmlu/task.py

@@ -0,0 +1,78 @@
+import numpy as np
+
+from typing import Dict, Tuple
+
+from evaluation import MultiChoiceTask
+
+categories = {
+    "STEM": [
+        "Abstract Algebra",
+        "Anatomy",
+        "Astronomy",
+        "College Biology",
+        "College Chemistry",
+        "College Computer Science",
+        "College Mathematics",
+        "College Physics",
+        "Computer Security",
+        "Conceptual Physics",
+        "Electrical Engineering",
+        "Elementary Mathematics",
+        "High School Biology",
+        "High School Chemistry",
+        "High School Computer Science",
+        "High School Mathematics",
+        "High School Physics",
+        "High School Statistics",
+        "Machine Learning",
+    ],
+    "Other": [
+        "Business Ethics",
+        "Clinical Knowledge",
+        "College Medicine",
+        "Global Facts",
+        "Human Aging",
+        "Management",
+        "Marketing",
+        "Medical Genetics",
+        "Miscellaneous",
+        "Nutrition",
+        "Professional Accounting",
+        "Professional Medicine",
+        "Virology",
+    ],
+    "Social Sciences": [
+        "Econometrics",
+        "High School Geography",
+        "High School Government and Politics",
+        "High School Macroeconomics",
+        "High School Microeconomics",
+        "High School Psychology",
+        "Human Sexuality",
+        "Professional Psychology",
+        "Public Relations",
+        "Security Studies",
+        "Sociology",
+        "US Foreign Policy",
+    ],
+    "Humanities": [
+        "Formal Logic",
+        "High School European History",
+        "High School US History",
+        "High School World History",
+        "International Law",
+        "Jurisprudence",
+        "Logical Fallacies",
+        "Moral Disputes",
+        "Moral Scenarios",
+        "Philosophy",
+        "Prehistory",
+        "Professional Law",
+        "World Religions",
+    ],
+}
+
+
+class MMLU(MultiChoiceTask):
+    def report_overall_metrics(self, result_dict_all: Dict[str, Tuple[Dict[str, float], int]]):
+        self.report_group_metrics("Overall", result_dict_all, level=0)

Niektóre pliki nie zostały wyświetlone z powodu dużej ilości zmienionych plików