Spaces:
Sleeping
Sleeping
File size: 7,496 Bytes
43c34cc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 |
import argparse
import logging
import os
# Module-level logger; parse_args uses it to warn about unrecognized experiment names.
logger = logging.getLogger(__name__)
def parse_args(argv=None):
    """Parse command-line options, prepare output directories and resolve API credentials.

    Beyond plain argument parsing this function:
      * creates ``--visual_dir`` and ``--output_dir`` if they do not exist;
      * adds ``args.player_to_test`` (always set, possibly ``None``) derived
        from ``--experiment_name``;
      * validates that credentials for the selected ``--openai_client_type``
        are available and exports them to environment variables
        (``OPENAI_API_KEY`` or the ``AZURE_*`` / ``OPENAI_API_VERSION`` set).

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. ``None`` (the default) keeps the original
            command-line behavior.

    Returns:
        argparse.Namespace: the parsed arguments plus ``player_to_test``.

    Raises:
        ValueError: a required credential is neither passed on the command
            line nor present in the environment.
        SystemExit: raised by argparse for invalid or unknown options.
    """
    args = _build_parser().parse_args(argv)

    # Ensure necessary directories exist
    os.makedirs(args.visual_dir, exist_ok=True)
    os.makedirs(args.output_dir, exist_ok=True)

    args.player_to_test = _infer_player_to_test(args.experiment_name)

    # Sanity checks for authentication
    print("Running sanity checks for the arguments...")
    if args.openai_client_type == "openai":
        _configure_openai_credentials(args)

    _warn_if_unknown_experiment(args.experiment_name)

    if args.openai_client_type == "azure_openai":
        _configure_azure_credentials(args)

    return args


# Experiment names with built-in configurations; other names are allowed but trigger a warning.
EXISTING_EXPERIMENT_SETTINGS = [
    "BASELINE", "benign_Rx1", "malicious_Rx1", "malicious_Rx2", "malicious_Rx3", "unknowledgeable_Rx1",
    "knowledgeable_Rx1", "responsible_Rx1", "irresponsible_Rx1", "irresponsible_Rx2", "irresponsible_Rx3",
    "inclusive_ACx1", "authoritarian_ACx1", "conformist_ACx1", "no_numeric_ratings"]


def _build_parser() -> argparse.ArgumentParser:
    """Build the ArgumentParser declaring every CLI option of the experiment runner."""
    parser = argparse.ArgumentParser(description="Argument parser for configuring OpenAI API and experiment settings")

    # Authentication details for OpenAI API
    parser.add_argument(
        "--openai_key", type=str, default=None,
        help="API key to authenticate with OpenAI. Can be set via this argument or through the OPENAI_API_KEY environment variable."
    )
    parser.add_argument(
        "--deployment", type=str, default=None,
        help="For Azure OpenAI: the deployment name to be used when calling the API."
    )
    parser.add_argument(
        "--openai_client_type", type=str, default="openai", choices=["openai", "azure_openai"],
        help="Specify the OpenAI client type to use: 'openai' for standard OpenAI API or 'azure_openai' for Azure-hosted OpenAI services."
    )
    parser.add_argument(
        "--endpoint", type=str, default=None,
        help="For Azure OpenAI: custom endpoint to access the API. Should be in the format 'https://<your-endpoint>.openai.azure.com'."
    )
    parser.add_argument(
        "--api_version", type=str, default="2023-05-15", help="API version to be used for making requests. Required "
                                                              "for Azure OpenAI clients."
    )

    # Experiment configuration
    parser.add_argument(
        "--ac_scoring_method", type=str, default="ranking", choices=["recommendation", "ranking"],
        help="Specifies the scoring method used by the Area Chair (AC) to evaluate papers: 'recommendation' or 'ranking'."
    )
    parser.add_argument(
        "--conference", type=str, default="ICLR2023",
        help="Conference name where the papers are being evaluated, e.g., 'ICLR2023'."
    )
    parser.add_argument(
        "--num_reviewers_per_paper", type=int, default=3, help="The number of reviewers assigned to each paper."
    )
    parser.add_argument(
        "--experiment_name",
        type=str, default=None, required=False,
        help="Specifies the name of the experiment to run. Choose from predefined experiment types based on the reviewer and AC behavior or experiment configuration."
    )
    parser.add_argument(
        "--overwrite", action="store_true",
        help="If set, existing results or output files will be overwritten without prompting."
    )
    parser.add_argument(
        "--skip_logging", action="store_true", help="If set, we do not log the messages in the console."
    )
    parser.add_argument(
        "--num_papers_per_area_chair", type=int, default=10,
        help="The number of papers each area chair is assigned for evaluation."
    )

    # Model configuration
    parser.add_argument(
        "--model_name", type=str, default="gpt-4o", choices=["gpt-4", "gpt-4o", "gpt-35-turbo"],
        help="Specifies which GPT model to use: 'gpt-4' for the standard GPT-4 model, 'gpt-35-turbo' for a "
             "cost-effective alternative, or 'gpt-4o' for larger context support."
    )

    # Output directories
    parser.add_argument(
        "--output_dir", type=str, default="outputs", help="Directory where results, logs, and outputs will be stored."
    )

    # Paper processing limits
    parser.add_argument(
        "--max_num_words", type=int, default=16384, help="Maximum number of words in the paper."
    )

    parser.add_argument(
        "--visual_dir", type=str, default="outputs/visual",
        help="Directory where visualization files (such as graphs and plots) will be stored."
    )

    # System configuration
    parser.add_argument(
        "--device", type=str, default='cuda',
        help="The device to be used for processing (e.g., 'cuda' for GPU acceleration or 'cpu' for standard processing)."
    )
    parser.add_argument(
        "--data_dir", type=str, default='data', help="Directory where input data (e.g., papers) are stored."
    )
    parser.add_argument(
        "--acceptance_rate", type=float, default=0.32,
        help="Fraction of papers to accept (e.g., 0.32 means 32%%). We use 0.32, the average acceptance rate for ICLR 2020 - 2023"
    )
    return parser


