Explorar o código

address comments

cndn hai 1 ano
pai
achega
096e9276a3
Modificáronse 2 ficheiros con 188 adicións e 2 borrados
  1. 6 2
      ggml/examples/unity/unity.cpp
  2. 182 0
      ggml/mt.py

+ 6 - 2
ggml/examples/unity/unity.cpp

@@ -29,6 +29,7 @@ struct unity_params {
         /*normalize_scores*/ true,
         /*mem_mb*/ 512
     };
+    int32_t max_audio_s = 30;
     bool verbose = false;
 };
 
@@ -48,6 +49,7 @@ void unity_print_usage(int /*argc*/, char ** argv, const unity_params & params)
     fprintf(stderr, "  --text                text output\n");
     fprintf(stderr, "  --beam-size           beam size (default: %d)\n", params.opts.beam_size);
     fprintf(stderr, "  -M, --mem             memory buffer, increase for long inputs (default: %d)\n", params.opts.mem_mb);
+    fprintf(stderr, " --max-audio max duration of audio in seconds (default: %d)\n", params.max_audio_s);
     fprintf(stderr, "\n");
 }
 
@@ -83,6 +85,8 @@ bool unity_params_parse(int argc, char ** argv, unity_params & params) {
             params.verbose = true;
         } else if (arg == "-M" || arg == "--mem") {
             params.opts.mem_mb = std::stoi(get_next_arg(i, argc, argv, arg, params));
+        } else if (arg == "--max-audio") {
+            params.max_audio_s = std::stoi(get_next_arg(i, argc, argv, arg, params));
         } else {
             params.files.push_back(std::string(arg));
         }
