|
@@ -53,10 +53,12 @@ class Translator(nn.Module):
|
|
|
model_name_or_card: Union[str, AssetCard],
|
|
|
vocoder_name_or_card: Union[str, AssetCard],
|
|
|
device: Device,
|
|
|
- dtype: DataType,
|
|
|
+ dtype: DataType = torch.float16,
|
|
|
):
|
|
|
super().__init__()
|
|
|
# Load the model.
|
|
|
+ if device == torch.device("cpu"):
|
|
|
+ dtype = torch.float32
|
|
|
self.model: UnitYModel = self.load_model_for_inference(
|
|
|
load_unity_model, model_name_or_card, device, dtype
|
|
|
)
|