def _infer_player_to_test(experiment_name):
    """Map an experiment name to the role it varies: 'Reviewer', 'Area Chair', 'Review Mechanism' or None."""
    if experiment_name is None:
        return None
    if "Rx" in experiment_name:
        return "Reviewer"
    if "ACx" in experiment_name:
        return "Area Chair"
    if "no_rebuttal" in experiment_name or "no_overall_score" in experiment_name:
        return "Review Mechanism"
    # Names such as "BASELINE" match no marker. Return None explicitly so the
    # attribute always exists; leaving it unset would make later reads raise AttributeError.
    # TODO(review): confirm whether "no_numeric_ratings" should map to "Review Mechanism".
    return None


def _warn_if_unknown_experiment(experiment_name):
    """Log a warning when a non-None experiment name is not one of the built-in settings."""
    if experiment_name is not None and experiment_name not in EXISTING_EXPERIMENT_SETTINGS:
        logger.warning("Experiment name '%s' is not recognized. "
                       "This can happen if you are customizing your own experiment settings. "
                       "Otherwise, please choose from the following: "
                       "%s", experiment_name, EXISTING_EXPERIMENT_SETTINGS)


def _require_str(value, message):
    """Return value if it is a str, else raise ValueError(message). Unlike assert, this survives `python -O`."""
    if not isinstance(value, str):
        raise ValueError(message)
    return value


def _configure_openai_credentials(args):
    """Ensure OPENAI_API_KEY is set, exporting --openai_key when the env var is absent."""
    if os.environ.get('OPENAI_API_KEY') is None:
        # The environment variable takes precedence; otherwise fall back to the CLI key
        # (previously a supplied --openai_key was rejected unconditionally).
        os.environ['OPENAI_API_KEY'] = _require_str(
            args.openai_key,
            "Please specify the `--openai_key` argument OR set the OPENAI_API_KEY environment variable.")


def _configure_azure_credentials(args):
    """Resolve Azure key, deployment, endpoint and API version from env or CLI and export them."""
    if os.environ.get('AZURE_OPENAI_KEY') is None:
        os.environ['AZURE_OPENAI_KEY'] = _require_str(
            args.openai_key,
            "Please specify the `--openai_key` argument OR set the AZURE_OPENAI_KEY environment variable.")

    if os.environ.get('AZURE_DEPLOYMENT') is None:
        os.environ['AZURE_DEPLOYMENT'] = _require_str(
            args.deployment,
            "Please specify the `--deployment` argument OR set the AZURE_DEPLOYMENT environment variable.")

    endpoint = os.environ.get('AZURE_ENDPOINT')
    if endpoint is None:
        endpoint = _require_str(
            args.endpoint,
            "Please specify the `--endpoint` argument OR set the AZURE_ENDPOINT environment variable.")
    # A bare resource name is expanded to the full Azure OpenAI URL.
    if not endpoint.startswith("https://"):
        endpoint = f"https://{endpoint}.openai.azure.com"
    os.environ['AZURE_ENDPOINT'] = endpoint

    if os.environ.get('OPENAI_API_VERSION') is None:
        os.environ['OPENAI_API_VERSION'] = _require_str(
            args.api_version,
            "Please specify the `--api_version` argument OR set the OPENAI_API_VERSION environment variable.")
|