File size: 7,496 Bytes
43c34cc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
import argparse
import logging
import os

logger = logging.getLogger(__name__)


def parse_args():
    parser = argparse.ArgumentParser(description="Argument parser for configuring OpenAI API and experiment settings")

    # Authentication details for OpenAI API
    parser.add_argument(
        "--openai_key", type=str, default=None,
        help="API key to authenticate with OpenAI. Can be set via this argument or through the OPENAI_API_KEY environment variable."
    )

    parser.add_argument(
        "--deployment", type=str, default=None,
        help="For Azure OpenAI: the deployment name to be used when calling the API."
    )

    parser.add_argument(
        "--openai_client_type", type=str, default="openai", choices=["openai", "azure_openai"],
        help="Specify the OpenAI client type to use: 'openai' for standard OpenAI API or 'azure_openai' for Azure-hosted OpenAI services."
    )

    parser.add_argument(
        "--endpoint", type=str, default=None,
        help="For Azure OpenAI: custom endpoint to access the API. Should be in the format 'https://<your-endpoint>.openai.azure.com'."
    )

    parser.add_argument(
        "--api_version", type=str, default="2023-05-15", help="API version to be used for making requests. Required "
                                                              "for Azure OpenAI clients."
    )

    # Experiment configuration
    parser.add_argument(
        "--ac_scoring_method", type=str, default="ranking", choices=["recommendation", "ranking"],
        help="Specifies the scoring method used by the Area Chair (AC) to evaluate papers: 'recommendation' or 'ranking'."
    )

    parser.add_argument(
        "--conference", type=str, default="ICLR2023",
        help="Conference name where the papers are being evaluated, e.g., 'ICLR2023'."
    )

    parser.add_argument(
        "--num_reviewers_per_paper", type=int, default=3, help="The number of reviewers assigned to each paper."
    )

    parser.add_argument(
        "--experiment_name",
        type=str, default=None, required=False,
        help="Specifies the name of the experiment to run. Choose from predefined experiment types based on the reviewer and AC behavior or experiment configuration."
    )

    parser.add_argument(
        "--overwrite", action="store_true",
        help="If set, existing results or output files will be overwritten without prompting."
    )
    parser.add_argument(
        "--skip_logging", action="store_true", help="If set, we do not log the messages in the console."
    )

    parser.add_argument(
        "--num_papers_per_area_chair", type=int, default=10,
        help="The number of papers each area chair is assigned for evaluation."
    )

    # Model configuration
    parser.add_argument(
        "--model_name", type=str, default="gpt-4o", choices=["gpt-4", "gpt-4o", "gpt-35-turbo"],
        help="Specifies which GPT model to use: 'gpt-4' for the standard GPT-4 model, 'gpt-35-turbo' for a "
             "cost-effective alternative, or 'gpt-4o' for larger context support."
    )

    # Output directories
    parser.add_argument(
        "--output_dir", type=str, default="outputs", help="Directory where results, logs, and outputs will be stored."
    )

    # Output directories
    parser.add_argument(
        "--max_num_words", type=int, default=16384, help="Maximum number of words in the paper."
    )

    parser.add_argument(
        "--visual_dir", type=str, default="outputs/visual",
        help="Directory where visualization files (such as graphs and plots) will be stored."
    )

    # System configuration
    parser.add_argument(
        "--device", type=str, default='cuda',
        help="The device to be used for processing (e.g., 'cuda' for GPU acceleration or 'cpu' for standard processing)."
    )

    parser.add_argument(
        "--data_dir", type=str, default='data', help="Directory where input data (e.g., papers) are stored."
    )

    parser.add_argument(
        "--acceptance_rate", type=float, default=0.32,
        help="Percentage of papers to accept. We use 0.32, the average acceptance rate for ICLR 2020 - 2023"
    )

    args = parser.parse_args()

    # Ensure necessary directories exist
    os.makedirs(args.visual_dir, exist_ok=True)
    os.makedirs(args.output_dir, exist_ok=True)

    # Set 'player_to_test' based on experiment name
    if args.experiment_name is None:
        args.player_to_test = None
    elif "Rx" in args.experiment_name:
        args.player_to_test = "Reviewer"
    elif "ACx" in args.experiment_name:
        args.player_to_test = "Area Chair"
    elif "no_rebuttal" in args.experiment_name or "no_overall_score" in args.experiment_name:
        args.player_to_test = "Review Mechanism"

    # Sanity checks for authentication
    print("Running sanity checks for the arguments...")

    if args.openai_client_type == "openai":
        if os.environ.get('OPENAI_API_KEY') is None:
            assert isinstance(args.openai_key, str), ("Please specify the `--openai_key` argument OR set the "
                                                      "OPENAI_API_KEY environment variable.")
            raise ValueError("OpenAI key is missing.")

    EXISTING_EXPERIMENT_SETTINGS = [
        "BASELINE", "benign_Rx1", "malicious_Rx1", "malicious_Rx2", "malicious_Rx3", "unknowledgeable_Rx1",
        "knowledgeable_Rx1", "responsible_Rx1", "irresponsible_Rx1", "irresponsible_Rx2", "irresponsible_Rx3",
        "inclusive_ACx1", "authoritarian_ACx1", "conformist_ACx1", "no_numeric_ratings"]

    if args.experiment_name not in EXISTING_EXPERIMENT_SETTINGS:
        logger.warning(f"Experiment name '{args.experiment_name}' is not recognized. "
                       f"This can happen if you are customizing your own experiment settings. "
                       f"Otherwise, please choose from the following: "
                       f"{EXISTING_EXPERIMENT_SETTINGS}")

    if args.openai_client_type == "azure_openai":
        if os.environ.get('AZURE_OPENAI_KEY') is None:
            assert isinstance(args.openai_key, str), ("Please specify the `--openai_key` argument OR set the "
                                                      "AZURE_OPENAI_KEY environment variable.")
            os.environ['AZURE_OPENAI_KEY'] = args.openai_key

        if os.environ.get('AZURE_DEPLOYMENT') is None:
            assert isinstance(args.deployment, str), ("Please specify the `--deployment` argument OR set the "
                                                      "AZURE_DEPLOYMENT environment variable.")
            os.environ['AZURE_DEPLOYMENT'] = args.deployment

        if os.environ.get('AZURE_ENDPOINT') is None:
            assert isinstance(args.endpoint, str), ("Please specify the `--endpoint` argument OR set the "
                                                    "AZURE_ENDPOINT environment variable.")
            endpoint = args.endpoint
        else:
            endpoint = os.environ.get('AZURE_ENDPOINT')

        if not endpoint.startswith("https://"):
            endpoint = f"https://{endpoint}.openai.azure.com"
        os.environ['AZURE_ENDPOINT'] = endpoint

        if os.environ.get('OPENAI_API_VERSION') is None:
            assert isinstance(args.api_version, str), ("Please specify the `--api_version` argument OR set the "
                                                       "OPENAI_API_VERSION environment variable.")
            os.environ['OPENAI_API_VERSION'] = args.api_version

    return args