|
@@ -149,14 +149,14 @@ class Translator(nn.Module):
|
|
|
@torch.inference_mode()
|
|
|
def predict(
|
|
|
self,
|
|
|
- input: Union[str, torch.Tensor],
|
|
|
+ input: Union[str, Tensor],
|
|
|
task_str: str,
|
|
|
tgt_lang: str,
|
|
|
src_lang: Optional[str] = None,
|
|
|
spkr: Optional[int] = -1,
|
|
|
ngram_filtering: bool = False,
|
|
|
sample_rate: int = 16000,
|
|
|
- ) -> Tuple[StringLike, Optional[List[Tensor]], Optional[int]]:
|
|
|
+ ) -> Tuple[StringLike, Optional[Tensor], Optional[int]]:
|
|
|
"""
|
|
|
The main method used to perform inference on all tasks.
|
|
|
|