Guillaume Wenzek il y a 1 an
Parent
commit
f2ef995b95
6 fichiers modifiés avec 71 ajouts et 80 suppressions
  1. 9 10
      ggml/ctypes_utils.py
  2. 7 14
      ggml/ggml.py
  3. 5 7
      ggml/ggml_convert.py
  4. 11 14
      ggml/test_ggml_integration.py
  5. 21 20
      ggml/test_unity_cpp.py
  6. 18 15
      ggml/third_party_ggml.py

+ 9 - 10
ggml/ctypes_utils.py

@@ -1,14 +1,9 @@
-import inspect
 import ctypes
-import types
-import functools
 import dataclasses
-from typing import Callable
-from typing import Any
-from typing import Optional
-from typing import Type
-from typing import TypeVar
-from typing import Generic
+import functools
+import inspect
+import types
+from typing import Any, Callable, Generic, Optional, Type, TypeVar
 
 T = TypeVar("T")
 
@@ -22,6 +17,7 @@ class Ptr(Generic[T], ctypes._Pointer):  # type: ignore
 
 NULLPTR: Ptr[Any] = None  # type: ignore[assignment]
 
+
 def c_struct(cls: Type[T]) -> Type[T]:
     struct = types.new_class(cls.__name__, bases=(ctypes.Structure,))
     struct.__module__ = cls.__module__
@@ -71,6 +67,7 @@ def _py_type_to_ctype(t: type) -> type:
 
 F = TypeVar("F", bound=Callable[..., Any])
 
+
 def _c_fn(module: Any, fn: F) -> F:
     if callable(module):
         c_fn = module
@@ -78,7 +75,9 @@ def _c_fn(module: Any, fn: F) -> F:
         c_fn = getattr(module, fn.__name__)
     annotations = fn.__annotations__
     if "return" not in annotations:
