Ver Fonte

Reformat using black

Sengxian há 2 anos atrás
pai
commit
c713dae241

+ 6 - 1
evaluation/dataset.py

@@ -228,7 +228,12 @@ class MultiChoiceTaskDataset(EvaluationDataset):
 
 
     @staticmethod
     @staticmethod
     def build_multiple_choice_sample(
     def build_multiple_choice_sample(
-        text, choices, is_single_token, unified_multitask_encoding=False, unidirectional=False, use_task_mask=False
+        text,
+        choices,
+        is_single_token,
+        unified_multitask_encoding=False,
+        unidirectional=False,
+        use_task_mask=False,
     ):
     ):
         tokenizer = get_tokenizer()
         tokenizer = get_tokenizer()
 
 

+ 1 - 2
evaluation/metrics.py

@@ -3,10 +3,9 @@ import math
 import string
 import string
 import functools
 import functools
 
 
-import torch
 import numpy as np
 import numpy as np
 
 
-from typing import Tuple, List
+from typing import List
 from collections import Counter
 from collections import Counter
 from collections import defaultdict
 from collections import defaultdict
 from SwissArmyTransformer import get_tokenizer
 from SwissArmyTransformer import get_tokenizer

+ 44 - 38
evaluation/model.py

@@ -7,22 +7,22 @@ from SwissArmyTransformer.mpu import vocab_parallel_cross_entropy
 
 
 
 
 def batch_filling_sequence(
 def batch_filling_sequence(
-        model,
-        seqs,
-        context_lengths,
-        strategy,
-        max_memory_length=100000,
-        get_masks_and_position_ids=get_masks_and_position_ids_default,
-        mems=None,
-        **kw_args
-        ):
-    '''
-        seq: [2, 3, 5, ..., -1(to be generated), -1, ...]
-        mems: [num_layers, batch_size, len_mems(index), mem_hidden_size]
-            cache, should be first mems.shape[1] parts of context_tokens.
-            mems are the first-level citizens here, but we don't assume what is memorized.
-            input mems are used when multi-phase generation.
-    '''
+    model,
+    seqs,
+    context_lengths,
+    strategy,
+    max_memory_length=100000,
+    get_masks_and_position_ids=get_masks_and_position_ids_default,
+    mems=None,
+    **kw_args
+):
+    """
+    seq: [2, 3, 5, ..., -1(to be generated), -1, ...]
+    mems: [num_layers, batch_size, len_mems(index), mem_hidden_size]
+        cache, should be first mems.shape[1] parts of context_tokens.
+        mems are the first-level citizens here, but we don't assume what is memorized.
+        input mems are used when multi-phase generation.
+    """
     assert len(seqs.shape) == 2
     assert len(seqs.shape) == 2
 
 
     # building the initial tokens, attention_mask, and position_ids
     # building the initial tokens, attention_mask, and position_ids
