
save model as fp16 (#277)

* save model as fp16

* rollback change to model format

* cleanup comments

* layer_filter

# Conflicts:
#	ggml/ggml_convert.py
Guillaume Wenzek · 1 year ago · commit c634c446d1
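
For orientation, a hedged sketch of how the options introduced by this change could be invoked, based on the `convert_model` signature updated in ggml/ggml_convert.py below; the model name, output path, and import path are placeholders, not part of this commit:

    from pathlib import Path

    from ggml_convert import convert_model  # assumes ggml/ is on the import path

    # Convert a fairseq2 model to a GGML file, keeping only layers whose
    # state_dict keys match the regex, and storing float32 weights as float16.
    convert_model(
        "seamlessM4T_medium",                      # placeholder model name
        out=Path("seamlessM4T_medium_fp16.ggml"),  # placeholder output path
        layers="text_decoder",                     # optional regex filter (new)
        fp16=True,                                 # new flag: halve on-disk size
    )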

+ 27 - 11
ggml/examples/unity/model_loader.cpp

@@ -39,12 +39,15 @@ std::int64_t
 model_loader::load_model_weights(fairseq2_model &model, std::ifstream &fin)
 {
     std::int64_t num_tensor = 0;
-    std::int64_t ctx_size = 0;
+    std::int64_t f32_tensor_size = 0;
     fin.read((char*) &num_tensor, sizeof(num_tensor));
-    fin.read((char*) &ctx_size, sizeof(ctx_size));
+    fin.read((char*) &f32_tensor_size, sizeof(f32_tensor_size));
 
+    // TODO: it might be interesting to allow the caller to not upcast the weights to float32.
+    // Note this requires changing the on-disk format
+    bool as_float32 = true;
     struct ggml_init_params params = {
-        /*.mem_size   =*/ ctx_size,
+        /*.mem_size   =*/ f32_tensor_size + num_tensor * (int64_t)ggml_tensor_overhead(),
         /*.mem_buffer =*/ NULL,
         /*.no_alloc   =*/ false,
     };
@@ -55,7 +58,7 @@ model_loader::load_model_weights(fairseq2_model &model, std::ifstream &fin)
         std::string name = get_name(fin);
         if (name.length() == 0)
             break;
-        auto tensor = load_tensor_value(fin, model.tensors_ctx);
+        auto tensor = load_tensor_value(fin, model.tensors_ctx, as_float32);
         if (tensor == nullptr) {
             // Abort in case of error, the input stream is corrupted at this point.
             printf("Error while reading tensor %s\n", name.c_str() );
@@ -75,10 +78,10 @@ model_loader::load_model_weights(fairseq2_model &model, std::ifstream &fin)
         __func__,
         model_size / mb,
         ggml_used_mem(model.tensors_ctx) / mb,
-        ctx_size / mb
+        ggml_get_mem_size(model.tensors_ctx) / mb
     );
 
-    return ctx_size;
+    return ggml_get_mem_size(model.tensors_ctx);
 }
 
 void assert_endianness() {
@@ -139,9 +142,9 @@ void model_loader::load_vocab(llama_vocab& vocab, std::ifstream &fin)
     std::int64_t ctx_size = vocab_size * sizeof(float) + vocab_size + 2 * ggml_tensor_overhead();
     ctx_size *= 2;
     ggml_context* ctx = ggml_init(ggml_init_params{ctx_size, nullptr, false});
-    ggml_tensor* lengths_tensor = load_tensor_value(fin, ctx);
+    ggml_tensor* lengths_tensor = load_tensor_value(fin, ctx, true);
     std::int8_t* lengths = (std::int8_t*)lengths_tensor->data;
-    ggml_tensor* scores_tensor = load_tensor_value(fin, ctx);
+    ggml_tensor* scores_tensor = load_tensor_value(fin, ctx, true);
     float* scores = ggml_get_data_f32(scores_tensor);
 
     int64_t offset = 0;
@@ -159,7 +162,7 @@ void model_loader::load_vocab(llama_vocab& vocab, std::ifstream &fin)
     // TODO: special tokens stuff ?
 }
 
-ggml_tensor* load_tensor_value(std::ifstream &fin, ggml_context* ctx)
+ggml_tensor* load_tensor_value(std::ifstream &fin, ggml_context* ctx, bool as_float32)
 {
     int32_t n_dims = 0;
     int32_t raw_type = 0;
@@ -176,8 +179,21 @@ ggml_tensor* load_tensor_value(std::ifstream &fin, ggml_context* ctx)
         fin.read(reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));
     }
 
-    ggml_tensor* tensor = ggml_new_tensor(ctx, type, n_dims, ne);
-    fin.read(reinterpret_cast<char *>(tensor->data), ggml_nbytes(tensor));
+    ggml_tensor* tensor;
+    if (as_float32 && type == GGML_TYPE_F16) {
+        // read half-precision (f16) weights from disk and convert them to f32.
+        tensor = ggml_new_tensor(ctx, GGML_TYPE_F32, n_dims, ne);
+        ggml_fp16_t buf[128];
+        int num_el = ggml_nelements(tensor);
+        for (int i = 0; i < num_el; i += 128) {
+            int block_size = std::min(128, num_el - i);
+            fin.read(reinterpret_cast<char *>(&buf), ggml_type_size(type) * block_size);
+            ggml_fp16_to_fp32_row((const ggml_fp16_t*)&buf, (float*)tensor->data + i, block_size);
+        }
+    } else {
+        tensor = ggml_new_tensor(ctx, type, n_dims, ne);
+        fin.read(reinterpret_cast<char *>(tensor->data), ggml_nbytes(tensor));
+    }
     return tensor;
 }
 
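The new branch in load_tensor_value reads f16 weights from disk in blocks of 128 elements and upcasts them into a float32 tensor in the ggml context. As an illustration only (not part of the patch), the same block-wise upcast expressed with numpy:

    import numpy as np
    from typing import BinaryIO

    def read_fp16_as_fp32(fin: BinaryIO, num_el: int, block: int = 128) -> np.ndarray:
        """Read num_el float16 values from fin in fixed-size blocks, upcast to float32."""
        out = np.empty(num_el, dtype=np.float32)
        for i in range(0, num_el, block):
            n = min(block, num_el - i)
            raw = fin.read(2 * n)  # sizeof(ggml_fp16_t) == 2 bytes
            out[i : i + n] = np.frombuffer(raw, dtype=np.float16, count=n).astype(np.float32)
        return out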

+ 1 - 1
ggml/examples/unity/model_loader.h

@@ -30,7 +30,7 @@ private:
     std::string get_name(std::ifstream &fin);
 };
 
-ggml_tensor* load_tensor_value(std::ifstream &fin, ggml_context* ctx);
+ggml_tensor* load_tensor_value(std::ifstream &fin, ggml_context* ctx, bool as_float32);
 
 std::ifstream open_ggml_file(const char* fname);
 

+ 73 - 33
ggml/ggml_convert.py

@@ -21,15 +21,19 @@ from fairseq2.nn.transformer import RelativePositionalEncoding
 from seamless_communication.models import unity
 
 import ggml
+import re
 
 Preprocessor = Callable[[Any], Any]
+log = logging.getLogger("ggml_convert")
 
 
 def convert_model(
     model_name: Union[str, torch.nn.Module],
     out: Optional[Path] = None,
+    layers: str = "",
     hparams: Optional[Dict[str, Any]] = None,
     vocab: Optional[List[Tuple[str, float]]] = None,
+    fp16: bool = False,
 ) -> None:
     if isinstance(model_name, str):
         # Load the corresponding fairseq2 model
@@ -43,7 +47,7 @@ def convert_model(
                 hparams = flatten_config(
                     dataclasses.asdict(model_config), separator="__"
                 )
-                print(hparams)
+                log.info(hparams)
             model = unity.load_unity_model(model_name)
             if vocab is None:
                 tokenizer = unity.load_unity_text_tokenizer(model_name)
@@ -59,11 +63,12 @@ def convert_model(
         model = model_name
 
     state_dict = model.state_dict()
-    fixup_model(model, state_dict)
-    layer_config = read_layer_config(model)
+    if layers:
+        state_dict = {k: v for k, v in state_dict.items() if re.match(layers, k)}
+    fixup_model(model, state_dict, layer_filter=layers)
+    layer_config = read_layer_config(model, layer_filter=layers)
     vocab = vocab or []
-
-    write_ggml_file(out, hparams, layer_config, vocab, state_dict)
+    write_ggml_file(out, hparams, layer_config, vocab, state_dict, fp16)
 
 
 def _nested_getattr(model: Any, name: str) -> Any:
@@ -76,13 +81,15 @@ def _nested_getattr(model: Any, name: str) -> Any:
     return node
 
 
-def find_children(model: torch.nn.Module, t: type) -> List[Tuple[str, torch.nn.Module]]:
+def find_children(model: torch.nn.Module, t: type, layer_filter: str = "") -> List[Tuple[str, torch.nn.Module]]:
     queue = list(model._modules.items())
     modules = []
     while queue:
         name, node = queue.pop()
         if node is None:
             continue
+        if layer_filter and not re.match(layer_filter, name):
+            continue
         if isinstance(node, t):
             modules.append((name, node))
         for child_name, child_node in node._modules.items():
@@ -91,39 +98,50 @@ def find_children(model: torch.nn.Module, t: type) -> List[Tuple[str, torch.nn.M
     return modules
 
 
-def fixup_model(model: torch.nn.Module, state_dict: Dict[str, torch.Tensor]) -> None:
+def fixup_model(model: torch.nn.Module, state_dict: Dict[str, torch.Tensor], layer_filter: str) -> None:
     # Bake the embedding scaling into the weights
-    frontends = find_children(model, TransformerEmbeddingFrontend)
-    print(
-        "Upgrading the following TransformerEmbeddingFrontend:",
-        [x[0] for x in frontends],
-    )
+    frontends = find_children(model, TransformerEmbeddingFrontend, layer_filter)
+    if frontends:
+        log.info(
+            "Upgrading the following TransformerEmbeddingFrontend: %s",
+            [x[0] for x in frontends],
+        )
     for name, frontend in frontends:
         embed_weights = state_dict[name + ".embed.weight"]
         state_dict[name + ".embed.weight"] = embed_weights * frontend.scale
 
     # Sinusoidal embeddings are typically not saved since they are easily recomputed,
     # but this allows to avoid porting the sinusoidal logic to GGML
-    pos_encoders = find_children(model, SinusoidalPositionEncoder)
-    print(
-        "Upgrading the following SinusoidalPositionEncoder:",
-        [x[0] for x in pos_encoders],
-    )
+    pos_encoders = find_children(model, SinusoidalPositionEncoder, layer_filter)
+    if pos_encoders:
+        log.info(
+            "Upgrading the following SinusoidalPositionEncoder: %s",
+            [x[0] for x in pos_encoders],
+        )
     for name, pos_encoder in pos_encoders:
         assert isinstance(pos_encoder.freqs, torch.Tensor)
         assert name not in state_dict
         state_dict[name] = pos_encoder.freqs
 
-    relative_pos_encs = find_children(model, RelativePositionalEncoding)
+    relative_pos_encs = find_children(model, RelativePositionalEncoding, layer_filter)
     # speech_encoder has several copies of the relative_pos_enc module.
     # For efficiency reasons we only make one copy of it to GGML.
     if relative_pos_encs:
-        print("Merging all speech_encoder RelativePositionalEncoding into one.")
+        log.info("Merging all speech_encoder RelativePositionalEncoding into one.")
         _, rel_pos_enc = relative_pos_encs[0]
         assert isinstance(rel_pos_enc.freqs, torch.Tensor)
         state_dict["speech_encoder.pos_enc"] = rel_pos_enc.freqs
 
 
+def convert_to_fp16(state_dict: Dict[str, torch.Tensor]) -> None:
+    for k in state_dict:
+        v = state_dict[k]
+        if v.dtype != torch.float32:
+            # ignore int tensors
+            continue
+        state_dict[k] = v.to(torch.float16)
+
+
 def read_vocab(tokenizer: Any) -> List[Tuple[str, float]]:
     vocab_info = tokenizer.vocab_info
     vocab = [
@@ -139,13 +157,14 @@ def write_ggml_file(
     layer_config: Dict[str, Any],
     vocab: List[Tuple[str, float]],
     state_dict: Dict[str, torch.Tensor],
+    fp16: bool,
 ) -> None:
     with out.open("wb") as o:
         write_ggml_header(o)
         write_hparams(o, hparams)
         write_hparams(o, layer_config)
         write_vocab(o, vocab)
-        write_state_dict(o, state_dict)
+        write_state_dict(o, state_dict, fp16)
 
 
 def write_ggml_header(out: BufferedWriter) -> None:
@@ -196,21 +215,40 @@ def write_vocab(out: BufferedWriter, vocab: List[Tuple[str, float]]) -> None:
     write_tensor(out, scores)
 
 
-def write_state_dict(out: BufferedWriter, state_dict: Dict[str, torch.Tensor]) -> None:
+def write_state_dict(
+    out: BufferedWriter, state_dict: Dict[str, torch.Tensor], fp16: bool
+) -> None:
     """Write pytorch state dict.
 
-    :paras state_dict:
+    :params state_dict:
         state dict returned by pytorch model
+    :params fp16:
+        convert float32 tensors to float16 on disk
     """
     out.write(struct.pack("<q", len(state_dict)))
-    # Size of each tensor
-    byte_size = sum(x.numel() * x.element_size() for x in state_dict.values())
-    # + tensor overhead
-    byte_size += ggml.ggml_tensor_overhead() * (len(state_dict) + 10)
-    out.write(struct.pack("<q", byte_size))
-    logging.warning(
-        f"Saving a ggml file with {len(state_dict)} tensors, for an estimated amount of {byte_size / (1024**3):.3f} GGML Gb"
-    )
+    # True size of each tensor (before downcasting to float16)
+    true_byte_size = sum(x.numel() * x.element_size() for x in state_dict.values())
+    out.write(struct.pack("<q", true_byte_size))
+
+    GB = 1024**3
+    if not fp16:
+        log.warning(
+            f"Saving a ggml file with {len(state_dict)} tensors, totalling {true_byte_size / GB:.3f}Gb"
+        )
+    else:
+
+        def _fp16_byte_size(x: torch.Tensor) -> int:
+            full_byte_size = x.numel() * x.element_size()
+            if fp16 and x.dtype == torch.float32:
+                full_byte_size //= 2
+            return full_byte_size
+
+        # Compressed size
+        compressed_byte_size = sum(_fp16_byte_size(x) for x in state_dict.values())
+        log.warning(
+            f"Saving a ggml file with {len(state_dict)} tensors, totalling {true_byte_size / GB:.3f}Gb compressed to {compressed_byte_size / GB:.3f}Gb"
+        )
+
     for key, value in state_dict.items():
         write_string(out, key)
         if key.endswith(".bias") and value.ndim == 1 and "adaptor" not in key:
@@ -220,6 +258,8 @@ def write_state_dict(out: BufferedWriter, state_dict: Dict[str, torch.Tensor]) -
             value = value.squeeze(-1)
         if "depthwise_conv" in key:
             value = value.squeeze(1)
+        if fp16 and value.dtype == torch.float32:
+            value = value.to(torch.float16)
         write_tensor(out, value.contiguous())
 
 
@@ -319,7 +359,7 @@ def flatten_config(
     return __flatten(config)
 
 
-def read_layer_config(model: torch.nn.Module) -> Dict[str, Any]:
+def read_layer_config(model: torch.nn.Module, layer_filter: str = "") -> Dict[str, Any]:
     layer_config = {}
 
     def _append_node_config(node: Any, prefix: str) -> None:
@@ -337,12 +377,12 @@ def read_layer_config(model: torch.nn.Module) -> Dict[str, Any]:
             try:
                 to_ctype(v)
             except ValueError:
-                logging.warning(f"Skipping layer config {k}={v!r}")
+                log.warning(f"Skipping layer config {k}={v!r}")
                 continue
             layer_config[prefix + k] = v
 
     _append_node_config(model, "")
-    for name, node in find_children(model, torch.nn.Module):
+    for name, node in find_children(model, torch.nn.Module, layer_filter):
         _append_node_config(node, name + ".")
     return layer_config
 

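A small, self-contained illustration of the new `layers` filter (the keys and tensors below are made up): the pattern is applied to every state_dict key with re.match, so it is anchored at the start of the key.

    import re

    import torch

    state_dict = {
        "text_decoder.layers.0.self_attn.q_proj.weight": torch.zeros(4, 4),
        "speech_encoder.layers.0.conv.depthwise_conv.weight": torch.zeros(4, 1, 3),
    }
    layers = "text_decoder"  # same regex that convert_model(..., layers=...) receives
    filtered = {k: v for k, v in state_dict.items() if re.match(layers, k)}
    assert list(filtered) == ["text_decoder.layers.0.self_attn.q_proj.weight"]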
+ 23 - 1
ggml/test_unity_cpp.py

@@ -100,7 +100,7 @@ def test_convert_linear(tmp_path: Path) -> None:
     layer_config = read_layer_config(module)
     assert layer_config == {"input_dim": 16, "output_dim": 24}
 
-    module_file = Path("module.ggml")
+    module_file = tmp_path / "module.ggml"
     convert_model(module, module_file)
     g_module = ggml.load_fairseq2_ggml_file(module_file)
 
@@ -109,6 +109,28 @@ def test_convert_linear(tmp_path: Path) -> None:
             ggml.fairseq2_model_layer_config_int(g_module.ptr, bytes(k, "ascii")) == v
         )
 
+def test_convert_linear_fp16(tmp_path: Path, ctx: Ctx) -> None:
+    pt_model = torch.nn.ModuleDict({"linear": fairseq2.nn.Linear(16, 24, True)})
+
+    layer_config = read_layer_config(pt_model)
+    assert layer_config == {"linear.input_dim": 16, "linear.output_dim": 24}
+
+    ggml_file = tmp_path / "linear.ggml"
+    convert_model(pt_model, ggml_file, fp16=True)
+    assert ggml_file.stat().st_size < (16 * 24 + 24) * 2 * 1.5
+    g_model = ggml.load_fairseq2_ggml_file(ggml_file)
+    ggml.lib.fairseq2_model_set_inference_ctx(g_model.ptr, ctx)
+
+    x = torch.empty((2, 5, 16))
+    torch.nn.init.uniform_(x, -1, 1)
+    y_exp = pt_model.linear(x).numpy()
+    gx = ggml.from_numpy(ctx, x)
+    gy = ggml.forward("Linear", g_model.ptr, "linear", gx)
+    ggml.build_and_compute(ctx, gy)
+    y = ggml.to_numpy(gy)
+
+    assert np.allclose(y_exp, y, atol=1e-3)
+
 
 def test_causal_attention_mask(ctx: Ctx):
     x = torch.zeros((1, 10, 32))