@@ -153,8 +157,8 @@ int main(int argc, char ** argv) {
         // Load audio input
         GGML_ASSERT(info.samplerate == 16000);
         GGML_ASSERT(info.channels == 1);
-        // stop at 30s. Ideally we should chunk input audio, but this will prevent most obvious OOM.
-        int n_frames = std::min(info.samplerate * 30, (int)info.frames);
+        // Truncate audio input. Ideally we should chunk it, but this will prevent most obvious OOM.
+        int n_frames = std::min(info.samplerate * params.max_audio_s, (int)info.frames);
         std::vector<float> data(n_frames * info.channels);
         sf_readf_float(sndfile, data.data(), n_frames);
 

+ 182 - 0
ggml/mt.py

@@ -0,0 +1,182 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# MIT_LICENSE file in the root directory of this source tree.
+#
+# This script contains the builder and loader for the MT models. It has some
+# overlaps with fairseq2.models.nllb, except for a few subtle changes
+# in the tokenizer, patches of layers, etc.
+
+from pathlib import Path
+from typing import Any, Mapping, Optional, Literal
+import torch
+from torch.nn.parameter import Parameter
+
+from fairseq2.assets import InProcAssetMetadataProvider, asset_store, download_manager
+from fairseq2.generation.beam_search import BeamSearchSeq2SeqGenerator
+from fairseq2.nn.embedding import StandardEmbedding
+from fairseq2.models.nllb.builder import NllbBuilder, NllbConfig
+from fairseq2.models.nllb.loader import load_nllb_config
+from fairseq2.nn.projection import TiedProjection
+from fairseq2.models.transformer.model import TransformerModel
+from fairseq2.models.utils import ModelLoader
+from fairseq2.typing import Device, DataType
+from fairseq2.models.utils.checkpoint import convert_fairseq_checkpoint
+
+import sentencepiece as spm
+
+
+class MTBuilder(NllbBuilder):
+    def build_embedding(self) -> StandardEmbedding:
+        return StandardEmbedding(
+            num_embeddings=self.config.vocab_info.size,
+            embedding_dim=self.config.model_dim,
+            pad_idx=self.config.vocab_info.pad_idx,
+            init_fn=lambda x: x,
+            device=self.device,
+            dtype=self.dtype,
+        ).requires_grad_(False)
+
+    def build_model(self) -> TransformerModel:
+        """Build a model."""
+        encoder_embed = self.build_embedding()
+        decoder_embed = self.build_embedding()
+
+        encoder_frontend = self.build_frontend(encoder_embed)
+        decoder_frontend = self.build_frontend(decoder_embed)
+
+        encoder = self.build_encoder()
+        decoder = self.build_decoder()
+
+        # Unlike NLLB, in MT we de-couple
+        new_weight = Parameter(torch.zeros_like(
+            encoder_embed.weight, requires_grad=False)
+        )
+        final_proj = TiedProjection(new_weight, bias=None)
+
+        return TransformerModel(
+            encoder_frontend,
+            encoder,
+            decoder_frontend,
+            decoder,
+            final_proj,
+            self.config.vocab_info,
+        )
+
+
+def create_mt_model(
+    config: NllbConfig,
+    *,
+    device: Optional[Device] = None,
+    dtype: Optional[DataType] = None,
+) -> TransformerModel:
+    return MTBuilder(config, device=device, dtype=dtype).build_model()
+
+
+def convert_mt_checkpoint(
+    ckpt: Mapping[str, Any], config: NllbConfig,
+) -> Mapping[str, Any]:
+    global_key_map = {
+        # fmt: off
+        r"^encoder\.embed_tokens\.":                              r"encoder_frontend.embed.",
+        r"^decoder\.embed_tokens\.":                              r"decoder_frontend.embed.",
+        r"^encoder\.embed_positions.weights":                     r"encoder_frontend.pos_encoder.freqs",
+        r"^decoder\.embed_positions.weights":                     r"decoder_frontend.pos_encoder.freqs",
+        r"^encoder\.layernorm_embedding\.":                       r"encoder_frontend.layer_norm.",
+        r"^decoder\.layernorm_embedding\.":                       r"decoder_frontend.layer_norm.",
+        r"^decoder\.layers\.([0-9]+)\.self_attn\.out_proj\.":     r"decoder.layers.\1.self_attn.output_proj.",
+        r"^encoder\.layers\.([0-9]+)\.self_attn\.out_proj\.":     r"encoder.layers.\1.self_attn.output_proj.",
+        r"^decoder\.layers\.([0-9]+)\.encoder_attn\.out_proj\.":  r"decoder.layers.\1.encoder_decoder_attn.output_proj.",
+        r"^decoder\.layers\.([0-9]+)\.encoder_attn\.":            r"decoder.layers.\1.encoder_decoder_attn.",
+        r"^decoder\.layers\.([0-9]+)\.encoder_attn_layer_norm\.": r"decoder.layers.\1.encoder_decoder_attn_layer_norm.",
+        r"^encoder\.layers\.([0-9]+)\.fc1\.":                     r"encoder.layers.\1.ffn.inner_proj.",
+        r"^decoder\.layers\.([0-9]+)\.fc1\.":                     r"decoder.layers.\1.ffn.inner_proj.",
+        r"^encoder\.layers\.([0-9]+)\.fc2\.":                     r"encoder.layers.\1.ffn.output_proj.",
+        r"^decoder\.layers\.([0-9]+)\.fc2\.":                     r"decoder.layers.\1.ffn.output_proj.",
+        r"^encoder\.layers\.([0-9]+)\.final_layer_norm\.":        r"encoder.layers.\1.ffn_layer_norm.",
+        r"^decoder\.layers\.([0-9]+)\.final_layer_norm\.":        r"decoder.layers.\1.ffn_layer_norm.",
+        r"^decoder\.output_projection\.":                         r"final_proj.",
+        # fmt: on
+    }
+    return convert_fairseq_checkpoint(ckpt, global_key_map)
+
+
+def load_vocab(model_dir: str, mode: Literal["src", "tgt"]):
+    vocab_file = f"{model_dir}/{mode}.spm"
+    spmp = spm.SentencePieceProcessor(vocab_file)
+
+    return [
+        (spmp.id_to_piece(id).replace("▁", " "), spmp.get_score(id))
+        for id in range(spmp.get_piece_size())
+    ], spmp
+
+
+def load_mt_model(model_dir: str):
+    """
+    Load MT model and the vocabulary processors (spm) for source and target languages
+    Args:
+        model_dir: Directory of the model. It must contain files averaged_checkpoint.pt, src.spm and tgt.spm
+    """
+
+    # Create a fairseq2 model card on the fly. This must ensure that we do not have any other fairseq2
+    # environment resolvers and always return
+    model_dir = Path(model_dir)
+    model_card_info = [
+        {
+            "name": "mt_model",
+            "model_type": "nllb",  # Re-use the same encoder-decoder arch of NLLB
+            "model_arch": "dense_600m",  # Dummy value to pass fairseq2 asset's valdilation logic
+            "checkpoint": "file://" + str(model_dir / "averaged_checkpoint.pt"),
+            "model_config": {
+                "model_dim": 512,
+                "num_encoder_layers": 4,
+                "num_decoder_layers": 2,
+                "ffn_inner_dim": 2048,
+                "vocab_info": {
+                    "size": 10000,
+                    "unk_idx": 3,
+                    "bos_idx": 0,
+                    "eos_idx": 2,
+                    "pad_idx": 1,
+                }
+            }
+        }
+    ]
+    asset_store.metadata_providers.append(
+        InProcAssetMetadataProvider(model_card_info)
+    )
+    mt_card = asset_store.retrieve_card("mt_model")
+
+    return ModelLoader[TransformerModel, NllbConfig](
+        asset_store,
+        download_manager,
+        load_nllb_config,
+        create_mt_model,
+        convert_mt_checkpoint,
+        restrict_checkpoints=False,
+    )(mt_card)
+
+
+def test_mt(
+    model: TransformerModel,
+    src_spm: spm.SentencePieceProcessor,
+    tgt_spm: spm.SentencePieceProcessor,
+):
+    from fairseq2.nn.padding import pad_seqs
+
+    # Tokens of "This is an example"
+    src_tokens = torch.LongTensor([688, 153, 62, 4581, 2])
+    src_seqs, src_padding_mask = pad_seqs(src_tokens, src_spm.pad_id())
+
+    # Force the developer begins with the EOS <s> token
+    prompt_tokens = torch.LongTensor([[2]])
+    generator = BeamSearchSeq2SeqGenerator(model)
+    output = generator(src_seqs, src_padding_mask, prompt_tokens, None)
+
+    print(output.hypotheses[0][0].seq)
+    tgt_tokens = output.hypotheses[0][0].seq.tolist()
+    out_text = tgt_spm.decode(tgt_tokens)
+
+    # assert out_text == "Este es un ejemplo"
+    print(out_text)