@@ -30,10 +30,10 @@ def batch_filling_sequence(
     seqs, attention_mask, position_ids = get_masks_and_position_ids(seqs)
     seqs, attention_mask, position_ids = get_masks_and_position_ids(seqs)
     tokens = seqs[..., :context_length]
     tokens = seqs[..., :context_length]
     if attention_mask.dtype != torch.bool:
     if attention_mask.dtype != torch.bool:
-        attention_mask = attention_mask.type_as(next(model.parameters())) # if fp16
+        attention_mask = attention_mask.type_as(next(model.parameters()))  # if fp16
     # initialize generation
     # initialize generation
-    counter = context_length - 1 # Last fixed index is ``counter''
-    index = 0 if mems is None else mems.shape[2] # Next forward starting index, also the length of cache.
+    counter = context_length - 1  # Last fixed index is ``counter''
+    index = 0 if mems is None else mems.shape[2]  # Next forward starting index, also the length of cache.
     num_beams = 1
     num_beams = 1
     # step-by-step generation
     # step-by-step generation
     while counter < seqs.shape[1] - 1:
     while counter < seqs.shape[1] - 1:
@@ -41,15 +41,19 @@ def batch_filling_sequence(
         # token[:, index: counter+1] needs forwarding.
         # token[:, index: counter+1] needs forwarding.
         # forward
         # forward
         tokens = tokens.reshape(batch_size * num_beams, -1)
         tokens = tokens.reshape(batch_size * num_beams, -1)
-        mems = mems.reshape(mems.shape[0], batch_size * num_beams, mems.shape[-2], mems.shape[-1]) if mems is not None else None
+        mems = (
+            mems.reshape(mems.shape[0], batch_size * num_beams, mems.shape[-2], mems.shape[-1])
+            if mems is not None
+            else None
+        )
         logits, *output_per_layers = model(
         logits, *output_per_layers = model(
             tokens[:, index:],
             tokens[:, index:],
-            position_ids[..., index: counter+1],
-            attention_mask[..., index: counter+1, :counter+1], # TODO memlen
+            position_ids[..., index : counter + 1],
+            attention_mask[..., index : counter + 1, : counter + 1],  # TODO memlen
             mems=mems,
             mems=mems,
             **kw_args
             **kw_args
         )
         )
-        mem_kv = [o['mem_kv'] for o in output_per_layers]
+        mem_kv = [o["mem_kv"] for o in output_per_layers]
         mems = update_mems(mem_kv, mems, max_memory_length=max_memory_length)
         mems = update_mems(mem_kv, mems, max_memory_length=max_memory_length)
         if counter == context_length - 1:
         if counter == context_length - 1:
             logits = logits[torch.arange(batch_size), context_lengths - 1]
             logits = logits[torch.arange(batch_size), context_lengths - 1]
@@ -66,10 +70,15 @@ def batch_filling_sequence(
         tokens, mems = strategy.forward(logits, tokens, mems)
         tokens, mems = strategy.forward(logits, tokens, mems)
         if len(tokens.shape) == 3 and num_beams == 1:
         if len(tokens.shape) == 3 and num_beams == 1:
             num_beams = tokens.shape[1]
             num_beams = tokens.shape[1]
-            position_ids = position_ids.unsqueeze(1).expand(batch_size, num_beams, -1).reshape(batch_size * num_beams, -1)
+            position_ids = (
+                position_ids.unsqueeze(1).expand(batch_size, num_beams, -1).reshape(batch_size * num_beams, -1)
+            )
             attention_mask_shape = attention_mask.shape[-3:]
             attention_mask_shape = attention_mask.shape[-3:]
-            attention_mask = attention_mask.unsqueeze(1).expand(batch_size, num_beams, -1, -1, -1).reshape(
-                batch_size * num_beams, *attention_mask_shape)
+            attention_mask = (
+                attention_mask.unsqueeze(1)
+                .expand(batch_size, num_beams, -1, -1, -1)
+                .reshape(batch_size * num_beams, *attention_mask_shape)
+            )
         if strategy.is_done:
         if strategy.is_done:
             break
             break
     return strategy.finalize(tokens, mems)
     return strategy.finalize(tokens, mems)
@@ -118,8 +127,7 @@ class ModelForEvaluation(torch.nn.Module):
                 log_probs.append(log_probs_single)
                 log_probs.append(log_probs_single)
         return log_probs
         return log_probs
 
 
-    def generate_text(self, sample, strategy, return_all_beams=False) -> Union[
-        List[List[int]], List[List[List[int]]]]:
+    def generate_text(self, sample, strategy, return_all_beams=False) -> Union[List[List[int]], List[List[List[int]]]]:
         """
         """
         @return: A list of text model generated, sorted by score in descending order
         @return: A list of text model generated, sorted by score in descending order
         """
         """
@@ -129,18 +137,17 @@ class ModelForEvaluation(torch.nn.Module):
 
 
         def get_masks_and_position_ids(seq):
         def get_masks_and_position_ids(seq):
             batch_size = seq.shape[0]
             batch_size = seq.shape[0]
-            max_gen_length = sample['target_position_ids'].shape[-1]
-            tokens = torch.nn.functional.pad(seq, (0, max_gen_length), mode='constant', value=-1)
-            position_ids = torch.cat((sample['position_ids'], sample['target_position_ids']), dim=-1)
+            max_gen_length = sample["target_position_ids"].shape[-1]
+            tokens = torch.nn.functional.pad(seq, (0, max_gen_length), mode="constant", value=-1)
+            position_ids = torch.cat((sample["position_ids"], sample["target_position_ids"]), dim=-1)
             position_ids = position_ids.to(device=self.device).long()
             position_ids = position_ids.to(device=self.device).long()
             attention_mask = sample["attention_mask"].to(device=self.device)
             attention_mask = sample["attention_mask"].to(device=self.device)
-            context_mask = attention_mask[torch.arange(batch_size), context_lengths - 1].unsqueeze(1).repeat(1,
-                                                                                                             max_gen_length,
-                                                                                                             1)
+            context_mask = (
+                attention_mask[torch.arange(batch_size), context_lengths - 1].unsqueeze(1).repeat(1, max_gen_length, 1)
+            )
             causal_mask = torch.tril(context_mask.new_ones((batch_size, max_gen_length, max_gen_length))) < 0.5
             causal_mask = torch.tril(context_mask.new_ones((batch_size, max_gen_length, max_gen_length))) < 0.5
-            generation_mask = torch.cat(
-                (context_mask, causal_mask), dim=-1)
-            attention_mask = torch.nn.functional.pad(attention_mask, (0, max_gen_length), mode='constant', value=1)
+            generation_mask = torch.cat((context_mask, causal_mask), dim=-1)
+            attention_mask = torch.nn.functional.pad(attention_mask, (0, max_gen_length), mode="constant", value=1)
             attention_mask = torch.cat((attention_mask, generation_mask), dim=1)
             attention_mask = torch.cat((attention_mask, generation_mask), dim=1)
             attention_mask = attention_mask.bool().unsqueeze(1)
             attention_mask = attention_mask.bool().unsqueeze(1)
             return tokens, attention_mask, position_ids
             return tokens, attention_mask, position_ids
@@ -177,7 +184,6 @@ class ModelForEvaluation(torch.nn.Module):
                 output_targets.append(output_target)
                 output_targets.append(output_target)
         return output_targets
         return output_targets
 
 
-
     def calculate_loss(self, batch) -> List[float]:
     def calculate_loss(self, batch) -> List[float]:
         tokens, position_ids, attention_mask = self.process_data(batch, self.device)
         tokens, position_ids, attention_mask = self.process_data(batch, self.device)
         targets, loss_masks = (
         targets, loss_masks = (

+ 12 - 4
evaluation/tasks.py

@@ -1,5 +1,6 @@
-import torch
+import os
 import time
 import time
+import torch
 import numpy as np
 import numpy as np
 import torch.distributed as dist
 import torch.distributed as dist
 
 
@@ -88,7 +89,6 @@ class BaseTask(ABC):
                     for _, batch in enumerate(dataloader):
                     for _, batch in enumerate(dataloader):
                         prediction.append(self.predict_single_batch(batch))
                         prediction.append(self.predict_single_batch(batch))
 
 
-
                 prediction = gather_result(prediction, len(dataset), self.config.micro_batch_size)
                 prediction = gather_result(prediction, len(dataset), self.config.micro_batch_size)
                 result_dict = {key: metric(prediction, dataset.data) for key, metric in self.metrics.items()}
                 result_dict = {key: metric(prediction, dataset.data) for key, metric in self.metrics.items()}
                 result_dict_group[file] = (result_dict, len(dataset))
                 result_dict_group[file] = (result_dict, len(dataset))
@@ -172,6 +172,13 @@ class GenerationTask(BaseTask, ABC):
     def build_dataset(self, relative_path):
     def build_dataset(self, relative_path):
         return GenerationTaskDataset(join(self.config.path, relative_path), self.config)
         return GenerationTaskDataset(join(self.config.path, relative_path), self.config)
 
 
+    def save_prediction_to_file(self, file, prediction, data):
+        """Write detokenized predictions to ``outputs/<config.name>/<file>.predict``.
+
+        Each item of ``prediction`` is detokenized with ``self.tokenizer`` and
+        written on its own line. The target directory is created if missing.
+        ``data`` is accepted but not used by this implementation.
+        """
+        filename = os.path.join("outputs", self.config.name, f"{file}.predict")
+        os.makedirs(os.path.dirname(filename), exist_ok=True)
+        # Explicit UTF-8: detokenized text may contain non-ASCII characters that
+        # the platform's locale default encoding cannot represent.
+        # The handle is named ``output`` so the ``file`` argument is not shadowed.
+        with open(filename, "w", encoding="utf-8") as output:
+            for item in prediction:
+                output.write(self.tokenizer.detokenize(item) + "\n")
+
     def __init__(self, model: ModelForEvaluation, tokenizer: _IceTokenizer, config: GenerationTaskConfig):
     def __init__(self, model: ModelForEvaluation, tokenizer: _IceTokenizer, config: GenerationTaskConfig):
         super(GenerationTask, self).__init__(model, tokenizer, config)
         super(GenerationTask, self).__init__(model, tokenizer, config)
 
 
@@ -181,8 +188,9 @@ class GenerationTask(BaseTask, ABC):
                 end_tokens.append(self.tokenizer.tokenize(token)[-1])
                 end_tokens.append(self.tokenizer.tokenize(token)[-1])
             print_rank_0(f"End tokens {end_tokens}")
             print_rank_0(f"End tokens {end_tokens}")
         if self.config.sampling_strategy == "BaseStrategy":
         if self.config.sampling_strategy == "BaseStrategy":
-            self.strategy = BaseStrategy(batch_size=self.config.micro_batch_size, temperature=1.0, top_k=1,
-                                         end_tokens=end_tokens)
+            self.strategy = BaseStrategy(
+                batch_size=self.config.micro_batch_size, temperature=1.0, top_k=1, end_tokens=end_tokens
+            )
         elif self.config.sampling_strategy == "BeamSearchStrategy":
         elif self.config.sampling_strategy == "BeamSearchStrategy":
             self.strategy = BeamSearchStrategy(
             self.strategy = BeamSearchStrategy(
                 self.config.micro_batch_size,
                 self.config.micro_batch_size,

+ 5 - 2
generation/strategies.py

@@ -3,8 +3,9 @@ import torch
 import torch.nn.functional as F
 import torch.nn.functional as F
 from SwissArmyTransformer.generation.sampling_strategies.base_strategy import top_k_logits
 from SwissArmyTransformer.generation.sampling_strategies.base_strategy import top_k_logits
 
 
+
 class BaseStrategy:
 class BaseStrategy:
-    def __init__(self, batch_size, invalid_slices=[], temperature=1., top_k=200, eps=1e-4, top_p=0.0, end_tokens=None):
+    def __init__(self, batch_size, invalid_slices=[], temperature=1.0, top_k=200, eps=1e-4, top_p=0.0, end_tokens=None):
         self.batch_size = batch_size
         self.batch_size = batch_size
         self.invalid_slices = invalid_slices
         self.invalid_slices = invalid_slices
         self.temperature = temperature
         self.temperature = temperature
@@ -153,7 +154,9 @@ class BeamSearchStrategy:
                     if self.ngram > 0:
                     if self.ngram > 0:
                         bans = self.cached_beam_ngram_bans[batch_idx][next_indices[batch_idx, i]].copy()
                         bans = self.cached_beam_ngram_bans[batch_idx][next_indices[batch_idx, i]].copy()
                         # TODO ngram=1
                         # TODO ngram=1
-                        ngram_prefix = tuple(tokens[batch_idx, next_indices[batch_idx, i], -(self.ngram - 1):].tolist())
+                        ngram_prefix = tuple(
+                            tokens[batch_idx, next_indices[batch_idx, i], -(self.ngram - 1) :].tolist()
+                        )
                         bans[ngram_prefix] = bans.get(ngram_prefix, tuple()) + (next_tokens[batch_idx, i],)
                         bans[ngram_prefix] = bans.get(ngram_prefix, tuple()) + (next_tokens[batch_idx, i],)
                         bans_continue.append(bans)
                         bans_continue.append(bans)
                 else:
                 else:

+ 129 - 42
server.py

@@ -161,10 +161,9 @@ def fill_blanks(raw_text: str, model, tokenizer, strategy) -> Tuple[List[str], L
     return answers, answers_with_style, blanks
     return answers, answers_with_style, blanks
 
 
 
 
-
 def generate_continually(func, raw_text):
 def generate_continually(func, raw_text):
     if not raw_text:
     if not raw_text:
-        return 'Input should not be empty!'
+        return "Input should not be empty!"
     try:
     try:
         start_time = time.time()
         start_time = time.time()
         answer = func(raw_text)
         answer = func(raw_text)
@@ -173,10 +172,12 @@ def generate_continually(func, raw_text):
         return answer
         return answer
     except (ValueError, FileNotFoundError) as e:
     except (ValueError, FileNotFoundError) as e:
         print(e)
         print(e)
-        return 'Error!'
+        return "Error!"
+
 
 
 strategy = None
 strategy = None
 
 
+
 def main(args):
 def main(args):
     model, tokenizer = initialize_model_and_tokenizer(args)
     model, tokenizer = initialize_model_and_tokenizer(args)
 
 
@@ -195,16 +196,55 @@ def main(args):
 
 
         return answers[0]
         return answers[0]
 
 
-    
-    def predict(text, seed=1234, out_seq_length=200, min_gen_length=20, sampling_strategy='BaseStrategy', 
-    num_beams=4, length_penalty=0.9, no_repeat_ngram_size=3, 
-    temperature=1, topk=1, topp=1):
+    def predict(
+        text,
+        seed=1234,
+        out_seq_length=200,
+        min_gen_length=20,
+        sampling_strategy="BaseStrategy",
+        num_beams=4,
+        length_penalty=0.9,
+        no_repeat_ngram_size=3,
+        temperature=1,
+        topk=1,
+        topp=1,
+    ):
 
 
         global strategy
         global strategy
 
 
         if torch.distributed.get_rank() == 0:
         if torch.distributed.get_rank() == 0:
-            print('info', [text, seed, out_seq_length, min_gen_length, sampling_strategy, num_beams, length_penalty, no_repeat_ngram_size, temperature, topk, topp])
-            dist.broadcast_object_list([text, seed, out_seq_length, min_gen_length, sampling_strategy, num_beams, length_penalty, no_repeat_ngram_size, temperature, topk, topp], src=0)
+            print(
+                "info",
+                [
+                    text,
+                    seed,
+                    out_seq_length,
+                    min_gen_length,
+                    sampling_strategy,
+                    num_beams,
+                    length_penalty,
+                    no_repeat_ngram_size,
+                    temperature,
+                    topk,
+                    topp,
+                ],
+            )
+            dist.broadcast_object_list(
+                [
+                    text,
+                    seed,
+                    out_seq_length,
+                    min_gen_length,
+                    sampling_strategy,
+                    num_beams,
+                    length_penalty,
+                    no_repeat_ngram_size,
+                    temperature,
+                    topk,
+                    topp,
+                ],
+                src=0,
+            )
 
 
         args.seed = seed
         args.seed = seed
         args.out_seq_length = out_seq_length
         args.out_seq_length = out_seq_length
@@ -237,11 +277,11 @@ def main(args):
         return generate_continually(process, text)
         return generate_continually(process, text)
 
 
     if torch.distributed.get_rank() == 0:
     if torch.distributed.get_rank() == 0:
-        en_fil = ['The Starry Night is an oil-on-canvas painting by [MASK] in June 1889.']
-        en_gen = ['Eight planets in solar system are [gMASK]']
-        ch_fil = ['凯旋门位于意大利米兰市古城堡旁。1807年为纪念[MASK]而建,门高25米,顶上矗立两武士青铜古兵车铸像。']
-        ch_gen = ['三亚位于海南岛的最南端,是中国最南部的热带滨海旅游城市 [gMASK]']
-        en_to_ch = ['Pencil in Chinese is [MASK].']
+        en_fil = ["The Starry Night is an oil-on-canvas painting by [MASK] in June 1889."]
+        en_gen = ["Eight planets in solar system are [gMASK]"]
+        ch_fil = ["凯旋门位于意大利米兰市古城堡旁。1807年为纪念[MASK]而建,门高25米,顶上矗立两武士青铜古兵车铸像。"]
+        ch_gen = ["三亚位于海南岛的最南端,是中国最南部的热带滨海旅游城市 [gMASK]"]
+        en_to_ch = ["Pencil in Chinese is [MASK]."]
         ch_to_en = ['"我思故我在"的英文是"[MASK]"。']
         ch_to_en = ['"我思故我在"的英文是"[MASK]"。']
 
 
         examples = [en_fil, en_gen, ch_fil, ch_gen, en_to_ch, ch_to_en]
         examples = [en_fil, en_gen, ch_fil, ch_gen, en_to_ch, ch_to_en]
@@ -253,28 +293,36 @@ def main(args):
                 GLM-130B uses two different mask tokens: `[MASK]` for short blank filling and `[gMASK]` for left-to-right long text generation. When the input does not contain any MASK token, `[gMASK]` will be automatically appended to the end of the text. We recommend that you use `[MASK]` to try text fill-in-the-blank to reduce wait time (ideally within seconds without queuing).
                 GLM-130B uses two different mask tokens: `[MASK]` for short blank filling and `[gMASK]` for left-to-right long text generation. When the input does not contain any MASK token, `[gMASK]` will be automatically appended to the end of the text. We recommend that you use `[MASK]` to try text fill-in-the-blank to reduce wait time (ideally within seconds without queuing).
                 
                 
                 Note: We suspect that there is a bug in the current FasterTransformer INT4 implementation that leads to gaps in generations compared to the FP16 model (e.g. more repititions), which we are troubleshooting, and the current model output is **for reference only**
                 Note: We suspect that there is a bug in the current FasterTransformer INT4 implementation that leads to gaps in generations compared to the FP16 model (e.g. more repititions), which we are troubleshooting, and the current model output is **for reference only**
-                """)
+                """
+            )
 
 
             with gr.Row():
             with gr.Row():
                 with gr.Column():
                 with gr.Column():
-                    model_input = gr.Textbox(lines=7, placeholder='Input something in English or Chinese', label='Input')
+                    model_input = gr.Textbox(
+                        lines=7, placeholder="Input something in English or Chinese", label="Input"
+                    )
                     with gr.Row():
                     with gr.Row():
                         gen = gr.Button("Generate")
                         gen = gr.Button("Generate")
                         clr = gr.Button("Clear")
                         clr = gr.Button("Clear")
-                       
-                outputs = gr.Textbox(lines=7, label='Output')
-                    
+
+                outputs = gr.Textbox(lines=7, label="Output")
+
             gr.Markdown(
             gr.Markdown(
                 """
                 """
                 Generation Parameter
                 Generation Parameter
-                """)
+                """
+            )
             with gr.Row():
             with gr.Row():
                 with gr.Column():
                 with gr.Column():
-                    seed = gr.Slider(maximum=100000, value=1234, step=1, label='Seed')
-                    out_seq_length = gr.Slider(maximum=512, value=128, minimum=32, step=1, label='Output Sequence Length')
+                    seed = gr.Slider(maximum=100000, value=1234, step=1, label="Seed")
+                    out_seq_length = gr.Slider(
+                        maximum=512, value=128, minimum=32, step=1, label="Output Sequence Length"
+                    )
                 with gr.Column():
                 with gr.Column():
-                    min_gen_length = gr.Slider(maximum=64, value=0, step=1, label='Min Generate Length')
-                    sampling_strategy = gr.Radio(choices=['BeamSearchStrategy', 'BaseStrategy'], value='BaseStrategy', label='Search Strategy')
+                    min_gen_length = gr.Slider(maximum=64, value=0, step=1, label="Min Generate Length")
+                    sampling_strategy = gr.Radio(
+                        choices=["BeamSearchStrategy", "BaseStrategy"], value="BaseStrategy", label="Search Strategy"
+                    )
 
 
             with gr.Row():
             with gr.Row():
                 with gr.Column():
                 with gr.Column():
@@ -282,32 +330,49 @@ def main(args):
                     gr.Markdown(
                     gr.Markdown(
                         """
                         """
                         BeamSearchStrategy
                         BeamSearchStrategy
-                        """)
-                    num_beams = gr.Slider(maximum=4, value=2, minimum=1, step=1, label='Number of Beams')
-                    length_penalty = gr.Slider(maximum=1, value=1, minimum=0, label='Length Penalty')
-                    no_repeat_ngram_size = gr.Slider(maximum=5, value=3, minimum=1, step=1, label='No Repeat Ngram Size')
+                        """
+                    )
+                    num_beams = gr.Slider(maximum=4, value=2, minimum=1, step=1, label="Number of Beams")
+                    length_penalty = gr.Slider(maximum=1, value=1, minimum=0, label="Length Penalty")
+                    no_repeat_ngram_size = gr.Slider(
+                        maximum=5, value=3, minimum=1, step=1, label="No Repeat Ngram Size"
+                    )
                 with gr.Column():
                 with gr.Column():
                     # base search
                     # base search
                     gr.Markdown(
                     gr.Markdown(
                         """
                         """
                         BaseStrategy
                         BaseStrategy
-                        """)
-                    temperature = gr.Slider(maximum=1, value=0.7, minimum=0, label='Temperature')
-                    topk = gr.Slider(maximum=40, value=0, minimum=0, step=1, label='Top K')
-                    topp = gr.Slider(maximum=1, value=0.7, minimum=0, label='Top P')
-                
-            inputs = [model_input, seed, out_seq_length, min_gen_length, sampling_strategy, num_beams, length_penalty, no_repeat_ngram_size, temperature, topk, topp]
+                        """
+                    )
+                    temperature = gr.Slider(maximum=1, value=0.7, minimum=0, label="Temperature")
+                    topk = gr.Slider(maximum=40, value=0, minimum=0, step=1, label="Top K")
+                    topp = gr.Slider(maximum=1, value=0.7, minimum=0, label="Top P")
+
+            inputs = [
+                model_input,
+                seed,
+                out_seq_length,
+                min_gen_length,
+                sampling_strategy,
+                num_beams,
+                length_penalty,
+                no_repeat_ngram_size,
+                temperature,
+                topk,
+                topp,
+            ]
             gen.click(fn=predict, inputs=inputs, outputs=outputs)
             gen.click(fn=predict, inputs=inputs, outputs=outputs)
             clr.click(fn=lambda value: gr.update(value=""), inputs=clr, outputs=model_input)
             clr.click(fn=lambda value: gr.update(value=""), inputs=clr, outputs=model_input)
-            
+
             gr_examples = gr.Examples(examples=examples, inputs=model_input)
             gr_examples = gr.Examples(examples=examples, inputs=model_input)
-            
+
             gr.Markdown(
             gr.Markdown(
                 """
                 """
                 Disclaimer inspired from [BLOOM](https://huggingface.co/spaces/bigscience/bloom-book)
                 Disclaimer inspired from [BLOOM](https://huggingface.co/spaces/bigscience/bloom-book)
                 
                 
                 GLM-130B was trained on web-crawled data, so it's hard to predict how GLM-130B will respond to particular prompts; harmful or otherwise offensive content may occur without warning. We prohibit users from knowingly generating or allowing others to knowingly generate harmful content, including Hateful, Harassment, Violence, Adult, Political, Deception, etc. 
                 GLM-130B was trained on web-crawled data, so it's hard to predict how GLM-130B will respond to particular prompts; harmful or otherwise offensive content may occur without warning. We prohibit users from knowingly generating or allowing others to knowingly generate harmful content, including Hateful, Harassment, Violence, Adult, Political, Deception, etc. 
-                """)
+                """
+            )
 
 
         demo.launch(share=True)
         demo.launch(share=True)
     else:
     else:
@@ -315,11 +380,33 @@ def main(args):
             info = [None, None, None, None, None, None, None, None, None, None, None]
             info = [None, None, None, None, None, None, None, None, None, None, None]
             dist.broadcast_object_list(info, src=0)
             dist.broadcast_object_list(info, src=0)
 
 
-            text, seed, out_seq_length, min_gen_length, sampling_strategy, num_beams, length_penalty, no_repeat_ngram_size, temperature, topk, topp = info
-
-            predict(text, seed, out_seq_length, min_gen_length, sampling_strategy, 
-                num_beams, length_penalty, no_repeat_ngram_size, 
-                temperature, topk, topp)
+            (
+                text,
+                seed,
+                out_seq_length,
+                min_gen_length,
+                sampling_strategy,
+                num_beams,
+                length_penalty,
+                no_repeat_ngram_size,
+                temperature,
+                topk,
+                topp,
+            ) = info
+
+            predict(
+                text,
+                seed,
+                out_seq_length,
+                min_gen_length,
+                sampling_strategy,
+                num_beams,
+                length_penalty,
+                no_repeat_ngram_size,
+                temperature,
+                topk,
+                topp,
+            )
 
 
 
 
 if __name__ == "__main__":
 if __name__ == "__main__":

+ 1 - 3
tasks/clean-e2e-nlg/task.py

@@ -38,9 +38,7 @@ class E2E(GenerationTask):
             target_de = self.tokenizer.detokenize(target["targets"][0])
             target_de = self.tokenizer.detokenize(target["targets"][0])
 
 
             scores_rouge = scorer_rouge.score(text_de, target_de)
             scores_rouge = scorer_rouge.score(text_de, target_de)
-            scores_bleurt = scorer_bleurt.score(
-                references=[target_de], candidates=[text_de]
-            )
+            scores_bleurt = scorer_bleurt.score(references=[target_de], candidates=[text_de])
             rouge2_precision = scores_rouge["rouge2"].precision
             rouge2_precision = scores_rouge["rouge2"].precision
             rouge2_recall = scores_rouge["rouge2"].recall
             rouge2_recall = scores_rouge["rouge2"].recall
             rouge2_fmeasure = scores_rouge["rouge2"].fmeasure
             rouge2_fmeasure = scores_rouge["rouge2"].fmeasure

+ 1 - 1
tasks/ethnic/crows-pair/tasks.py

@@ -62,7 +62,7 @@ class CrowsPairTask(MultiChoiceTask, ABC):
             for value1 in result.items():
             for value1 in result.items():
                 value1 = value1[1]
                 value1 = value1[1]
                 for key, value in value1.items():
                 for key, value in value1.items():
-                    print_rank_0("category:{cat}        score:{score}".format(cat=key, score=round(value * 100,2)))
+                    print_rank_0("category:{cat}        score:{score}".format(cat=key, score=round(value * 100, 2)))
 
 
 
 
 class CrowsPairDataset(MultiChoiceTaskDataset):
 class CrowsPairDataset(MultiChoiceTaskDataset):

+ 1 - 3
tasks/web-nlg/task.py

@@ -38,9 +38,7 @@ class WEB(GenerationTask):
             target_de = self.tokenizer.detokenize(target["targets"][0])
             target_de = self.tokenizer.detokenize(target["targets"][0])
 
 
             scores_rouge = scorer_rouge.score(text_de, target_de)
             scores_rouge = scorer_rouge.score(text_de, target_de)
-            scores_bleurt = scorer_bleurt.score(
-                references=[target_de], candidates=[text_de]
-            )
+            scores_bleurt = scorer_bleurt.score(references=[target_de], candidates=[text_de])
             rouge2_precision = scores_rouge["rouge2"].precision
             rouge2_precision = scores_rouge["rouge2"].precision
             rouge2_recall = scores_rouge["rouge2"].recall
             rouge2_recall = scores_rouge["rouge2"].recall
             rouge2_fmeasure = scores_rouge["rouge2"].fmeasure
             rouge2_fmeasure = scores_rouge["rouge2"].fmeasure

+ 1 - 3
tasks/wiki-lingua/task.py

@@ -39,9 +39,7 @@ class WIKI(GenerationTask):
             target_de = self.tokenizer.detokenize(target["targets"][0])
             target_de = self.tokenizer.detokenize(target["targets"][0])
 
 
             scores_rouge = scorer_rouge.score(text_de, target_de)
             scores_rouge = scorer_rouge.score(text_de, target_de)
-            scores_bleurt = scorer_bleurt.score(
-                references=[target_de], candidates=[text_de]
-            )
+            scores_bleurt = scorer_bleurt.score(references=[target_de], candidates=[text_de])
             rouge2_precision = scores_rouge["rouge2"].precision
             rouge2_precision = scores_rouge["rouge2"].precision
             rouge2_recall = scores_rouge["rouge2"].recall
             rouge2_recall = scores_rouge["rouge2"].recall
             rouge2_fmeasure = scores_rouge["rouge2"].fmeasure
             rouge2_fmeasure = scores_rouge["rouge2"].fmeasure