Эх сурвалжийг харах

Rename upgrade_checkpoint to convert_checkpoint (#32)

Can Balioglu 2 жил өмнө
parent
commit
4ea2510ed2

+ 1 - 1
src/seamless_communication/models/unity/loader.py

@@ -29,7 +29,7 @@ class UnitYLoader(ModelLoader[UnitYModel, UnitYConfig]):
     """Loads UnitY models."""
 
     @finaloverride
-    def _upgrade_checkpoint(
+    def _convert_checkpoint(
         self, checkpoint: Mapping[str, Any], config: UnitYConfig
     ) -> Mapping[str, Any]:
         state_dict = checkpoint["model"]

+ 1 - 1
src/seamless_communication/models/vocoder/loader.py

@@ -23,7 +23,7 @@ class VocoderLoader(ModelLoader[Vocoder, VocoderConfig]):
     """Loads Vocoder models."""
 
     @finaloverride
-    def _upgrade_checkpoint(
+    def _convert_checkpoint(
         self, checkpoint: Mapping[str, Any], config: VocoderConfig
     ) -> Mapping[str, Any]:
         old_state_dict = checkpoint["generator"]