# AgentReview: agentreview/arguments.py

import argparse
import logging
import os
logger = logging.getLogger(__name__)


def parse_args():
    parser = argparse.ArgumentParser(
        description="Argument parser for configuring OpenAI API and experiment settings")

    # Authentication details for OpenAI API
    parser.add_argument(
        "--openai_key", type=str, default=None,
        help="API key to authenticate with OpenAI. Can be set via this argument or through the "
             "OPENAI_API_KEY environment variable."
    )
    parser.add_argument(
        "--deployment", type=str, default=None,
        help="For Azure OpenAI: the deployment name to be used when calling the API."
    )
    parser.add_argument(
        "--openai_client_type", type=str, default="openai", choices=["openai", "azure_openai"],
        help="Specify the OpenAI client type to use: 'openai' for standard OpenAI API or "
             "'azure_openai' for Azure-hosted OpenAI services."
    )
    parser.add_argument(
        "--endpoint", type=str, default=None,
        help="For Azure OpenAI: custom endpoint to access the API. Should be in the format "
             "'https://<your-endpoint>.openai.azure.com'."
    )
    parser.add_argument(
        "--api_version", type=str, default="2023-05-15",
        help="API version to be used for making requests. Required for Azure OpenAI clients."
    )

    # Experiment configuration
    parser.add_argument(
        "--ac_scoring_method", type=str, default="ranking", choices=["recommendation", "ranking"],
        help="Specifies the scoring method used by the Area Chair (AC) to evaluate papers: "
             "'recommendation' or 'ranking'."
    )
    parser.add_argument(
        "--conference", type=str, default="ICLR2023",
        help="Conference name where the papers are being evaluated, e.g., 'ICLR2023'."
    )
    parser.add_argument(
        "--num_reviewers_per_paper", type=int, default=3,
        help="The number of reviewers assigned to each paper."
    )
    parser.add_argument(
        "--experiment_name", type=str, default=None, required=False,
        help="Specifies the name of the experiment to run. Choose from predefined experiment types "
             "based on the reviewer and AC behavior or experiment configuration."
    )
    parser.add_argument(
        "--overwrite", action="store_true",
        help="If set, existing results or output files will be overwritten without prompting."
    )
    parser.add_argument(
        "--skip_logging", action="store_true",
        help="If set, messages are not logged to the console."
    )
    parser.add_argument(
        "--num_papers_per_area_chair", type=int, default=10,
        help="The number of papers each area chair is assigned for evaluation."
    )

    # Model configuration
    parser.add_argument(
        "--model_name", type=str, default="gpt-4o", choices=["gpt-4", "gpt-4o", "gpt-35-turbo"],
        help="Specifies which GPT model to use: 'gpt-4' for the standard GPT-4 model, 'gpt-35-turbo' "
             "for a cost-effective alternative, or 'gpt-4o' for larger context support."
    )

    # Output directories
    parser.add_argument(
        "--output_dir", type=str, default="outputs",
        help="Directory where results, logs, and outputs will be stored."
    )
    parser.add_argument(
        "--visual_dir", type=str, default="outputs/visual",
        help="Directory where visualization files (such as graphs and plots) will be stored."
    )

    # Paper processing
    parser.add_argument(
        "--max_num_words", type=int, default=16384,
        help="Maximum number of words in the paper."
    )

    # System configuration
    parser.add_argument(
        "--device", type=str, default="cuda",
        help="The device to be used for processing (e.g., 'cuda' for GPU acceleration or 'cpu' for "
             "standard processing)."
    )
    parser.add_argument(
        "--data_dir", type=str, default="data",
        help="Directory where input data (e.g., papers) are stored."
    )
    parser.add_argument(
        "--acceptance_rate", type=float, default=0.32,
        help="Fraction of papers to accept. Defaults to 0.32, the average acceptance rate for "
             "ICLR 2020-2023."
    )

    args = parser.parse_args()

    # Ensure necessary directories exist
    os.makedirs(args.visual_dir, exist_ok=True)
    os.makedirs(args.output_dir, exist_ok=True)

    # Set 'player_to_test' based on the experiment name
    if args.experiment_name is None:
        args.player_to_test = None
    elif "Rx" in args.experiment_name:
        args.player_to_test = "Reviewer"
    elif "ACx" in args.experiment_name:
        args.player_to_test = "Area Chair"
    elif "no_rebuttal" in args.experiment_name or "no_overall_score" in args.experiment_name:
        args.player_to_test = "Review Mechanism"
    else:
        # Experiment names that do not target a specific role (e.g., "BASELINE") test no player.
        args.player_to_test = None

    # Sanity checks for authentication
    print("Running sanity checks for the arguments...")

    if args.openai_client_type == "openai":
        if os.environ.get('OPENAI_API_KEY') is None:
            assert isinstance(args.openai_key, str), ("Please specify the `--openai_key` argument OR set the "
                                                      "OPENAI_API_KEY environment variable.")
            # Export the key so downstream OpenAI clients can read it from the environment.
            os.environ['OPENAI_API_KEY'] = args.openai_key

    EXISTING_EXPERIMENT_SETTINGS = [
        "BASELINE", "benign_Rx1", "malicious_Rx1", "malicious_Rx2", "malicious_Rx3", "unknowledgeable_Rx1",
        "knowledgeable_Rx1", "responsible_Rx1", "irresponsible_Rx1", "irresponsible_Rx2", "irresponsible_Rx3",
        "inclusive_ACx1", "authoritarian_ACx1", "conformist_ACx1", "no_numeric_ratings"]

    if args.experiment_name is not None and args.experiment_name not in EXISTING_EXPERIMENT_SETTINGS:
        logger.warning(f"Experiment name '{args.experiment_name}' is not recognized. "
                       f"This can happen if you are customizing your own experiment settings. "
                       f"Otherwise, please choose from the following: "
                       f"{EXISTING_EXPERIMENT_SETTINGS}")

    if args.openai_client_type == "azure_openai":
        if os.environ.get('AZURE_OPENAI_KEY') is None:
            assert isinstance(args.openai_key, str), ("Please specify the `--openai_key` argument OR set the "
                                                      "AZURE_OPENAI_KEY environment variable.")
            os.environ['AZURE_OPENAI_KEY'] = args.openai_key

        if os.environ.get('AZURE_DEPLOYMENT') is None:
            assert isinstance(args.deployment, str), ("Please specify the `--deployment` argument OR set the "
                                                      "AZURE_DEPLOYMENT environment variable.")
            os.environ['AZURE_DEPLOYMENT'] = args.deployment

        if os.environ.get('AZURE_ENDPOINT') is None:
            assert isinstance(args.endpoint, str), ("Please specify the `--endpoint` argument OR set the "
                                                    "AZURE_ENDPOINT environment variable.")
            endpoint = args.endpoint
        else:
            endpoint = os.environ.get('AZURE_ENDPOINT')

        if not endpoint.startswith("https://"):
            endpoint = f"https://{endpoint}.openai.azure.com"
        os.environ['AZURE_ENDPOINT'] = endpoint

        if os.environ.get('OPENAI_API_VERSION') is None:
            assert isinstance(args.api_version, str), ("Please specify the `--api_version` argument OR set the "
                                                       "OPENAI_API_VERSION environment variable.")
            os.environ['OPENAI_API_VERSION'] = args.api_version

    return args
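

# Minimal usage sketch (not part of the original module): running this file directly parses the
# command-line arguments and prints the resulting namespace. The example invocation below assumes
# the standard OpenAI client with the key supplied via the OPENAI_API_KEY environment variable;
# the module path `agentreview.arguments` is inferred from the file path and may differ.
if __name__ == "__main__":
    # Example (hypothetical):
    #   OPENAI_API_KEY=sk-... python -m agentreview.arguments --conference ICLR2023 --experiment_name BASELINE
    parsed = parse_args()
    for name, value in sorted(vars(parsed).items()):
        print(f"{name}: {value}")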