|
@@ -4,6 +4,7 @@ from jiwer import wer
|
|
|
import os
|
|
|
from typing import Tuple, Iterable, Dict, Any
|
|
|
import logging
|
|
|
+from whisper.normalizers import EnglishTextNormalizer
|
|
|
|
|
|
logging.basicConfig(level=logging.INFO)
|
|
|
|
|
@@ -12,15 +13,20 @@ from seamless_communication.inference import Translator
|
|
|
|
|
|
log = logging.getLogger("l")
|
|
|
|
|
|
-TOKEN = "<YOU HF TOKEN HERE>"
|
|
|
+TOKEN = "dummy"
|
|
|
MAX_SAMPLES = 100
|
|
|
CHCK_PATH = os.path.expanduser("~/tune_chck/chck.pt")
|
|
|
|
|
|
+norm = EnglishTextNormalizer()
|
|
|
|
|
|
-def _iterate_test_ds() -> Iterable[Tuple[torch.Tensor, str]]:
|
|
|
+
|
|
|
+DATASET = [] # type:ignore
|
|
|
+
|
|
|
+
|
|
|
+def __iterate_test_ds() -> Iterable[Tuple[torch.Tensor, str]]:
|
|
|
ds = load_dataset(
|
|
|
"speechcolab/gigaspeech",
|
|
|
- "xs",
|
|
|
+ "s",
|
|
|
token=os.environ.get("HF_TOKEN", TOKEN),
|
|
|
split="test",
|
|
|
streaming=True,
|
|
@@ -33,11 +39,21 @@ def _iterate_test_ds() -> Iterable[Tuple[torch.Tensor, str]]:
|
|
|
yield (torch.from_numpy(item["audio"]["array"]), item["text"])
|
|
|
|
|
|
|
|
|
+def _iterate_test_ds() -> Iterable[Tuple[torch.Tensor, str]]:
|
|
|
+ global DATASET
|
|
|
+ if not DATASET:
|
|
|
+ DATASET = list(__iterate_test_ds())
|
|
|
+ yield from DATASET
|
|
|
+
|
|
|
+
|
|
|
def _eval(translator: Translator) -> float:
|
|
|
references = []
|
|
|
predictions = []
|
|
|
for idx, (wav, text) in enumerate(_iterate_test_ds()):
|
|
|
- references.append(text)
|
|
|
+ reference = norm(text)
|
|
|
+ if not reference:
|
|
|
+ reference = "."
|
|
|
+ references.append(reference)
|
|
|
prediction = str(
|
|
|
translator.predict(
|
|
|
input=wav,
|
|
@@ -46,12 +62,15 @@ def _eval(translator: Translator) -> float:
|
|
|
src_lang="eng",
|
|
|
)[0][0]
|
|
|
)
|
|
|
+ prediction = norm(prediction)
|
|
|
+ if not prediction:
|
|
|
+ prediction = "."
|
|
|
log.info(idx)
|
|
|
- log.info(f"REF: {text}")
|
|
|
+ log.info(f"REF: {reference}")
|
|
|
log.info(f"PRE: {prediction}")
|
|
|
log.info("----")
|
|
|
predictions.append(prediction)
|
|
|
- return wer(reference=references, hypothesis=predictions)
|
|
|
+ return wer(reference=references, hypothesis=predictions) # type:ignore
|
|
|
|
|
|
|
|
|
def _select_keys(state_dict: Dict[str, Any], prefix: str) -> Dict[str, Any]:
|
|
@@ -82,5 +101,6 @@ def main() -> None:
|
|
|
log.info(f"WER non-tuned: {non_tuned_wer:.3f}")
|
|
|
log.info(f"WER tuned: {tuned_wer:.3f}")
|
|
|
|
|
|
+
|
|
|
if __name__ == "__main__":
|
|
|
main()
|