code_eval_octopack / execute.py
Muennighoff's picture
Fix bug
0406cd0
raw
history blame
13.6 kB
# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This code is adapted from OpenAI's release
# https://github.com/openai/human-eval/blob/master/human_eval/execution.py
import contextlib
import faulthandler
import io
import logging
import multiprocessing
import os
import platform
import signal
import subprocess
import tempfile
# https://github.com/THUDM/CodeGeeX/blob/ebeb850f227a90c79de39f7e26b1302f374f3240/codegeex/benchmark/rust/Cargo.toml
BASE_CARGO = """[package]
name = "rust"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
rand = "0.4"
regex = "1"
md5 = "0.7.0
"""
def check_correctness(check_program, timeout, task_id, completion_id, language):
"""
Evaluates the functional correctness of a completion by running the test
suite provided in the problem.
:param completion_id: an optional completion ID so we can match
the results later even if execution finishes asynchronously.
"""
manager = multiprocessing.Manager()
result = manager.list()
if language == "python":
p = multiprocessing.Process(target=unsafe_execute, args=(check_program, result, timeout))
elif language == "cpp":
p = multiprocessing.Process(target=unsafe_execute_cpp, args=(check_program, result, timeout))
elif language == "go":
p = multiprocessing.Process(target=unsafe_execute_go, args=(check_program, result, timeout))
elif language == "java":
p = multiprocessing.Process(target=unsafe_execute_java, args=(check_program, result, timeout))
elif language == "javascript":
p = multiprocessing.Process(target=unsafe_execute_js, args=(check_program, result, timeout))
elif language == "rust":
p = multiprocessing.Process(target=unsafe_execute_rust, args=(check_program, result, timeout))
else:
raise ValueError(f"Language {language} not supported. Feel free to add it :)")
p.start()
p.join(timeout=timeout + 1)
if p.is_alive():
p.kill()
if not result:
result.append("timed out")
return dict(
task_id=task_id,
passed=result[0] == "passed",
result=result[0],
completion_id=completion_id,
)
def unsafe_execute(check_program, result, timeout):
with create_tempdir():
# These system calls are needed when cleaning up tempdir.
import os
import shutil
rmtree = shutil.rmtree
rmdir = os.rmdir
chdir = os.chdir
# Disable functionalities that can make destructive changes to the test.
reliability_guard()
# Run program.
try:
exec_globals = {}
with swallow_io():
with time_limit(timeout):
exec(check_program, exec_globals)
result.append("passed")
except TimeoutException:
result.append("timed out")
except BaseException as e:
result.append(f"failed: {e}")
# Needed for cleaning up.
shutil.rmtree = rmtree
os.rmdir = rmdir
os.chdir = chdir
def unsafe_execute_cpp(check_program, result, timeout):
with create_tempdir():
open(f"test.cpp", 'w').write(check_program)
# -lcrypto & -lssl are needed for the crypto library required by HumanEval-X-Bugs Task 162
compilation_result = subprocess.run(
["/usr/bin/g++", "-std=c++11", "test.cpp", "-lcrypto", "-lssl"],
timeout=timeout,
capture_output=True,
)
if compilation_result.returncode != 0:
if compilation_result.stderr:
err = compilation_result.stderr.decode()
else:
err = compilation_result.stdout.decode()
result.append(f"failed: compilation error: {err}")
else:
try:
exec_result = subprocess.run(["./a.out"], timeout=timeout, capture_output=True)
if exec_result.returncode == 0:
result.append("passed")
else:
if exec_result.stderr:
try:
err = exec_result.stderr.decode()
except:
err = exec_result.stderr
else:
try:
err = exec_result.stdout.decode()
except:
err = exec_result.stdout
result.append(f"failed: {err}")
except subprocess.TimeoutExpired as e:
result.append("timed out")
def unsafe_execute_go(check_program, result, timeout):
with create_tempdir():
open(f"main_test.go", 'w').write(check_program)
try:
exec_result = subprocess.run(["go", "test", f"-timeout={timeout}s", "main_test.go"], timeout=timeout, capture_output=True)
if exec_result.returncode == 0:
result.append("passed")
else:
if exec_result.stderr:
try:
err = exec_result.stderr.decode()
except:
err = exec_result.stderr
else:
try:
err = exec_result.stdout.decode()
except:
err = exec_result.stdout
result.append(f"failed: {err}")
except subprocess.TimeoutExpired as e:
result.append("timed out")
def unsafe_execute_java(check_program, result, timeout):
with create_tempdir():
open(f"Main.java", 'w').write(check_program)
res = "failed: unknown error"
compile_returncode = -1
for _ in range(5):
try:
compilation_result = subprocess.run(
['javac', "Main.java"], timeout=5, capture_output=True
)
compile_returncode = compilation_result.returncode
break
except subprocess.TimeoutExpired as e:
continue
if compile_returncode != 0:
res = "failed: compilation error"
else:
try:
exec_result = subprocess.run([f'java', '-cp', '.', 'Main'], timeout=timeout, capture_output=True)
if exec_result.returncode == 0:
res = "passed"
elif exec_result.returncode == 1:
if "AssertionError" in exec_result.stderr.decode('unicode-escape'):
res = "failed: wrong answer"
else:
res = f"failed: {exec_result.stderr.decode()}"
except subprocess.TimeoutExpired as e:
res = "timed out"
except BaseException as e:
res = f"failed: {e}"
result.append(res)
def unsafe_execute_js(check_program, result, timeout):
with create_tempdir():
open(f"test.js", 'w').write(check_program)
# Run program.
try:
exec_result = subprocess.run(["node", "test.js"], timeout=timeout, capture_output=True)
if exec_result.stderr.decode():
err = exec_result.stderr.decode()
result.append(f"failed: {err}")
elif exec_result.stdout.decode():
err = exec_result.stdout.decode()
result.append(f"failed: {err}")
else:
result.append("passed")
except subprocess.TimeoutExpired as e:
result.append("timed out")
def unsafe_execute_rust(check_program, result, timeout):
with create_tempdir():
WD: str = os.getcwd()
RUST_DIR: str = os.path.join(WD, "rust")
RUST_SRC: str = os.path.join(RUST_DIR, "src")
RUST_BIN: str = os.path.join(RUST_SRC, "bin")
RUST_TMP_DIR: str = os.path.join(RUST_DIR, "tmp")
RUST_LOGS: str = os.path.join(RUST_TMP_DIR, "logs")
RUST_EXT: str = ".rs"
# Create mandatory tmp directories
os.makedirs(RUST_TMP_DIR, exist_ok=True)
os.makedirs(RUST_LOGS, exist_ok=True)
os.makedirs(RUST_SRC, exist_ok=True)
os.makedirs(RUST_BIN, exist_ok=True)
# Check if Cargo exists, if so copy it here
if os.path.exists("../Cargo.toml"):
shutil.copy("../Cargo.toml", RUST_DIR)
else:
# Warn that no Cargo was found in the parent directory
logging.warning(f"Cargo.toml not found in {os.path.abspath('../')}. Creating a new one. Timeout of >300 is recommended.")
# Create Cargo.toml
open(f"{RUST_DIR}/Cargo.toml", 'w').write(BASE_CARGO)
open(f"test.rs", 'w').write(check_program)
compilation_result = subprocess.run(["cargo", "check", "--bin", "test", "--message-format", "json"], capture_output=True)
if compilation_result.returncode == 0:
exec_result = subprocess.run(["cargo", "test", "--bin", "test", "--message-format", "json"], capture_output=True)
if exec_result.returncode == 0:
result.append("passed")
else:
result.append("failed: execution error: " + exec_result.stderr.decode())
else:
result.append("failed: compilation error: " + compilation_result.stderr.decode())
@contextlib.contextmanager
def time_limit(seconds):
def signal_handler(signum, frame):
raise TimeoutException("Timed out!")
signal.setitimer(signal.ITIMER_REAL, seconds)
signal.signal(signal.SIGALRM, signal_handler)
try:
yield
finally:
signal.setitimer(signal.ITIMER_REAL, 0)
@contextlib.contextmanager
def swallow_io():
stream = WriteOnlyStringIO()
with contextlib.redirect_stdout(stream):
with contextlib.redirect_stderr(stream):
with redirect_stdin(stream):
yield
@contextlib.contextmanager
def create_tempdir():
with tempfile.TemporaryDirectory() as dirname:
with chdir(dirname):
yield dirname
class TimeoutException(Exception):
pass
class WriteOnlyStringIO(io.StringIO):
"""StringIO that throws an exception when it's read from"""
def read(self, *args, **kwargs):
raise OSError
def readline(self, *args, **kwargs):
raise OSError
def readlines(self, *args, **kwargs):
raise OSError
def readable(self, *args, **kwargs):
"""Returns True if the IO object can be read."""
return False
class redirect_stdin(contextlib._RedirectStream): # type: ignore
_stream = "stdin"
@contextlib.contextmanager
def chdir(root):
if root == ".":
yield
return
cwd = os.getcwd()
os.chdir(root)
try:
yield
except BaseException as exc:
raise exc
finally:
os.chdir(cwd)
def reliability_guard(maximum_memory_bytes=None):
"""
This disables various destructive functions and prevents the generated code
from interfering with the test (e.g. fork bomb, killing other processes,
removing filesystem files, etc.)
WARNING
This function is NOT a security sandbox. Untrusted code, including, model-
generated code, should not be blindly executed outside of one. See the
Codex paper for more information about OpenAI's code sandbox, and proceed
with caution.
"""
if maximum_memory_bytes is not None:
import resource
resource.setrlimit(resource.RLIMIT_AS, (maximum_memory_bytes, maximum_memory_bytes))
resource.setrlimit(resource.RLIMIT_DATA, (maximum_memory_bytes, maximum_memory_bytes))
if not platform.uname().system == "Darwin":
resource.setrlimit(resource.RLIMIT_STACK, (maximum_memory_bytes, maximum_memory_bytes))
faulthandler.disable()
import builtins
builtins.exit = None
builtins.quit = None
import os
os.environ["OMP_NUM_THREADS"] = "1"
os.kill = None
os.system = None
os.putenv = None
os.remove = None
os.removedirs = None
os.rmdir = None
os.fchdir = None
os.setuid = None
os.fork = None
os.forkpty = None
os.killpg = None
os.rename = None
os.renames = None
os.truncate = None
os.replace = None
os.unlink = None
os.fchmod = None
os.fchown = None
os.chmod = None
os.chown = None
os.chroot = None
os.fchdir = None
os.lchflags = None
os.lchmod = None
os.lchown = None
os.getcwd = None
os.chdir = None
import shutil
shutil.rmtree = None
shutil.move = None
shutil.chown = None
import subprocess
subprocess.Popen = None # type: ignore
__builtins__["help"] = None
import sys
sys.modules["ipdb"] = None
sys.modules["joblib"] = None
sys.modules["resource"] = None
sys.modules["psutil"] = None
sys.modules["tkinter"] = None