|
@@ -3,7 +3,6 @@
|
|
|
#
|
|
|
# This source code is licensed under the license found in the
|
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
-
|
|
|
from typing import Any, Mapping, final
|
|
|
|
|
|
from fairseq2.assets import asset_store, download_manager
|
|
@@ -29,6 +28,12 @@ class VocoderLoader(ModelLoader[Vocoder, VocoderConfig]):
|
|
|
def _convert_checkpoint(
|
|
|
self, checkpoint: Mapping[str, Any], config: VocoderConfig
|
|
|
) -> Mapping[str, Any]:
|
|
|
+ if (
|
|
|
+ "model" in checkpoint
|
|
|
+ and "code_generator.resblocks.0.convs1.0.weight_g" in checkpoint["model"]
|
|
|
+ ):
|
|
|
+ return checkpoint
|
|
|
+
|
|
|
old_state_dict = checkpoint["generator"]
|
|
|
new_state_dict = {}
|
|
|
for key in old_state_dict:
|