Guillaume Wenzek před 1 rokem
rodič
revize
02fe0cbec5
3 změnil soubory, kde provedl 10 přidání a 2 odebrání
  1. 3 0
      ggml/examples/unity/fairseq2.h
  2. 6 1
      ggml/examples/unity/unity.cpp
  3. 1 1
      ggml/ggml.py

+ 3 - 0
ggml/examples/unity/fairseq2.h

@@ -276,6 +276,9 @@ struct SequenceGeneratorOptions {
 
     /// If ``True``, normalizes scores by the length of generated sequences.
     bool normalize_scores = true;
+
+    // memory needed is largely a fn of model size + sentence length and beam_size
+    int mem_mb = 256;
 };
 
 

+ 6 - 1
ggml/examples/unity/unity.cpp

@@ -34,7 +34,9 @@ struct unity_params {
         /*len_penalty*/ 1.0,
         /*unk_penalty*/ 0.0,
         /*normalize_scores*/ true,
+        /*mem_mb*/ 256,
     };
+    int32_t mem_mb = 256; // mem_usage
 };
 
 
@@ -48,6 +50,7 @@ void unity_print_usage(int /*argc*/, char ** argv, const unity_params & params)
     fprintf(stderr, "                        model path (default: %s)\n", params.model.c_str());
     fprintf(stderr, "  --text                text output\n");
     fprintf(stderr, "  --beam-size           beam size (default: %d)\n", params.opts.beam_size);
+    fprintf(stderr, "  -M, --mem             memory buffer, increase for long inputs (default: %d)\n", params.mem_mb);
     fprintf(stderr, "\n");
 }
 
@@ -77,6 +80,8 @@ bool unity_params_parse(int argc, char ** argv, unity_params & params) {
             params.text = true;
         } else if (arg == "-b" || arg == "--beam-size") {
             params.opts.beam_size = std::stoi(get_next_arg(i, argc, argv, arg, params));
+        } else if (arg == "-M" || arg == "--mem") {
+            params.mem_mb = std::stoi(get_next_arg(i, argc, argv, arg, params));
         } else {
             params.files.push_back(std::string(arg));
         }
@@ -136,7 +141,7 @@ int main(int argc, char ** argv) {
     }
 
     // The ctx_size_mb mostly depends of input length and model dim.
-    int ctx_size_mb = 128;
+    int ctx_size_mb = params.mem_mb;
     auto encoder_buf = std::vector<uint8_t>(128 * 1024 * 1024);
     auto encoder_fwd_buf = std::vector<uint8_t>(ctx_size_mb * 1024 * 1024);
     ggml_allocr* fwd_alloc = ggml_allocr_new(encoder_fwd_buf.data(), encoder_fwd_buf.capacity(), 8);

+ 1 - 1
ggml/ggml.py

@@ -459,7 +459,7 @@ class SequenceGeneratorOptions:
     len_penalty: float = 1.0
     unk_penalty: float = 0.0
     normalize_scores: bool = True
-
+    mem_mb: int = 256
 
 @c_struct
 @dataclasses.dataclass