Pārlūkot izejas kodu

fixing inference related scripts and command examples (#222)

Ilia Kulikov 1 gadu atpakaļ
vecāks
revīzija
1779d9d4fe

+ 9 - 9
docs/expressive/README.md

@@ -111,6 +111,7 @@ export SPLIT="dev_mexpresso_eng_spa" # example, change for your split
 export TGT_LANG="spa"
 export SRC_LANG="eng"
 export GENERATED_DIR="path_to_generated_output_for_given_data_split"
+export GENERATED_TSV="generate-${SPLIT}.tsv"
 export STOPES_ROOT="path_to_stopes_code_repo"
 export SC_ROOT="path_to_this_repo"
 ```
@@ -124,7 +125,6 @@ python ${SC_ROOT}/src/seamless_communication/cli/expressivity/evaluate/run_asr_b
     --tgt_lang=${TGT_LANG}
 ```
 * `generate-${SPLIT}.tsv` is an expected output from inference described in pre-requisite
-* `run_asr_bleu.py` creates an additional manifest called `output_manifest.tsv` inside `--generation_dir_path` which includes all relevant columns needed for this evaluation
 
 After completion resulting ASR-BLEU score is written in `${GENERATED_DIR}/s2st_asr_bleu_normalized.json`.
 
@@ -137,10 +137,10 @@ python -m stopes.modules +vocal_style_similarity=base \
     launcher.cluster=local \
     vocal_style_similarity.model_type=valle \
     +vocal_style_similarity.model_path=${SPEECH_ENCODER_MODEL_PATH} \
-    +vocal_style_similarity.input_file=${GENERATED_DIR}/output_manifest.tsv \
+    +vocal_style_similarity.input_file=${GENERATED_DIR}/${GENERATED_TSV} \
     +vocal_style_similarity.output_file=${GENERATED_DIR}/vocal_style_sim_result.txt \
     vocal_style_similarity.named_columns=true \
-    vocal_style_similarity.src_audio_column=audio \
+    vocal_style_similarity.src_audio_column=src_audio \
     vocal_style_similarity.tgt_audio_column=hypo_audio
 ```
 * We report average number from all utterance scores written in `${GENERATED_DIR}/vocal_style_sim_result.txt`.
@@ -150,8 +150,8 @@ python -m stopes.modules +vocal_style_similarity=base \
 ```bash
 python -m stopes.modules +compare_audios=AutoPCP_multilingual_v2 \
     launcher.cluster=local \
-    +compare_audios.input_file=${GENERATED_DIR}/output_manifest.tsv \
-    compare_audios.src_audio_column=audio \
+    +compare_audios.input_file=${GENERATED_DIR}/${GENERATED_TSV} \
+    compare_audios.src_audio_column=src_audio \
     compare_audios.tgt_audio_column=hypo_audio \
     +compare_audios.named_columns=true \
     +compare_audios.output_file=${GENERATED_DIR}/autopcp_result.txt
@@ -165,10 +165,10 @@ This stage includes 3 steps: (1) src lang annotation, (2) tgt lang annotation, (
 ```bash
 # src lang pause&rate annotation
 python ${STOPES_ROOT}/stopes/eval/local_prosody/annotate_utterances.py \
-    +data_path=${GENERATED_DIR}/output_manifest.tsv \
+    +data_path=${GENERATED_DIR}/${GENERATED_TSV} \
     +result_path=${GENERATED_DIR}/${SRC_LANG}_speech_rate_pause_annotation.tsv \
-    +audio_column=audio \
-    +text_column=raw_src_text \
+    +audio_column=src_audio \
+    +text_column=src_text \
     +speech_units=[syllable] \
     +vad=true \
     +net=true \
@@ -177,7 +177,7 @@ python ${STOPES_ROOT}/stopes/eval/local_prosody/annotate_utterances.py \
 
 # tgt lang pause&rate annotation
 python ${STOPES_ROOT}/stopes/eval/local_prosody/annotate_utterances.py \
-    +data_path=${GENERATED_DIR}/output_manifest.tsv \
+    +data_path=${GENERATED_DIR}/${GENERATED_TSV} \
     +result_path=${GENERATED_DIR}/${TGT_LANG}_speech_rate_pause_annotation.tsv \
     +audio_column=hypo_audio \
     +text_column=s2t_out \

+ 11 - 18
src/seamless_communication/cli/expressivity/evaluate/pretssel_inference.py

@@ -61,7 +61,9 @@ def build_data_pipeline(
 ) -> DataPipeline:
     with open(args.data_file, "r") as f:
         header = f.readline().strip("\n").split("\t")
-        assert args.audio_field in header, f"Input file does not contain {args.audio_field} field"
+        assert (
+            args.audio_field in header
+        ), f"Input file does not contain {args.audio_field} field"
 
     n_parallel = 4
 
@@ -73,7 +75,9 @@ def build_data_pipeline(
 
     map_file = FileMapper(root_dir=args.audio_root_dir, cached_fd_count=10)
 
-    pipeline_builder.map(map_file, selector=args.audio_field, num_parallel_calls=n_parallel)
+    pipeline_builder.map(
+        map_file, selector=args.audio_field, num_parallel_calls=n_parallel
+    )
 
     decode_audio = AudioDecoder(dtype=torch.float32, device=device)
 
@@ -150,8 +154,8 @@ def main() -> None:
     parser.add_argument(
         "--duration_factor",
         type=float,
-        help="The duration factor for NAR T2U model. Expressivity model uses 1.1",
-        default=1.1,
+        help="The duration factor for NAR T2U model.",
+        default=1.0,
     )
     parser.add_argument(
         "--output_result_tsv",
@@ -212,6 +216,7 @@ def main() -> None:
 
     hyps = []
     refs = []
+    audio_hyps = []
 
     with contextlib.ExitStack() as stack:
         hyp_file = stack.enter_context(
@@ -286,6 +291,7 @@ def main() -> None:
                     speech_output.audio_wavs[i][0].to(torch.float32).cpu(),
                     sample_rate=speech_output.sample_rate,
                 )
+                audio_hyps.append((waveforms_dir / f"{idx}_pred.wav").as_posix())
 
                 sample_id += 1
                 progress_bar.update(1)
@@ -306,25 +312,12 @@ def main() -> None:
             for line in file:
                 unit_out.append(line.strip())
 
+        output_tsv["hypo_audio"] = audio_hyps
         output_tsv["s2t_out"] = text_out
         output_tsv["orig_unit"] = unit_out
         output_tsv.to_csv(output_tsv_file, quoting=3, sep="\t", index=False)
         logger.info(f"Output results in {output_tsv_file}")
 
-    if len(hyps) == len(refs):
-        logger.info(f"Calculating S2T BLEU using {args.ref_field} column")
-        if args.tgt_lang in ("cmn", "jpn", "lao", "mya", "tha"):
-            tokenizer = "char"
-        else:
-            tokenizer = "13a"
-
-        bleu = BLEU(tokenize=tokenizer)
-        score = bleu.corpus_score(hyps, [refs])
-        bleu_filename = output_path / f"{args.data_file.stem}_text_output_bleu.json"
-        with open(bleu_filename, "w") as f:
-            f.write(score.format(signature=str(bleu.get_signature()), is_json=True))
-        logger.info(score.format(signature=bleu.get_signature()))
-
 
 if __name__ == "__main__":
     main()

+ 2 - 36
src/seamless_communication/cli/expressivity/evaluate/run_asr_bleu.py

@@ -15,47 +15,13 @@ from fairseq2.typing import Device
 from pathlib import Path
 
 
-def create_output_manifest(
-    generation_dir_path: str,
-    generate_tsv_filename: str,
-) -> pd.DataFrame:
-    generate_df = pd.read_csv(
-        f"{generation_dir_path}/{generate_tsv_filename}",
-        sep="\t",
-        quoting=csv.QUOTE_MINIMAL,
-    )
-
-    # fetch waveforms following indices from generate_df
-    waveform_paths = []
-    for idx in generate_df["id"]:
-        waveform_path = f"{generation_dir_path}/waveform/{idx}_pred.wav"
-        assert os.path.exists(waveform_path)
-        waveform_paths.append(waveform_path)
-
-    generate_df["hypo_audio"] = waveform_paths
-
-    generate_df.set_index("id").to_csv(
-        f"{generation_dir_path}/output_manifest.tsv",
-        sep="\t",
-        quoting=csv.QUOTE_MINIMAL,
-    )
-    return generate_df
-
-
 def run_asr_bleu_expressive_model(
     generation_dir_path: str,
     generate_tsv_filename: str,
     tgt_lang: str,
-) -> None:
-    output_manifest_path = Path(generation_dir_path) / "output_manifest.tsv"
-
-    if not output_manifest_path.exists():
-        _ = create_output_manifest(
-            generation_dir_path, generate_tsv_filename
-        ).set_index("id")
-
+):
     compute_quality_metrics(
-        output_manifest_path,
+        f"{generation_dir_path}/{generate_tsv_filename}",
         Path(generation_dir_path),
         tgt_lang,
         "S2ST",