Forráskód Böngészése

Minor adjustments to mini_eval

Ruslan Mavlyutov 1 éve
szülő
commit
d381da1354
1 módosított fájl, 26 hozzáadás és 6 törlés
  1. 26 6
      src/seamless_communication/cli/m4t/finetune/mini_eval.py

+ 26 - 6
src/seamless_communication/cli/m4t/finetune/mini_eval.py

@@ -4,6 +4,7 @@ from jiwer import wer
 import os
 from typing import Tuple, Iterable, Dict, Any
 import logging
+from whisper.normalizers import EnglishTextNormalizer
 
 logging.basicConfig(level=logging.INFO)
 
@@ -12,15 +13,20 @@ from seamless_communication.inference import Translator
 
 log = logging.getLogger("l")
 
-TOKEN = "<YOU HF TOKEN HERE>"
+TOKEN = "dummy"
 MAX_SAMPLES = 100
 CHCK_PATH = os.path.expanduser("~/tune_chck/chck.pt")
 
+norm = EnglishTextNormalizer()
 
-def _iterate_test_ds() -> Iterable[Tuple[torch.Tensor, str]]:
+
+DATASET = []  # type:ignore
+
+
+def __iterate_test_ds() -> Iterable[Tuple[torch.Tensor, str]]:
     ds = load_dataset(
         "speechcolab/gigaspeech",
-        "xs",
+        "s",
         token=os.environ.get("HF_TOKEN", TOKEN),
         split="test",
         streaming=True,
@@ -33,11 +39,21 @@ def _iterate_test_ds() -> Iterable[Tuple[torch.Tensor, str]]:
         yield (torch.from_numpy(item["audio"]["array"]), item["text"])
 
 
+def _iterate_test_ds() -> Iterable[Tuple[torch.Tensor, str]]:
+    global DATASET
+    if not DATASET:
+        DATASET = list(__iterate_test_ds())
+    yield from DATASET
+
+
 def _eval(translator: Translator) -> float:
     references = []
     predictions = []
     for idx, (wav, text) in enumerate(_iterate_test_ds()):
-        references.append(text)
+        reference = norm(text)
+        if not reference:
+            reference = "."
+        references.append(reference)
         prediction = str(
             translator.predict(
                 input=wav,
@@ -46,12 +62,15 @@ def _eval(translator: Translator) -> float:
                 src_lang="eng",
             )[0][0]
         )
+        prediction = norm(prediction)
+        if not prediction:
+            prediction = "."
         log.info(idx)
-        log.info(f"REF: {text}")
+        log.info(f"REF: {reference}")
         log.info(f"PRE: {prediction}")
         log.info("----")
         predictions.append(prediction)
-    return wer(reference=references, hypothesis=predictions)
+    return wer(reference=references, hypothesis=predictions)  # type:ignore
 
 
 def _select_keys(state_dict: Dict[str, Any], prefix: str) -> Dict[str, Any]:
@@ -82,5 +101,6 @@ def main() -> None:
     log.info(f"WER non-tuned: {non_tuned_wer:.3f}")
     log.info(f"WER tuned: {tuned_wer:.3f}")
 
+
 if __name__ == "__main__":
     main()