execution.py 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230
  1. from typing import Optional, Callable, Dict
  2. import ast
  3. import contextlib
  4. import faulthandler
  5. import io
  6. import os
  7. import multiprocessing
  8. import platform
  9. import signal
  10. import tempfile
  11. def check_correctness(problem: Dict, completion: str, timeout: float,
  12. completion_id: Optional[int] = None) -> Dict:
  13. """
  14. Evaluates the functional correctness of a completion by running the test
  15. suite provided in the problem.
  16. :param completion_id: an optional completion ID so we can match
  17. the results later even if execution finishes asynchronously.
  18. """
  19. def unsafe_execute():
  20. with create_tempdir():
  21. # These system calls are needed when cleaning up tempdir.
  22. import os
  23. import shutil
  24. rmtree = shutil.rmtree
  25. rmdir = os.rmdir
  26. chdir = os.chdir
  27. # Disable functionalities that can make destructive changes to the test.
  28. reliability_guard()
  29. # Construct the check program and run it.
  30. check_program = (
  31. problem["prompt"] + completion + "\n" +
  32. problem["test"] + "\n" +
  33. f"check({problem['entry_point']})"
  34. )
  35. try:
  36. exec_globals = {}
  37. with swallow_io():
  38. with time_limit(timeout):
  39. # WARNING
  40. # This program exists to execute untrusted model-generated code. Although
  41. # it is highly unlikely that model-generated code will do something overtly
  42. # malicious in response to this test suite, model-generated code may act
  43. # destructively due to a lack of model capability or alignment.
  44. # Users are strongly encouraged to sandbox this evaluation suite so that it
  45. # does not perform destructive actions on their host or network. For more
  46. # information on how OpenAI sandboxes its code, see the accompanying paper.
  47. # Once you have read this disclaimer and taken appropriate precautions,
  48. # uncomment the following line and proceed at your own risk:
  49. exec(check_program, exec_globals)
  50. result.append("passed")
  51. except TimeoutException:
  52. result.append("timed out")
  53. except BaseException as e:
  54. result.append(f"failed: {e}")
  55. # Needed for cleaning up.
  56. shutil.rmtree = rmtree
  57. os.rmdir = rmdir
  58. os.chdir = chdir
  59. manager = multiprocessing.Manager()
  60. result = manager.list()
  61. p = multiprocessing.Process(target=unsafe_execute)
  62. p.start()
  63. p.join(timeout=timeout + 1)
  64. if p.is_alive():
  65. p.kill()
  66. if not result:
  67. result.append("timed out")
  68. return dict(
  69. task_id=problem["task_id"],
  70. passed=result[0] == "passed",
  71. result=result[0],
  72. completion_id=completion_id,
  73. )
  74. @contextlib.contextmanager
  75. def time_limit(seconds: float):
  76. def signal_handler(signum, frame):
  77. raise TimeoutException("Timed out!")
  78. signal.setitimer(signal.ITIMER_REAL, seconds)
  79. signal.signal(signal.SIGALRM, signal_handler)
  80. try:
  81. yield
  82. finally:
  83. signal.setitimer(signal.ITIMER_REAL, 0)
  84. @contextlib.contextmanager
  85. def swallow_io():
  86. stream = WriteOnlyStringIO()
  87. with contextlib.redirect_stdout(stream):
  88. with contextlib.redirect_stderr(stream):
  89. with redirect_stdin(stream):
  90. yield
  91. @contextlib.contextmanager
  92. def create_tempdir():
  93. with tempfile.TemporaryDirectory() as dirname:
  94. with chdir(dirname):
  95. yield dirname
  96. class TimeoutException(Exception):
  97. pass
  98. class WriteOnlyStringIO(io.StringIO):
  99. """ StringIO that throws an exception when it's read from """
  100. def read(self, *args, **kwargs):
  101. raise IOError
  102. def readline(self, *args, **kwargs):
  103. raise IOError
  104. def readlines(self, *args, **kwargs):
  105. raise IOError
  106. def readable(self, *args, **kwargs):
  107. """ Returns True if the IO object can be read. """
  108. return False
  109. class redirect_stdin(contextlib._RedirectStream): # type: ignore
  110. _stream = 'stdin'
  111. @contextlib.contextmanager
  112. def chdir(root):
  113. if root == ".":
  114. yield
  115. return
  116. cwd = os.getcwd()
  117. os.chdir(root)
  118. try:
  119. yield
  120. except BaseException as exc:
  121. raise exc
  122. finally:
  123. os.chdir(cwd)
  124. def reliability_guard(maximum_memory_bytes: Optional[int] = None):
  125. """
  126. This disables various destructive functions and prevents the generated code
  127. from interfering with the test (e.g. fork bomb, killing other processes,
  128. removing filesystem files, etc.)
  129. WARNING
  130. This function is NOT a security sandbox. Untrusted code, including, model-
  131. generated code, should not be blindly executed outside of one. See the
  132. Codex paper for more information about OpenAI's code sandbox, and proceed
  133. with caution.
  134. """
  135. if maximum_memory_bytes is not None:
  136. import resource
  137. resource.setrlimit(resource.RLIMIT_AS, (maximum_memory_bytes, maximum_memory_bytes))
  138. resource.setrlimit(resource.RLIMIT_DATA, (maximum_memory_bytes, maximum_memory_bytes))
  139. if not platform.uname().system == 'Darwin':
  140. resource.setrlimit(resource.RLIMIT_STACK, (maximum_memory_bytes, maximum_memory_bytes))
  141. faulthandler.disable()
  142. import builtins
  143. builtins.exit = None
  144. builtins.quit = None
  145. import os
  146. os.environ['OMP_NUM_THREADS'] = '1'
  147. os.kill = None
  148. os.system = None
  149. os.putenv = None
  150. os.remove = None
  151. os.removedirs = None
  152. os.rmdir = None
  153. os.fchdir = None
  154. os.setuid = None
  155. os.fork = None
  156. os.forkpty = None
  157. os.killpg = None
  158. os.rename = None
  159. os.renames = None
  160. os.truncate = None
  161. os.replace = None
  162. os.unlink = None
  163. os.fchmod = None
  164. os.fchown = None
  165. os.chmod = None
  166. os.chown = None
  167. os.chroot = None
  168. os.fchdir = None
  169. os.lchflags = None
  170. os.lchmod = None
  171. os.lchown = None
  172. os.getcwd = None
  173. os.chdir = None
  174. import shutil
  175. shutil.rmtree = None
  176. shutil.move = None
  177. shutil.chown = None
  178. import subprocess
  179. subprocess.Popen = None # type: ignore
  180. __builtins__['help'] = None
  181. import sys
  182. sys.modules['ipdb'] = None
  183. sys.modules['joblib'] = None
  184. sys.modules['resource'] = None
  185. sys.modules['psutil'] = None
  186. sys.modules['tkinter'] = None