|
@@ -28,7 +28,7 @@ from seamless_communication.models.monotonic_decoder.model import MonotonicDecod
|
|
class MonotonicDecoderLoader(
|
|
class MonotonicDecoderLoader(
|
|
ModelLoader[MonotonicDecoderModel, MonotonicDecoderConfig]
|
|
ModelLoader[MonotonicDecoderModel, MonotonicDecoderConfig]
|
|
):
|
|
):
|
|
- """Loads NLLB models."""
|
|
|
|
|
|
+ """Loads Monotonic Decoder models."""
|
|
|
|
|
|
@finaloverride
|
|
@finaloverride
|
|
def _convert_checkpoint(
|
|
def _convert_checkpoint(
|
|
@@ -37,7 +37,7 @@ class MonotonicDecoderLoader(
|
|
state_dict = checkpoint["model"]
|
|
state_dict = checkpoint["model"]
|
|
|
|
|
|
# Check if we have a fairseq2 checkpoint.
|
|
# Check if we have a fairseq2 checkpoint.
|
|
- if "decoder_frontend.embed_weight" in state_dict:
|
|
|
|
|
|
+ if "text_decoder.layers.0.self_attn.k_proj.weight" in state_dict:
|
|
return checkpoint
|
|
return checkpoint
|
|
|
|
|
|
key_map = self._fairseq_key_map()
|
|
key_map = self._fairseq_key_map()
|