@@ -61,12 +61,14 @@ def main():
 
     if torch.cuda.is_available():
         device = torch.device("cuda:0")
-        logger.info("Running inference on the GPU.")
+        dtype = torch.float16
+        logger.info(f"Running inference on the GPU in {dtype}.")
     else:
         device = torch.device("cpu")
-        logger.info("Running inference on the CPU.")
+        dtype = torch.float32
+        logger.info(f"Running inference on the CPU in {dtype}.")
 
-    translator = Translator(args.model_name, args.vocoder_name, device)
+    translator = Translator(args.model_name, args.vocoder_name, device, dtype)
     translated_text, wav, sr = translator.predict(
         args.input,
         args.task,
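The point of the change is that the chosen dtype is threaded into the Translator so the model weights and its inputs are cast consistently: half precision (torch.float16) roughly halves memory and is well supported on CUDA, while many CPU kernels are slow or missing in float16, hence the float32 fallback. Below is a minimal sketch of how a constructor might consume the new argument; the stand-in nn.Linear model and the single-tensor predict signature are illustrative assumptions, not the real Translator API.

import torch


class Translator:
    """Minimal sketch; the real class loads model_name/vocoder_name checkpoints."""

    def __init__(self, model_name: str, vocoder_name: str,
                 device: torch.device, dtype: torch.dtype):
        self.device = device
        self.dtype = dtype
        # Stand-in module; a real loader would build the checkpoint named by
        # model_name instead of this toy layer.
        self.model = torch.nn.Linear(16, 16).to(device=device, dtype=dtype)

    @torch.inference_mode()
    def predict(self, x: torch.Tensor) -> torch.Tensor:
        # Cast inputs to the model's device/dtype so float16 weights on the
        # GPU never meet float32 activations.
        return self.model(x.to(device=self.device, dtype=self.dtype))


# Same selection logic as the diff above; "model"/"vocoder" are placeholders.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
dtype = torch.float16 if torch.cuda.is_available() else torch.float32
translator = Translator("model", "vocoder", device, dtype)
print(translator.predict(torch.randn(4, 16)).dtype)  # matches the chosen dtype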