Эх сурвалжийг харах

working out new way of saving hparams

Guillaume Wenzek 1 жил өмнө
parent
commit
c31926c1a8

+ 14 - 4
ggml/examples/unity/fairseq2.cpp

@@ -13,14 +13,23 @@ extern "C" fairseq2_model* fairseq2_model_alloc() {
     // pre-allocate some memory to write hyperparameters and tensors pointers
     auto* model = new fairseq2_model;
     model->hparams = new std::uint8_t[8 * 1024];
-    model->arch = new std::uint64_t[16 * 1024];  // max tensors allowed
     model->tensors_ctx = nullptr;
     return model;
 }
 
+
+double fairseq2_model_layer_config_double(const fairseq2_model& model, std::string name) {
+    const std::int64_t* data = &model.layer_config.at(name);
+    return *(double*)data;
+}
+
+std::int64_t fairseq2_model_layer_config_int(const fairseq2_model& model, std::string name) {
+    return model.layer_config.at(name);
+}
+
+
 extern "C" void fairseq2_model_free(fairseq2_model* model) {
     if (model->tensors_ctx) ggml_free(model->tensors_ctx);
-    delete (std::uint64_t*)(model->arch);
     delete (std::uint8_t*)model->hparams;
     delete model;
 }
@@ -68,8 +77,9 @@ extern "C" ggml_tensor* LayerNorm_forward(
     GGML_ASSERT(bias != nullptr);
 
     auto ctx = model.ctx;
-    // TODO: should `eps` be part of unity hparams ?
-    input = ggml_norm(ctx, input, /*eps*/1e-5);
+    double eps = fairseq2_model_layer_config_double(model, prefix + ".eps");
+
+    input = ggml_norm(ctx, input, /*eps*/eps);
     return ggml_add_inplace(
         ctx,
         ggml_mul_inplace(ctx, ggml_repeat(ctx, weight, input), input),

+ 3 - 1
ggml/examples/unity/fairseq2.h

@@ -12,13 +12,15 @@ struct fairseq2_model {
     ggml_context* tensors_ctx;
     // Named tensors, all tensors should belong to tensors_ctx
     std::map<std::string, struct ggml_tensor *> tensors;
-    void* arch;
+    std::map<std::string, std::int64_t> layer_config;
     void* hparams;
     // an inference context, not managed by this object
     // TODO: is this the best place to store this or should we also pass this to all forward methods ?
     ggml_context* ctx;
 };
 
+double fairseq2_model_layer_config_double(const fairseq2_model& model, std::string name);
+
 /// allocate the fairseq2 model and hyperparameters
 extern "C" fairseq2_model* fairseq2_model_alloc();
 // free the models and all its owned tensors

+ 19 - 1
ggml/examples/unity/model_loader.cpp

@@ -38,8 +38,10 @@ void register_prefix(fairseq2_model &model, const std::string& name) {
 int
 model_loader::load_model_weights(fairseq2_model &model, std::ifstream &fin)
 {
+    int num_tensor = 0;
+    fin.read((char*) &num_tensor, sizeof(num_tensor));
     size_t total_size = 0;
-    while (!fin.eof()) {
+    for (int i = 0; i < num_tensor; ++i) {
         std::string name = get_name(fin);
         if (name.length() == 0)
             break;
@@ -62,6 +64,22 @@ model_loader::load_model_weights(fairseq2_model &model, std::ifstream &fin)
     return 0;
 }
 
+
+int
+model_loader::load_layer_config(fairseq2_model &model, std::ifstream &fin)
+{
+    std::int64_t value;
+    while (!fin.eof()) {
+        std::string name = get_name(fin);
+        if (name.length() == 0)
+            break;
+        fin.read((char*) &value, sizeof(value));
+        model.layer_config[name] = value;
+    }
+
+    return 0;
+}
+
 ggml_tensor* load_tensor_value(std::ifstream &fin, ggml_context* ctx)
 {
     int32_t n_dims = 0;

+ 6 - 1
ggml/examples/unity/model_loader.h

@@ -27,6 +27,8 @@ public:
 
     int load_model_weights(fairseq2_model &model, std::ifstream &fin);
 
+    int load_layer_config(fairseq2_model &model, std::ifstream &fin);
+
 private:
     ggml_tensor * next_tensor(std::ifstream &fin, fairseq2_model &model);
 
@@ -51,6 +53,9 @@ int load_fairseq2_ggml_file(fairseq2_model& model, const char* fname) {
     };
     model.tensors_ctx = ggml_init(params);
 
-    return loader.load_model_weights(model, fin);
+    int err = loader.load_model_weights(model, fin);
+    if (err) return err;
+
+    return loader.load_layer_config(model, fin);
 }
 

+ 1 - 1
ggml/examples/unity/unity_model_loader.cpp

@@ -15,7 +15,7 @@
 void unity_model_loader::load_hparams(fairseq2_model& model, std::ifstream &fin)
 {
     unity_hparams* hparams = (unity_hparams*)model.hparams;
-    read_unity_hparams(hparams, fin);
+    read_unity_hparams(*hparams, fin);
     if (hparams->__end_of_hparams__ != 6877961321223123048) {
         throw std::invalid_argument("");
     }

+ 73 - 76
ggml/examples/unity/unity_model_loader.h

@@ -10,14 +10,13 @@
 #include "model_loader.h"
 
 
-// TODO Merge with Ning implementation
 struct unity_hparams {
     std::int64_t model_dim;
     std::int64_t w2v2_encoder_config__model_dim;
     std::int64_t w2v2_encoder_config__max_seq_len;
     std::int64_t w2v2_encoder_config__feature_dim;
     std::int64_t w2v2_encoder_config__use_fbank;
-    float w2v2_encoder_config__first_pass_dropout_p;
+    double w2v2_encoder_config__first_pass_dropout_p;
     std::int64_t w2v2_encoder_config__layer_norm_features;
     // Error: Unsupported type <class 'list'> w2v2_encoder_config__feature_extractor_layer_descs;
     std::int64_t w2v2_encoder_config__feature_extractor_bias;
@@ -34,21 +33,21 @@ struct unity_hparams {
     std::int64_t w2v2_encoder_config__num_encoder_layers;
     std::int64_t w2v2_encoder_config__num_encoder_attn_heads;
     std::int64_t w2v2_encoder_config__ffn_inner_dim;
-    float w2v2_encoder_config__dropout_p;
-    float w2v2_encoder_config__attn_dropout_p;
-    float w2v2_encoder_config__layer_drop_p;
-    std::int32_t w2v2_encoder_config__norm_order;
+    double w2v2_encoder_config__dropout_p;
+    double w2v2_encoder_config__attn_dropout_p;
+    double w2v2_encoder_config__layer_drop_p;
+    std::int64_t w2v2_encoder_config__norm_order;
     std::int64_t w2v2_encoder_config__depthwise_conv_kernel_size;
-    std::int64_t nllb_config__model_dim;
-    std::int64_t nllb_config__max_seq_len;
-    std::int64_t nllb_config__vocabulary_size;
-    std::int64_t nllb_config__pad_idx;
-    std::int64_t nllb_config__num_encoder_layers;
-    std::int64_t nllb_config__num_decoder_layers;
-    std::int64_t nllb_config__num_encoder_attn_heads;
-    std::int64_t nllb_config__num_decoder_attn_heads;
-    std::int64_t nllb_config__ffn_inner_dim;
-    float nllb_config__dropout_p;
+    std::int64_t mt_model_config__model_dim;
+    std::int64_t mt_model_config__max_seq_len;
+    std::int64_t mt_model_config__vocabulary_size;
+    std::int64_t mt_model_config__pad_idx;
+    std::int64_t mt_model_config__num_encoder_layers;
+    std::int64_t mt_model_config__num_decoder_layers;
+    std::int64_t mt_model_config__num_encoder_attn_heads;
+    std::int64_t mt_model_config__num_decoder_attn_heads;
+    std::int64_t mt_model_config__ffn_inner_dim;
+    double mt_model_config__dropout_p;
     std::int64_t t2u_config__model_dim;
     std::int64_t t2u_config__unit_max_seq_len;
     std::int64_t t2u_config__unit_vocabulary_size;
@@ -58,76 +57,74 @@ struct unity_hparams {
     std::int64_t t2u_config__num_encoder_attn_heads;
     std::int64_t t2u_config__num_decoder_attn_heads;
     std::int64_t t2u_config__ffn_inner_dim;
-    float t2u_config__dropout_p;
+    double t2u_config__dropout_p;
     std::int64_t use_text_encoder;
     std::int64_t use_conformer_adaptor;
     std::int64_t num_adaptor_layers;
     std::int64_t adaptor_kernel_size;
     std::int64_t adaptor_stride;
     std::int64_t adaptor_layer_norm;
-    float adaptor_dropout_p;
+    double adaptor_dropout_p;
     std::int64_t model_byte_size;
     std::int64_t __end_of_hparams__;
-
 };
 
-void read_unity_hparams(unity_hparams* out, std::ifstream &fin) {
-    fin.read((char*) &out->model_dim, sizeof(out->model_dim));
-    fin.read((char*) &out->w2v2_encoder_config__model_dim, sizeof(out->w2v2_encoder_config__model_dim));
-    fin.read((char*) &out->w2v2_encoder_config__max_seq_len, sizeof(out->w2v2_encoder_config__max_seq_len));
-    fin.read((char*) &out->w2v2_encoder_config__feature_dim, sizeof(out->w2v2_encoder_config__feature_dim));
-    fin.read((char*) &out->w2v2_encoder_config__use_fbank, sizeof(out->w2v2_encoder_config__use_fbank));
-    fin.read((char*) &out->w2v2_encoder_config__first_pass_dropout_p, sizeof(out->w2v2_encoder_config__first_pass_dropout_p));
-    fin.read((char*) &out->w2v2_encoder_config__layer_norm_features, sizeof(out->w2v2_encoder_config__layer_norm_features));
-    fin.read((char*) &out->w2v2_encoder_config__feature_extractor_bias, sizeof(out->w2v2_encoder_config__feature_extractor_bias));
-    fin.read((char*) &out->w2v2_encoder_config__feature_extractor_layer_norm_convs, sizeof(out->w2v2_encoder_config__feature_extractor_layer_norm_convs));
-    fin.read((char*) &out->w2v2_encoder_config__feature_grad_scale, sizeof(out->w2v2_encoder_config__feature_grad_scale));
-    fin.read((char*) &out->w2v2_encoder_config__num_fbank_channels, sizeof(out->w2v2_encoder_config__num_fbank_channels));
-    fin.read((char*) &out->w2v2_encoder_config__fbank_stride, sizeof(out->w2v2_encoder_config__fbank_stride));
-    fin.read((char*) &out->w2v2_encoder_config__sample_fbank_every_k, sizeof(out->w2v2_encoder_config__sample_fbank_every_k));
-    fin.read((char*) &out->w2v2_encoder_config__pos_encoder_depth, sizeof(out->w2v2_encoder_config__pos_encoder_depth));
-    fin.read((char*) &out->w2v2_encoder_config__pos_conv_kernel_size, sizeof(out->w2v2_encoder_config__pos_conv_kernel_size));
-    fin.read((char*) &out->w2v2_encoder_config__num_pos_conv_groups, sizeof(out->w2v2_encoder_config__num_pos_conv_groups));
-    fin.read((char*) &out->w2v2_encoder_config__use_conformer, sizeof(out->w2v2_encoder_config__use_conformer));
-    fin.read((char*) &out->w2v2_encoder_config__num_encoder_layers, sizeof(out->w2v2_encoder_config__num_encoder_layers));
-    fin.read((char*) &out->w2v2_encoder_config__num_encoder_attn_heads, sizeof(out->w2v2_encoder_config__num_encoder_attn_heads));
-    fin.read((char*) &out->w2v2_encoder_config__ffn_inner_dim, sizeof(out->w2v2_encoder_config__ffn_inner_dim));
-    fin.read((char*) &out->w2v2_encoder_config__dropout_p, sizeof(out->w2v2_encoder_config__dropout_p));
-    fin.read((char*) &out->w2v2_encoder_config__attn_dropout_p, sizeof(out->w2v2_encoder_config__attn_dropout_p));
-    fin.read((char*) &out->w2v2_encoder_config__layer_drop_p, sizeof(out->w2v2_encoder_config__layer_drop_p));
-    fin.read((char*) &out->w2v2_encoder_config__norm_order, sizeof(out->w2v2_encoder_config__norm_order));
-    fin.read((char*) &out->w2v2_encoder_config__depthwise_conv_kernel_size, sizeof(out->w2v2_encoder_config__depthwise_conv_kernel_size));
-    fin.read((char*) &out->nllb_config__model_dim, sizeof(out->nllb_config__model_dim));
-    fin.read((char*) &out->nllb_config__max_seq_len, sizeof(out->nllb_config__max_seq_len));
-    fin.read((char*) &out->nllb_config__vocabulary_size, sizeof(out->nllb_config__vocabulary_size));
-    fin.read((char*) &out->nllb_config__pad_idx, sizeof(out->nllb_config__pad_idx));
-    fin.read((char*) &out->nllb_config__num_encoder_layers, sizeof(out->nllb_config__num_encoder_layers));
-    fin.read((char*) &out->nllb_config__num_decoder_layers, sizeof(out->nllb_config__num_decoder_layers));
-    fin.read((char*) &out->nllb_config__num_encoder_attn_heads, sizeof(out->nllb_config__num_encoder_attn_heads));
-    fin.read((char*) &out->nllb_config__num_decoder_attn_heads, sizeof(out->nllb_config__num_decoder_attn_heads));
-    fin.read((char*) &out->nllb_config__ffn_inner_dim, sizeof(out->nllb_config__ffn_inner_dim));
-    fin.read((char*) &out->nllb_config__dropout_p, sizeof(out->nllb_config__dropout_p));
-    fin.read((char*) &out->t2u_config__model_dim, sizeof(out->t2u_config__model_dim));
-    fin.read((char*) &out->t2u_config__unit_max_seq_len, sizeof(out->t2u_config__unit_max_seq_len));
-    fin.read((char*) &out->t2u_config__unit_vocabulary_size, sizeof(out->t2u_config__unit_vocabulary_size));
-    fin.read((char*) &out->t2u_config__unit_pad_idx, sizeof(out->t2u_config__unit_pad_idx));
-    fin.read((char*) &out->t2u_config__num_encoder_layers, sizeof(out->t2u_config__num_encoder_layers));
-    fin.read((char*) &out->t2u_config__num_decoder_layers, sizeof(out->t2u_config__num_decoder_layers));
-    fin.read((char*) &out->t2u_config__num_encoder_attn_heads, sizeof(out->t2u_config__num_encoder_attn_heads));
-    fin.read((char*) &out->t2u_config__num_decoder_attn_heads, sizeof(out->t2u_config__num_decoder_attn_heads));
-    fin.read((char*) &out->t2u_config__ffn_inner_dim, sizeof(out->t2u_config__ffn_inner_dim));
-    fin.read((char*) &out->t2u_config__dropout_p, sizeof(out->t2u_config__dropout_p));
-    fin.read((char*) &out->use_text_encoder, sizeof(out->use_text_encoder));
-    fin.read((char*) &out->use_conformer_adaptor, sizeof(out->use_conformer_adaptor));
-    fin.read((char*) &out->num_adaptor_layers, sizeof(out->num_adaptor_layers));
-    fin.read((char*) &out->adaptor_kernel_size, sizeof(out->adaptor_kernel_size));
-    fin.read((char*) &out->adaptor_stride, sizeof(out->adaptor_stride));
-    fin.read((char*) &out->adaptor_layer_norm, sizeof(out->adaptor_layer_norm));
-    fin.read((char*) &out->adaptor_dropout_p, sizeof(out->adaptor_dropout_p));
-    fin.read((char*) &out->model_byte_size, sizeof(out->model_byte_size));
-    fin.read((char*) &out->__end_of_hparams__, sizeof(out->__end_of_hparams__));
-
-}
+void read_unity_hparams(unity_hparams& out, std::ifstream &fin) {
+    fin.read((char*) &out.model_dim, sizeof(out.model_dim));
+    fin.read((char*) &out.w2v2_encoder_config__model_dim, sizeof(out.w2v2_encoder_config__model_dim));
+    fin.read((char*) &out.w2v2_encoder_config__max_seq_len, sizeof(out.w2v2_encoder_config__max_seq_len));
+    fin.read((char*) &out.w2v2_encoder_config__feature_dim, sizeof(out.w2v2_encoder_config__feature_dim));
+    fin.read((char*) &out.w2v2_encoder_config__use_fbank, sizeof(out.w2v2_encoder_config__use_fbank));
+    fin.read((char*) &out.w2v2_encoder_config__first_pass_dropout_p, sizeof(out.w2v2_encoder_config__first_pass_dropout_p));
+    fin.read((char*) &out.w2v2_encoder_config__layer_norm_features, sizeof(out.w2v2_encoder_config__layer_norm_features));
+    fin.read((char*) &out.w2v2_encoder_config__feature_extractor_bias, sizeof(out.w2v2_encoder_config__feature_extractor_bias));
+    fin.read((char*) &out.w2v2_encoder_config__feature_extractor_layer_norm_convs, sizeof(out.w2v2_encoder_config__feature_extractor_layer_norm_convs));
+    fin.read((char*) &out.w2v2_encoder_config__feature_grad_scale, sizeof(out.w2v2_encoder_config__feature_grad_scale));
+    fin.read((char*) &out.w2v2_encoder_config__num_fbank_channels, sizeof(out.w2v2_encoder_config__num_fbank_channels));
+    fin.read((char*) &out.w2v2_encoder_config__fbank_stride, sizeof(out.w2v2_encoder_config__fbank_stride));
+    fin.read((char*) &out.w2v2_encoder_config__sample_fbank_every_k, sizeof(out.w2v2_encoder_config__sample_fbank_every_k));
+    fin.read((char*) &out.w2v2_encoder_config__pos_encoder_depth, sizeof(out.w2v2_encoder_config__pos_encoder_depth));
+    fin.read((char*) &out.w2v2_encoder_config__pos_conv_kernel_size, sizeof(out.w2v2_encoder_config__pos_conv_kernel_size));
+    fin.read((char*) &out.w2v2_encoder_config__num_pos_conv_groups, sizeof(out.w2v2_encoder_config__num_pos_conv_groups));
+    fin.read((char*) &out.w2v2_encoder_config__use_conformer, sizeof(out.w2v2_encoder_config__use_conformer));
+    fin.read((char*) &out.w2v2_encoder_config__num_encoder_layers, sizeof(out.w2v2_encoder_config__num_encoder_layers));
+    fin.read((char*) &out.w2v2_encoder_config__num_encoder_attn_heads, sizeof(out.w2v2_encoder_config__num_encoder_attn_heads));
+    fin.read((char*) &out.w2v2_encoder_config__ffn_inner_dim, sizeof(out.w2v2_encoder_config__ffn_inner_dim));
+    fin.read((char*) &out.w2v2_encoder_config__dropout_p, sizeof(out.w2v2_encoder_config__dropout_p));
+    fin.read((char*) &out.w2v2_encoder_config__attn_dropout_p, sizeof(out.w2v2_encoder_config__attn_dropout_p));
+    fin.read((char*) &out.w2v2_encoder_config__layer_drop_p, sizeof(out.w2v2_encoder_config__layer_drop_p));
+    fin.read((char*) &out.w2v2_encoder_config__norm_order, sizeof(out.w2v2_encoder_config__norm_order));
+    fin.read((char*) &out.w2v2_encoder_config__depthwise_conv_kernel_size, sizeof(out.w2v2_encoder_config__depthwise_conv_kernel_size));
+    fin.read((char*) &out.mt_model_config__model_dim, sizeof(out.mt_model_config__model_dim));
+    fin.read((char*) &out.mt_model_config__max_seq_len, sizeof(out.mt_model_config__max_seq_len));
+    fin.read((char*) &out.mt_model_config__vocabulary_size, sizeof(out.mt_model_config__vocabulary_size));
+    fin.read((char*) &out.mt_model_config__pad_idx, sizeof(out.mt_model_config__pad_idx));
+    fin.read((char*) &out.mt_model_config__num_encoder_layers, sizeof(out.mt_model_config__num_encoder_layers));
+    fin.read((char*) &out.mt_model_config__num_decoder_layers, sizeof(out.mt_model_config__num_decoder_layers));
+    fin.read((char*) &out.mt_model_config__num_encoder_attn_heads, sizeof(out.mt_model_config__num_encoder_attn_heads));
+    fin.read((char*) &out.mt_model_config__num_decoder_attn_heads, sizeof(out.mt_model_config__num_decoder_attn_heads));
+    fin.read((char*) &out.mt_model_config__ffn_inner_dim, sizeof(out.mt_model_config__ffn_inner_dim));
+    fin.read((char*) &out.mt_model_config__dropout_p, sizeof(out.mt_model_config__dropout_p));
+    fin.read((char*) &out.t2u_config__model_dim, sizeof(out.t2u_config__model_dim));
+    fin.read((char*) &out.t2u_config__unit_max_seq_len, sizeof(out.t2u_config__unit_max_seq_len));
+    fin.read((char*) &out.t2u_config__unit_vocabulary_size, sizeof(out.t2u_config__unit_vocabulary_size));
+    fin.read((char*) &out.t2u_config__unit_pad_idx, sizeof(out.t2u_config__unit_pad_idx));
+    fin.read((char*) &out.t2u_config__num_encoder_layers, sizeof(out.t2u_config__num_encoder_layers));
+    fin.read((char*) &out.t2u_config__num_decoder_layers, sizeof(out.t2u_config__num_decoder_layers));
+    fin.read((char*) &out.t2u_config__num_encoder_attn_heads, sizeof(out.t2u_config__num_encoder_attn_heads));
+    fin.read((char*) &out.t2u_config__num_decoder_attn_heads, sizeof(out.t2u_config__num_decoder_attn_heads));
+    fin.read((char*) &out.t2u_config__ffn_inner_dim, sizeof(out.t2u_config__ffn_inner_dim));
+    fin.read((char*) &out.t2u_config__dropout_p, sizeof(out.t2u_config__dropout_p));
+    fin.read((char*) &out.use_text_encoder, sizeof(out.use_text_encoder));
+    fin.read((char*) &out.use_conformer_adaptor, sizeof(out.use_conformer_adaptor));
+    fin.read((char*) &out.num_adaptor_layers, sizeof(out.num_adaptor_layers));
+    fin.read((char*) &out.adaptor_kernel_size, sizeof(out.adaptor_kernel_size));
+    fin.read((char*) &out.adaptor_stride, sizeof(out.adaptor_stride));
+    fin.read((char*) &out.adaptor_layer_norm, sizeof(out.adaptor_layer_norm));
+    fin.read((char*) &out.adaptor_dropout_p, sizeof(out.adaptor_dropout_p));
+    fin.read((char*) &out.model_byte_size, sizeof(out.model_byte_size));
+    fin.read((char*) &out.__end_of_hparams__, sizeof(out.__end_of_hparams__));
+};
 
 class unity_model_loader: public model_loader {
     public:

+ 39 - 5
ggml/ggml_convert.py

@@ -14,6 +14,8 @@ from typing import Any, Callable, Dict, Optional, Tuple, Union
 import math
 import torch
 import ggml
+from typing import Callable
+from typing import Optional
 from typing import List
 from fairseq2.assets import AssetCard
 from fairseq2.models.transformer.frontend import TransformerEmbeddingFrontend
@@ -84,6 +86,7 @@ def convert_model(model_name: str, out: Optional[Path] = None) -> None:
 
     with out.open("wb") as o:
         write_ggml_file(o, hparams, state_dict)
+        write_layer_config(o, model)
 
     with out.with_suffix(".hparams.h").open("w") as h:
         h.write(generate_hparams_struct(hparams, "unity_hparams"))
@@ -174,8 +177,8 @@ def write_hparams(out: BufferedWriter, hparams: Dict[str, Any]) -> None:
             # TODO: this is not cross platform, what's the standard way of writing hparams in GGML ?
             ctype, cvalue = to_ctype(value)
             out.write(struct.pack(ctype, cvalue))
-        except ValueError as e:
-            logging.warning(f"[Warning] {e}. Skipping config for key {key}")
+        except ValueError:
+            logging.warning(f"Skipping config for key {key}={value!r}")
             continue
 
 
@@ -185,6 +188,7 @@ def write_state_dict(out: BufferedWriter, state_dict: Dict[str, torch.Tensor]) -
     :paras state_dict:
         state dict returned by pytorch model
     """
+    out.write(struct.pack("i", len(state_dict)))
     for key, value in state_dict.items():
         write_string(out, key)
         if key.endswith(".bias") and value.ndim == 1 and "adaptor" not in key:
@@ -197,6 +201,27 @@ def write_state_dict(out: BufferedWriter, state_dict: Dict[str, torch.Tensor]) -
         write_tensor(out, value.contiguous())
 
 
def write_layer_config(out: BufferedWriter, model: torch.nn.Module) -> None:
    """Write scalar config attributes of every submodule as (name, 8-byte value) pairs.

    Each serializable attribute is emitted as its dotted module path
    (``"<module>.<attr>"``) followed by the value packed by ``to_ctype``.
    Attributes that ``to_ctype`` cannot convert are skipped with a warning.

    :param out:
        destination buffer, positioned after the tensor section
    :param model:
        the pytorch module whose submodule attributes are serialized
    """
    for name, node in find_children(model, torch.nn.Module):
        for k, v in node.__dict__.items():
            # Skip special members. In particular all children modules and tensors
            # will be hidden in the special dicts `_parameters` and `_modules`.
            if k.startswith("_"):
                continue
            # All modules have a "training" flag
            if k == "training":
                continue
            if v is None:
                continue
            try:
                ctype, cvalue = to_ctype(v)
                # Pack before writing the key: if packing fails, the original
                # code had already written the name, leaving a dangling key
                # with no value and corrupting the stream.
                packed = struct.pack(ctype, cvalue)
            except ValueError:
                # `e` was bound but unused in the original; match the style of
                # the equivalent handler in write_hparams.
                logging.warning(f"Skipping config for {name}.{k}={v!r}")
                continue
            write_string(out, f"{name}.{k}")
            out.write(packed)
+
+
 def write_string(out: BufferedWriter, value: str) -> None:
     """Write string in utf-8 format.
 
@@ -301,11 +326,18 @@ def to_ctype(value: Any) -> Tuple[str, Any]:
     if isinstance(value, int):
         return ("l", value)
     if isinstance(value, float):
-        return ("f", value)
+        return ("d", value)
     if isinstance(value, bool):
-        return ("?", value)
+        return ("l", value)
     if isinstance(value, Enum):
-        return ("i", value.value)
+        return ("l", value.value)
+    if isinstance(value, tuple) and len(value) == 1:
+        return to_ctype(value[0])
+    if isinstance(value, str) and len(value) < 8:
+        value = bytes(value, "ascii")
+        if len(value) < 8:
+            value = value + (8 - len(value)) * b"\0"
+        return ("l", struct.unpack("l", value)[0])
 
     raise ValueError(f"Unsupported type {type(value)}")
 
@@ -331,6 +363,8 @@ def get_cpp_type(value: Any) -> str:
         return "std::int64_t"
     if ctype == "f":
         return "float"
+    if ctype == "d":
+        return "double"
     if ctype == "?":
         return "bool"