Kaushik Ram Sadagopan 2 лет назад
Родитель
Сommit
e605855da8

+ 1 - 1
src/seamless_communication/assets/cards/xlsr2_1b_v2.yaml

@@ -7,4 +7,4 @@
 name: xlsr2_1b_v2
 model_type: wav2vec2
 model_arch: xlsr2_1b_v2
-checkpoint: "file://private/home/changhan/data/models/wav2vec2/xlsr2_1b_v2.pt"
+checkpoint: "https://dl.fbaipublicfiles.com/seamlessM4T/models/unit_extraction/xlsr2_1b_v2.pt"

+ 41 - 0
src/seamless_communication/models/unit_extraction/kmeans.py

@@ -3,3 +3,44 @@
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
+
+import torch
+from torch import Tensor, nn
+import numpy as np
+from pathlib import Path
+from fairseq2.typing import Device
+from seamless_communication.assets import download_manager
+
+
+class KmeansModel(nn.Module):
+    @staticmethod
+    def load_model(km_path: Path, device: Device) -> "KmeansModel":
+        km_model = np.load(km_path)
+        centroids_numpy = km_model.transpose()
+        return KmeansModel(torch.from_numpy(centroids_numpy), device)
+
+    def __init__(self, centroids: Tensor, device: Device):
+        super().__init__()
+
+        self.centroids = nn.Parameter(centroids, requires_grad=False).to(device)
+        self.centroid_norm = nn.Parameter(
+            (centroids**2).sum(0, keepdims=True), requires_grad=False
+        ).to(device)
+
+    def forward(self, x: Tensor) -> Tensor:
+        dist: Tensor = (
+            x.pow(2).sum(1, keepdim=True)
+            - 2 * torch.matmul(x, self.centroids)
+            + self.centroid_norm
+        )
+        return dist.argmin(dim=-1)
+
+
+if __name__ == "__main__":
+    kmeans_uri = "https://dl.fbaipublicfiles.com/seamlessM4T/models/unit_extraction/kmeans_10k.npy"
+    km_path = download_manager.download_checkpoint(kmeans_uri, "kmeans_10k")
+    device = torch.device("cuda:1")
+    model = KmeansModel.load_model(km_path, device)
+    t = torch.randn((1000, 1280), device=device, dtype=torch.float32)
+    units = model(t)
+    print(units)