@@ -18,6 +18,11 @@ import json
import torch
import importlib
import dataclasses
+import tensorflow as tf
+
+gpus = tf.config.list_physical_devices("GPU")
+local_rank = int(os.environ["LOCAL_RANK"]) % 8
+tf.config.set_visible_devices(gpus[local_rank], "GPU")
from typing import *
@@ -27,7 +27,7 @@ class ModelForBigBench(Model):
config = GenerationTaskConfig(
name="big-bench",
- type=TaskType.MULTICHOICE,
+ type=TaskType.GENERATION,
max_gen_length=max_length or 128,
path="",
micro_batch_size=micro_batch_size,