@@ -4,8 +4,7 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import Any, Dict, Optional
-
+from typing import Any, Dict, Optional, List, Union
 import torch
 from torch import Tensor
 from torch.nn import Module
@@ -26,18 +25,25 @@ class Vocoder(Module):
     def forward(
         self,
         units: Tensor,
-        lang: str,
-        spkr: Optional[int] = -1,
+        lang_list: Union[List[str], str],
+        spkr_list: Union[Optional[List[int]], int] = None,
         dur_prediction: bool = True,
     ) -> Tensor:
-        lang_idx = self.lang_spkr_idx_map["multilingual"][lang]
-        spkr_list = self.lang_spkr_idx_map["multispkr"][lang]
-        if not spkr:
-            spkr = -1
-        spkr = spkr_list[0] if spkr == -1 else spkr
+        # TODO: Do we need this backward compatibility, or just update all calling sites?
+        if len(units.shape) == 1:
+            units = units.unsqueeze(0)  # add batch dim
+        if isinstance(lang_list, str):
+            lang_list = [lang_list] * units.size(0)
+        if isinstance(spkr_list, int):
+            spkr_list = [spkr_list] * units.size(0)
+        lang_idx_list = [self.lang_spkr_idx_map["multilingual"][l] for l in lang_list]
+        if not spkr_list:
+            spkr_list = [-1 for _ in range(len(lang_list))]
+        spkr_list = [self.lang_spkr_idx_map["multispkr"][lang_list[i]][0] if spkr_list[i] == -1 else spkr_list[i] for i in range(len(spkr_list))]
         x = {
-            "code": units.view(1, -1),
-            "spkr": torch.tensor([[spkr]], device=units.device),
-            "lang": torch.tensor([[lang_idx]], device=units.device),
+            "code": units.view(units.size(0), -1),
+            "spkr": torch.tensor([spkr_list], device=units.device).t(),
+            "lang": torch.tensor([lang_idx_list], device=units.device).t(),
+
         }
         return self.code_generator(x, dur_prediction)  # type: ignore[no-any-return]
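
For reference, a minimal usage sketch of the updated forward signature (not part of the patch). It assumes `vocoder` is an already-loaded Vocoder instance and that "eng" and "spa" are valid keys in its lang_spkr_idx_map["multilingual"] / ["multispkr"] tables; the loading code, language codes, and unit values below are placeholders, not taken from this diff.

import torch

# `vocoder`: an already-loaded Vocoder instance (loading code not shown in this diff).

# Old-style call sites keep working: a 1-D unit tensor and a single language
# string are broadcast to a batch of size 1 by the compatibility branches above;
# leaving spkr_list as None falls back to the first speaker registered for
# that language.
units_1d = torch.randint(0, 100, (200,))    # (T,) dummy discrete units
out_single = vocoder(units_1d, "eng")

# New batched call: a (B, T) unit tensor plus per-sample language and speaker
# lists of length B; a speaker id of -1 likewise falls back to that language's
# first registered speaker.
units_2d = torch.randint(0, 100, (2, 200))  # (B, T) dummy discrete units
out_batch = vocoder(
    units_2d,
    lang_list=["eng", "spa"],
    spkr_list=[0, -1],
    dur_prediction=True,
)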