소스 검색

nicer ctypes_utils

# Conflicts:
#	ggml/ggml.py
#	ggml/test_unity_cpp.py
Guillaume Wenzek 1 년 전
부모
커밋
a771adc782
3개의 변경된 파일82개의 추가작업 그리고 61개의 파일을 삭제
  1. 26 12
      ggml/ctypes_utils.py
  2. 14 9
      ggml/ggml.py
  3. 42 40
      ggml/test_unity_cpp.py

+ 26 - 12
ggml/ctypes_utils.py

@@ -2,32 +2,44 @@ import inspect
 import ctypes
 import ctypes
 import types
 import types
 import functools
 import functools
+import dataclasses
+from typing import Callable
+from typing import Any
+from typing import Optional
+from typing import Type
 from typing import TypeVar
 from typing import TypeVar
 from typing import Generic
 from typing import Generic
 
 
 T = TypeVar("T")
 T = TypeVar("T")
 
 
 
 
-class Ptr(Generic[T]):
+class Ptr(Generic[T], ctypes._Pointer):  # type: ignore
     contents: T
     contents: T
 
 
-    def __new__(cls):
-        breakpoint()
-        return ctypes.pointer()
+    def __new__(cls, x: T) -> "Ptr[T]":
+        return ctypes.pointer(x)  # type: ignore
 
 
 
 
-def c_struct(cls):
+NULLPTR: Ptr[Any] = None  # type: ignore[assignment]
+
+def c_struct(cls: Type[T]) -> Type[T]:
     struct = types.new_class(cls.__name__, bases=(ctypes.Structure,))
     struct = types.new_class(cls.__name__, bases=(ctypes.Structure,))
     struct.__module__ = cls.__module__
     struct.__module__ = cls.__module__
-    struct._fields_ = [
+    struct._fields_ = [  # type: ignore
         (k, _py_type_to_ctype(v)) for k, v in cls.__annotations__.items()
         (k, _py_type_to_ctype(v)) for k, v in cls.__annotations__.items()
     ]
     ]
 
 
+    def nice_init(self: T, *args: Any, **kwargs: Any) -> None:
+        dc = cls(*args, **kwargs)
+        for k, _ in self._fields_:  # type: ignore
+            setattr(self, k, getattr(dc, k))
+
+    setattr(struct, "__init__", nice_init)
     return struct
     return struct
 
 
 
 
 @functools.lru_cache(256)
 @functools.lru_cache(256)