-        raise ValueError("@c_fn decorator requires type annotations on the decorated function.")
+        raise ValueError(
+            "@c_fn decorator requires type annotations on the decorated function."
+        )
 
     c_fn.argtypes = [
         _py_type_to_ctype(t) for k, t in fn.__annotations__.items() if k != "return"

+ 7 - 14
ggml/ggml.py

@@ -3,25 +3,19 @@ We are vendoring https://github.com/abetlen/ggml-python (MIT License)
 adding a few utilities to convert between ggml and numpy tensors for testing.
 """
 
-import numpy as np
+import contextlib
 import ctypes
-import torch
+import dataclasses
 import functools
 import logging
-import dataclasses
-import contextlib
-from typing import Iterator
-from typing import NamedTuple
 from pathlib import Path
-from typing import Dict
-from typing import Callable
-from typing import Any
-from typing import Tuple
-from typing import Union
-from typing import Type
+from typing import Any, Callable, Dict, Iterator, NamedTuple, Tuple, Type, Union
 
+import numpy as np
+import torch
+
+from ctypes_utils import Ptr, c_fn, c_struct
 from third_party_ggml import *
-from ctypes_utils import c_struct, c_fn, Ptr
 
 ### Helpers
 
@@ -44,7 +38,6 @@ def numpy_dtype(ggml_type: ctypes.c_int) -> np.dtype:
 
 @functools.lru_cache()
 def from_numpy_dtype(dtype: np.dtype) -> ctypes.c_int:
-
     def _ggml_type(name: bytes, value: int) -> ctypes.c_int:
         t = ctypes.c_int(value)
         type_name = ggml_type_name(t)

+ 5 - 7
ggml/ggml_convert.py

@@ -6,23 +6,22 @@
 
 import dataclasses
 import logging
+import math
 import struct
 from enum import Enum
 from io import BufferedWriter
 from pathlib import Path
-from typing import Any, Callable, Dict, Optional, Tuple, Union
-import math
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
 import torch
-import ggml
-from typing import Callable
-from typing import Optional
-from typing import List
 from fairseq2.assets import AssetCard
 from fairseq2.models.transformer.frontend import TransformerEmbeddingFrontend
 from fairseq2.nn import SinusoidalPositionEncoder
 from fairseq2.nn.transformer import RelativePositionalEncoding
 from seamless_communication.models.unity import load_unity_config, load_unity_model
 
+import ggml
+
 Preprocessor = Callable[[Any], Any]
 
 
@@ -111,7 +110,6 @@ def fixup_model(model: torch.nn.Module, state_dict: Dict[str, torch.Tensor]) ->
         assert name not in state_dict
         state_dict[name] = pos_encoder.weight
 
-
     relative_pos_encs = find_children(model, RelativePositionalEncoding)
     # speech_encoder has several copies of the relative_pos_enc module.
     # For efficiency reasons we only make one copy of it to GGML.

+ 11 - 14
ggml/test_ggml_integration.py

@@ -1,24 +1,21 @@
-import ggml
 import ctypes
-import torch
-import pytest
-import numpy as np
-import torch
-import fairseq2.nn
-import fairseq2.nn.transformer
+import functools
 import logging
 import sys
-import functools
-from typing import Tuple
-from pathlib import Path
-from ctypes_utils import Ptr
 from ctypes import c_void_p
-from typing import Any
 from pathlib import Path
-from typing import Iterator
+from typing import Any, Iterator, Tuple
+
+import fairseq2.nn
+import fairseq2.nn.transformer
+import numpy as np
+import pytest
+import torch
+
+import ggml
+from ctypes_utils import Ptr
 from ggml import NativeObj
 from ggml_convert import convert_model
-from seamless_communication.models.inference.translator import Translator, Modality
 
 Ctx = ggml.ggml_context_p
 

+ 21 - 20
ggml/test_unity_cpp.py

@@ -1,29 +1,27 @@
-import ggml
 import ctypes
-import torch
-import pytest
-import numpy as np
-import torch
-import fairseq2.nn
-import fairseq2.nn.transformer
+import functools
 import logging
 import sys
-import functools
-from typing import Tuple
-from pathlib import Path
-from ctypes_utils import Ptr
 from ctypes import c_void_p
-from typing import Any
 from pathlib import Path
-from typing import Iterator
-from ggml import NativeObj
-from ggml_convert import convert_model, read_layer_config
-from seamless_communication.models.inference.translator import Translator, Modality
-from fairseq2.data.audio import WaveformToFbankConverter
+from typing import Any, Iterator, List, Tuple
+
+import fairseq2.nn
+import fairseq2.nn.transformer
+import numpy as np
+import pytest
+import torch
 import torchaudio
-from typing import List
-from ctypes_utils import NULLPTR
+from fairseq2.data.audio import WaveformToFbankConverter
+from fairseq2.generation import SequenceGeneratorOptions
 from fairseq2.models.wav2vec2.feature_extractor import Wav2Vec2FbankFeatureExtractor
+from seamless_communication.inference.translator import Modality, Translator
+
+import ggml
+from ctypes_utils import NULLPTR, Ptr
+from ggml import NativeObj
+from ggml_convert import convert_model, read_layer_config
+
 
 Ctx = ggml.ggml_context_p
 
@@ -569,7 +567,10 @@ def test_PositionalEmbedding_forward_with_cache(ctx: Ctx, g_model: c_void_p) ->
             gseq = ggml.from_numpy(ctx, seq[:, t : t + 1, :].numpy())
             ggml.ggml_set_name(gseq, b"seq")
             gy = ggml.forward(
-                "PositionalEmbedding", g_model, "text_decoder_frontend.pos_encoder", gseq
+                "PositionalEmbedding",
+                g_model,
+                "text_decoder_frontend.pos_encoder",
+                gseq,
             )
             gf = ggml.ggml_build_forward(gy)
             ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)

+ 18 - 15
ggml/third_party_ggml.py

@@ -49,20 +49,15 @@ ggml.ggml_free(ctx)
 ```
 
 """
-import os
-import sys
 import ctypes
-import pathlib
 import importlib.resources
-import numpy as np
-from typing import Union
-from typing import Type
-from typing import Callable
-from typing import Tuple
-from typing import Dict
-from typing import Any
+import os
+import pathlib
+import sys
 from pathlib import Path
-from typing import List, Optional, Sequence, Union
+from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Type, Union
+
+import numpy as np
 from typing_extensions import TypeAlias
 
 NULL: ctypes.c_void_p = None  # ignore: type
@@ -4863,7 +4858,9 @@ ggml_custom3_op_f32_t = ctypes.CFUNCTYPE(
 #         struct ggml_tensor         * a,
 #                ggml_unary_op_f32_t   fun);
 def ggml_map_unary_f32(
-    ctx: ggml_context_p, a: ggml_tensor_p, fun: "ctypes._FuncPointer"  # type: ignore
+    ctx: ggml_context_p,
+    a: ggml_tensor_p,
+    fun: "ctypes._FuncPointer",  # type: ignore
 ) -> ggml_tensor_p:
     return lib.ggml_map_unary_f32(ctx, a, fun)
 
@@ -4881,7 +4878,9 @@ lib.ggml_map_unary_f32.restype = ctypes.POINTER(ggml_tensor)
 #         struct ggml_tensor         * a,
 #                 ggml_unary_op_f32_t   fun);
 def ggml_map_unary_inplace_f32(
-    ctx: ggml_context_p, a: ggml_tensor_p, fun: "ctypes._FuncPointer"  # type: ignore
+    ctx: ggml_context_p,
+    a: ggml_tensor_p,
+    fun: "ctypes._FuncPointer",  # type: ignore
 ) -> ggml_tensor_p:
     return lib.ggml_map_unary_inplace_f32(ctx, a, fun)
 
@@ -4945,7 +4944,9 @@ lib.ggml_map_binary_inplace_f32.restype = ctypes.POINTER(ggml_tensor)
 #         struct ggml_tensor           * a,
 #                 ggml_custom1_op_f32_t   fun);
 def ggml_map_custom1_f32(
-    ctx: ggml_context_p, a: ggml_tensor_p, fun: "ctypes._FuncPointer"  # type: ignore
+    ctx: ggml_context_p,
+    a: ggml_tensor_p,
+    fun: "ctypes._FuncPointer",  # type: ignore
 ) -> ggml_tensor_p:
     """Custom unary operator on a tensor.
 
@@ -4985,7 +4986,9 @@ lib.ggml_map_custom1_f32.restype = ctypes.POINTER(ggml_tensor)
 #         struct ggml_tensor           * a,
 #                 ggml_custom1_op_f32_t   fun);
 def ggml_map_custom1_inplace_f32(
-    ctx: ggml_context_p, a: ggml_tensor_p, fun: "ctypes._CFuncPtr"  # type: ignore
+    ctx: ggml_context_p,
+    a: ggml_tensor_p,
+    fun: "ctypes._CFuncPtr",  # type: ignore
 ) -> ggml_tensor_p:
     """Custom unary operator on a tensor inplace.