-def _py_type_to_ctype(t: type):
+def _py_type_to_ctype(t: type) -> type:
     if isinstance(t, str):
     if isinstance(t, str):
         raise ValueError(
         raise ValueError(
             f"Type parsing of '{t}' isn't supported, you need to provide a real type annotation."
             f"Type parsing of '{t}' isn't supported, you need to provide a real type annotation."
@@ -51,13 +63,15 @@ def _py_type_to_ctype(t: type):
         return ctypes.c_char_p
         return ctypes.c_char_p
 
 
     if getattr(t, "__origin__", None) is Ptr:
     if getattr(t, "__origin__", None) is Ptr:
-        pointee = _py_type_to_ctype(t.__args__[0])
+        pointee = _py_type_to_ctype(t.__args__[0])  # type: ignore
         return ctypes.POINTER(pointee)
         return ctypes.POINTER(pointee)
 
 
     return ctypes.c_void_p
     return ctypes.c_void_p
 
 
 
 
-def _c_fn(module, fn):
+F = TypeVar("F", bound=Callable[..., Any])
+
+def _c_fn(module: Any, fn: F) -> F:
     c_fn = getattr(module, fn.__name__)
     c_fn = getattr(module, fn.__name__)
     annotations = fn.__annotations__
     annotations = fn.__annotations__
     if "return" not in annotations:
     if "return" not in annotations:
@@ -69,12 +83,12 @@ def _c_fn(module, fn):
     c_fn.restype = _py_type_to_ctype(fn.__annotations__["return"])
     c_fn.restype = _py_type_to_ctype(fn.__annotations__["return"])
 
 
     @functools.wraps(fn)
     @functools.wraps(fn)
-    def actual_fn(*args, **kwargs):
+    def actual_fn(*args, **kwargs):  # type: ignore
         raw_res = c_fn(*args, **kwargs)
         raw_res = c_fn(*args, **kwargs)
         return raw_res
         return raw_res
 
 
-    return actual_fn
+    return actual_fn  # type: ignore
 
 
 
 
-def c_fn(module):
+def c_fn(module: Any) -> Callable[[F], F]:
     return functools.partial(_c_fn, module)
     return functools.partial(_c_fn, module)

+ 14 - 9
ggml/ggml.py

@@ -8,6 +8,8 @@ import ctypes
 import torch
 import torch
 import functools
 import functools
 import logging
 import logging
+import dataclasses
+from typing import NamedTuple
 from pathlib import Path
 from pathlib import Path
 from typing import Dict
 from typing import Dict
 from typing import Callable
 from typing import Callable
@@ -199,7 +201,7 @@ def _compute_nbytes(
 
 
 def from_numpy(
 def from_numpy(
     ctx: ggml_context_p, array: Union[np.ndarray, "torch.Tensor"], name: bytes = b""
     ctx: ggml_context_p, array: Union[np.ndarray, "torch.Tensor"], name: bytes = b""
-) -> ggml_tensor_p:
+) -> Ptr[ggml_tensor]:
     if type(array).__name__ == "Tensor":
     if type(array).__name__ == "Tensor":
         array = array.numpy()
         array = array.numpy()
     # Create an empty tensor so we don't allocate memory for the data pointer
     # Create an empty tensor so we don't allocate memory for the data pointer
@@ -219,7 +221,7 @@ def from_numpy(
     setattr(tensor_p, "__data", array)
     setattr(tensor_p, "__data", array)
     if name:
     if name:
         ggml_set_name(tensor_p, name)
         ggml_set_name(tensor_p, name)
-    return tensor_p
+    return tensor_p  # type: ignore
 
 
 
 
 def ggml_can_mul_mat(t0: ggml_tensor_p, t1: ggml_tensor_p) -> bool:
 def ggml_can_mul_mat(t0: ggml_tensor_p, t1: ggml_tensor_p) -> bool:
@@ -425,18 +427,20 @@ def ggml_unflatten_1d(
 
 
 
 
 @c_struct
 @c_struct
+@dataclasses.dataclass
 class SequenceGeneratorOptions:
 class SequenceGeneratorOptions:
     beam_size: int
     beam_size: int
-    min_seq_len: int
-    soft_max_seq_len_a: float
-    soft_max_seq_len_b: int
-    hard_max_seq_len: int
-    len_penalty: float
-    unk_penalty: float
-    normalize_scores: bool
+    min_seq_len: int = 5
+    soft_max_seq_len_a: float = 1.0
+    soft_max_seq_len_b: int = 200
+    hard_max_seq_len: int = 1024
+    len_penalty: float = 1.0
+    unk_penalty: float = 0.0
+    normalize_scores: bool = True
 
 
 
 
 @c_struct
 @c_struct
+@dataclasses.dataclass
 class SequenceGeneratorJob:
 class SequenceGeneratorJob:
     opts: SequenceGeneratorOptions
     opts: SequenceGeneratorOptions
     prefix_seq: Ptr[ggml_tensor]
     prefix_seq: Ptr[ggml_tensor]
@@ -444,6 +448,7 @@ class SequenceGeneratorJob:
     unk_idx: int
     unk_idx: int
     bos_idx: int
     bos_idx: int
     eos_idx: int
     eos_idx: int
+    num_threads: int = 1
 
 
 
 
 @c_struct
 @c_struct

+ 42 - 40
ggml/test_unity_cpp.py

@@ -21,6 +21,7 @@ from ggml_convert import convert_model, read_layer_config
 from seamless_communication.models.inference.translator import Translator, Modality
 from seamless_communication.models.inference.translator import Translator, Modality
 from fairseq2.data.audio import WaveformToFbankConverter
 from fairseq2.data.audio import WaveformToFbankConverter
 import torchaudio
 import torchaudio
+from ctypes_utils import NULLPTR
 from fairseq2.models.wav2vec2.feature_extractor import Wav2Vec2FbankFeatureExtractor
 from fairseq2.models.wav2vec2.feature_extractor import Wav2Vec2FbankFeatureExtractor
 
 
 Ctx = ggml.ggml_context_p
 Ctx = ggml.ggml_context_p
@@ -31,6 +32,8 @@ CTX_PARAMS = ggml.ggml_init_params(mem_size=1024 * 1024 * 1024 * 5, mem_buffer=N
 FAIRSEQ2_CPP = Path(__file__).parent / "examples/unity/fairseq2.cpp"
 FAIRSEQ2_CPP = Path(__file__).parent / "examples/unity/fairseq2.cpp"
 UNITY_FLASH_ATTN = "\n# define UNITY_FLASH_ATTN 0\n" not in FAIRSEQ2_CPP.read_text()
 UNITY_FLASH_ATTN = "\n# define UNITY_FLASH_ATTN 0\n" not in FAIRSEQ2_CPP.read_text()
 
 
+DATA = Path(__file__).parent
+
 
 
 @pytest.fixture(name="ctx")
 @pytest.fixture(name="ctx")
 def _ctx() -> Iterator[Ctx]:
 def _ctx() -> Iterator[Ctx]:
@@ -188,7 +191,7 @@ def test_MultiheadAttention_forward(
         gxq,
         gxq,
         gxk,
         gxk,
         gxk,
         gxk,
-        None,  # TODO: tests with causal attention masks
+        NULLPTR,  # TODO: tests with causal attention masks
     )
     )
     gf = ggml.ggml_build_forward(gy)
     gf = ggml.ggml_build_forward(gy)
     ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)
     ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)
@@ -453,9 +456,7 @@ def test_StandardTransformerEncoder_forward(ctx: Ctx, g_model: c_void_p) -> None
 
 
 def test_StandardConformerEncoder_forward(ctx: Ctx, g_model: c_void_p) -> None:
 def test_StandardConformerEncoder_forward(ctx: Ctx, g_model: c_void_p) -> None:
     pt_model = load_pt_model()
     pt_model = load_pt_model()
-    wav, _ = torchaudio.load(
-        "/private/home/dnn/internal_sc/seamless_communication/ggml/examples/unity/test.wav"
-    )
+    wav, _ = torchaudio.load(DATA / "test.wav")
     gx = ggml.from_numpy(ctx, wav * 2**15)  # Apply scale before sending into ggml!
     gx = ggml.from_numpy(ctx, wav * 2**15)  # Apply scale before sending into ggml!
     ggml.ggml_set_name(gx, b"x")
     ggml.ggml_set_name(gx, b"x")
     gy = ggml.forward(
     gy = ggml.forward(
@@ -613,11 +614,11 @@ def test_StandardTransformerDecoder_forward(ctx: Ctx, g_model: c_void_p) -> None
     assert np.allclose(y_exp, y, atol=1e-4 if UNITY_FLASH_ATTN else 1e-3)
     assert np.allclose(y_exp, y, atol=1e-4 if UNITY_FLASH_ATTN else 1e-3)
 
 
 
 
-def test_t2tt(ctx: Ctx, g_model: c_void_p):
+def test_t2tt(ctx: Ctx, g_model: c_void_p) -> None:
     src_lang = "eng"
     src_lang = "eng"
     src_text = "We are all in a yellow submarine."
     src_text = "We are all in a yellow submarine."
     tgt_lang = "fra"
     tgt_lang = "fra"
-    sample_file = Path(__file__).parent / "sample_input.npz"
+    sample_file = DATA / "sample_input.npz"
     beam_size = 2
     beam_size = 2
 
 
     if not sample_file.exists():
     if not sample_file.exists():
@@ -663,21 +664,24 @@ def test_t2tt(ctx: Ctx, g_model: c_void_p):
     prefix_seq = np.array(text_out["hypotheses"][0]["seq"][:2]).astype(np.int32)
     prefix_seq = np.array(text_out["hypotheses"][0]["seq"][:2]).astype(np.int32)
     max_seq_len = max(len(h["seq"]) for h in text_out["hypotheses"])
     max_seq_len = max(len(h["seq"]) for h in text_out["hypotheses"])
 
 
-    job = ggml.SequenceGeneratorJob()
-    job.opts.beam_size = beam_size
-    job.opts.min_seq_len = 1
-    job.opts.soft_max_seq_len_a = 1
-    job.opts.soft_max_seq_len_b = 200
-    job.opts.hard_max_seq_len = int(max_seq_len * 1.5)
-    job.opts.len_penalty = 1.0
-    job.opts.unk_penalty = 0.0
-    job.opts.normalize_scores = True
-
-    job.prefix_seq = ggml.from_numpy(ctx, prefix_seq)
-    job.pad_idx = 0
-    job.unk_idx = 1
-    job.bos_idx = 2
-    job.eos_idx = 3
+    opts = ggml.SequenceGeneratorOptions(
+        beam_size=beam_size,
+        min_seq_len=1,
+        soft_max_seq_len_a=1,
+        soft_max_seq_len_b=200,
+        hard_max_seq_len=int(max_seq_len * 1.5),
+        len_penalty=1.0,
+        unk_penalty=0.0,
+        normalize_scores=True,
+    )
+    job = ggml.SequenceGeneratorJob(
+        opts=opts,
+        prefix_seq=ggml.from_numpy(ctx, prefix_seq),
+        pad_idx=0,
+        unk_idx=1,
+        bos_idx=2,
+        eos_idx=3,
+    )
 
 
     result_ptr = ggml.generate_sequence(
     result_ptr = ggml.generate_sequence(
         g_model, job, encoder_out, encoder_padding_mask, ctx
         g_model, job, encoder_out, encoder_padding_mask, ctx
@@ -695,9 +699,7 @@ def test_t2tt(ctx: Ctx, g_model: c_void_p):
 
 
 
 
 def test_s2tt(ctx: Ctx, g_model: c_void_p):
 def test_s2tt(ctx: Ctx, g_model: c_void_p):
-    src_audio_wav, _ = torchaudio.load(
-        "/private/home/dnn/internal_sc/seamless_communication/ggml/examples/unity/test.wav"
-    )
+    src_audio_wav, _ = torchaudio.load(DATA / "test.wav")
     # translator = load_translator()
     # translator = load_translator()
     # token_encoder = translator.text_tokenizer.create_encoder(
     # token_encoder = translator.text_tokenizer.create_encoder(
     #     task="translation"
     #     task="translation"
@@ -738,6 +740,7 @@ def test_s2tt(ctx: Ctx, g_model: c_void_p):
         253935,
         253935,
         3,
         3,
     ]  # "大家好 , 世界无主题。"
     ]  # "大家好 , 世界无主题。"
+    score = -1.606838583946228
     gx = ggml.from_numpy(
     gx = ggml.from_numpy(
         ctx, src_audio_wav * 2**15
         ctx, src_audio_wav * 2**15
     )  # Apply scale before sending into ggml!
     )  # Apply scale before sending into ggml!
@@ -754,21 +757,20 @@ def test_s2tt(ctx: Ctx, g_model: c_void_p):
 
 
     encoder_out = gy
     encoder_out = gy
 
 
-    job = ggml.SequenceGeneratorJob()
-    job.opts.beam_size = 5
-    job.opts.min_seq_len = 1
-    job.opts.soft_max_seq_len_a = 1
-    job.opts.soft_max_seq_len_b = 200
-    job.opts.hard_max_seq_len = 1000
-    job.opts.len_penalty = 1.0
-    job.opts.unk_penalty = 0.0
-    job.prefix_seq = ggml.from_numpy(ctx, np.array([3, 256200]).astype(np.int32))
-    job.opts.normalize_scores = True
-    job.pad_idx = 0
-    job.unk_idx = 1
-    job.bos_idx = 2
-    job.eos_idx = 3
-
-    result_ptr = ggml.generate_sequence(g_model, job, encoder_out, None, ctx)
+    opts = ggml.SequenceGeneratorOptions(
+        beam_size=5,
+        soft_max_seq_len_a=1,
+        soft_max_seq_len_b=200,
+        hard_max_seq_len=1000,
+    )
+    job = ggml.SequenceGeneratorJob(
+        opts=opts,
+        prefix_seq=ggml.from_numpy(ctx, np.array([3, 256200]).astype(np.int32)),
+        pad_idx=0,
+        unk_idx=1,
+        bos_idx=2,
+        eos_idx=3,
+    )
+    result_ptr = ggml.generate_sequence(g_model, job, encoder_out, NULLPTR, ctx)
     g_tokens = list(ggml.to_numpy(result_ptr[0].seq))
     g_tokens = list(ggml.to_numpy(result_ptr[0].seq))
     assert g_tokens == tgt_tokens
     assert g_tokens == tgt_tokens