Spaces:
Sleeping
Sleeping
create the app
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitignore +185 -0
- LICENSE +203 -0
- README.md +135 -6
- agentreview/__init__.py +8 -0
- agentreview/agent.py +255 -0
- agentreview/arena.py +201 -0
- agentreview/arguments.py +171 -0
- agentreview/backends/__init__.py +30 -0
- agentreview/backends/anthropic.py +119 -0
- agentreview/backends/bard.py +90 -0
- agentreview/backends/base.py +66 -0
- agentreview/backends/cohere.py +126 -0
- agentreview/backends/dummy.py +14 -0
- agentreview/backends/hf_transformers.py +127 -0
- agentreview/backends/human.py +23 -0
- agentreview/backends/langchain.py +169 -0
- agentreview/backends/openai.py +178 -0
- agentreview/config.py +143 -0
- agentreview/const.py +112 -0
- agentreview/database.py +136 -0
- agentreview/dataset/__init__.py +0 -0
- agentreview/dataset/download_openreview_paper.py +135 -0
- agentreview/dataset/process_submissions.py +112 -0
- agentreview/environments/__init__.py +25 -0
- agentreview/environments/base.py +188 -0
- agentreview/environments/conversation.py +198 -0
- agentreview/environments/paper_decision.py +161 -0
- agentreview/environments/paper_review.py +213 -0
- agentreview/experiment_config.py +265 -0
- agentreview/message.py +150 -0
- agentreview/paper_processor.py +163 -0
- agentreview/paper_review_arena.py +183 -0
- agentreview/paper_review_message.py +103 -0
- agentreview/paper_review_player.py +133 -0
- agentreview/paper_review_settings.py +110 -0
- agentreview/role_descriptions.py +515 -0
- agentreview/ui/__init__.py +0 -0
- agentreview/ui/cli.py +259 -0
- agentreview/utility/__init__.py +0 -0
- agentreview/utility/authentication_utils.py +44 -0
- agentreview/utility/data_utils.py +28 -0
- agentreview/utility/experiment_utils.py +84 -0
- agentreview/utility/general_utils.py +16 -0
- agentreview/utility/metrics_utils.py +17 -0
- agentreview/utility/text_utils.py +125 -0
- agentreview/utility/utils.py +582 -0
- agentreview/utils.py +116 -0
- app.py +612 -0
- docs/devdoc/design.md +39 -0
- docs/devdoc/mainloop.md +62 -0
.gitignore
ADDED
@@ -0,0 +1,185 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
|
3 |
+
key.py
|
4 |
+
|
5 |
+
*.pdf
|
6 |
+
*.json
|
7 |
+
*.png
|
8 |
+
*.jpg
|
9 |
+
*.jpeg
|
10 |
+
*.gif
|
11 |
+
data/
|
12 |
+
unused_data/
|
13 |
+
demo/
|
14 |
+
Summary/
|
15 |
+
|
16 |
+
# Byte-compiled / optimized / DLL files
|
17 |
+
__pycache__/
|
18 |
+
*.py[cod]
|
19 |
+
*$py.class
|
20 |
+
|
21 |
+
# C extensions
|
22 |
+
*.so
|
23 |
+
|
24 |
+
outputs
|
25 |
+
|
26 |
+
# Distribution / packaging
|
27 |
+
.Python
|
28 |
+
build/
|
29 |
+
develop-eggs/
|
30 |
+
dist/
|
31 |
+
downloads/
|
32 |
+
eggs/
|
33 |
+
.eggs/
|
34 |
+
lib/
|
35 |
+
lib64/
|
36 |
+
parts/
|
37 |
+
sdist/
|
38 |
+
var/
|
39 |
+
wheels/
|
40 |
+
pip-wheel-metadata/
|
41 |
+
share/python-wheels/
|
42 |
+
*.egg-info/
|
43 |
+
.installed.cfg
|
44 |
+
*.egg
|
45 |
+
MANIFEST
|
46 |
+
|
47 |
+
# PyInstaller
|
48 |
+
# Usually these files are written by a python script from a template
|
49 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
50 |
+
*.manifest
|
51 |
+
*.spec
|
52 |
+
|
53 |
+
# Installer logs
|
54 |
+
pip-log.txt
|
55 |
+
pip-delete-this-directory.txt
|
56 |
+
|
57 |
+
# Unit test / coverage reports
|
58 |
+
htmlcov/
|
59 |
+
.tox/
|
60 |
+
.nox/
|
61 |
+
.coverage
|
62 |
+
.coverage.*
|
63 |
+
.cache
|
64 |
+
nosetests.xml
|
65 |
+
coverage.xml
|
66 |
+
*.cover
|
67 |
+
*.py,cover
|
68 |
+
.hypothesis/
|
69 |
+
.pytest_cache/
|
70 |
+
|
71 |
+
# Translations
|
72 |
+
*.mo
|
73 |
+
*.pot
|
74 |
+
|
75 |
+
# Django stuff:
|
76 |
+
*.log
|
77 |
+
local_settings.py
|
78 |
+
db.sqlite3
|
79 |
+
db.sqlite3-journal
|
80 |
+
|
81 |
+
# Flask stuff:
|
82 |
+
instance/
|
83 |
+
.webassets-cache
|
84 |
+
|
85 |
+
# Scrapy stuff:
|
86 |
+
.scrapy
|
87 |
+
|
88 |
+
# Sphinx documentation
|
89 |
+
docs/_build/
|
90 |
+
|
91 |
+
# PyBuilder
|
92 |
+
.pybuilder/
|
93 |
+
target/
|
94 |
+
|
95 |
+
# Jupyter Notebook
|
96 |
+
.ipynb_checkpoints
|
97 |
+
|
98 |
+
# IPython
|
99 |
+
profile_default/
|
100 |
+
ipython_config.py
|
101 |
+
|
102 |
+
# pyenv
|
103 |
+
# For a library or package, you might want to ignore these files since the code is
|
104 |
+
# intended to run in multiple environments; otherwise, check them in:
|
105 |
+
.python-version
|
106 |
+
|
107 |
+
# pipenv
|
108 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
109 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
110 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
111 |
+
# install all needed dependencies.
|
112 |
+
#Pipfile.lock
|
113 |
+
|
114 |
+
# poetry
|
115 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
116 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
117 |
+
# commonly ignored for libraries.
|
118 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
119 |
+
#poetry.lock
|
120 |
+
|
121 |
+
# pdm
|
122 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
123 |
+
#pdm.lock
|
124 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
125 |
+
# in version control.
|
126 |
+
# https://pdm.fming.dev/#use-with-ide
|
127 |
+
.pdm.toml
|
128 |
+
|
129 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
130 |
+
__pypackages__/
|
131 |
+
|
132 |
+
# Celery stuff
|
133 |
+
celerybeat-schedule
|
134 |
+
celerybeat.pid
|
135 |
+
|
136 |
+
# SageMath parsed files
|
137 |
+
*.sage.py
|
138 |
+
|
139 |
+
# Environments
|
140 |
+
.env
|
141 |
+
.venv
|
142 |
+
env/
|
143 |
+
venv/
|
144 |
+
ENV/
|
145 |
+
env.bak/
|
146 |
+
venv.bak/
|
147 |
+
|
148 |
+
# Spyder project settings
|
149 |
+
.spyderproject
|
150 |
+
.spyproject
|
151 |
+
|
152 |
+
# Rope project settings
|
153 |
+
.ropeproject
|
154 |
+
|
155 |
+
# mkdocs documentation
|
156 |
+
/site
|
157 |
+
|
158 |
+
# mypy
|
159 |
+
.mypy_cache/
|
160 |
+
.dmypy.json
|
161 |
+
dmypy.json
|
162 |
+
|
163 |
+
# Pyre type checker
|
164 |
+
.pyre/
|
165 |
+
|
166 |
+
# pytype static type analyzer
|
167 |
+
.pytype/
|
168 |
+
|
169 |
+
# Cython debug symbols
|
170 |
+
cython_debug/
|
171 |
+
|
172 |
+
# PyCharm
|
173 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
174 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
175 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
176 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
177 |
+
.idea/
|
178 |
+
|
179 |
+
.DS_Store
|
180 |
+
hf-spaces/
|
181 |
+
etc/
|
182 |
+
.conda
|
183 |
+
*.xlsx
|
184 |
+
*.csv
|
185 |
+
*.zip
|
LICENSE
ADDED
@@ -0,0 +1,203 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Copyright 2023 ChatArena. All rights reserved.
|
2 |
+
|
3 |
+
Apache License
|
4 |
+
Version 2.0, January 2004
|
5 |
+
http://www.apache.org/licenses/
|
6 |
+
|
7 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
8 |
+
|
9 |
+
1. Definitions.
|
10 |
+
|
11 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
12 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
13 |
+
|
14 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
15 |
+
the copyright owner that is granting the License.
|
16 |
+
|
17 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
18 |
+
other entities that control, are controlled by, or are under common
|
19 |
+
control with that entity. For the purposes of this definition,
|
20 |
+
"control" means (i) the power, direct or indirect, to cause the
|
21 |
+
direction or management of such entity, whether by contract or
|
22 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
23 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
24 |
+
|
25 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
26 |
+
exercising permissions granted by this License.
|
27 |
+
|
28 |
+
"Source" form shall mean the preferred form for making modifications,
|
29 |
+
including but not limited to software source code, documentation
|
30 |
+
source, and configuration files.
|
31 |
+
|
32 |
+
"Object" form shall mean any form resulting from mechanical
|
33 |
+
transformation or translation of a Source form, including but
|
34 |
+
not limited to compiled object code, generated documentation,
|
35 |
+
and conversions to other media types.
|
36 |
+
|
37 |
+
"Work" shall mean the work of authorship, whether in Source or
|
38 |
+
Object form, made available under the License, as indicated by a
|
39 |
+
copyright notice that is included in or attached to the work
|
40 |
+
(an example is provided in the Appendix below).
|
41 |
+
|
42 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
43 |
+
form, that is based on (or derived from) the Work and for which the
|
44 |
+
editorial revisions, annotations, elaborations, or other modifications
|
45 |
+
represent, as a whole, an original work of authorship. For the purposes
|
46 |
+
of this License, Derivative Works shall not include works that remain
|
47 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
48 |
+
the Work and Derivative Works thereof.
|
49 |
+
|
50 |
+
"Contribution" shall mean any work of authorship, including
|
51 |
+
the original version of the Work and any modifications or additions
|
52 |
+
to that Work or Derivative Works thereof, that is intentionally
|
53 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
54 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
55 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
56 |
+
means any form of electronic, verbal, or written communication sent
|
57 |
+
to the Licensor or its representatives, including but not limited to
|
58 |
+
communication on electronic mailing lists, source code control systems,
|
59 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
60 |
+
Licensor for the purpose of discussing and improving the Work, but
|
61 |
+
excluding communication that is conspicuously marked or otherwise
|
62 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
63 |
+
|
64 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
65 |
+
on behalf of whom a Contribution has been received by Licensor and
|
66 |
+
subsequently incorporated within the Work.
|
67 |
+
|
68 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
69 |
+
this License, each Contributor hereby grants to You a perpetual,
|
70 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
71 |
+
copyright license to reproduce, prepare Derivative Works of,
|
72 |
+
publicly display, publicly perform, sublicense, and distribute the
|
73 |
+
Work and such Derivative Works in Source or Object form.
|
74 |
+
|
75 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
76 |
+
this License, each Contributor hereby grants to You a perpetual,
|
77 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
78 |
+
(except as stated in this section) patent license to make, have made,
|
79 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
80 |
+
where such license applies only to those patent claims licensable
|
81 |
+
by such Contributor that are necessarily infringed by their
|
82 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
83 |
+
with the Work to which such Contribution(s) was submitted. If You
|
84 |
+
institute patent litigation against any entity (including a
|
85 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
86 |
+
or a Contribution incorporated within the Work constitutes direct
|
87 |
+
or contributory patent infringement, then any patent licenses
|
88 |
+
granted to You under this License for that Work shall terminate
|
89 |
+
as of the date such litigation is filed.
|
90 |
+
|
91 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
92 |
+
Work or Derivative Works thereof in any medium, with or without
|
93 |
+
modifications, and in Source or Object form, provided that You
|
94 |
+
meet the following conditions:
|
95 |
+
|
96 |
+
(a) You must give any other recipients of the Work or
|
97 |
+
Derivative Works a copy of this License; and
|
98 |
+
|
99 |
+
(b) You must cause any modified files to carry prominent notices
|
100 |
+
stating that You changed the files; and
|
101 |
+
|
102 |
+
(c) You must retain, in the Source form of any Derivative Works
|
103 |
+
that You distribute, all copyright, patent, trademark, and
|
104 |
+
attribution notices from the Source form of the Work,
|
105 |
+
excluding those notices that do not pertain to any part of
|
106 |
+
the Derivative Works; and
|
107 |
+
|
108 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
109 |
+
distribution, then any Derivative Works that You distribute must
|
110 |
+
include a readable copy of the attribution notices contained
|
111 |
+
within such NOTICE file, excluding those notices that do not
|
112 |
+
pertain to any part of the Derivative Works, in at least one
|
113 |
+
of the following places: within a NOTICE text file distributed
|
114 |
+
as part of the Derivative Works; within the Source form or
|
115 |
+
documentation, if provided along with the Derivative Works; or,
|
116 |
+
within a display generated by the Derivative Works, if and
|
117 |
+
wherever such third-party notices normally appear. The contents
|
118 |
+
of the NOTICE file are for informational purposes only and
|
119 |
+
do not modify the License. You may add Your own attribution
|
120 |
+
notices within Derivative Works that You distribute, alongside
|
121 |
+
or as an addendum to the NOTICE text from the Work, provided
|
122 |
+
that such additional attribution notices cannot be construed
|
123 |
+
as modifying the License.
|
124 |
+
|
125 |
+
You may add Your own copyright statement to Your modifications and
|
126 |
+
may provide additional or different license terms and conditions
|
127 |
+
for use, reproduction, or distribution of Your modifications, or
|
128 |
+
for any such Derivative Works as a whole, provided Your use,
|
129 |
+
reproduction, and distribution of the Work otherwise complies with
|
130 |
+
the conditions stated in this License.
|
131 |
+
|
132 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
133 |
+
any Contribution intentionally submitted for inclusion in the Work
|
134 |
+
by You to the Licensor shall be under the terms and conditions of
|
135 |
+
this License, without any additional terms or conditions.
|
136 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
137 |
+
the terms of any separate license agreement you may have executed
|
138 |
+
with Licensor regarding such Contributions.
|
139 |
+
|
140 |
+
6. Trademarks. This License does not grant permission to use the trade
|
141 |
+
names, trademarks, service marks, or product names of the Licensor,
|
142 |
+
except as required for reasonable and customary use in describing the
|
143 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
144 |
+
|
145 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
146 |
+
agreed to in writing, Licensor provides the Work (and each
|
147 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
148 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
149 |
+
implied, including, without limitation, any warranties or conditions
|
150 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
151 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
152 |
+
appropriateness of using or redistributing the Work and assume any
|
153 |
+
risks associated with Your exercise of permissions under this License.
|
154 |
+
|
155 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
156 |
+
whether in tort (including negligence), contract, or otherwise,
|
157 |
+
unless required by applicable law (such as deliberate and grossly
|
158 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
159 |
+
liable to You for damages, including any direct, indirect, special,
|
160 |
+
incidental, or consequential damages of any character arising as a
|
161 |
+
result of this License or out of the use or inability to use the
|
162 |
+
Work (including but not limited to damages for loss of goodwill,
|
163 |
+
work stoppage, computer failure or malfunction, or any and all
|
164 |
+
other commercial damages or losses), even if such Contributor
|
165 |
+
has been advised of the possibility of such damages.
|
166 |
+
|
167 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
168 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
169 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
170 |
+
or other liability obligations and/or rights consistent with this
|
171 |
+
License. However, in accepting such obligations, You may act only
|
172 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
173 |
+
of any other Contributor, and only if You agree to indemnify,
|
174 |
+
defend, and hold each Contributor harmless for any liability
|
175 |
+
incurred by, or claims asserted against, such Contributor by reason
|
176 |
+
of your accepting any such warranty or additional liability.
|
177 |
+
|
178 |
+
END OF TERMS AND CONDITIONS
|
179 |
+
|
180 |
+
APPENDIX: How to apply the Apache License to your work.
|
181 |
+
|
182 |
+
To apply the Apache License to your work, attach the following
|
183 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
184 |
+
replaced with your own identifying information. (Don't include
|
185 |
+
the brackets!) The text should be enclosed in the appropriate
|
186 |
+
comment syntax for the file format. We also recommend that a
|
187 |
+
file or class name and description of purpose be included on the
|
188 |
+
same "printed page" as the copyright notice for easier
|
189 |
+
identification within third-party archives.
|
190 |
+
|
191 |
+
Copyright [yyyy] [name of copyright owner]
|
192 |
+
|
193 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
194 |
+
you may not use this file except in compliance with the License.
|
195 |
+
You may obtain a copy of the License at
|
196 |
+
|
197 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
198 |
+
|
199 |
+
Unless required by applicable law or agreed to in writing, software
|
200 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
201 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
202 |
+
See the License for the specific language governing permissions and
|
203 |
+
limitations under the License.
|
README.md
CHANGED
@@ -1,14 +1,143 @@
|
|
1 |
---
|
2 |
title: AgentReview
|
3 |
-
emoji:
|
4 |
-
colorFrom:
|
5 |
-
colorTo:
|
6 |
sdk: gradio
|
7 |
-
sdk_version: 5.
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
license: apache-2.0
|
11 |
-
short_description:
|
12 |
---
|
13 |
|
14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
---
|
2 |
title: AgentReview
|
3 |
+
emoji: 🎓
|
4 |
+
colorFrom: indigo
|
5 |
+
colorTo: pink
|
6 |
sdk: gradio
|
7 |
+
sdk_version: 5.4.0
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
license: apache-2.0
|
11 |
+
short_description: EMNLP 2024
|
12 |
---
|
13 |
|
14 |
+
# AgentReview
|
15 |
+
|
16 |
+
Official implementation for the 🔗[EMNLP 2024](https://2024.emnlp.org/) main track (Oral) paper -- [AgentReview: Exploring Peer Review Dynamics with LLM Agents](https://arxiv.org/abs/2406.12708)
|
17 |
+
|
18 |
+
* 🌐 Website: [https://agentreview.github.io/](https://agentreview.github.io/)
|
19 |
+
* 📄 Paper: [https://arxiv.org/abs/2406.12708](https://arxiv.org/abs/2406.12708)
|
20 |
+
|
21 |
+
|
22 |
+
|
23 |
+
```bibtex
|
24 |
+
@inproceedings{jin2024agentreview,
|
25 |
+
title={AgentReview: Exploring Peer Review Dynamics with LLM Agents},
|
26 |
+
author={Jin, Yiqiao and Zhao, Qinlin and Wang, Yiyang and Chen, Hao and Zhu, Kaijie and Xiao, Yijia and Wang, Jindong},
|
27 |
+
booktitle={EMNLP},
|
28 |
+
year={2024}
|
29 |
+
}
|
30 |
+
```
|
31 |
+
|
32 |
+
<img src="static/img/Overview.png">
|
33 |
+
|
34 |
+
---
|
35 |
+
|
36 |
+
## Introduction
|
37 |
+
|
38 |
+
AgentReview is a pioneering large language model (LLM)-based framework for simulating peer review processes, developed to analyze and address the complex, multivariate factors influencing review outcomes. Unlike traditional statistical methods, AgentReview captures latent variables while respecting the privacy of sensitive peer review data.
|
39 |
+
|
40 |
+
### Academic Abstract
|
41 |
+
|
42 |
+
Peer review is fundamental to the integrity and advancement of scientific publication. Traditional methods of peer review analyses often rely on exploration and statistics of existing peer review data, which do not adequately address the multivariate nature of the process, account for the latent variables, and are further constrained by privacy concerns due to the sensitive nature of the data. We introduce AgentReview, the first large language model (LLM) based peer review simulation
|
43 |
+
framework, which effectively disentangles the impacts of multiple latent factors and addresses the privacy issue. Our study reveals significant insights, including a notable 37.1% variation in paper decisions due to reviewers' biases, supported by sociological theories such as the social influence theory, altruism fatigue, and authority bias. We believe that this study could offer valuable insights to improve the design of peer review mechanisms.
|
44 |
+
|
45 |
+
|
46 |
+
![Review Stage Design](static/img/ReviewPipeline.png)
|
47 |
+
|
48 |
+
## Getting Started
|
49 |
+
|
50 |
+
### Installation
|
51 |
+
|
52 |
+
**Download the data**
|
53 |
+
|
54 |
+
Download both zip files in the [Dropbox](https://www.dropbox.com/scl/fo/etzu5h8kwrx8vrcaep9tt/ALCnxFt2cT9aF477d-h1-E8?rlkey=9r5ep9psp8u4yaxxo9caf5nnc&st=k946oui5&dl=0):
|
55 |
+
|
56 |
+
Unzip [AgentReview_Paper_Data.zip](https://www.dropbox.com/scl/fi/l17brtbzsy3xwflqd58ja/AgentReview_Paper_Data.zip?rlkey=vldiexmgzi7zycmz7pumgbooc&st=b6g3nkry&dl=0) under `data/`, which contains:
|
57 |
+
1. The PDF versions of the paper
|
58 |
+
2. The real-world peer review for ICLR 2020 - 2023
|
59 |
+
|
60 |
+
```bash
|
61 |
+
unzip AgentReview_Paper_Data.zip -d data/
|
62 |
+
```
|
63 |
+
|
64 |
+
(Optional) Unzip [AgentReview_LLM_Reviews.zip](https://www.dropbox.com/scl/fi/ckr0hpxyedx8u9s6235y6/AgentReview_LLM_Reviews.zip?rlkey=cgexir5xu38tm79eiph8ulbkq&st=q23x2trr&dl=0) under `outputs/`, which contains:
|
65 |
+
1. The LLM-generated reviews, (our LLM-generated dataset)
|
66 |
+
|
67 |
+
```bash
|
68 |
+
unzip AgentReview_LLM_Review.zip -d outputs/
|
69 |
+
```
|
70 |
+
|
71 |
+
**Install Required Packages**:
|
72 |
+
```
|
73 |
+
cd AgentReview/
|
74 |
+
pip install -r requirements.txt
|
75 |
+
```
|
76 |
+
|
77 |
+
3. Set environment variables
|
78 |
+
|
79 |
+
If you use OpenAI API, set OPENAI_API_KEY.
|
80 |
+
|
81 |
+
```bash
|
82 |
+
export OPENAI_API_KEY=... # Format: sk-...
|
83 |
+
```
|
84 |
+
|
85 |
+
If you use AzureOpenAI API, set the following
|
86 |
+
|
87 |
+
```bash
|
88 |
+
export AZURE_ENDPOINT=... # Format: https://<your-endpoint>.openai.azure.com/
|
89 |
+
export AZURE_DEPLOYMENT=... # Your Azure OpenAI deployment here
|
90 |
+
export AZURE_OPENAI_KEY=... # Your Azure OpenAI key here
|
91 |
+
```
|
92 |
+
|
93 |
+
**Running the Project**
|
94 |
+
|
95 |
+
Set the environment variables in `run.sh` and run it:
|
96 |
+
|
97 |
+
```bash
|
98 |
+
bash run.sh
|
99 |
+
```
|
100 |
+
|
101 |
+
**Note: all project files should be run from the `AgentReview` directory.**
|
102 |
+
|
103 |
+
**Demo**
|
104 |
+
|
105 |
+
A demo can be found in `notebooks/demo.ipynb`
|
106 |
+
|
107 |
+
### Customizing your own environment
|
108 |
+
|
109 |
+
You can add a new setting in `agentreview/experiment_config.py`, then add the setting as a new entry to the `all_settings` dictionary:
|
110 |
+
|
111 |
+
```python
|
112 |
+
all_settings = {
|
113 |
+
"BASELINE": baseline_setting,
|
114 |
+
"benign_Rx1": benign_Rx1_setting,
|
115 |
+
...
|
116 |
+
"your_setting_name": your_setting
|
117 |
+
```
|
118 |
+
|
119 |
+
## Framework Overview
|
120 |
+
|
121 |
+
### Stage Design
|
122 |
+
|
123 |
+
Our simulation adopts a structured, 5-phase pipeline
|
124 |
+
|
125 |
+
* **Phase I. Reviewer Assessment.** Each manuscript is evaluated by three reviewers independently.
|
126 |
+
* **Phase II. Author-Reviewer Discussion.** Authors submit rebuttals to address reviewers' concerns;
|
127 |
+
* **Phase III. Reviewer-AC Discussion.** The AC facilitates discussions among reviewers, prompting updates to their initial assessments.
|
128 |
+
* **Phase IV. Meta-Review Compilation.** The AC synthesizes the discussions into a meta-review.
|
129 |
+
* **Phase V. Paper Decision.** The AC makes the final decision on whether to accept or reject the paper, based on all gathered inputs.
|
130 |
+
|
131 |
+
## Note
|
132 |
+
|
133 |
+
- We use a fixed acceptance rate of 32%, corresponding to the actual acceptance rate of ICLR 2020 -- 2023. See [Conference Acceptance Rates](https://github.com/lixin4ever/Conference-Acceptance-Rate) for more information.
|
134 |
+
- Sometimes the API can apply strict filtering to the request. You may need to adjust the content filtering to get the desired results.
|
135 |
+
|
136 |
+
|
137 |
+
## License
|
138 |
+
|
139 |
+
This project is licensed under the Apache-2.0 License.
|
140 |
+
|
141 |
+
## Acknowledgements
|
142 |
+
|
143 |
+
The implementation is partially based on the [chatarena](https://github.com/Farama-Foundation/chatarena) framework.
|
agentreview/__init__.py
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
ROOT_DIR = (
|
4 |
+
os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir)) + os.path.sep
|
5 |
+
)
|
6 |
+
EXAMPLES_DIR = os.path.join(ROOT_DIR, "examples")
|
7 |
+
|
8 |
+
__version__ = "0.1.16"
|
agentreview/agent.py
ADDED
@@ -0,0 +1,255 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import re
|
3 |
+
import uuid
|
4 |
+
from abc import abstractmethod
|
5 |
+
from argparse import Namespace
|
6 |
+
from typing import List, Union
|
7 |
+
|
8 |
+
from tenacity import RetryError
|
9 |
+
|
10 |
+
from .backends import IntelligenceBackend, load_backend
|
11 |
+
from .config import AgentConfig, BackendConfig, Configurable
|
12 |
+
from .message import SYSTEM_NAME, Message
|
13 |
+
|
14 |
+
# A special signal sent by the player to indicate that it is not possible to continue the conversation, and it requests to end the conversation.
|
15 |
+
# It contains a random UUID string to avoid being exploited by any of the players.
|
16 |
+
SIGNAL_END_OF_CONVERSATION = f"<<<<<<END_OF_CONVERSATION>>>>>>{uuid.uuid4()}"
|
17 |
+
|
18 |
+
|
19 |
+
class Agent(Configurable):
|
20 |
+
"""An abstract base class for all the agents in the chatArena environment."""
|
21 |
+
|
22 |
+
@abstractmethod
|
23 |
+
def __init__(
|
24 |
+
self, name: str, role_desc: str, global_prompt: str = None, *args, **kwargs
|
25 |
+
):
|
26 |
+
"""
|
27 |
+
Initialize the agent.
|
28 |
+
|
29 |
+
Parameters:
|
30 |
+
name (str): The name of the agent.
|
31 |
+
role_desc (str): Description of the agent's role.
|
32 |
+
global_prompt (str): A universal prompt that applies to all agents. Defaults to None.
|
33 |
+
"""
|
34 |
+
super().__init__(
|
35 |
+
name=name, role_desc=role_desc, global_prompt=global_prompt, **kwargs
|
36 |
+
)
|
37 |
+
self.name = name
|
38 |
+
self.role_desc = role_desc
|
39 |
+
self.global_prompt = global_prompt
|
40 |
+
|
41 |
+
|
42 |
+
class Player(Agent):
    """
    The Player class represents a player in the chatArena environment.

    A player can observe the environment
    and perform an action (generate a response) based on the observation.
    """

    def __init__(
        self,
        name: str,
        role_desc: str,
        backend: Union[BackendConfig, IntelligenceBackend],
        global_prompt: str = None,
        args: Namespace = None,
        **kwargs,
    ):
        """
        Initialize the player with a name, role description, backend, and a global prompt.

        Parameters:
            name (str): The name of the player.
            role_desc (str): Description of the player's role.
            backend (Union[BackendConfig, IntelligenceBackend]): The backend that will be
                used for decision making. It can be either a LLM backend config or an
                already-constructed backend (including a Human backend).
            global_prompt (str): A universal prompt that applies to all players. Defaults to None.
            args (Namespace): Optional parsed CLI arguments. When provided and the
                backend is given as a config, ``args.openai_client_type`` is copied
                into that config before the backend is instantiated.
        """
        # data_dir is popped so it is not forwarded into the Configurable fields.
        self.data_dir = kwargs.pop("data_dir", None)
        self.args = args

        if isinstance(backend, BackendConfig):
            backend_config = backend
            # BUGFIX: previously this line dereferenced args unconditionally,
            # raising AttributeError when args was None (its documented default).
            if args is not None:
                backend_config['openai_client_type'] = args.openai_client_type
            backend = load_backend(backend_config)
        elif isinstance(backend, IntelligenceBackend):
            backend_config = backend.to_config()
        else:
            raise ValueError(
                f"backend must be a BackendConfig or an IntelligenceBackend, but got {type(backend)}"
            )

        assert (
            name != SYSTEM_NAME
        ), f"Player name cannot be {SYSTEM_NAME}, which is reserved for the system."

        # Register the fields in the _config
        super().__init__(
            name=name,
            role_desc=role_desc,
            backend=backend_config,
            global_prompt=global_prompt,
            **kwargs,
        )

        self.backend = backend

    def to_config(self) -> AgentConfig:
        """Serialize this player (including its backend) into an AgentConfig."""
        return AgentConfig(
            name=self.name,
            role_desc=self.role_desc,
            backend=self.backend.to_config(),
            global_prompt=self.global_prompt,
        )

    def act(self, observation: List[Message]) -> str:
        """
        Take an action based on the observation (Generate a response), which can later be parsed to actual actions that affect the game dynamics.

        Parameters:
            observation (List[Message]): The messages that the player has observed from the environment.

        Returns:
            str: The action (response) of the player.
        """
        try:
            response = self.backend.query(
                agent_name=self.name,
                role_desc=self.role_desc,
                history_messages=observation,
                global_prompt=self.global_prompt,
                request_msg=None,
            )
        except RetryError as e:
            # The backend exhausted its retries; emit the end-of-conversation
            # signal so the environment terminates gracefully instead of crashing.
            err_msg = f"Agent {self.name} failed to generate a response. Error: {e.last_attempt.exception()}. Sending signal to end the conversation."
            logging.warning(err_msg)
            response = SIGNAL_END_OF_CONVERSATION + err_msg

        return response

    def __call__(self, observation: List[Message]) -> str:
        """Alias for act() so a player can be invoked like a function."""
        return self.act(observation)

    async def async_act(self, observation: List[Message]) -> str:
        """
        Async version of act().

        This is used when you want to generate a response asynchronously.

        Parameters:
            observation (List[Message]): The messages that the player has observed from the environment.

        Returns:
            str: The action (response) of the player.
        """
        try:
            # NOTE(review): async_query is not awaited here; if it is a coroutine
            # function this returns a coroutine object rather than a string —
            # confirm against the backend's async API before relying on this path.
            response = self.backend.async_query(
                agent_name=self.name,
                role_desc=self.role_desc,
                history_messages=observation,
                global_prompt=self.global_prompt,
                request_msg=None,
            )
        except RetryError as e:
            err_msg = f"Agent {self.name} failed to generate a response. Error: {e.last_attempt.exception()}. Sending signal to end the conversation."
            logging.warning(err_msg)
            response = SIGNAL_END_OF_CONVERSATION + err_msg

        return response

    def reset(self):
        """
        Reset the player's backend in case they are not stateless.

        This is usually called at the end of each episode.
        """
        self.backend.reset()
|
169 |
+
|
170 |
+
|
171 |
+
class Moderator(Player):
    """
    The Moderator class represents a special type of player that moderates the conversation.

    It is usually used as a component of the environment when the transition dynamics is conditioned on natural language that are not easy to parse programmatically.
    """

    def __init__(
        self,
        role_desc: str,
        backend: Union[BackendConfig, IntelligenceBackend],
        terminal_condition: str,
        global_prompt: str = None,
        **kwargs,
    ):
        """
        Initialize the moderator with a role description, backend, terminal condition, and a global prompt.

        Parameters:
            role_desc (str): Description of the moderator's role.
            backend (Union[BackendConfig, IntelligenceBackend]): The backend that will be used for decision making.
            terminal_condition (str): The condition (phrased as a question for the backend) that signifies the end of the conversation.
            global_prompt (str): A universal prompt that applies to the moderator. Defaults to None.
        """
        # The moderator's name is fixed; there can only be one per conversation.
        name = "Moderator"
        super().__init__(
            name=name,
            role_desc=role_desc,
            backend=backend,
            global_prompt=global_prompt,
            **kwargs,
        )

        self.terminal_condition = terminal_condition

    def to_config(self) -> AgentConfig:
        """Serialize the moderator, including its terminal condition, into an AgentConfig."""
        return AgentConfig(
            name=self.name,
            role_desc=self.role_desc,
            backend=self.backend.to_config(),
            terminal_condition=self.terminal_condition,
            global_prompt=self.global_prompt,
        )

    def is_terminal(self, history: List[Message], *args, **kwargs) -> bool:
        """
        Check whether an episode is terminated based on the terminal condition.

        Parameters:
            history (List[Message]): The conversation history.
                NOTE(review): an empty history would raise IndexError below —
                callers appear to guarantee at least one message; confirm.

        Returns:
            bool: True if the conversation is over, otherwise False.
        """
        # If the last message is the signal, then the conversation is over
        if history[-1].content == SIGNAL_END_OF_CONVERSATION:
            return True

        try:
            # Ask the backend the terminal-condition question; turn=-1 marks it
            # as a system-side request rather than a real conversation turn.
            request_msg = Message(
                agent_name=self.name, content=self.terminal_condition, turn=-1
            )
            # NOTE(review): positional *args after keyword arguments is legal
            # but unusual; extra positionals are forwarded to backend.query.
            response = self.backend.query(
                agent_name=self.name,
                role_desc=self.role_desc,
                history_messages=history,
                global_prompt=self.global_prompt,
                request_msg=request_msg,
                *args,
                **kwargs,
            )
        except RetryError as e:
            # If the backend cannot answer, fail safe by ending the episode.
            logging.warning(
                f"Agent {self.name} failed to generate a response. "
                f"Error: {e.last_attempt.exception()}."
            )
            return True

        # Any response starting with an affirmative word counts as "yes, end it".
        if re.match(
            r"yes|y|yea|yeah|yep|yup|sure|ok|okay|alright", response, re.IGNORECASE
        ):
            # print(f"Decision: {response}. Conversation is ended by moderator.")
            return True
        else:
            return False
|
agentreview/arena.py
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import csv
|
2 |
+
import json
|
3 |
+
import logging
|
4 |
+
import uuid
|
5 |
+
from typing import Dict, List, Union
|
6 |
+
|
7 |
+
from .agent import Player
|
8 |
+
from .backends import Human
|
9 |
+
from .config import ArenaConfig
|
10 |
+
from .environments import Environment, TimeStep, load_environment
|
11 |
+
|
12 |
+
|
13 |
+
class TooManyInvalidActions(Exception):
    """Raised when a player exceeds the retry limit with invalid actions."""
|
15 |
+
|
16 |
+
|
17 |
+
class Arena:
    """Utility class that manages the game environment and players."""

    def __init__(
        self, players: List[Player], environment: Environment, args, global_prompt: str = None
    ):
        """
        Parameters:
            players (List[Player]): The agents participating in the game.
            environment (Environment): The environment arbitrating turns and state.
            args: Parsed command-line arguments (may be None).
            global_prompt (str): A universal prompt shared by all players. Defaults to None.
        """
        # Create a container for the players and environment and reset the game
        self.players = players
        self.environment = environment
        self.global_prompt = global_prompt

        self.current_timestep = environment.reset()
        self.uuid = uuid.uuid4()  # Generate a unique id for the game
        self.invalid_actions_retry = 5  # max attempts per turn before aborting
        self.args = args

    @property
    def num_players(self):
        """Number of players as reported by the environment."""
        return self.environment.num_players

    @property
    def name_to_player(self) -> Dict[str, Player]:
        """Mapping from player name to Player object (names are unique)."""
        return {player.name: player for player in self.players}

    def reset(self) -> TimeStep:
        """Reset the environment, every player, and the game id; return the fresh timestep."""
        self.current_timestep = self.environment.reset()
        for player in self.players:
            player.reset()
        self.uuid = uuid.uuid4()
        return self.current_timestep

    def step(self) -> TimeStep:
        """Take a step in the game: one player takes an action and the environment updates."""
        player_name = self.environment.get_next_player()
        player = self.name_to_player[player_name]  # get the player object
        observation = self.environment.get_observation(
            player_name
        )  # get the observation for the player

        timestep = None
        # Allow a bounded number of retries in case the player produces invalid actions.
        for _ in range(self.invalid_actions_retry):
            action = player(observation)  # take an action
            if self.environment.check_action(action, player_name):  # action is valid
                timestep = self.environment.step(
                    player_name, action
                )  # update the environment
                break
            else:  # action is invalid
                logging.warning(f"{player_name} made an invalid action {action}")
                continue

        if timestep is None:
            # The player made invalid actions for too many times; terminate the game.
            warning_msg = f"{player_name} has made invalid actions for {self.invalid_actions_retry} times. Terminating the game."
            logging.warning(warning_msg)
            raise TooManyInvalidActions(warning_msg)

        return timestep

    def next_is_human(self):
        """Check if the next player is human."""
        player_name = self.environment.get_next_player()
        player = self.name_to_player[player_name]
        return isinstance(player.backend, Human)

    def run(self, num_steps: int = 1):
        """Run the game for up to num_steps, stopping early at a terminal timestep."""
        for _ in range(num_steps):
            timestep = self.step()
            if timestep.terminal:
                break

    @classmethod
    def from_config(cls, config: Union[str, ArenaConfig]):
        """Create an arena from a config object or a path to a config file."""
        # If config is a path, load the config
        if isinstance(config, str):
            config = ArenaConfig.load(config)

        global_prompt = config.get("global_prompt", None)

        # Create the players
        players = []
        for player_config in config.players:
            # Add public_prompt to the player config
            if global_prompt is not None:
                player_config["global_prompt"] = global_prompt

            player = Player.from_config(player_config)
            players.append(player)

        # Check that the player names are unique
        player_names = [player.name for player in players]
        assert len(player_names) == len(
            set(player_names)
        ), "Player names must be unique"

        # Create the environment
        config.environment[
            "player_names"
        ] = player_names  # add the player names to the environment config
        env = load_environment(config.environment)

        # BUGFIX: __init__ requires the positional `args` parameter, which was
        # previously omitted here, so from_config always raised TypeError.
        return cls(players, env, args=None, global_prompt=global_prompt)

    def to_config(self) -> ArenaConfig:
        """Convert the arena to a config."""
        return ArenaConfig(
            players=[player.to_config() for player in self.players],
            environment=self.environment.to_config(),
            global_prompt=self.global_prompt,
        )

    def launch_cli(self, max_steps: int = None, interactive: bool = True):
        """Launch the command line interface."""
        from agentreview.ui.cli import ArenaCLI

        cli = ArenaCLI(self)
        cli.launch(max_steps=max_steps, interactive=interactive)

    def save_config(self, path: str):
        """Save the config to a file."""
        config = self.to_config()
        config.save(path)

    def save_history(self, path: str):
        """
        Save the history of the game to a file.

        Supports csv and json formats, selected by the file extension of `path`.

        Raises:
            ValueError: if the path ends with neither ".csv" nor ".json".
        """
        messages = self.environment.get_observation()
        message_rows = []

        if path.endswith(".csv"):
            header = [
                "agent_name",
                "content",
                "turn",
                "timestamp",
                "visible_to",
                "msg_type",
            ]
            for message in messages:
                message_rows.append([
                    message.agent_name,
                    message.content,
                    message.turn,
                    str(message.timestamp),
                    message.visible_to,
                    message.msg_type,
                ])

            # newline="" prevents the csv module from emitting blank rows on Windows.
            with open(path, "w", newline="") as f:
                writer = csv.writer(f)
                writer.writerow(header)
                writer.writerows(message_rows)
        elif path.endswith(".json"):
            for message in messages:
                message_rows.append({
                    "agent_name": message.agent_name,
                    "content": message.content,
                    "turn": message.turn,
                    "timestamp": str(message.timestamp),
                    "visible_to": message.visible_to,
                    "msg_type": message.msg_type,
                })

            with open(path, "w") as f:
                json.dump(message_rows, f, indent=2)
        else:
            raise ValueError("Invalid file format")
|
agentreview/arguments.py
ADDED
@@ -0,0 +1,171 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import logging
|
3 |
+
import os
|
4 |
+
|
5 |
+
logger = logging.getLogger(__name__)
|
6 |
+
|
7 |
+
|
8 |
+
def parse_args():
    """
    Parse CLI arguments for OpenAI/Azure authentication and experiment settings.

    Side effects: creates the output/visual directories and exports the
    relevant API credentials into environment variables.

    Returns:
        argparse.Namespace: the parsed arguments, augmented with
        ``player_to_test`` derived from ``experiment_name``.

    Raises:
        AssertionError: if a required credential/endpoint is missing from both
        the CLI arguments and the environment.
    """
    parser = argparse.ArgumentParser(description="Argument parser for configuring OpenAI API and experiment settings")

    # Authentication details for OpenAI API
    parser.add_argument(
        "--openai_key", type=str, default=None,
        help="API key to authenticate with OpenAI. Can be set via this argument or through the OPENAI_API_KEY environment variable."
    )

    parser.add_argument(
        "--deployment", type=str, default=None,
        help="For Azure OpenAI: the deployment name to be used when calling the API."
    )

    parser.add_argument(
        "--openai_client_type", type=str, default="openai", choices=["openai", "azure_openai"],
        help="Specify the OpenAI client type to use: 'openai' for standard OpenAI API or 'azure_openai' for Azure-hosted OpenAI services."
    )

    parser.add_argument(
        "--endpoint", type=str, default=None,
        help="For Azure OpenAI: custom endpoint to access the API. Should be in the format 'https://<your-endpoint>.openai.azure.com'."
    )

    parser.add_argument(
        "--api_version", type=str, default="2023-05-15", help="API version to be used for making requests. Required "
                                                              "for Azure OpenAI clients."
    )

    # Experiment configuration
    parser.add_argument(
        "--ac_scoring_method", type=str, default="ranking", choices=["recommendation", "ranking"],
        help="Specifies the scoring method used by the Area Chair (AC) to evaluate papers: 'recommendation' or 'ranking'."
    )

    parser.add_argument(
        "--conference", type=str, default="ICLR2023",
        help="Conference name where the papers are being evaluated, e.g., 'ICLR2023'."
    )

    parser.add_argument(
        "--num_reviewers_per_paper", type=int, default=3, help="The number of reviewers assigned to each paper."
    )

    parser.add_argument(
        "--experiment_name",
        type=str, default=None, required=False,
        help="Specifies the name of the experiment to run. Choose from predefined experiment types based on the reviewer and AC behavior or experiment configuration."
    )

    parser.add_argument(
        "--overwrite", action="store_true",
        help="If set, existing results or output files will be overwritten without prompting."
    )
    parser.add_argument(
        "--skip_logging", action="store_true", help="If set, we do not log the messages in the console."
    )

    parser.add_argument(
        "--num_papers_per_area_chair", type=int, default=10,
        help="The number of papers each area chair is assigned for evaluation."
    )

    # Model configuration
    parser.add_argument(
        "--model_name", type=str, default="gpt-4o", choices=["gpt-4", "gpt-4o", "gpt-35-turbo"],
        help="Specifies which GPT model to use: 'gpt-4' for the standard GPT-4 model, 'gpt-35-turbo' for a "
             "cost-effective alternative, or 'gpt-4o' for larger context support."
    )

    # Output directories
    parser.add_argument(
        "--output_dir", type=str, default="outputs", help="Directory where results, logs, and outputs will be stored."
    )

    parser.add_argument(
        "--max_num_words", type=int, default=16384, help="Maximum number of words in the paper."
    )

    parser.add_argument(
        "--visual_dir", type=str, default="outputs/visual",
        help="Directory where visualization files (such as graphs and plots) will be stored."
    )

    # System configuration
    parser.add_argument(
        "--device", type=str, default='cuda',
        help="The device to be used for processing (e.g., 'cuda' for GPU acceleration or 'cpu' for standard processing)."
    )

    parser.add_argument(
        "--data_dir", type=str, default='data', help="Directory where input data (e.g., papers) are stored."
    )

    parser.add_argument(
        "--acceptance_rate", type=float, default=0.32,
        help="Percentage of papers to accept. We use 0.32, the average acceptance rate for ICLR 2020 - 2023"
    )

    args = parser.parse_args()

    # Ensure necessary directories exist
    os.makedirs(args.visual_dir, exist_ok=True)
    os.makedirs(args.output_dir, exist_ok=True)

    # Set 'player_to_test' based on experiment name
    if args.experiment_name is None:
        args.player_to_test = None
    elif "Rx" in args.experiment_name:
        args.player_to_test = "Reviewer"
    elif "ACx" in args.experiment_name:
        args.player_to_test = "Area Chair"
    elif "no_rebuttal" in args.experiment_name or "no_overall_score" in args.experiment_name:
        args.player_to_test = "Review Mechanism"
    else:
        # BUGFIX: names like "BASELINE" previously left player_to_test unset,
        # so any later access raised AttributeError.
        args.player_to_test = None

    # Sanity checks for authentication
    print("Running sanity checks for the arguments...")

    if args.openai_client_type == "openai":
        if os.environ.get('OPENAI_API_KEY') is None:
            assert isinstance(args.openai_key, str), ("Please specify the `--openai_key` argument OR set the "
                                                      "OPENAI_API_KEY environment variable.")
            # BUGFIX: previously this branch raised ValueError unconditionally,
            # even when a valid --openai_key was supplied. Export it instead so
            # the OpenAI client can pick it up.
            os.environ['OPENAI_API_KEY'] = args.openai_key

    EXISTING_EXPERIMENT_SETTINGS = [
        "BASELINE", "benign_Rx1", "malicious_Rx1", "malicious_Rx2", "malicious_Rx3", "unknowledgeable_Rx1",
        "knowledgeable_Rx1", "responsible_Rx1", "irresponsible_Rx1", "irresponsible_Rx2", "irresponsible_Rx3",
        "inclusive_ACx1", "authoritarian_ACx1", "conformist_ACx1", "no_numeric_ratings"]

    if args.experiment_name not in EXISTING_EXPERIMENT_SETTINGS:
        # Resolve the module logger explicitly; getLogger returns the same
        # logger object as the module-level `logger`.
        logging.getLogger(__name__).warning(
            f"Experiment name '{args.experiment_name}' is not recognized. "
            f"This can happen if you are customizing your own experiment settings. "
            f"Otherwise, please choose from the following: "
            f"{EXISTING_EXPERIMENT_SETTINGS}")

    if args.openai_client_type == "azure_openai":
        if os.environ.get('AZURE_OPENAI_KEY') is None:
            assert isinstance(args.openai_key, str), ("Please specify the `--openai_key` argument OR set the "
                                                      "AZURE_OPENAI_KEY environment variable.")
            os.environ['AZURE_OPENAI_KEY'] = args.openai_key

        if os.environ.get('AZURE_DEPLOYMENT') is None:
            assert isinstance(args.deployment, str), ("Please specify the `--deployment` argument OR set the "
                                                      "AZURE_DEPLOYMENT environment variable.")
            os.environ['AZURE_DEPLOYMENT'] = args.deployment

        if os.environ.get('AZURE_ENDPOINT') is None:
            assert isinstance(args.endpoint, str), ("Please specify the `--endpoint` argument OR set the "
                                                    "AZURE_ENDPOINT environment variable.")
            endpoint = args.endpoint
        else:
            endpoint = os.environ.get('AZURE_ENDPOINT')

        # Normalize bare resource names into full https endpoints.
        if not endpoint.startswith("https://"):
            endpoint = f"https://{endpoint}.openai.azure.com"
        os.environ['AZURE_ENDPOINT'] = endpoint

        if os.environ.get('OPENAI_API_VERSION') is None:
            assert isinstance(args.api_version, str), ("Please specify the `--api_version` argument OR set the "
                                                       "OPENAI_API_VERSION environment variable.")
            os.environ['OPENAI_API_VERSION'] = args.api_version

    return args
|
agentreview/backends/__init__.py
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from ..config import BackendConfig
|
2 |
+
from .anthropic import Claude
|
3 |
+
from .base import IntelligenceBackend
|
4 |
+
from .cohere import CohereAIChat
|
5 |
+
from .hf_transformers import TransformersConversational
|
6 |
+
from .human import Human
|
7 |
+
from .openai import OpenAIChat
|
8 |
+
from .dummy import Dummy
|
9 |
+
|
10 |
+
# Every backend implementation shipped with the package. Order is not
# significant; each class declares a unique `type_name` class attribute.
ALL_BACKENDS = [
    Human,
    OpenAIChat,
    CohereAIChat,
    TransformersConversational,
    Claude,
    Dummy,
]

# Mapping from a backend's `type_name` string to its class; used by
# `load_backend` below to resolve the `backend_type` field of a BackendConfig.
BACKEND_REGISTRY = {backend.type_name: backend for backend in ALL_BACKENDS}
|
20 |
+
|
21 |
+
|
22 |
+
# Load a backend from a config dictionary
|
23 |
+
def load_backend(config: BackendConfig):
    """Instantiate the backend class named by ``config.backend_type``."""
    if config.backend_type not in BACKEND_REGISTRY:
        raise ValueError(f"Unknown backend type: {config.backend_type}")
    backend_cls = BACKEND_REGISTRY[config.backend_type]
    return backend_cls.from_config(config)
|
agentreview/backends/anthropic.py
ADDED
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import re
|
3 |
+
from typing import List
|
4 |
+
|
5 |
+
from tenacity import retry, stop_after_attempt, wait_random_exponential
|
6 |
+
|
7 |
+
from ..message import SYSTEM_NAME as SYSTEM
|
8 |
+
from ..message import Message
|
9 |
+
from .base import IntelligenceBackend
|
10 |
+
|
11 |
+
# Optional dependency probe: the Claude backend is only usable when the
# `anthropic` package is importable AND the ANTHROPIC_API_KEY env var is set.
# The flag is checked by Claude.__init__ below.
try:
    import anthropic
except ImportError:
    is_anthropic_available = False
    # logging.warning("anthropic package is not installed")
else:
    anthropic_api_key = os.environ.get("ANTHROPIC_API_KEY")
    if anthropic_api_key is None:
        # logging.warning("Anthropic API key is not set. Please set the environment variable ANTHROPIC_API_KEY")
        is_anthropic_available = False
    else:
        is_anthropic_available = True

# Default sampling budget and model for the Claude backend.
DEFAULT_MAX_TOKENS = 256
DEFAULT_MODEL = "claude-v1"
|
26 |
+
|
27 |
+
|
28 |
+
class Claude(IntelligenceBackend):
    """Interface to the Claude offered by Anthropic."""

    stateful = False  # each query rebuilds the full prompt from history
    type_name = "claude"

    def __init__(
        self, max_tokens: int = DEFAULT_MAX_TOKENS, model: str = DEFAULT_MODEL, **kwargs
    ):
        """
        Parameters:
            max_tokens (int): Maximum tokens Claude may sample per completion.
            model (str): Anthropic model identifier.
        """
        assert (
            is_anthropic_available
        ), "anthropic package is not installed or the API key is not set"
        super().__init__(max_tokens=max_tokens, model=model, **kwargs)

        self.max_tokens = max_tokens
        self.model = model

        # NOTE(review): positional api_key matches the legacy `anthropic.Client`
        # API; newer SDK versions expect `anthropic.Anthropic(api_key=...)` — confirm
        # the pinned anthropic version.
        self.client = anthropic.Client(os.environ["ANTHROPIC_API_KEY"])

    # Retry with exponential backoff on transient API failures (up to 6 attempts).
    @retry(stop=stop_after_attempt(6), wait=wait_random_exponential(min=1, max=60))
    def _get_response(self, prompt: str) -> str:
        """Send one completion request and return the stripped completion text."""
        response = self.client.completion(
            prompt=prompt,
            stop_sequences=[anthropic.HUMAN_PROMPT],
            model=self.model,
            max_tokens_to_sample=self.max_tokens,
        )

        response = response["completion"].strip()
        return response

    def query(
        self,
        agent_name: str,
        role_desc: str,
        history_messages: List[Message],
        global_prompt: str = None,
        request_msg: Message = None,
        *args,
        **kwargs,
    ) -> str:
        """
        Format the input and call the Claude API.

        args:
            agent_name: the name of the agent
            role_desc: the description of the role of the agent
            env_desc: the description of the environment
            history_messages: the history of the conversation, or the observation for the agent
            request_msg: the request from the system to guide the agent's next response
        """
        # Collect (speaker, text) pairs: system prompts first, then history,
        # then the optional system request.
        all_messages = (
            [(SYSTEM, global_prompt), (SYSTEM, role_desc)]
            if global_prompt
            else [(SYSTEM, role_desc)]
        )

        for message in history_messages:
            all_messages.append((message.agent_name, message.content))
        if request_msg:
            all_messages.append((SYSTEM, request_msg.content))

        # Fold the pairs into Anthropic's Human/Assistant turn format. Messages
        # from this agent become assistant turns; everything else (system and
        # other agents) becomes human turns tagged with the speaker name.
        prompt = ""
        prev_is_human = False  # Whether the previous message is from human (in anthropic, the human is the user)
        for i, message in enumerate(all_messages):
            if i == 0:
                assert (
                    message[0] == SYSTEM
                )  # The first message should be from the system

            if message[0] == agent_name:
                if prev_is_human:
                    prompt = f"{prompt}{anthropic.AI_PROMPT} {message[1]}"
                else:
                    prompt = f"{prompt}\n\n{message[1]}"
                prev_is_human = False
            else:
                if prev_is_human:
                    prompt = f"{prompt}\n\n[{message[0]}]: {message[1]}"
                else:
                    prompt = f"{prompt}{anthropic.HUMAN_PROMPT}\n[{message[0]}]: {message[1]}"
                prev_is_human = True
        assert prev_is_human  # The last message should be from the human
        # Add the AI prompt for Claude to generate the response
        prompt = f"{prompt}{anthropic.AI_PROMPT}"

        # NOTE(review): _get_response only accepts `prompt`; extra *args/**kwargs
        # forwarded here would raise TypeError — presumably always empty in practice.
        response = self._get_response(prompt, *args, **kwargs)

        # Remove the agent name if the response starts with it
        response = re.sub(rf"^\s*\[{agent_name}]:?", "", response).strip()

        return response
|
agentreview/backends/bard.py
ADDED
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import re
|
3 |
+
from typing import List
|
4 |
+
|
5 |
+
from tenacity import retry, stop_after_attempt, wait_random_exponential
|
6 |
+
|
7 |
+
from ..message import SYSTEM_NAME as SYSTEM
|
8 |
+
from ..message import Message
|
9 |
+
from .base import IntelligenceBackend
|
10 |
+
|
11 |
+
# Optional dependency probe: the Bard backend is only usable when the
# `bardapi` package is importable AND the _BARD_API_KEY env var is set.
# The flag is checked by Bard.__init__ below.
try:
    import bardapi
except ImportError:
    is_bard_available = False
    # logging.warning("bard package is not installed")
else:
    bard_api_key = os.environ.get("_BARD_API_KEY")
    if bard_api_key is None:
        # logging.warning(
        #     "Bard API key is not set. Please set the environment variable _BARD_API_KEY")
        is_bard_available = False
    else:
        is_bard_available = True

# Default sampling budget for the Bard backend.
DEFAULT_MAX_TOKENS = 4096
|
26 |
+
|
27 |
+
|
28 |
+
class Bard(IntelligenceBackend):
    """Interface to the Bard offered by Google."""

    stateful = False  # every query rebuilds the prompt from the full history
    type_name = "bard"

    def __init__(self, max_tokens: int = DEFAULT_MAX_TOKENS, **kwargs):
        assert (
            is_bard_available
        ), "bard package is not installed or the API key is not set"
        super().__init__(max_tokens=max_tokens, **kwargs)
        self.max_tokens = max_tokens
        self.client = bardapi.core.Bard()

    @retry(stop=stop_after_attempt(6), wait=wait_random_exponential(min=1, max=60))
    def _get_response(self, prompt: str):
        """Send one request to Bard and return the stripped answer text."""
        answer = self.client.get_answer(
            input_text=prompt,
        )
        return answer["content"].strip()

    def query(
        self,
        agent_name: str,
        role_desc: str,
        history_messages: List[Message],
        global_prompt: str = None,
        request_msg: Message = None,
        *args,
        **kwargs,
    ) -> str:
        """
        Format the input and call the Bard API.

        args:
            agent_name: the name of the agent
            role_desc: the description of the role of the agent
            env_desc: the description of the environment
            history_messages: the history of the conversation, or the observation for the agent
            request_msg: the request from the system to guide the agent's next response
        """
        # Build (speaker, text) pairs: system prompts first, then the
        # conversation history, then the optional system request.
        all_messages = [(SYSTEM, role_desc)]
        if global_prompt:
            all_messages.insert(0, (SYSTEM, global_prompt))
        all_messages.extend(
            (msg.agent_name, msg.content) for msg in history_messages
        )
        if request_msg:
            all_messages.append((SYSTEM, request_msg.content))

        # current bard api doesn't support role system, so just dump the raw messages as prompt
        reply = self._get_response(str(all_messages), *args, **kwargs)

        # Remove the agent name if the response starts with it
        return re.sub(rf"^\s*\[{agent_name}]:?", "", reply).strip()
|
agentreview/backends/base.py
ADDED
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from abc import abstractmethod
|
2 |
+
from typing import List
|
3 |
+
|
4 |
+
from ..config import BackendConfig, Configurable
|
5 |
+
from ..message import Message
|
6 |
+
|
7 |
+
|
8 |
+
class IntelligenceBackend(Configurable):
    """An abstraction of the intelligence source of the agents."""

    # Subclasses must override both attributes; enforced in __init_subclass__.
    stateful = None
    type_name = None

    @abstractmethod
    def __init__(self, **kwargs):
        super().__init__(**kwargs)  # registers the arguments with Configurable

    def __init_subclass__(cls, **kwargs):
        # Reject subclasses that forget to declare the required class attributes.
        missing = [
            attr for attr in ("stateful", "type_name") if getattr(cls, attr) is None
        ]
        if missing:
            raise TypeError(
                f"Can't instantiate abstract class {cls.__name__} without {missing[0]} attribute defined"
            )
        return super().__init_subclass__(**kwargs)

    def to_config(self) -> BackendConfig:
        # Record which backend type this instance is so it can be re-created.
        self._config_dict["backend_type"] = self.type_name
        return BackendConfig(**self._config_dict)

    @abstractmethod
    def query(
        self,
        agent_name: str,
        role_desc: str,
        history_messages: List[Message],
        global_prompt: str = None,
        request_msg: Message = None,
        *args,
        **kwargs,
    ) -> str:
        raise NotImplementedError

    @abstractmethod
    async def async_query(
        self,
        agent_name: str,
        role_desc: str,
        history_messages: List[Message],
        global_prompt: str = None,
        request_msg: Message = None,
        *args,
        **kwargs,
    ) -> str:
        """Async querying."""
        raise NotImplementedError

    def reset(self):
        """Reset the state of the backend; stateful subclasses must override."""
        if self.stateful:
            raise NotImplementedError
|
agentreview/backends/cohere.py
ADDED
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
from typing import List

from tenacity import retry, stop_after_attempt, wait_random_exponential

from ..message import Message
from .base import IntelligenceBackend

# Try to import the cohere package and check whether the API key is set.
is_cohere_available = False
try:
    import cohere
except ImportError:
    pass
else:
    is_cohere_available = os.environ.get("COHEREAI_API_KEY") is not None

# Default config follows the [Cohere documentation](https://cohere-sdk.readthedocs.io/en/latest/cohere.html#cohere.client.Client.chat)
DEFAULT_TEMPERATURE = 0.8
DEFAULT_MAX_TOKENS = 200
DEFAULT_MODEL = "command-xlarge"
|
24 |
+
|
25 |
+
|
26 |
+
class CohereAIChat(IntelligenceBackend):
    """Interface to the Cohere API.

    Stateful: keeps the Cohere chat ``session_id`` and a hash of the last
    message it sent, so consecutive ``query`` calls only forward the messages
    that arrived since the previous call.
    """

    stateful = True
    type_name = "cohere-chat"

    def __init__(
        self,
        temperature: float = DEFAULT_TEMPERATURE,
        max_tokens: int = DEFAULT_MAX_TOKENS,
        model: str = DEFAULT_MODEL,
        **kwargs,
    ):
        super().__init__(
            temperature=temperature, max_tokens=max_tokens, model=model, **kwargs
        )

        self.temperature = temperature
        self.max_tokens = max_tokens
        self.model = model

        assert (
            is_cohere_available
        ), "Cohere package is not installed or the API key is not set"
        self.client = cohere.Client(os.environ.get("COHEREAI_API_KEY"))

        # Stateful variables
        self.session_id = None  # The session id for the last conversation
        self.last_msg_hash = (
            None  # The hash of the last message of the last conversation
        )

    def reset(self):
        # Drop the Cohere-side session and the incremental-history marker.
        self.session_id = None
        self.last_msg_hash = None

    @retry(stop=stop_after_attempt(6), wait=wait_random_exponential(min=1, max=60))
    def _get_response(self, new_message: str, persona_prompt: str):
        """Send one message to the Cohere chat endpoint and return its reply."""
        response = self.client.chat(
            new_message,
            persona_prompt=persona_prompt,
            temperature=self.temperature,
            max_tokens=self.max_tokens,
            session_id=self.session_id,
        )

        self.session_id = response.session_id  # Update the session id
        return response.reply

    def query(
        self,
        agent_name: str,
        role_desc: str,
        history_messages: List[Message],
        global_prompt: str = None,
        request_msg: Message = None,
        *args,
        **kwargs,
    ) -> str:
        """
        Format the input and call the Cohere API.

        args:
            agent_name: the name of the agent
            role_desc: the description of the role of the agent
            history_messages: the history of the conversation, or the observation for the agent
            global_prompt: environment description included in the persona prompt
            request_msg: the request for the CohereAI
        """
        # Find the index of the last message of the last conversation
        new_message_start_idx = 0
        if self.last_msg_hash is not None:
            for i, message in enumerate(history_messages):
                if message.msg_hash == self.last_msg_hash:
                    new_message_start_idx = i + 1
                    break

        new_messages = history_messages[new_message_start_idx:]
        assert len(new_messages) > 0, "No new messages found (this should not happen)"

        new_conversations = []
        for message in new_messages:
            if message.agent_name != agent_name:
                # Since there are more than one player, we need to distinguish between the players
                new_conversations.append(f"[{message.agent_name}]: {message.content}")

        if request_msg:
            new_conversations.append(
                f"[{request_msg.agent_name}]: {request_msg.content}"
            )

        # Concatenate all new messages into one message because the Cohere API only accepts one message
        new_message = "\n".join(new_conversations)
        persona_prompt = f"Environment:\n{global_prompt}\n\nYour role:\n{role_desc}"

        response = self._get_response(new_message, persona_prompt)

        # Only update the last message hash if the API call is successful
        self.last_msg_hash = new_messages[-1].msg_hash

        return response
|
agentreview/backends/dummy.py
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from agentreview.config import Configurable
|
2 |
+
|
3 |
+
|
4 |
+
class Dummy(Configurable):
    """A dummy backend that never makes any API calls.

    Used for extracting paper contents in PaperExtractor and also for testing.
    """

    stateful = False
    type_name = "dummy"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def reset(self):
        # Nothing to reset: this backend keeps no state.
        pass
|
agentreview/backends/hf_transformers.py
ADDED
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from contextlib import contextmanager, redirect_stderr, redirect_stdout
|
3 |
+
from typing import List
|
4 |
+
|
5 |
+
from tenacity import retry, stop_after_attempt, wait_random_exponential
|
6 |
+
|
7 |
+
from ..message import SYSTEM_NAME as SYSTEM
|
8 |
+
from ..message import Message
|
9 |
+
from .base import IntelligenceBackend
|
10 |
+
|
11 |
+
|
12 |
+
@contextmanager
def suppress_stdout_stderr():
    """A context manager that redirects stdout and stderr to devnull."""
    with open(os.devnull, "w") as devnull:
        # Both streams point at the same devnull handle for the duration
        # of the block; the tuple is (stderr target, stdout target).
        with redirect_stdout(devnull) as out, redirect_stderr(devnull) as err:
            yield (err, out)
|
18 |
+
|
19 |
+
|
20 |
+
# Import transformers quietly so its startup chatter does not pollute the CLI.
with suppress_stdout_stderr():
    is_transformers_available = True
    try:
        import transformers
        from transformers import pipeline
        from transformers.pipelines.conversational import (
            Conversation,
            ConversationalPipeline,
        )
    except ImportError:
        is_transformers_available = False
|
33 |
+
|
34 |
+
|
35 |
+
class TransformersConversational(IntelligenceBackend):
    """Interface to the Transformers ConversationalPipeline."""

    stateful = False
    type_name = "transformers:conversational"

    def __init__(self, model: str, device: int = -1, **kwargs):
        # device=-1 is the transformers convention for CPU inference.
        super().__init__(model=model, device=device, **kwargs)
        self.model = model
        self.device = device

        assert is_transformers_available, "Transformers package is not installed"
        self.chatbot = pipeline(
            task="conversational", model=self.model, device=self.device
        )

    @retry(stop=stop_after_attempt(6), wait=wait_random_exponential(min=1, max=60))
    def _get_response(self, conversation):
        """Run the pipeline on *conversation* and return its newest reply."""
        conversation = self.chatbot(conversation)
        response = conversation.generated_responses[-1]
        return response

    @staticmethod
    def _msg_template(agent_name, content):
        # Prefix each message with its speaker so turns stay distinguishable
        # after merging.
        return f"[{agent_name}]: {content}"

    def query(
        self,
        agent_name: str,
        role_desc: str,
        history_messages: List[Message],
        global_prompt: str = None,
        request_msg: Message = None,
        *args,
        **kwargs,
    ) -> str:
        """Rebuild the dialog as alternating user/assistant turns and query the pipeline.

        Consecutive messages from the same side are merged into a single turn:
        everything not authored by *agent_name* counts as "user" input, the
        agent's own past messages count as generated responses.
        """
        user_inputs, generated_responses = [], []
        all_messages = (
            [(SYSTEM, global_prompt), (SYSTEM, role_desc)]
            if global_prompt
            else [(SYSTEM, role_desc)]
        )

        for msg in history_messages:
            all_messages.append((msg.agent_name, msg.content))
        if request_msg:
            all_messages.append((SYSTEM, request_msg.content))

        prev_is_user = False  # Whether the previous message is from the user
        for i, message in enumerate(all_messages):
            if i == 0:
                assert (
                    message[0] == SYSTEM
                )  # The first message should be from the system

            if message[0] != agent_name:
                # Anything not spoken by this agent becomes (part of) a user turn.
                if not prev_is_user:
                    user_inputs.append(self._msg_template(message[0], message[1]))
                else:
                    user_inputs[-1] += "\n" + self._msg_template(message[0], message[1])
                prev_is_user = True
            else:
                # The agent's own messages become (part of) an assistant turn.
                if prev_is_user:
                    generated_responses.append(message[1])
                else:
                    generated_responses[-1] += "\n" + message[1]
                prev_is_user = False

        # The pipeline expects one more user turn than assistant turns: the
        # final user turn is the new input awaiting a response.
        assert len(user_inputs) == len(generated_responses) + 1
        past_user_inputs = user_inputs[:-1]
        new_user_input = user_inputs[-1]

        # Recreate a conversation object from the history messages
        conversation = Conversation(
            text=new_user_input,
            past_user_inputs=past_user_inputs,
            generated_responses=generated_responses,
        )

        # Get the response
        response = self._get_response(conversation)
        return response
|
117 |
+
|
118 |
+
|
119 |
+
# conversation = Conversation("Going to the movies tonight - any suggestions?")
|
120 |
+
#
|
121 |
+
# # Steps usually performed by the model when generating a response:
|
122 |
+
# # 1. Mark the user input as processed (moved to the history)
|
123 |
+
# conversation.mark_processed()
|
124 |
+
# # 2. Append a mode response
|
125 |
+
# conversation.append_response("The Big lebowski.")
|
126 |
+
#
|
127 |
+
# conversation.add_user_input("Is it good?")
|
agentreview/backends/human.py
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from ..config import BackendConfig
|
2 |
+
from .base import IntelligenceBackend
|
3 |
+
|
4 |
+
|
5 |
+
# An Error class for the human backend
|
6 |
+
class HumanBackendError(Exception):
    """Raised when a human-backed agent needs a UI to collect its input."""

    def __init__(self, agent_name: str):
        self.agent_name = agent_name
        message = f"Human backend requires a UI to get input from {agent_name}."
        super().__init__(message)
|
10 |
+
|
11 |
+
|
12 |
+
class Human(IntelligenceBackend):
    """Backend that defers to a human operator instead of a model."""

    stateful = False
    type_name = "human"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def to_config(self) -> BackendConfig:
        return BackendConfig(backend_type=self.type_name)

    def query(self, agent_name: str, **kwargs) -> str:
        # There is no programmatic answer: the UI layer catches this error
        # and prompts the human player for input instead.
        raise HumanBackendError(agent_name)
|
agentreview/backends/langchain.py
ADDED
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
import re
from typing import List

from tenacity import retry, stop_after_attempt, wait_random_exponential

from ..message import SYSTEM_NAME, Message
from .base import IntelligenceBackend

# The backend is usable only when langchain's OpenAI wrapper imports cleanly
# AND the OPENAI_API_KEY environment variable is set.
is_langchain_openai_available = False
try:
    from langchain.llms import OpenAI
except ImportError:
    # logging.warning("openai package is not installed")
    pass
else:
    api_key = os.environ.get("OPENAI_API_KEY")
    if api_key is None:
        # logging.warning("OpenAI API key is not set. Please set the environment variable OPENAI_API_KEY")
        pass
    else:
        is_langchain_openai_available = True

# Default config follows the OpenAI playground
DEFAULT_TEMPERATURE = 0.7
DEFAULT_MAX_TOKENS = 2048
DEFAULT_MODEL = "gpt-4"

END_OF_MESSAGE = "<EOS>"  # End of message token specified by us not OpenAI
STOP = ("<|endoftext|>", END_OF_MESSAGE)  # End of sentence token
BASE_PROMPT = f"The messages always end with the token {END_OF_MESSAGE}."
|
31 |
+
|
32 |
+
|
33 |
+
class LangChainOpenAIChat(IntelligenceBackend):
    """Interface to the ChatGPT style model with system, user, assistant roles separation."""

    stateful = False
    type_name = "openai-chat"

    def __init__(
        self,
        temperature: float = DEFAULT_TEMPERATURE,
        max_tokens: int = DEFAULT_MAX_TOKENS,
        model: str = DEFAULT_MODEL,
        merge_other_agents_as_one_user: bool = True,
        **kwargs,
    ):
        """
        Instantiate the OpenAIChat backend.

        args:
            temperature: the temperature of the sampling
            max_tokens: the maximum number of tokens to sample
            model: the model to use
            merge_other_agents_as_one_user: whether to merge messages from other agents as one user message
        """
        assert (
            is_langchain_openai_available
        ), "langchain package is not installed or the API key is not set"
        super().__init__(
            temperature=temperature,
            max_tokens=max_tokens,
            model=model,
            merge_other_agents_as_one_user=merge_other_agents_as_one_user,
            **kwargs,
        )

        self.temperature = temperature
        self.max_tokens = max_tokens
        self.model = model
        self.merge_other_agent_as_user = merge_other_agents_as_one_user
        self.llm = OpenAI(
            model_name=model,
            temperature=temperature,
            max_tokens=max_tokens,
            openai_api_key=api_key,
        )

    @retry(stop=stop_after_attempt(6), wait=wait_random_exponential(min=1, max=60))
    def _get_response(self, messages):
        # NOTE(review): `messages` arrives from query() as a list of role dicts,
        # but langchain's completion-style OpenAI wrapper expects a string
        # prompt — verify this renders as intended before relying on it.
        response = self.llm(prompt=messages, stop=STOP)
        return response

    def query(
        self,
        agent_name: str,
        role_desc: str,
        history_messages: List[Message],
        global_prompt: str = None,
        request_msg: Message = None,
        *args,
        **kwargs,
    ) -> str:
        """
        Format the input and call the ChatGPT/GPT-4 API.

        args:
            agent_name: the name of the agent
            role_desc: the description of the role of the agent
            history_messages: the history of the conversation, or the observation for the agent
            global_prompt: optional environment description prepended to the system prompt
            request_msg: the request from the system to guide the agent's next response
        """

        # Merge the role description and the global prompt as the system prompt for the agent
        if global_prompt:  # Prepend the global prompt if it exists
            system_prompt = f"{global_prompt.strip()}\n{BASE_PROMPT}\n\nYour name: {agent_name}\n\nYour role:{role_desc}"
        else:
            system_prompt = (
                f"You are {agent_name}.\n\nYour role:{role_desc}\n\n{BASE_PROMPT}"
            )

        all_messages = [(SYSTEM_NAME, system_prompt)]
        for msg in history_messages:
            if msg.agent_name == SYSTEM_NAME:
                all_messages.append((SYSTEM_NAME, msg.content))
            else:  # non-system messages are suffixed with the end of message token
                all_messages.append((msg.agent_name, f"{msg.content}{END_OF_MESSAGE}"))

        if request_msg:
            all_messages.append((SYSTEM_NAME, request_msg.content))
        else:  # The default request message that reminds the agent its role and instruct it to speak
            all_messages.append(
                (SYSTEM_NAME, f"Now you speak, {agent_name}.{END_OF_MESSAGE}")
            )

        # Convert the (speaker, content) tuples into OpenAI chat-role dicts,
        # merging consecutive turns that map to the same role.
        messages = []
        for i, msg in enumerate(all_messages):
            if i == 0:
                assert (
                    msg[0] == SYSTEM_NAME
                )  # The first message should be from the system
                messages.append({"role": "system", "content": msg[1]})
            else:
                if msg[0] == agent_name:
                    messages.append({"role": "assistant", "content": msg[1]})
                else:
                    if messages[-1]["role"] == "user":  # last message is from user
                        if self.merge_other_agent_as_user:
                            messages[-1][
                                "content"
                            ] = f"{messages[-1]['content']}\n\n[{msg[0]}]: {msg[1]}"
                        else:
                            messages.append(
                                {"role": "user", "content": f"[{msg[0]}]: {msg[1]}"}
                            )
                    elif (
                        messages[-1]["role"] == "assistant"
                    ):  # consecutive assistant messages
                        # Merge the assistant messages
                        messages[-1]["content"] = f"{messages[-1]['content']}\n{msg[1]}"
                    elif messages[-1]["role"] == "system":
                        messages.append(
                            {"role": "user", "content": f"[{msg[0]}]: {msg[1]}"}
                        )
                    else:
                        raise ValueError(f"Invalid role: {messages[-1]['role']}")

        response = self._get_response(messages, *args, **kwargs)

        # Remove the agent name if the response starts with it
        response = re.sub(rf"^\s*\[.*]:", "", response).strip()  # noqa: F541
        response = re.sub(
            rf"^\s*{re.escape(agent_name)}\s*:", "", response
        ).strip()  # noqa: F541

        # Remove the tailing end of message token
        response = re.sub(rf"{END_OF_MESSAGE}$", "", response).strip()

        return response
|
agentreview/backends/openai.py
ADDED
@@ -0,0 +1,178 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
from typing import List
|
3 |
+
|
4 |
+
from tenacity import retry, stop_after_attempt, wait_random_exponential
|
5 |
+
|
6 |
+
from agentreview.arguments import parse_args
|
7 |
+
from agentreview.utility.authentication_utils import get_openai_client
|
8 |
+
from .base import IntelligenceBackend
|
9 |
+
from ..message import SYSTEM_NAME, Message
|
10 |
+
|
11 |
+
# Default config follows the OpenAI playground
|
12 |
+
DEFAULT_TEMPERATURE = 1.0
|
13 |
+
DEFAULT_MAX_TOKENS = 4096
|
14 |
+
|
15 |
+
# Check https://platform.openai.com/docs/models for more models
|
16 |
+
|
17 |
+
DEFAULT_MODEL = "gpt-4o"
|
18 |
+
|
19 |
+
END_OF_MESSAGE = "<EOS>" # End of message token specified by us not OpenAI
|
20 |
+
STOP = ("<|endoftext|>", END_OF_MESSAGE) # End of sentence token
|
21 |
+
BASE_PROMPT = f"The messages always end with the token {END_OF_MESSAGE}."
|
22 |
+
|
23 |
+
|
24 |
+
class OpenAIChat(IntelligenceBackend):
    """Interface to the ChatGPT style model with system, user, assistant roles separation."""

    stateful = False
    type_name = "openai-chat"

    def __init__(
        self,
        temperature: float = DEFAULT_TEMPERATURE,
        max_tokens: int = DEFAULT_MAX_TOKENS,
        model: str = DEFAULT_MODEL,
        merge_other_agents_as_one_user: bool = True,
        **kwargs,
    ):
        """
        Instantiate the OpenAIChat backend.

        args:
            temperature: the temperature of the sampling
            max_tokens: the maximum number of tokens to sample
            model: the model to use
            merge_other_agents_as_one_user: whether to merge messages from other agents as one user message
        """
        super().__init__(
            temperature=temperature,
            max_tokens=max_tokens,
            model=model,
            merge_other_agents_as_one_user=merge_other_agents_as_one_user,
            **kwargs,
        )
        # "openai" or "azure_openai"; selects which client get_openai_client builds.
        self.client_type = kwargs.get("openai_client_type", None)
        self.client = get_openai_client(self.client_type)
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.model = model
        self.merge_other_agent_as_user = merge_other_agents_as_one_user

    @retry(stop=stop_after_attempt(6), wait=wait_random_exponential(min=1, max=60))
    def _get_response(self, messages):
        """Call the chat completion endpoint and return the stripped reply text.

        Refer to https://learn.microsoft.com/en-us/azure/ai-services/openai/how-to/switching-endpoints
        for how to make API calls.
        """
        # Both the OpenAI and Azure OpenAI clients expose the same
        # chat.completions.create() signature, so a single call path suffices
        # (the original duplicated two byte-identical branches).
        if self.client_type not in ("openai", "azure_openai"):
            raise NotImplementedError

        completion = self.client.chat.completions.create(
            model=self.model,
            messages=messages,
            temperature=self.temperature,
            max_tokens=self.max_tokens,
            stop=STOP,
        )

        response = completion.choices[0].message.content
        return response.strip()

    def query(
        self,
        agent_name: str,
        role_desc: str,
        history_messages: List[Message],
        global_prompt: str = None,
        request_msg: Message = None,
        *args,
        **kwargs,
    ) -> str:
        """
        Format the input and call the ChatGPT/GPT-4 API.

        args:
            agent_name: the name of the agent
            role_desc: the description of the role of the agent
            history_messages: the history of the conversation, or the observation for the agent
            global_prompt: optional environment description prepended to the system prompt
            request_msg: the request from the system to guide the agent's next response
        """

        # Merge the role description and the global prompt as the system prompt for the agent
        if global_prompt:  # Prepend the global prompt if it exists
            system_prompt = f"You are a helpful assistant.\n{global_prompt.strip()}\n{BASE_PROMPT}\n\nYour name is {agent_name}.\n\nYour role:{role_desc}"
        else:
            system_prompt = f"You are a helpful assistant. Your name is {agent_name}.\n\nYour role:{role_desc}\n\n{BASE_PROMPT}"

        all_messages = [(SYSTEM_NAME, system_prompt)]
        for msg in history_messages:
            if msg.agent_name == SYSTEM_NAME:
                all_messages.append((SYSTEM_NAME, msg.content))
            else:  # non-system messages are suffixed with the end of message token
                all_messages.append((msg.agent_name, f"{msg.content}{END_OF_MESSAGE}"))

        if request_msg:
            all_messages.append((SYSTEM_NAME, request_msg.content))
        else:  # The default request message that reminds the agent its role and instruct it to speak
            all_messages.append(
                (SYSTEM_NAME, f"Now you speak, {agent_name}.{END_OF_MESSAGE}")
            )

        # Convert the (speaker, content) tuples into OpenAI chat-role dicts,
        # merging consecutive turns that map to the same role.
        messages = []
        for i, msg in enumerate(all_messages):
            if i == 0:
                assert (
                    msg[0] == SYSTEM_NAME
                )  # The first message should be from the system
                messages.append({"role": "system", "content": msg[1]})
            elif msg[0] == agent_name:
                messages.append({"role": "assistant", "content": msg[1]})
            elif messages[-1]["role"] == "user":  # last message is from user
                if self.merge_other_agent_as_user:
                    messages[-1][
                        "content"
                    ] = f"{messages[-1]['content']}\n\n[{msg[0]}]: {msg[1]}"
                else:
                    messages.append(
                        {"role": "user", "content": f"[{msg[0]}]: {msg[1]}"}
                    )
            elif (
                messages[-1]["role"] == "assistant"
            ):  # consecutive assistant messages
                # Merge the assistant messages
                messages[-1]["content"] = f"{messages[-1]['content']}\n{msg[1]}"
            elif messages[-1]["role"] == "system":
                messages.append({"role": "user", "content": f"[{msg[0]}]: {msg[1]}"})
            else:
                raise ValueError(f"Invalid role: {messages[-1]['role']}")

        response = self._get_response(messages, *args, **kwargs)

        # Remove the agent name if the response starts with it.
        # "[^\]]*" keeps the match inside the first bracket pair; the original
        # greedy "\[.*]:" could swallow response text up to the last "]:" on
        # the first line.
        response = re.sub(r"^\s*\[[^\]]*]:", "", response).strip()
        response = re.sub(
            rf"^\s*{re.escape(agent_name)}\s*:", "", response
        ).strip()  # noqa: F541

        # Remove the tailing end of message token
        response = re.sub(rf"{END_OF_MESSAGE}$", "", response).strip()

        return response
|
agentreview/config.py
ADDED
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import copy
|
2 |
+
import json
|
3 |
+
|
4 |
+
from .utils import AttributedDict
|
5 |
+
|
6 |
+
|
7 |
+
class Config(AttributedDict):
    """
    Config class to manage the configuration of the games.

    The class has a few useful methods to load and save the config.
    """

    # convert dict to Config recursively
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        for key, value in self.items():

            # Try to convert the value (the "metadata" field) to dict if applicable
            # NOTE(review): eval() executes arbitrary code from the config value;
            # only load config files from trusted sources.
            try:
                value = dict(eval(value))
            except Exception:
                pass

            if isinstance(value, dict):
                self[key] = init_config(value)  # convert dict to Config recursively
            # convert list of dict to list of Config recursively
            elif isinstance(value, list) and len(value) > 0:
                self[key] = [
                    init_config(item) if isinstance(item, dict) else item
                    for item in value
                ]

    def save(self, path: str):
        # save config to file (pretty-printed JSON)
        with open(path, "w") as f:
            json.dump(self, f, indent=4)

    @classmethod
    def load(cls, path: str):
        # load config from file
        with open(path) as f:
            config = json.load(f)
        return cls(config)

    def deepcopy(self):
        # get the config class so that subclasses can be copied in the correct class
        config_class = self.__class__
        # make a deep copy of the config
        return config_class(copy.deepcopy(self))
|
51 |
+
|
52 |
+
|
53 |
+
class Configurable:
    """Interface for classes that can be initialized from, and exported to, a Config."""

    def __init__(self, **kwargs):
        # Remember the raw constructor keyword arguments so the object can be
        # re-serialized later via to_config().
        self._config_dict = kwargs

    @classmethod
    def from_config(cls, config: Config):
        """Instantiate the class using the fields of *config* as kwargs."""
        return cls(**config)

    def to_config(self) -> Config:
        """Export the stored constructor arguments as a Config."""
        return Config(**self._config_dict)

    def save_config(self, path: str):
        """Write the exported Config to *path* as JSON."""
        self.to_config().save(path)
|
69 |
+
|
70 |
+
|
71 |
+
class EnvironmentConfig(Config):
    """EnvironmentConfig contains a env_type field to indicate the name of the environment."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Every environment config must name the environment type it describes.
        if "env_type" not in self:
            raise ValueError("The env_type field is not specified")
|
79 |
+
|
80 |
+
|
81 |
+
class BackendConfig(Config):
    """BackendConfig contains a backend_type field to indicate the name of the backend."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Every backend config must name the backend type it describes.
        if "backend_type" not in self:
            raise ValueError("The backend_type field is not specified")
|
89 |
+
|
90 |
+
|
91 |
+
class AgentConfig(Config):
    """AgentConfig contains role_desc and backend fields."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Both fields are mandatory for an agent definition; the backend must
        # already have been promoted to a BackendConfig by init_config.
        if "role_desc" not in self:
            raise ValueError("The role_desc field is not specified")
        if "backend" not in self:
            raise ValueError("The backend field is not specified")
        if not isinstance(self["backend"], BackendConfig):
            raise ValueError("The backend field must be a BackendConfig")
|
105 |
+
|
106 |
+
|
107 |
+
class ArenaConfig(Config):
    """ArenaConfig contains a list of AgentConfig."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # "players" is required and must hold only AgentConfig entries.
        if "players" not in self:
            raise ValueError("The players field is not specified")
        if not isinstance(self["players"], list):
            raise ValueError("The players field must be a list")
        for player_cfg in self["players"]:
            if not isinstance(player_cfg, AgentConfig):
                raise ValueError("The players field must be a list of AgentConfig")

        # "environment" is required and must already be an EnvironmentConfig.
        if "environment" not in self:
            raise ValueError("The environment field is not specified")
        if not isinstance(self["environment"], EnvironmentConfig):
            raise ValueError("The environment field must be an EnvironmentConfig")
|
126 |
+
|
127 |
+
|
128 |
+
# Initialize with different config class depending on whether the config is for environment or backend
|
129 |
+
def init_config(config: dict):
    """Instantiate the most specific Config subclass for *config*.

    The subclass is selected by which marker key the dict carries, checked
    in priority order; a dict with none of the markers becomes a plain Config.
    """
    if not isinstance(config, dict):
        raise ValueError("The config must be a dict")

    # Dispatch on the first marker key present, in priority order.
    for marker, config_cls in (
        ("env_type", EnvironmentConfig),
        ("backend_type", BackendConfig),
        ("role_desc", AgentConfig),
        ("players", ArenaConfig),
    ):
        if marker in config:
            return config_cls(config)
    return Config(config)
|
agentreview/const.py
ADDED
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
|
3 |
+
"""
|
4 |
+
Note
|
5 |
+
- ICLR 2021 has a category for "Significant-concerns"
|
6 |
+
- ICLR 2023 categories the papers as "Accept-notable-top-5", "Accept-notable-top-25", "Accept-poster", and "Reject"
|
7 |
+
"""
|
8 |
+
PAPER_DECISIONS = ["Reject","Accept-oral", "Accept-spotlight", "Accept-poster",]
|
9 |
+
PAPER_DECISIONS_ICLR2019 = ["Accept-oral", "Accept-poster", "Reject"]
|
10 |
+
|
11 |
+
AREA_CHAIR_TYPES = ['inclusive', 'conformist', 'authoritarian', 'BASELINE']
|
12 |
+
|
13 |
+
GLOBAL_PROMPT = "This is a realistic simulation of academic peer review."
|
14 |
+
|
15 |
+
# These are papers that contain potentially sensitive content. GPT-4 refused to generate reviews for these papers.
|
16 |
+
FILTERED_PAPER_IDS = {
|
17 |
+
"ICLR2020": [],
|
18 |
+
"ICLR2021": [],
|
19 |
+
"ICLR2022": [186, 200, 270],
|
20 |
+
"ICLR2023": []
|
21 |
+
}
|
22 |
+
|
23 |
+
ALL_REVIEW_PHASES = ["reviewer_write_reviews", "author_reviewer_discussion", "reviewer_ac_discussion", "ac_discussion"]
|
24 |
+
|
25 |
+
|
26 |
+
EXPERIMENT_NAME2REVIEWER_TYPES = {
|
27 |
+
"BASELINE": "BASELINE",
|
28 |
+
"knowledgeable_Rx1": "knowledgeable",
|
29 |
+
"unknowledgeable_Rx1": "unknowledgeable",
|
30 |
+
"irresponsible_Rx1": "irresponsible",
|
31 |
+
"irresponsible_Rx2": "irresponsible",
|
32 |
+
"irresponsible_Rx3": "irresponsible",
|
33 |
+
"responsible_Rx1": "responsible",
|
34 |
+
"malicious_Rx1": "malicious",
|
35 |
+
"malicious_Rx2": "malicious",
|
36 |
+
"malicious_Rx3": "malicious",
|
37 |
+
"benign_Rx1": "benign",
|
38 |
+
"inclusive_ACx1": "BASELINE",
|
39 |
+
"authoritarian_ACx1": "BASELINE",
|
40 |
+
"conformist_ACx1": "BASELINE",
|
41 |
+
"authors_are_famous_Rx1": "authors_are_famous",
|
42 |
+
"authors_are_famous_Rx2": "authors_are_famous",
|
43 |
+
"authors_are_famous_Rx3": "authors_are_famous",
|
44 |
+
"authors_are_famous_Rx1_no_rebuttal": "authors_are_famous",
|
45 |
+
"authors_are_famous_Rx2_no_rebuttal": "authors_are_famous",
|
46 |
+
"authors_are_famous_Rx3_no_rebuttal": "authors_are_famous",
|
47 |
+
"no_rebuttal": "BASELINE",
|
48 |
+
"no_overall_score": "BASELINE",
|
49 |
+
}
|
50 |
+
|
51 |
+
|
52 |
+
# NOTE(review): per-year lists of ICLR submission IDs used in the experiments.
# The meaning of the concatenated sub-lists is not documented here — confirm
# their provenance (e.g. sampling batches) against the dataset scripts.
year2paper_ids = {
    "ICLR2018": [45, 47, 59, 76, 229, 254, 372, 415, 447, 517, 543, 544, 562, 596, 615, 639] +
                [1, 2, 7, 10, 16, 26, 33, 51, 60, 61, 65,
                 67, 69, 72, 73, 77, 84, 88, 94, 99, 104,
                 117, 121, 124, 131, 132, 134, 136, 143, 147, 148, 149, 155, 162, 164, 166, 168, 169, 171,
                 175, 178, 179, 189, 196, 201, 203, 204, 205] +
                [3, 4, 6, 8, 9, 11, 12, 13, 15, 17, 18, 19, 20, 21, 24, 25, 27, 28, 30, 31, 32, 34, 36, 37, 39, 40,
                 41, 42, 43, 44, 46, 52, 53, 54, 55, 56, 58, 63, 66, 68, 71, 74, 75, 78, 80, 83, 85, 87, 89,
                 91, 92, 93, 95, 96, 97, 100, 101, 102, 103, 105, 106, 107, 108, 109, 110, 111, 113, 114, 115,
                 116, 118, 120, 122, 123, 125, 127, 128, 129, 133, 135, 138, 141, 142, 144, 153, 154, 156, 157, 158,
                 159, 161, 163, 170, 172, 173, 174, 176, 177, 180, 181, 182, 184, 185, 186, 187, 190, 191, 193, 194,
                 197, 200, 206, 207, 209, 211, 213, 214, 218, 219, 221, 222, 223, 225, 226, 230, 234, 237, 238, 241,
                 243, 244, 247, 248, 253, 255, 256, 257, 258, 259, 266, 268, 271, 272, 273, 275, 276, 278, 283,
                 286],


    "ICLR2019": [1, 26, 119, 220, 231, 507, 563, 566, 574, 632, 654, 709, 734, 780, 835, 917] + [4, 27, 33, 39, 40,
                 51, 57, 67, 70, 72, 73, 76, 77, 82, 87, 98, 99, 100, 106, 108, 109, 110, 111, 113, 114, 116, 123,
                 129, 130, 143, 146, 147, 150, 155, 177, 184, 187, 190, 194, 201, 202, 203, 205, 211, 213, 222, 237,
                 238] + [2, 3, 6, 8, 9, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 28, 29, 32, 35, 36, 37, 38, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 52, 54, 55, 58, 59, 60, 61, 62, 63, 65, 66, 68, 69, 71, 74, 75, 78, 79, 80, 83, 84, 85, 86, 89, 90, 91, 92, 93, 94, 95, 96, 97, 101, 102, 104, 105, 107, 112, 115, 117, 118, 120, 122, 124, 125, 127, 128, 131, 132, 133, 134, 135, 136, 137, 139, 140, 141, 142, 144, 145, 148, 149, 152, 153, 154, 158, 160, 161, 162, 163, 164, 165, 167, 168, 171, 172, 173, 174, 178, 179, 180, 181, 182, 185, 186, 189, 191, 192, 193, 195, 196, 197, 198, 204, 206, 208, 209, 210, 214, 216, 217, 218, 221, 223, 224, 225, 226, 228, 229, 230, 233],



    "ICLR2020": [2, 3, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 17, 19, 20, 21, 23, 27, 28, 31, 32, 33, 34, 35, 37, 38,
                 41, 42, 43, 46, 48, 49, 52, 56, 58, 60, 62, 63, 66, 67, 68, 74, 75, 76, 79, 80, 82, 83, 84, 85, 86,
                 87, 90, 94, 97, 101, 105, 106, 107, 109, 110, 111, 115, 116, 118, 119, 120, 121, 122, 123, 124, 126,
                 1, 16, 18, 22, 24, 29, 40, 45, 50, 53, 55, 57, 61, 70, 73, 77, 91, 93, 95, 103, 108, 112, 113, 117,
                 125, 132, 5, 44, 47, 54, 71, 88, 69, 78, 102, 221],
    "ICLR2021": [140, 218, 294, 332, 362, 420, 1, 5, 12, 20, 28, 52, 67, 69, 75, 102, 103, 110, 126, 135, 138, 147, 149, 151, 160, 170, 174, 181, 182, 190, 44, 3, 4, 10, 11, 14, 17, 18, 21, 22, 24, 27, 33, 36, 38, 45, 47, 49, 59, 70, 72, 73, 78, 79, 82, 84, 86, 88, 89, 91, 92, 98, 100, 101, 104, 105, 106, 107, 108, 111, 114, 120, 123, 124, 125, 130, 131, 133, 137, 141, 142, 143, 154, 155, 156, 157, 159, 161, 162, 164, 166, 167, 168, 171, 172, 176, 193, 194, 81, 94, 95, 146, 177, 179, 184, 186],
    "ICLR2022": [86, 154, 208, 222, 224, 284, 9, 10, 11, 12, 14, 25, 30, 31, 39, 42, 45, 56, 68, 73, 80, 88, 89, 90, 96, 101, 102, 104, 109, 1, 4, 6, 27, 36, 43, 47, 61, 62, 63, 65, 67, 69, 81, 82, 95, 98, 99, 100, 103, 105, 106, 108, 115, 120, 121, 122, 130, 134, 142, 143, 144, 145, 152, 153, 157, 159, 168, 173, 174, 175, 176, 179, 180, 186, 187, 193, 194, 197, 200, 201, 205, 210, 216, 226, 229, 233, 234, 235, 236, 239, 248, 261, 262, 263, 264, 269, 270, 271, 112, 113, 34, 64, 158, 172, 277, 280, 283, 286],
    "ICLR2023": [210, 219, 1759, 1774, 9, 11, 12, 33, 54, 55, 61, 70, 79, 86, 88, 90, 97, 116, 128, 129, 143, 152, 160, 168, 174, 177, 181, 193, 1647, 1651, 1666, 1670, 1673, 1675, 1677, 1678, 1680, 1683, 1692, 1698, 1703, 1709, 1716, 1720, 1723, 1727, 1728, 1742, 1743, 1752, 1754, 1760, 113, 156, 214, 220, 317, 318, 1657, 1686, 1740, 1762, 1783, 1817, 2, 3, 6, 8, 17, 18, 24, 25, 29, 30, 45, 62, 77, 80, 82, 84, 89, 96, 104, 105, 107, 108, 118, 119, 120, 122, 130, 131, 133, 139, 141, 145, 146, 149, 150, 151, 153, 158, 161, 163, 164, 169, 175, 178, 179, 198, 200, 206, 207, 211, 212, 225, 226, 231, 235, 236, 237, 245, 246, 249, 252, 253, 255, 257, 258, 259, 264, 265, 266, 275, 1645, 1649, 1655, 1658, 1663, 1664, 1665, 1672, 1679, 1682, 1685, 1695, 1697, 1701, 1704, 1706, 1708, 1710, 1712, 1713, 1715, 1722, 1726, 1729, 1731, 1734, 1736, 1738, 1739, 1741, 1744, 1745, 1749, 1750, 1755, 1756, 1758, 1761, 1764, 1767, 1772, 1773, 1778, 1779, 1780, 1786, 1788, 1790, 1791, 1795, 1796, 1797, 1800, 1802, 1803, 1805, 1810, 1812, 1813, 1821, 1822, 1827, 1829, 1833, 1840, 1845, 1851, 1856],
    "ICLR2024": [39, 247, 289, 400, 489, 742, 749] + [62, 78, 159, 161, 170, 192, 198, 215, 219, 335, 344, 386, 427, 432, 448, 451, 461, 472, 485, 536, 546, 559, 573, 577, 597] + [5, 9, 11, 19, 20, 30, 31, 32, 40, 49, 52, 53, 54, 56, 61, 66, 67, 73, 74, 77, 85, 87, 100, 104, 114, 116, 124, 130, 133, 138, 145, 151, 153, 156, 165, 166, 172, 181, 183, 187, 195, 204, 212, 221, 224, 230, 237, 243, 248, 257, 258, 259, 263, 272, 278, 287, 288, 291, 292, 298, 300, 302, 304, 306, 308, 318, 320, 321, 324, 325, 326, 327, 331, 332, 334, 336, 338, 340, 345, 349, 350, 356, 357, 358, 360] + [1, 2, 12, 14, 24, 26, 33, 35, 36, 41, 42, 44, 50, 51, 55, 57, 59, 70, 72, 75, 76, 81, 89, 90, 93,
                 94, 97, 99, 101, 105, 110, 111, 112, 117, 119, 120, 125, 128, 129, 131, 134, 135, 140, 148, 150, 157, 158, 163, 167, 173, 175, 177, 182, 185, 186, 188, 189, 197, 202, 207, 209, 210, 214, 216, 226, 231, 234, 236, 238, 239, 241, 244, 245, 249, 260, 262, 264, 265, 271, 276, 277, 279, 281, 282, 284, 286, 290, 294, 295, 301, 303, 307, 309, 313, 315, 319, 322, 333, 337, 339, 342, 354, 363, 364, 369, 373, 374, 375, 377, 378, 381, 382, 385, 388, 398, 399, 401, 407, 412, 413, 415, 416, 417, 420, 421, 422, 426, 428, 436, 437, 444, 446, 449, 453, 454, 463, 464, 469, 478, 480, 487, 490, 496, 498, 501, 502, 504, 506, 513, 516, 517, 518, 520, 521, 523, 524, 525, 537, 541, 545, 551, 552, 554, 555, 558, 562, 563, 574, 575, 579, 581, 584, 588, 595, 596, 598, 607, 608, 615, 622, 624, 625, 627, 629, 630, 634, 636, 641, 645, 647, 648, 651, 652, 654, 655, 662, 667, 668, 671, 672, 673, 681, 682, 685, 689, 690, 691, 697, 698, 701]
}
|
103 |
+
AGENTREVIEW_LOGO = r"""
|
104 |
+
_ _____ _
|
105 |
+
/\ | | | __ \ (_)
|
106 |
+
/ \ __ _ ___ _ __ | |_| |__) |_____ ___ _____ __
|
107 |
+
/ /\ \ / _` |/ _ \ '_ \| __| _ // _ \ \ / / |/ _ \ \ /\ / /
|
108 |
+
/ ____ \ (_| | __/ | | | |_| | \ \ __/\ V /| | __/\ V V /
|
109 |
+
/_/ \_\__, |\___|_| |_|\__|_| \_\___| \_/ |_|\___| \_/\_/
|
110 |
+
__/ |
|
111 |
+
|___/
|
112 |
+
"""
|
agentreview/database.py
ADDED
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Datastore module for chat_arena.
|
3 |
+
|
4 |
+
This module provides utilities for storing the messages and the game results into database.
|
5 |
+
Currently, it supports Supabase.
|
6 |
+
"""
|
7 |
+
import json
|
8 |
+
import os
|
9 |
+
import uuid
|
10 |
+
from typing import List
|
11 |
+
|
12 |
+
from .arena import Arena
|
13 |
+
from .message import Message
|
14 |
+
|
15 |
+
# Attempt importing Supabase.
# Best-effort: the module stays importable without supabase installed or
# configured; SupabaseDB then refuses to construct (supabase_available is
# checked first in its assert, so the possibly-undefined SUPABASE_* names are
# never evaluated in that case thanks to short-circuiting).
try:
    import supabase

    # Get the Supabase URL and secret key from environment variables
    SUPABASE_URL = os.environ.get("SUPABASE_URL", "")
    SUPABASE_SECRET_KEY = os.environ.get("SUPABASE_SECRET_KEY", "")
    assert SUPABASE_URL and SUPABASE_SECRET_KEY
except Exception:
    # Either the package is missing or the credentials are not set.
    supabase_available = False
else:
    supabase_available = True
|
27 |
+
|
28 |
+
|
29 |
+
# Store the messages into the Supabase database
|
30 |
+
class SupabaseDB:
    """Persists arena runs (environment, moderator, players, messages) to Supabase.

    Requires the ``supabase`` package and the SUPABASE_URL /
    SUPABASE_SECRET_KEY environment variables (validated at import time).
    """

    def __init__(self):
        # Fail fast if Supabase is unavailable or credentials are missing.
        assert supabase_available and SUPABASE_URL and SUPABASE_SECRET_KEY
        supabase_client = supabase.create_client(SUPABASE_URL, SUPABASE_SECRET_KEY)
        self.client = supabase_client

    # Save Arena state to Supabase
    def save_arena(self, arena: Arena):
        """Save the full arena state: environment, player configs, then messages."""
        # Save the environment config
        self._save_environment(arena)

        # Save the player configs
        self._save_player_configs(arena)

        # Save the messages
        self.save_messages(arena)

    # Save the environment config of the arena
    def _save_environment(self, arena: Arena):
        """Insert one "Arena" row (and a "Moderator" row if one is configured)."""
        env = arena.environment
        env_config = env.to_config()
        # The moderator (if any) lives in its own table, so detach it from the
        # serialized environment config first.
        moderator_config = env_config.pop("moderator", None)

        arena_row = {
            "arena_id": str(arena.uuid),
            "global_prompt": arena.global_prompt,
            "env_type": env_config["env_type"],
            "env_config": json.dumps(env_config),
        }
        self.client.table("Arena").insert(arena_row).execute()

        # Get the moderator config
        if moderator_config:
            moderator_row = {
                # Deterministic ID derived from the arena UUID plus the config,
                # so re-saving the same moderator yields the same key.
                "moderator_id": str(
                    uuid.uuid5(arena.uuid, json.dumps(moderator_config))
                ),
                "arena_id": str(arena.uuid),
                "role_desc": moderator_config["role_desc"],
                "terminal_condition": moderator_config["terminal_condition"],
                "backend_type": moderator_config["backend"]["backend_type"],
                "temperature": moderator_config["backend"]["temperature"],
                "max_tokens": moderator_config["backend"]["max_tokens"],
            }
            self.client.table("Moderator").insert(moderator_row).execute()

    # Save the player configs of the arena
    def _save_player_configs(self, arena: Arena):
        """Insert one "Player" row per player, in a single batched insert."""
        player_rows = []
        for player in arena.players:
            player_config = player.to_config()
            player_row = {
                # Deterministic per-arena player ID (see _save_environment).
                "player_id": str(uuid.uuid5(arena.uuid, json.dumps(player_config))),
                "arena_id": str(arena.uuid),
                "name": player.name,
                "role_desc": player_config["role_desc"],
                "backend_type": player_config["backend"]["backend_type"],
                "temperature": player_config["backend"].get("temperature", None),
                "max_tokens": player_config["backend"].get("max_tokens", None),
            }
            player_rows.append(player_row)

        self.client.table("Player").insert(player_rows).execute()

    # Save the messages
    def save_messages(self, arena: Arena, messages: List[Message] = None):
        """Insert not-yet-logged messages into "Message" and mark them logged.

        Args:
            arena: The arena the messages belong to.
            messages: Messages to save; when None, defaults to the
                environment's current observation.
        """
        if messages is None:
            messages = arena.environment.get_observation()

        # Filter messages that are already logged
        messages = [msg for msg in messages if not msg.logged]

        message_rows = []
        for message in messages:
            message_row = {
                "message_id": str(uuid.uuid5(arena.uuid, message.msg_hash)),
                "arena_id": str(arena.uuid),
                "agent_name": message.agent_name,
                "content": message.content,
                "turn": message.turn,
                "timestamp": str(message.timestamp),
                "msg_type": message.msg_type,
                "visible_to": json.dumps(message.visible_to),
            }
            message_rows.append(message_row)

        self.client.table("Message").insert(message_rows).execute()

        # Mark the messages as logged so repeated calls do not duplicate rows.
        for message in messages:
            message.logged = True
|
121 |
+
|
122 |
+
|
123 |
+
# Log the arena results into the Supabase database
|
124 |
+
def log_arena(arena: "Arena", database=None):
    """Log the arena results into the Supabase database.

    Args:
        arena: The arena whose state should be saved.
        database: An object exposing ``save_arena`` (e.g. SupabaseDB).
            When None, logging is a silent no-op.
    """
    # Positive guard instead of the original "if None: pass / else:" shape.
    if database is not None:
        database.save_arena(arena)
|
129 |
+
|
130 |
+
|
131 |
+
# Log the messages into the Supabase database
|
132 |
+
def log_messages(arena: "Arena", messages: "List[Message]", database=None):
    """Log the messages into the Supabase database.

    Args:
        arena: The arena the messages belong to.
        messages: Messages to persist.
        database: An object exposing ``save_messages`` (e.g. SupabaseDB).
            When None, logging is a silent no-op.
    """
    # Positive guard instead of the original "if None: pass / else:" shape.
    if database is not None:
        database.save_messages(arena, messages)
|
agentreview/dataset/__init__.py
ADDED
File without changes
|
agentreview/dataset/download_openreview_paper.py
ADDED
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Download all papers from one year of ICLR conference using OpenReview API.
|
3 |
+
|
4 |
+
This script downloads all paper PDFs and their corresponding metadata
|
5 |
+
from the ICLR 2023 conference using the OpenReview API.
|
6 |
+
|
7 |
+
Alternative methods to download can be found in this
|
8 |
+
[colab notebook](https://colab.research.google.com/drive/1vXXNxn8lnO3j1dgoidjybbKIN0DW0Bt2),
|
9 |
+
though it's not used here.
|
10 |
+
"""
|
11 |
+
|
12 |
+
import glob
|
13 |
+
import json
|
14 |
+
import os
|
15 |
+
import time
|
16 |
+
import requests
|
17 |
+
|
18 |
+
from agentreview.arguments import parse_args
|
19 |
+
|
20 |
+
try:
|
21 |
+
import openreview
|
22 |
+
except ImportError:
|
23 |
+
raise ImportError("Please install openreview package using `pip install openreview-py`")
|
24 |
+
|
25 |
+
def download_papers(args):
    """Downloads all papers for one ICLR year using the OpenReview API.

    This function authenticates with the OpenReview API using environment
    variables for the username and password. It then iterates through the
    available papers, downloads each PDF, and saves the corresponding metadata
    (in JSON format) in the specified directories. Already-downloaded papers
    are skipped, and each download is retried up to 10 times.

    Args:
        args: Parsed CLI arguments with ``data_dir`` and ``conference``
            (e.g. "ICLR2023") attributes.

    Raises:
        AssertionError: If the OPENREVIEW_USERNAME or OPENREVIEW_PASSWORD environment
            variables are not set.
        AssertionError: If the conference argument is not for ICLR.
    """

    openreview_username = os.environ.get("OPENREVIEW_USERNAME")
    openreview_password = os.environ.get("OPENREVIEW_PASSWORD")

    assert openreview_username is not None, (
        "Please set your OpenReview username through the OPENREVIEW_USERNAME environment variable."
    )
    assert openreview_password is not None, (
        "Please set your OpenReview password through the OPENREVIEW_PASSWORD environment variable."
    )

    client = openreview.Client(
        baseurl='https://api.openreview.net',
        username=openreview_username,
        password=openreview_password
    )

    page_size = 1000
    offset = 0
    papers_directory = os.path.join(args.data_dir, args.conference, "paper")
    notes_directory = os.path.join(args.data_dir, args.conference, "notes")

    assert "ICLR" in args.conference, "Only works for ICLR conferences!"
    year = int(args.conference.split("ICLR")[-1])  # Only works for ICLR currently

    # Create directories if they don't exist
    for path in [papers_directory, notes_directory]:
        os.makedirs(path, exist_ok=True)

    while True:
        # Fetch submissions with pagination
        notes = client.get_notes(
            invitation=f'ICLR.cc/{year}/Conference/-/Blind_Submission',
            details='all',
            offset=offset,
            limit=page_size
        )

        if not notes:
            break  # Exit if no more notes are available

        # Get existing paper IDs to avoid re-downloading
        existing_papers = glob.glob(f"{papers_directory}/*.pdf")
        existing_paper_ids = {int(os.path.basename(paper).split(".pdf")[0]) for paper in existing_papers}

        for note in notes:
            paper_id = note.number
            paper_path = os.path.join(papers_directory, f"{paper_id}.pdf")
            note_path = os.path.join(notes_directory, f"{paper_id}.json")

            # Skip existing papers
            if paper_id in existing_paper_ids:
                print(f"Paper {paper_id} already downloaded.")
                continue

            print(f"Title: {note.content.get('title', 'N/A')}")
            print(f"Abstract: {note.content.get('abstract', 'N/A')}")
            print(f"TL;DR: {note.content.get('TL;DR', 'N/A')}")
            pdf_link = f"https://openreview.net/pdf?id={note.id}"
            print(f"PDF Link: {pdf_link}")

            # Attempt to download the paper PDF, retry if fails
            tries = 0
            while tries < 10:
                try:
                    # Bounded timeout so one stuck request cannot hang the
                    # whole crawl; a timeout raises and is retried below.
                    response = requests.get(pdf_link, timeout=60)

                    if response.status_code == 200:

                        with open(paper_path, "wb") as pdf_file:
                            pdf_file.write(response.content)

                        print(f"PDF downloaded successfully as {paper_path}")

                        # Save metadata as JSON, which contains the reviews, rebuttals, and decisions.
                        with open(note_path, "w") as note_file:
                            json.dump(note.to_json(), note_file, indent=2)

                        break

                    else:
                        print(f"Attempt {tries} failed. Status code: {response.status_code}")
                        if response.status_code == 429:  # Too many requests
                            print("Too many requests. Sleeping for 10 seconds.")
                            time.sleep(10)

                except Exception as e:
                    print(f"Attempt {tries} failed with error: {e}")

                tries += 1

        offset += page_size
|
131 |
+
|
132 |
+
|
133 |
+
if __name__ == "__main__":
|
134 |
+
args = parse_args()
|
135 |
+
download_papers(args)
|
agentreview/dataset/process_submissions.py
ADDED
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Process and classify ICLR submissions using OpenReview API.
|
3 |
+
|
4 |
+
This script processes ICLR submissions, classifies them into subdirectories
|
5 |
+
based on decisions, extracts paper content into JSON format, and checks the
|
6 |
+
validity of the processed papers.
|
7 |
+
|
8 |
+
It includes three main functions:
|
9 |
+
- classify_ICLR_submissions_into_subdirectories: Classifies papers into
|
10 |
+
directories based on decisions.
|
11 |
+
- process_submission: Processes each submission by extracting text and saving
|
12 |
+
it as a JSON file.
|
13 |
+
- check_processed_paper: Verifies if all processed papers are valid JSON files.
|
14 |
+
"""
|
15 |
+
|
16 |
+
import os
|
17 |
+
import sys
|
18 |
+
import traceback
|
19 |
+
from collections import Counter
|
20 |
+
|
21 |
+
from tqdm import tqdm
|
22 |
+
|
23 |
+
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
24 |
+
|
25 |
+
from agentreview.arguments import parse_args
|
26 |
+
from agentreview.utility.utils import print_colored
|
27 |
+
|
28 |
+
# Maps the raw OpenReview decision strings (whose wording varies by year) to
# the canonical directory names used by this project.
decision_map = {
    # ICLR 2023
    "Reject": "Reject",
    "Accept: poster": "Accept-poster",
    "Accept: notable-top-25%": "Accept-notable-top-25",
    "Accept: notable-top-5%": "Accept-notable-top-5",

    # ICLR 2022
    "Accept (Poster)": "Accept-poster",
    "Accept (Oral)": "Accept-oral",
    "Accept (Spotlight)": "Accept-spotlight",

    # ICLR 2021
    "Significant concerns (Do not publish)": "Significant-concerns",
    "Concerns raised (can publish with adjustment)": "Concerns-raised",

    # ICLR 2020
    "Accept (Talk)": "Accept-oral",  # We assume this signifies an oral presentation

    # ICLR 2018
    "Invite to Workshop Track": "Reject"
}
|
50 |
+
|
51 |
+
|
52 |
+
def categorize_ICLR_submissions_into_subdirectories(cli_args=None):
    """Classifies ICLR submissions into subdirectories based on review decisions.

    This function iterates through the review notes and identifies the decision
    (recommendation or final decision) for each submission. It then moves the
    notes and their corresponding papers into directories based on the decision.

    Args:
        cli_args: Parsed CLI arguments with a ``conference`` attribute. When
            None, falls back to the module-level ``args`` global (set in
            ``__main__``) so existing zero-argument callers keep working.

    Raises:
        AssertionError: If the line containing the decision does not have the
            expected format.
    """
    # Backward-compatible fallback to the module-level global.
    active_args = cli_args if cli_args is not None else args

    note_dir = f"data/{active_args.conference}/notes"
    paper_dir = f"data/{active_args.conference}/paper"

    for note in os.listdir(note_dir):
        print(note)

        # Skip directories or irrelevant files
        if os.path.isdir(os.path.join(note_dir, note)) or ".DS_Store" in note:
            continue

        note_path = os.path.join(note_dir, note)
        # Context manager so the file handle is always closed (the original
        # left it open via open(...).readlines()).
        with open(note_path, "r") as note_file:
            lines = note_file.readlines()
        decision = None

        # Scan the raw JSON text line-by-line for the decision field.
        for line in tqdm(lines):
            if "\"recommendation\"" in line:
                assert Counter(line)["\""] == 4, "Unexpected format in recommendation line."
                print(line)
                decision = line.split("\"recommendation\"")[1].split("\"")[1]
                break

            elif "\"decision\"" in line:
                assert Counter(line)["\""] == 4, "Unexpected format in decision line."
                print(line)
                try:
                    decision = line.split("\"decision\"")[1].split("\"")[1]
                    break
                except Exception:
                    traceback.print_exc()
                    print_colored(line, 'red')

        if decision is None:
            # Possibly withdrawn papers
            print_colored(f"Could not find decision for {note}", "red")
            continue

        # Move both the note JSON and the paper PDF into the decision directory.
        os.makedirs(os.path.join(note_dir, decision_map[decision]), exist_ok=True)
        os.makedirs(os.path.join(paper_dir, decision_map[decision]), exist_ok=True)
        os.rename(note_path, os.path.join(note_dir, decision_map[decision], note))

        paper_id = int(note.split(".json")[0])
        paper_path = os.path.join(paper_dir, f"{paper_id}.pdf")
        os.rename(paper_path, os.path.join(paper_dir, decision_map[decision], f"{paper_id}.pdf"))
|
106 |
+
|
107 |
+
|
108 |
+
if __name__ == "__main__":
|
109 |
+
args = parse_args()
|
110 |
+
|
111 |
+
# Extract contents of each paper into a JSON file
|
112 |
+
categorize_ICLR_submissions_into_subdirectories()
|
agentreview/environments/__init__.py
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from ..config import EnvironmentConfig
|
2 |
+
from .base import Environment, TimeStep
|
3 |
+
from .conversation import Conversation, ModeratedConversation
|
4 |
+
from .paper_review import PaperReview
|
5 |
+
from .paper_decision import PaperDecision
|
6 |
+
|
7 |
+
# All environment classes that can be instantiated from a config.
ALL_ENVIRONMENTS = [
    Conversation,
    ModeratedConversation,
    PaperReview,
    PaperDecision,
]

# Registry mapping each environment's ``type_name`` to its class; used by
# load_environment for config-driven lookup.
ENV_REGISTRY = {env.type_name: env for env in ALL_ENVIRONMENTS}
|
15 |
+
|
16 |
+
|
17 |
+
# Load an environment from a config dictionary
|
18 |
+
def load_environment(config: EnvironmentConfig):
|
19 |
+
try:
|
20 |
+
env_cls = ENV_REGISTRY[config["env_type"]]
|
21 |
+
except KeyError:
|
22 |
+
raise ValueError(f"Unknown environment type: {config['env_type']}")
|
23 |
+
|
24 |
+
env = env_cls.from_config(config)
|
25 |
+
return env
|
agentreview/environments/base.py
ADDED
@@ -0,0 +1,188 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from abc import abstractmethod
|
2 |
+
from dataclasses import dataclass
|
3 |
+
from typing import Dict, List
|
4 |
+
|
5 |
+
from ..config import Configurable, EnvironmentConfig
|
6 |
+
from ..message import Message
|
7 |
+
from ..utils import AttributedDict
|
8 |
+
|
9 |
+
|
10 |
+
@dataclass
class TimeStep(AttributedDict):
    """
    Represents a single step in time within the simulation.

    It includes observation, reward, and terminal state.

    Attributes:
        observation (List[Message]): A list of messages (observations) for the current timestep.
        reward (Dict[str, float]): A dictionary with player names as keys and corresponding rewards as values.
        terminal (bool): A boolean indicating whether the current state is terminal (end of episode).
    """

    # NOTE(review): inherits AttributedDict (from ..utils), which presumably
    # provides attribute-style access to dict entries — confirm in utils.
    observation: List[Message]
    reward: Dict[str, float]
    terminal: bool
26 |
+
|
27 |
+
|
28 |
+
class Environment(Configurable):
    """
    Abstract class representing an environment.

    It defines the necessary methods any environment must implement.

    Inherits from:
        Configurable: A custom class that provides methods to handle configuration settings.

    Attributes:
        type_name (str): Type of the environment, typically set to the lower case of the class name.

    Note:
        Subclasses should override and implement the abstract methods defined here.
    """

    type_name = None  # filled in by __init_subclass__ with the lowercased class name when left None
    phase_index = 0  # current phase of the simulation; some subclasses shadow this with a property
    task = None  # optional task identifier; set by subclasses that need it
    @abstractmethod
    def __init__(self, player_names: List[str], **kwargs):
        """
        Initialize the Environment.

        Parameters:
            player_names (List[str]): Names of the players in the environment.
        """
        super().__init__(
            player_names=player_names, **kwargs
        )  # registers the arguments with Configurable
        self.player_names = player_names

    def __init_subclass__(cls, **kwargs):
        """
        Automatically called when a subclass is being initialized.

        Here it's used to check if the subclass has the required attributes.
        """
        # Default `type_name` to the lowercased subclass name when the
        # subclass did not set it explicitly.
        for required in ("type_name",):
            if getattr(cls, required) is None:
                cls.type_name = cls.__name__.lower()

        return super().__init_subclass__(**kwargs)

    @abstractmethod
    def reset(self):
        """
        Reset the environment to its initial state.

        Note:
            This method must be implemented by subclasses.
        """
        pass

    def to_config(self) -> EnvironmentConfig:
        # Serialize the environment back to a config, recording its type so
        # `load_environment` can reconstruct it.
        self._config_dict["env_type"] = self.type_name
        return EnvironmentConfig(**self._config_dict)

    @property
    def num_players(self) -> int:
        """Get the number of players."""
        return len(self.player_names)

    @abstractmethod
    def get_next_player(self) -> str:
        """
        Return the name of the next player.

        Note:
            This method must be implemented by subclasses.

        Returns:
            str: The name of the next player.
        """
        pass

    @abstractmethod
    def get_observation(self, player_name=None) -> List[Message]:
        """
        Return observation for a given player.

        Note:
            This method must be implemented by subclasses.

        Parameters:
            player_name (str, optional): The name of the player for whom to get the observation.

        Returns:
            List[Message]: The observation for the player in the form of a list of messages.
        """
        pass

    @abstractmethod
    def print(self):
        """Print the environment state."""
        pass

    @abstractmethod
    def step(self, player_name: str, action: str) -> TimeStep:
        """
        Execute a step in the environment given an action from a player.

        Note:
            This method must be implemented by subclasses.

        Parameters:
            player_name (str): The name of the player.
            action (str): The action that the player wants to take.

        Returns:
            TimeStep: An object of the TimeStep class containing the observation, reward, and done state.
        """
        pass

    @abstractmethod
    def check_action(self, action: str, player_name: str) -> bool:
        """
        Check whether a given action is valid for a player.

        Note:
            This method must be implemented by subclasses.
            Although marked abstract, it provides a permissive default that
            overriding implementations may reuse via super().

        Parameters:
            action (str): The action to be checked.
            player_name (str): The name of the player.

        Returns:
            bool: True if the action is valid, False otherwise.
        """
        return True

    @abstractmethod
    def is_terminal(self) -> bool:
        """
        Check whether the environment is in a terminal state (end of episode).

        Note:
            This method must be implemented by subclasses.

        Returns:
            bool: True if the environment is in a terminal state, False otherwise.
        """
        pass

    def get_zero_rewards(self) -> Dict[str, float]:
        """
        Return a dictionary with all player names as keys and zero as reward.

        Returns:
            Dict[str, float]: A dictionary of players and their rewards (all zero).
        """
        return {player_name: 0.0 for player_name in self.player_names}

    def get_one_rewards(self) -> Dict[str, float]:
        """
        Return a dictionary with all player names as keys and one as reward.

        Returns:
            Dict[str, float]: A dictionary of players and their rewards (all one).
        """
        return {player_name: 1.0 for player_name in self.player_names}
|
agentreview/environments/conversation.py
ADDED
@@ -0,0 +1,198 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List, Union
|
2 |
+
|
3 |
+
from ..agent import SIGNAL_END_OF_CONVERSATION, Moderator
|
4 |
+
from ..config import AgentConfig, EnvironmentConfig
|
5 |
+
from ..message import Message, MessagePool
|
6 |
+
from .base import Environment, TimeStep
|
7 |
+
|
8 |
+
|
9 |
+
class Conversation(Environment):
    """
    Turn-based fully observable conversation environment.

    Next speaker order is either parallel or round-robin.
    """

    type_name = "conversation"

    def __init__(self, player_names: List[str], parallel: bool = False, **kwargs):
        """
        Args:
            player_names: names of the participating players.
            parallel: if True, all players speak within the same turn;
                otherwise speakers alternate round-robin.
        """
        super().__init__(player_names=player_names, parallel=parallel, **kwargs)

        self.parallel = parallel

        # The "state" of the environment is maintained by the message pool
        self.message_pool = MessagePool()

        self._current_turn = 0
        self._next_player_index = 0
        # Fix: initialize the backing field of the `phase_index` property.
        # Previously, reading `phase_index` before the setter was ever used
        # raised AttributeError because `_phase_index` did not exist.
        self._phase_index = 0

    def reset(self):
        """Reset turn counters and clear the message pool."""
        self._current_turn = 0
        self._next_player_index = 0
        self.message_pool.reset()

        init_timestep = TimeStep(
            observation=[], reward=self.get_zero_rewards(), terminal=False
        )
        return init_timestep

    @property
    def phase_index(self):
        # Shadows the plain class attribute on `Environment` with a property
        # so subclasses can hook assignment.
        return self._phase_index

    @phase_index.setter
    def phase_index(self, value):
        self._phase_index = value

    def to_config(self) -> EnvironmentConfig:
        return EnvironmentConfig(
            env_type=self.type_name,
            player_names=self.player_names,
            parallel=self.parallel,
        )

    def print(self):
        """Print the message history."""
        self.message_pool.print()

    def get_next_player(self) -> str:
        """Get the next player."""
        return self.player_names[self._next_player_index]

    def get_observation(self, player_name=None) -> List[Message]:
        """Get observation for the player (all messages when no player given)."""
        if player_name is None:
            return self.message_pool.get_all_messages()
        else:
            return self.message_pool.get_visible_messages(
                player_name, turn=self._current_turn
            )

    def is_terminal(self) -> bool:
        """Check if the conversation is over."""
        # Fix: the original crashed with AttributeError when the pool was
        # empty (e.g. right after reset()) and implicitly returned None when
        # the end signal was absent; now it returns an explicit bool.
        # (Assumes MessagePool.last_message is None for an empty pool —
        # confirm against the MessagePool implementation.)
        last_message = self.message_pool.last_message
        if last_message is None:
            return False
        return last_message.content.startswith(SIGNAL_END_OF_CONVERSATION)

    def step(self, player_name: str, action: str) -> TimeStep:
        """
        Step function that is called by the arena.

        Args:
            player_name: the name of the player that takes the action
            action: the action that the agents wants to take
        """
        message = Message(
            agent_name=player_name, content=action, turn=self._current_turn
        )
        self.message_pool.append_message(message)

        # Update the counters: in round-robin mode every message advances the
        # turn; in parallel mode the turn advances once per full round.
        if not self.parallel or self._next_player_index == 0:
            self._current_turn += 1
        self._next_player_index = (self._next_player_index + 1) % self.num_players

        timestep = TimeStep(
            observation=self.get_observation(),
            reward=self.get_zero_rewards(),
            terminal=self.is_terminal(),
        )  # Return all the messages
        return timestep
102 |
+
|
103 |
+
|
104 |
+
class ModeratedConversation(Conversation):
    """
    Turn-based fully observable conversation environment.

    Next speaker order is either parallel or round-robin.
    Moderator is a special agent that can see all messages and can decide whether the conversation is over.
    """

    type_name = "moderated_conversation"

    def __init__(
        self,
        player_names: List[str],
        moderator: Union[Moderator, AgentConfig],
        parallel: bool = False,
        moderator_visibility="all",
        moderator_period=None,
        **kwargs,
    ):
        """
        Parameters:
            player_names: names of the participating players.
            moderator: a Moderator instance, or an AgentConfig to build one from.
            parallel: if True, all players speak within the same turn.
            moderator_visibility: which players see the moderator's messages.
            moderator_period: "turn" or "round"; defaults to "round" when
                parallel, otherwise "turn".
        """
        super().__init__(player_names=player_names, parallel=parallel, **kwargs)

        if isinstance(moderator, AgentConfig):
            moderator_config = moderator
            moderator = Moderator.from_config(moderator_config)
        elif not isinstance(moderator, Moderator):
            raise ValueError(
                "moderator must be either an AgentConfig or a Moderator instance."
            )

        self.moderator = moderator
        self.moderator_visibility = moderator_visibility
        # Default cadence: once per round in parallel mode, else every turn.
        if moderator_period is None:
            if parallel:
                self.moderator_period = "round"
            else:
                self.moderator_period = "turn"
        else:
            self.moderator_period = moderator_period

    def to_config(self) -> EnvironmentConfig:
        # This environment contains some special config arguments that needs to be handle specially
        return EnvironmentConfig(
            env_type=self.type_name,
            player_names=self.player_names,
            parallel=self.parallel,
            moderator=self.moderator.to_config(),
            moderator_visibility=self.moderator_visibility,
            moderator_period=self.moderator_period,
        )

    def step(self, player_name: str, action: str) -> TimeStep:
        """
        Step function that is called by the arena.

        Args:
            player_name: the name of the player that takes the action
            action: the action that the agents wants to take
        """
        message = Message(
            agent_name=player_name, content=action, turn=self._current_turn
        )
        self.message_pool.append_message(message)

        # Round-robin order for the next player
        self._next_player_index = (self._next_player_index + 1) % self.num_players

        # The moderator speaks after every turn, or after each full round
        # (when the player index wraps back to 0), depending on its period.
        if self.moderator_period == "turn" or (
            self.moderator_period == "round" and self._next_player_index == 0
        ):
            # Moderator's turn
            moderator_history = self.message_pool.get_all_messages()
            moderator_response = self.moderator(moderator_history)
            moderator_message = Message(
                agent_name=self.moderator.name,
                content=moderator_response,
                turn=self._current_turn,
                visible_to=self.moderator_visibility,
            )
            self.message_pool.append_message(moderator_message)
            # NOTE(review): `moderator_history` does not include the
            # moderator's own message appended just above — confirm intended.
            terminal = (
                self.moderator.is_terminal(moderator_history) or self.is_terminal()
            )
        else:
            terminal = self.is_terminal()

        # Update the counters
        if not self.parallel or self._next_player_index == 0:
            self._current_turn += 1

        timestep = TimeStep(
            observation=self.get_observation(),
            reward=self.get_zero_rewards(),
            terminal=terminal,
        )  # Return all the messages
        return timestep
|
agentreview/environments/paper_decision.py
ADDED
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import traceback
|
3 |
+
from typing import List
|
4 |
+
|
5 |
+
from agentreview.environments import Conversation
|
6 |
+
from .base import TimeStep
|
7 |
+
from ..message import Message, MessagePool
|
8 |
+
|
9 |
+
|
10 |
+
logger = logging.getLogger(__name__)
|
11 |
+
|
12 |
+
|
13 |
+
class PaperDecision(Conversation):
    """
    Area chairs make decision based on the meta reviews.

    This environment only runs the final phase (Phase 5) of the review
    pipeline.
    """

    type_name = "paper_decision"

    def __init__(self,
                 player_names: List[str],
                 experiment_setting: dict,
                 paper_ids: List[int] = None,
                 metareviews: List[str] = None,
                 parallel: bool = False,

                 **kwargs):
        """
        Args:
            player_names: names of the players (the area chairs).
            experiment_setting: configuration of this experiment run.
            paper_ids: ids of the papers under decision, such as 917.
            metareviews: the meta reviews the decisions are based on.
            parallel: whether players act in parallel.
        """

        # Inherit from the parent class of `class Conversation`
        super(Conversation, self).__init__(player_names=player_names, parallel=parallel, **kwargs)

        self.paper_ids = paper_ids
        self.metareviews = metareviews
        self.parallel = parallel
        self.experiment_setting = experiment_setting
        # "ranking" or "recommendation"; controls how AC output is parsed.
        self.ac_scoring_method = kwargs.get("ac_scoring_method")
        # The "state" of the environment is maintained by the message pool
        self.message_pool = MessagePool()

        self.ac_decisions = None

        self._current_turn = 0
        self._next_player_index = 0
        self.phase_index = 5  # "ACs make decision based on meta review" is the last phase (Phase 5)

        self._phases = None

    @property
    def phases(self):
        """Lazily-built phase table: only Phase 5 exists here."""
        if self._phases is None:
            self._phases = {
                5: {
                    "name": "ac_make_decisions",
                    'speaking_order': ["AC"]
                },
            }
        return self._phases

    def step(self, player_name: str, action: str) -> TimeStep:
        """
        Step function that is called by the arena.

        Args:
            player_name: the name of the player that takes the action
            action: the action that the agents wants to take
        """

        message = Message(
            agent_name=player_name, content=action, turn=self._current_turn
        )
        self.message_pool.append_message(message)

        speaking_order = self.phases[self.phase_index]["speaking_order"]

        # Fix: use the module-level `logger` consistently (the original mixed
        # `logging.info` with `logger.info`).
        logger.info(f"Phase {self.phase_index}: {self.phases[self.phase_index]['name']} "
                    f"| Player {self._next_player_index}: {speaking_order[self._next_player_index]}")

        # Reached the end of the speaking order. Move to the next phase.
        if self._next_player_index == len(speaking_order) - 1:
            self._next_player_index = 0
            logger.info(f"Phase {self.phase_index}: end of the speaking order. Move to Phase {self.phase_index + 1}.")
            self.phase_index += 1
            self._current_turn += 1
        else:
            self._next_player_index += 1

        timestep = TimeStep(
            observation=self.get_observation(),
            reward=self.get_zero_rewards(),
            terminal=self.is_terminal(),
        )  # Return all the messages

        return timestep

    def check_action(self, action: str, player_name: str) -> bool:
        """Check if the action is valid (for ACs: parses into a decision dict)."""

        if player_name.startswith("AC"):

            try:
                self.ac_decisions = self.parse_ac_decisions(action)

            # Fix: catch Exception instead of a bare `except:` so that
            # KeyboardInterrupt / SystemExit are not swallowed.
            except Exception:
                traceback.print_exc()
                return False

            if not isinstance(self.ac_decisions, dict):
                return False

        return True

    @property
    def ac_decisions(self):
        # Mapping {paper_id: rank-or-decision} parsed from the AC's output.
        return self._ac_decisions

    @ac_decisions.setter
    def ac_decisions(self, value):
        self._ac_decisions = value

    def parse_ac_decisions(self, action: str):
        """
        Parse the decisions made by the ACs.

        Expects "Paper ID: <int>" lines followed by either
        "Willingness to accept: <int>" (ranking mode) or "Decision: <text>"
        (recommendation mode).

        Raises:
            ValueError: if the same paper id is assigned a rank twice.
        """

        lines = action.split("\n")

        paper2rating = {}

        paper_id, rank = None, None

        for line in lines:

            if line.lower().startswith("paper id:"):
                # Strip a possible trailing "(...)" annotation after the id.
                paper_id = int(line.split(":")[1].split('(')[0].strip())
            elif self.ac_scoring_method == "ranking" and line.lower().startswith("willingness to accept:"):
                rank = int(line.split(":")[1].strip())

            elif self.ac_scoring_method == "recommendation" and line.lower().startswith("decision"):
                rank = line.split(":")[1].strip()

            if paper_id in paper2rating:
                raise ValueError(f"Paper {paper_id} is assigned a rank twice.")

            # Record the pair only once both halves have been seen.
            if paper_id is not None and rank is not None:
                paper2rating[paper_id] = rank
                paper_id, rank = None, None

        return paper2rating
|
agentreview/environments/paper_review.py
ADDED
@@ -0,0 +1,213 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import logging
|
3 |
+
import os.path as osp
|
4 |
+
from typing import List
|
5 |
+
|
6 |
+
from agentreview.environments import Conversation
|
7 |
+
from agentreview.utility.utils import get_rebuttal_dir
|
8 |
+
from .base import TimeStep
|
9 |
+
from ..message import Message
|
10 |
+
from ..paper_review_message import PaperReviewMessagePool
|
11 |
+
|
12 |
+
|
13 |
+
logger = logging.getLogger(__name__)
|
14 |
+
|
15 |
+
class PaperReview(Conversation):
    """
    Discussion between reviewers and area chairs.

    There are several phases in the reviewing process:
        reviewer_write_reviews: reviewers write their reviews based on the paper content.
        author_reviewer_discussion: An author respond to comments from the reviewers.
        reviewer_ac_discussion: reviewers and an area chair discuss the paper.
        ac_discussion: an area chair makes the final decision.
    """

    type_name = "paper_review"

    def __init__(self, player_names: List[str], paper_id: int, paper_decision: str, experiment_setting: dict, args,
                 parallel: bool = False,

                 **kwargs):
        """
        Args:
            paper_id (int): the id of the paper, such as 917
            paper_decision (str): the decision of the paper, such as "Accept: notable-top-25%"
        """

        # Inherit from the parent class of `class Conversation`
        super(Conversation, self).__init__(player_names=player_names, parallel=parallel, **kwargs)
        self.args = args
        self.paper_id = paper_id
        self.paper_decision = paper_decision
        self.parallel = parallel
        self.experiment_setting = experiment_setting
        self.player_to_test = experiment_setting.get('player_to_test', None)
        self.task = kwargs.get("task")
        self.experiment_name = args.experiment_name

        # The "state" of the environment is maintained by the message pool
        self.message_pool = PaperReviewMessagePool(experiment_setting)

        # Fix: Conversation.__init__ is deliberately skipped (see super call
        # above), so initialize the turn counters here; otherwise calling
        # step() before reset() raised AttributeError. Mirrors PaperDecision.
        self._current_turn = 0
        self._next_player_index = 0

        self.phase_index = 0
        self._phases = None

    @property
    def phases(self):
        """Lazily-built phase table keyed by phase index (0-5)."""

        if self._phases is not None:
            return self._phases

        reviewer_names = [name for name in self.player_names if name.startswith("Reviewer")]

        num_reviewers = len(reviewer_names)

        # Normalize reviewer names to "Reviewer 1..N".
        reviewer_names = [f"Reviewer {i}" for i in range(1, num_reviewers + 1)]

        self._phases = {
            # In phase 0, no LLM-based agents are called.
            0: {
                "name": "paper_extraction",
                'speaking_order': ["Paper Extractor"],
            },

            1: {
                "name": 'reviewer_write_reviews',
                'speaking_order': reviewer_names
            },

            # The author responds to each reviewer's review
            2: {
                'name': 'author_reviewer_discussion',
                'speaking_order': ["Author" for _ in reviewer_names],
            },

            3: {
                'name': 'reviewer_ac_discussion',
                'speaking_order': ["AC"] + reviewer_names,
            },

            4: {
                'name': 'ac_write_metareviews',
                'speaking_order': ["AC"]
            },
            5: {
                'name': 'ac_makes_decisions',
                'speaking_order': ["AC"]
            },
        }

        # Fix: return the backing field directly; the original's
        # `return self.phases` re-entered this property a second time.
        return self._phases

    @phases.setter
    def phases(self, value):
        self._phases = value

    def reset(self):
        # NOTE(review): `_current_phase` appears unused elsewhere in this
        # class — confirm before removing.
        self._current_phase = "review"
        self.phase_index = 0
        return super().reset()

    def load_message_history_from_cache(self):
        """Replay cached BASELINE messages up to the AC's second message, then jump to Phase 4."""
        if self.phase_index == 0:

            print("Loading message history from BASELINE experiment")

            full_paper_discussion_path = get_rebuttal_dir(paper_id=self.paper_id,
                                                          experiment_name="BASELINE",
                                                          model_name=self.args.model_name,
                                                          conference=self.args.conference)

            # Fix: close the file handle (the original used a bare `open`
            # inside json.load, leaking the descriptor).
            with open(osp.join(full_paper_discussion_path, f"{self.paper_id}.json"), 'r',
                      encoding='utf-8') as f:
                messages = json.load(f)['messages']

            num_messages_from_AC = 0

            for msg in messages:

                # We have already extracted contents from the paper.
                if msg['agent_name'] == "Paper Extractor":
                    continue

                # Encountering the 2nd message from the AC. Stop loading messages.
                if msg['agent_name'] == "AC" and num_messages_from_AC == 1:
                    break

                if msg['agent_name'] == "AC":
                    num_messages_from_AC += 1

                message = Message(**msg)
                self.message_pool.append_message(message)

            num_unique_reviewers = len(
                set([msg['agent_name'] for msg in messages if msg['agent_name'].startswith("Reviewer")]))

            assert num_unique_reviewers == self.args.num_reviewers_per_paper

            self.phase_index = 4

    def step(self, player_name: str, action: str) -> TimeStep:
        """
        Step function that is called by the arena.

        Args:
            player_name: the name of the player that takes the action
            action: the action that the agents wants to take
        """

        message = Message(
            agent_name=player_name, content=action, turn=self._current_turn
        )
        self.message_pool.append_message(message)

        speaking_order = self.phases[self.phase_index]["speaking_order"]

        # Fix: use the module-level `logger` consistently (the original mixed
        # `logging.info` with `logger.info`).
        logger.info(f"Phase {self.phase_index}: {self.phases[self.phase_index]['name']} "
                    f"| Player {self._next_player_index}: {speaking_order[self._next_player_index]}")

        terminal = self.is_terminal()

        # Reached the end of the speaking order. Move to the next phase.
        if self._next_player_index == len(speaking_order) - 1:
            self._next_player_index = 0

            if self.phase_index == 4:
                # Phases I-IV end here; Phase 5 runs in a separate script.
                terminal = True
                logger.info(
                    "Finishing the simulation for Phase I - IV. Please run `python run_paper_decision_cli.py ` for "
                    "Phase V. (AC makes decisions).")

            else:
                print(f"Phase {self.phase_index}: end of the speaking order. Move to Phase ({self.phase_index + 1}).")
                self.phase_index += 1
                self._current_turn += 1

        else:
            self._next_player_index += 1

        timestep = TimeStep(
            observation=self.get_observation(),
            reward=self.get_zero_rewards(),
            terminal=terminal,
        )  # Return all the messages

        return timestep

    def get_next_player(self) -> str:
        """Get the next player in the current phase."""
        speaking_order = self.phases[self.phase_index]["speaking_order"]
        next_player = speaking_order[self._next_player_index]
        return next_player

    def get_observation(self, player_name=None) -> List[Message]:
        """Get observation for the player."""
        if player_name is None:
            return self.message_pool.get_all_messages()
        else:

            return self.message_pool.get_visible_messages_for_paper_review(
                player_name, phase_index=self.phase_index, next_player_idx=self._next_player_index,
                player_names=self.player_names
            )
agentreview/experiment_config.py
ADDED
@@ -0,0 +1,265 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
BASELINE: The default settings which all other settings compare against.
|
3 |
+
|
4 |
+
"""
|
5 |
+
|
6 |
+
baseline_setting = {
|
7 |
+
"AC": [
|
8 |
+
"BASELINE"
|
9 |
+
],
|
10 |
+
|
11 |
+
"reviewer": [
|
12 |
+
"BASELINE",
|
13 |
+
"BASELINE",
|
14 |
+
"BASELINE"
|
15 |
+
],
|
16 |
+
|
17 |
+
"author": [
|
18 |
+
"BASELINE"
|
19 |
+
],
|
20 |
+
"global_settings":{
|
21 |
+
"provides_numeric_rating": ['reviewer', 'ac'],
|
22 |
+
"persons_aware_of_authors_identities": []
|
23 |
+
}
|
24 |
+
}
|
25 |
+
|
26 |
+
benign_Rx1_setting = {
|
27 |
+
"AC": [
|
28 |
+
"BASELINE"
|
29 |
+
],
|
30 |
+
|
31 |
+
"reviewer": [
|
32 |
+
"benign",
|
33 |
+
"BASELINE",
|
34 |
+
"BASELINE"
|
35 |
+
],
|
36 |
+
|
37 |
+
"author": [
|
38 |
+
"BASELINE"
|
39 |
+
],
|
40 |
+
"global_settings":{
|
41 |
+
"provides_numeric_rating": ['reviewer', 'ac'],
|
42 |
+
"persons_aware_of_authors_identities": []
|
43 |
+
}
|
44 |
+
}
|
45 |
+
|
46 |
+
malicious_Rx1_setting = {
|
47 |
+
"AC": [
|
48 |
+
"BASELINE"
|
49 |
+
],
|
50 |
+
|
51 |
+
"reviewer": [
|
52 |
+
"malicious",
|
53 |
+
"BASELINE",
|
54 |
+
"BASELINE"
|
55 |
+
],
|
56 |
+
|
57 |
+
"author": [
|
58 |
+
"BASELINE"
|
59 |
+
],
|
60 |
+
"global_settings":{
|
61 |
+
"provides_numeric_rating": ['reviewer', 'ac'],
|
62 |
+
"persons_aware_of_authors_identities": []
|
63 |
+
}
|
64 |
+
}
|
65 |
+
|
66 |
+
unknowledgeable_Rx1_setting = {
|
67 |
+
"AC": [
|
68 |
+
"BASELINE"
|
69 |
+
],
|
70 |
+
|
71 |
+
"reviewer": [
|
72 |
+
"knowledgeable",
|
73 |
+
"BASELINE",
|
74 |
+
"BASELINE"
|
75 |
+
],
|
76 |
+
|
77 |
+
"author": [
|
78 |
+
"BASELINE"
|
79 |
+
],
|
80 |
+
"global_settings":{
|
81 |
+
"provides_numeric_rating": ['reviewer', 'ac'],
|
82 |
+
"persons_aware_of_authors_identities": []
|
83 |
+
}
|
84 |
+
}
|
85 |
+
|
86 |
+
knowledgeable_Rx1_setting = {
|
87 |
+
"AC": [
|
88 |
+
"BASELINE"
|
89 |
+
],
|
90 |
+
|
91 |
+
"reviewer": [
|
92 |
+
"knowledgeable",
|
93 |
+
"BASELINE",
|
94 |
+
"BASELINE"
|
95 |
+
],
|
96 |
+
|
97 |
+
"author": [
|
98 |
+
"BASELINE"
|
99 |
+
],
|
100 |
+
"global_settings":{
|
101 |
+
"provides_numeric_rating": ['reviewer', 'ac'],
|
102 |
+
"persons_aware_of_authors_identities": []
|
103 |
+
}
|
104 |
+
}
|
105 |
+
|
106 |
+
|
107 |
+
responsible_Rx1_setting = {
|
108 |
+
"AC": [
|
109 |
+
"BASELINE"
|
110 |
+
],
|
111 |
+
|
112 |
+
"reviewer": [
|
113 |
+
"responsible",
|
114 |
+
"BASELINE",
|
115 |
+
"BASELINE"
|
116 |
+
],
|
117 |
+
|
118 |
+
"author": [
|
119 |
+
"BASELINE"
|
120 |
+
],
|
121 |
+
"global_settings":{
|
122 |
+
"provides_numeric_rating": ['reviewer', 'ac'],
|
123 |
+
"persons_aware_of_authors_identities": []
|
124 |
+
}
|
125 |
+
}
|
126 |
+
|
127 |
+
irresponsible_Rx1_setting = {
|
128 |
+
"AC": [
|
129 |
+
"BASELINE"
|
130 |
+
],
|
131 |
+
|
132 |
+
"reviewer": [
|
133 |
+
"irresponsible",
|
134 |
+
"BASELINE",
|
135 |
+
"BASELINE"
|
136 |
+
],
|
137 |
+
|
138 |
+
"author": [
|
139 |
+
"BASELINE"
|
140 |
+
],
|
141 |
+
"global_settings":{
|
142 |
+
"provides_numeric_rating": ['reviewer', 'ac'],
|
143 |
+
"persons_aware_of_authors_identities": []
|
144 |
+
}
|
145 |
+
}
|
146 |
+
|
147 |
+
conformist_ACx1_setting = {
|
148 |
+
"AC": [
|
149 |
+
"conformist"
|
150 |
+
],
|
151 |
+
|
152 |
+
"reviewer": [
|
153 |
+
"BASELINE",
|
154 |
+
"BASELINE",
|
155 |
+
"BASELINE"
|
156 |
+
],
|
157 |
+
|
158 |
+
"author": [
|
159 |
+
"BASELINE"
|
160 |
+
],
|
161 |
+
"global_settings":{
|
162 |
+
"provides_numeric_rating": ['reviewer', 'ac'],
|
163 |
+
"persons_aware_of_authors_identities": []
|
164 |
+
}
|
165 |
+
}
|
166 |
+
|
167 |
+
authoritarian_ACx1_setting = {
|
168 |
+
"AC": [
|
169 |
+
"authoritarian"
|
170 |
+
],
|
171 |
+
|
172 |
+
"reviewer": [
|
173 |
+
"BASELINE",
|
174 |
+
"BASELINE",
|
175 |
+
"BASELINE"
|
176 |
+
],
|
177 |
+
|
178 |
+
"author": [
|
179 |
+
"BASELINE"
|
180 |
+
],
|
181 |
+
"global_settings":{
|
182 |
+
"provides_numeric_rating": ['reviewer', 'ac'],
|
183 |
+
"persons_aware_of_authors_identities": []
|
184 |
+
}
|
185 |
+
}
|
186 |
+
|
187 |
+
inclusive_ACx1_setting = {
|
188 |
+
"AC": [
|
189 |
+
"inclusive"
|
190 |
+
],
|
191 |
+
|
192 |
+
"reviewer": [
|
193 |
+
"BASELINE",
|
194 |
+
"BASELINE",
|
195 |
+
"BASELINE"
|
196 |
+
],
|
197 |
+
|
198 |
+
"author": [
|
199 |
+
"BASELINE"
|
200 |
+
],
|
201 |
+
"global_settings":{
|
202 |
+
"provides_numeric_rating": ['reviewer', 'ac'],
|
203 |
+
"persons_aware_of_authors_identities": []
|
204 |
+
}
|
205 |
+
}
|
206 |
+
|
207 |
+
|
208 |
+
|
209 |
+
no_numeric_ratings_setting = {
|
210 |
+
"AC": [
|
211 |
+
"BASELINE"
|
212 |
+
],
|
213 |
+
|
214 |
+
"reviewer": [
|
215 |
+
"BASELINE"
|
216 |
+
],
|
217 |
+
|
218 |
+
"author": [
|
219 |
+
"BASELINE"
|
220 |
+
],
|
221 |
+
"global_settings":{
|
222 |
+
"provides_numeric_rating": [],
|
223 |
+
"persons_aware_of_authors_identities": []
|
224 |
+
}
|
225 |
+
}
|
226 |
+
|
227 |
+
malicious_and_irresponsible_Rx1_setting = {
|
228 |
+
"AC": [
|
229 |
+
"BASELINE"
|
230 |
+
],
|
231 |
+
|
232 |
+
"reviewer": [
|
233 |
+
"malicious irresponsible",
|
234 |
+
"BASELINE",
|
235 |
+
"BASELINE"
|
236 |
+
],
|
237 |
+
|
238 |
+
"author": [
|
239 |
+
"BASELINE"
|
240 |
+
],
|
241 |
+
"global_settings":{
|
242 |
+
"provides_numeric_rating": ['reviewer', 'ac'],
|
243 |
+
"persons_aware_of_authors_identities": []
|
244 |
+
}
|
245 |
+
}
|
246 |
+
|
247 |
+
|
248 |
+
# All experimental settings.
|
249 |
+
# Customize your own by adding new settings to this dict.
|
250 |
+
all_settings = {
|
251 |
+
"BASELINE": baseline_setting,
|
252 |
+
"benign_Rx1": benign_Rx1_setting,
|
253 |
+
"malicious_Rx1": malicious_Rx1_setting,
|
254 |
+
"knowledgeable_Rx1": knowledgeable_Rx1_setting,
|
255 |
+
"unknowledgeable_Rx1": unknowledgeable_Rx1_setting,
|
256 |
+
"responsible_Rx1": responsible_Rx1_setting,
|
257 |
+
"irresponsible_Rx1": irresponsible_Rx1_setting,
|
258 |
+
"conformist_ACx1": conformist_ACx1_setting,
|
259 |
+
"authoritarian_ACx1": authoritarian_ACx1_setting,
|
260 |
+
"inclusive_ACx1": inclusive_ACx1_setting,
|
261 |
+
"no_numeric_ratings": no_numeric_ratings_setting,
|
262 |
+
"malicious_and_irresponsible_Rx1": malicious_and_irresponsible_Rx1_setting,
|
263 |
+
|
264 |
+
}
|
265 |
+
|
agentreview/message.py
ADDED
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import hashlib
import time
from dataclasses import dataclass, field
from typing import List, Union
from uuid import uuid1
|
6 |
+
|
7 |
+
# Preserved roles
|
8 |
+
SYSTEM_NAME = "System"
|
9 |
+
MODERATOR_NAME = "Moderator"
|
10 |
+
|
11 |
+
|
12 |
+
def _hash(input: str):
|
13 |
+
"""
|
14 |
+
Helper function that generates a SHA256 hash of a given input string.
|
15 |
+
|
16 |
+
Parameters:
|
17 |
+
input (str): The input string to be hashed.
|
18 |
+
|
19 |
+
Returns:
|
20 |
+
str: The SHA256 hash of the input string.
|
21 |
+
"""
|
22 |
+
hex_dig = hashlib.sha256(input.encode()).hexdigest()
|
23 |
+
return hex_dig
|
24 |
+
|
25 |
+
|
26 |
+
@dataclass
class Message:
    """
    Represents a message in the chatArena environment.

    Attributes:
        agent_name (str): Name of the agent who sent the message.
        content (str): Content of the message.
        turn (int): The turn at which the message was sent.
        timestamp (int): Wall time (ns) at which the message was created.
        visible_to (Union[str, List[str]]): The receivers of the message. Can be a single agent, multiple agents, or 'all'. Defaults to 'all'.
        msg_type (str): Type of the message, e.g., 'text'. Defaults to 'text'.
        logged (bool): Whether the message is logged in the database. Defaults to False.
    """

    agent_name: str
    content: str
    turn: int
    # BUG FIX: the original used `timestamp: int = time.time_ns()`, which is
    # evaluated once at class-definition time, so every Message constructed
    # with the default shared the same timestamp (and identical content would
    # produce colliding msg_hash values). default_factory runs per instance.
    timestamp: int = field(default_factory=time.time_ns)
    visible_to: Union[str, List[str]] = "all"
    msg_type: str = "text"
    logged: bool = False  # Whether the message is logged in the database

    @property
    def msg_hash(self):
        # Generate a unique message id given the content, timestamp and role
        return _hash(
            f"agent: {self.agent_name}\ncontent: {self.content}\ntimestamp: {str(self.timestamp)}\nturn: {self.turn}\nmsg_type: {self.msg_type}"
        )
|
55 |
+
|
56 |
+
|
57 |
+
class MessagePool:
    """
    A pool to manage the messages in the chatArena environment.

    The pool is essentially a list of messages with a unified treatment of
    visibility. It supports two configurations for step definition: multiple
    players can act in the same turn (like in rock-paper-scissors). Agents can
    only see the messages that 1) were sent before the current turn, and
    2) are visible to the current role.
    """

    def __init__(self):
        """Initialize the MessagePool with a unique conversation ID."""
        self.conversation_id = str(uuid1())
        self._messages: List[Message] = []
        self._last_message_idx = 0

    def reset(self):
        """Clear the message pool."""
        self._messages = []

    def append_message(self, message: Message):
        """
        Append a message to the pool.

        Parameters:
            message (Message): The message to be added to the pool.
        """
        self._messages.append(message)

    def print(self):
        """Print all the messages in the pool."""
        for msg in self._messages:
            print(f"[{msg.agent_name}->{msg.visible_to}]: {msg.content}")

    @property
    def last_turn(self):
        """
        Get the turn of the last message in the pool.

        Returns:
            int: The turn of the last message, or 0 for an empty pool.
        """
        if not self._messages:
            return 0
        return self._messages[-1].turn

    @property
    def last_message(self):
        """
        Get the last message in the pool.

        Returns:
            Message: The last message, or None for an empty pool.
        """
        if not self._messages:
            return None
        return self._messages[-1]

    def get_all_messages(self) -> List[Message]:
        """
        Get all the messages in the pool.

        Returns:
            List[Message]: A list of all messages.
        """
        return self._messages

    def get_visible_messages(self, agent_name, turn: int) -> List[Message]:
        """
        Get all the messages that are visible to a given agent before a specified turn.

        Parameters:
            agent_name (str): The name of the agent.
            turn (int): The specified turn.

        Returns:
            List[Message]: A list of visible messages.
        """
        # Only messages sent strictly before the given turn are candidates.
        earlier = [m for m in self._messages if m.turn < turn]

        # NOTE: when visible_to is a plain string, `agent_name in m.visible_to`
        # is a substring test — preserved from the original behavior.
        return [
            m
            for m in earlier
            if m.visible_to == "all"
            or agent_name in m.visible_to
            or agent_name == "Moderator"
        ]
|
agentreview/paper_processor.py
ADDED
@@ -0,0 +1,163 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Read papers from a PDF file and extract the title, abstract, figures and tables captions, and main content. These
|
3 |
+
functions work best with ICLR / NeurIPS papers.
|
4 |
+
|
5 |
+
"""
|
6 |
+
|
7 |
+
from io import StringIO
|
8 |
+
|
9 |
+
from pdfminer.converter import TextConverter
|
10 |
+
from pdfminer.layout import LAParams
|
11 |
+
from pdfminer.pdfinterp import PDFResourceManager, PDFPageInterpreter
|
12 |
+
from pdfminer.pdfpage import PDFPage
|
13 |
+
|
14 |
+
|
15 |
+
def extract_text_from_pdf(path: str) -> str:
    """Extracts text from a PDF file.

    Args:
        path (str): A string specifying the path to the PDF file.

    Returns:
        A string containing the extracted text from the PDF, with form feed
        characters replaced by newlines.
    """
    # Shared resources for the pdfminer pipeline.
    resource_manager = PDFResourceManager()
    text_output = StringIO()
    converter = TextConverter(resource_manager, text_output, laparams=LAParams())
    interpreter = PDFPageInterpreter(resource_manager, converter)

    with open(path, 'rb') as file_handle:
        # Process each page in the PDF.
        for page in PDFPage.get_pages(file_handle, caching=True, check_extractable=True):
            interpreter.process_page(page)

    # Retrieve the extracted text, then release the pipeline objects.
    extracted_text = text_output.getvalue()
    text_output.close()
    converter.close()

    # Replace form feed characters with newlines.
    return extracted_text.replace('\x0c', '\n')
|
53 |
+
|
54 |
+
|
55 |
+
def convert_text_into_dict(text: str) -> dict:
    """Converts the extracted text into a dictionary.

    Works best with ICLR / NeurIPS-style layouts: title block, "ABSTRACT",
    "INTRODUCTION", figure/table captions, "REFERENCES", then an appendix
    whose first heading is all-caps and starts with "A".

    Args:
        text (str): the extracted text from the PDF.

    Returns:
        A dict with keys "Title", "Abstract", "Figures/Tables Captions",
        "Main Content" and "Appendix".
    """

    lines = text.split('\n')

    # Drop running headers/footers inserted by the conference template.
    filtered_lines = [line for line in lines if not (line.startswith('Under review') or
                                                     line.startswith('Published as') or
                                                     line.startswith('Paper under double-blind review'))]

    # Remove the first few empty lines before the title
    while filtered_lines[0].strip() == "":
        filtered_lines.pop(0)

    # The title runs until the first blank line.
    title = ""
    while filtered_lines[0] != "":
        title += filtered_lines.pop(0) + ' '

    # NOTE: .capitalize() lowercases everything after the first character;
    # preserved from the original behavior.
    title = title.strip().capitalize()

    # Remove the author information between the title and the abstract
    while filtered_lines[0].lower() != "abstract":
        filtered_lines.pop(0)
    filtered_lines.pop(0)

    # The abstract runs until the "Introduction" heading.
    abstract = ""
    while filtered_lines[0].lower() != "introduction":
        abstract += filtered_lines.pop(0) + ' '

    main_content = ""

    figures_captions = []
    tables_captions = []

    # Walk the body up to the references, separating captions from prose.
    while filtered_lines != [] and not filtered_lines[0].lower().startswith("references"):
        figure_caption = ""
        table_caption = ""

        if filtered_lines[0].lower().startswith("figure"):
            # A caption ends at the next blank line (guard against EOF).
            while filtered_lines and filtered_lines[0] != "":
                figure_caption += filtered_lines.pop(0) + ' '

        # BUG FIX: the original compared the lowercased line against "Table"
        # (capital T), which can never match, so table captions leaked into
        # the main content and tables_captions stayed empty.
        elif filtered_lines[0].lower().startswith("table"):
            while filtered_lines and filtered_lines[0] != "":
                table_caption += filtered_lines.pop(0) + ' '

        else:
            main_content += filtered_lines.pop(0) + ' '

        if figure_caption != "":
            figures_captions.append(figure_caption)

        if table_caption != "":
            tables_captions.append(table_caption)

    figures_captions = "\n".join(figures_captions) + "\n" + "\n".join(tables_captions)

    # Skip the references; the appendix starts at the first all-caps section
    # heading beginning with "A" (example: "A ENVIRONMENT DETAILS").
    while filtered_lines != [] and not (filtered_lines[0].isupper() and filtered_lines[0][0] == "A"):
        filtered_lines.pop(0)

    appendix = ""

    while filtered_lines != []:
        appendix += filtered_lines.pop(0) + ' '

    paper = {
        "Title": title.strip(),
        "Abstract": abstract.strip(),
        "Figures/Tables Captions": figures_captions.strip(),
        "Main Content": main_content.strip(),
        "Appendix": appendix.strip(),
    }

    return paper
|
148 |
+
|
149 |
+
|
150 |
+
if __name__ == "__main__":
|
151 |
+
from agentreview.utility.authentication_utils import read_and_set_openai_key
|
152 |
+
from agentreview.review import get_lm_review
|
153 |
+
|
154 |
+
read_and_set_openai_key()
|
155 |
+
|
156 |
+
path = "data/rejected/6359.pdf"
|
157 |
+
text = extract_text_from_pdf(path)
|
158 |
+
|
159 |
+
parsed_paper = convert_text_into_dict(text)
|
160 |
+
|
161 |
+
review_generated = get_lm_review(parsed_paper)
|
162 |
+
|
163 |
+
print(review_generated["review_generated"])
|
agentreview/paper_review_arena.py
ADDED
@@ -0,0 +1,183 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import csv
|
2 |
+
import json
|
3 |
+
import logging
|
4 |
+
from typing import Union
|
5 |
+
|
6 |
+
from agentreview.arena import Arena, TooManyInvalidActions
|
7 |
+
from agentreview.role_descriptions import get_reviewer_description
|
8 |
+
from agentreview.utility.utils import format_metareviews
|
9 |
+
from .agent import Player
|
10 |
+
from .config import ArenaConfig
|
11 |
+
from .environments import TimeStep, load_environment
|
12 |
+
from .paper_review_player import PaperExtractorPlayer, AreaChair, Reviewer
|
13 |
+
|
14 |
+
|
15 |
+
logger = logging.getLogger(__name__)
|
16 |
+
|
17 |
+
|
18 |
+
class PaperReviewArena(Arena):
    """Arena for the paper review environment.

    Orchestrates the simulation: picks the next player from the environment,
    lets the player act on its observation, and feeds valid actions back into
    the environment.
    """

    # PaperReviewArena.from_config
    @classmethod
    def from_config(cls, config: Union[str, ArenaConfig]):
        """Create an arena from a config (path or ArenaConfig).

        Players are instantiated by name prefix ("Paper Extractor", "AC",
        "Reviewer"); anything else falls back to the generic Player.
        """
        # If config is a path, load the config
        if isinstance(config, str):
            config = ArenaConfig.load(config)

        global_prompt = config.get("global_prompt", None)

        # Create the players
        players = []
        for player_config in config.players:
            # Add public_prompt to the player config
            if global_prompt is not None:
                player_config["global_prompt"] = global_prompt

            if player_config['name'].startswith("Paper Extractor"):
                player = PaperExtractorPlayer.from_config(player_config)

            elif player_config['name'].startswith("AC"):
                player = AreaChair.from_config(player_config)

            elif player_config['name'].startswith("Reviewer"):
                player = Reviewer.from_config(player_config)

            else:
                player = Player.from_config(player_config)
            players.append(player)

        # Check that the player names are unique
        player_names = [player.name for player in players]
        assert len(player_names) == len(
            set(player_names)
        ), f"Player names must be unique, current players: {[','.join(player_names)]}"

        # Create the environment
        config.environment[
            "player_names"
        ] = player_names  # add the player names to the environment config
        env = load_environment(config.environment)

        return cls(players, env, global_prompt=global_prompt)

    # PaperReviewArena.step()
    def step(self) -> TimeStep:
        """Take a step in the game: one player takes an action and the environment updates.

        Retries invalid actions up to `self.invalid_actions_retry` times and
        raises TooManyInvalidActions when all retries are exhausted.
        """
        player_name = self.environment.get_next_player()

        player = self.name_to_player[player_name]  # get the player object

        observation = self.environment.get_observation(
            player_name
        )  # get the observation for the player

        timestep = None

        # try to take an action for a few times
        for i in range(self.invalid_actions_retry):

            # Update reviewer description for rebuttal
            if self.environment.phase_index == 3 and player.name.startswith("Reviewer"):
                # CONSISTENCY FIX: use the module-level `logger` (as the rest
                # of this module does) instead of the root `logging` module.
                logger.info("Update reviewers' role_desc for Phase 3 (reviewer_ac_discussion)")
                reviewer_index = int(player.name.split("Reviewer ")[1])

                # reviewer_index starts from 1, so we need to subtract 1 to get
                # the index of the reviewer in the list
                player.role_desc = get_reviewer_description(phase="reviewer_ac_discussion",
                                                            **self.environment.experiment_setting["players"][
                                                                'Reviewer'][reviewer_index - 1])

            elif self.environment.phase_index == 5:  # Phase 5 AC Makes Decisions

                player.role_desc += format_metareviews(self.environment.metareviews, self.environment.paper_ids)

            action = player(observation)  # take an action

            if self.environment.check_action(action, player_name):  # action is valid
                timestep = self.environment.step(
                    player_name, action
                )  # update the environment
                break
            else:  # action is invalid
                logger.warning(f"{player_name} made an invalid action {action}")
                continue

        if (
                timestep is None
        ):  # if the player made invalid actions for too many times, terminate the game
            warning_msg = f"{player_name} has made invalid actions for {self.invalid_actions_retry} times. Terminating the game."
            logger.warning(warning_msg)
            raise TooManyInvalidActions(warning_msg)

        return timestep

    def save_history(self, path: str):
        """
        Save the history of the game to a file.

        Supports csv and json formats (chosen by file extension); raises
        ValueError for any other extension.
        """
        messages = self.environment.get_observation()
        message_rows = []

        if path.endswith(".csv"):
            header = [
                "agent_name",
                "content",
                "turn",
                "timestamp",
                "visible_to",
                "msg_type",
            ]
            for message in messages:
                message_row = [
                    message.agent_name,
                    message.content,
                    message.turn,
                    str(message.timestamp),
                    message.visible_to,
                    message.msg_type,
                ]
                message_rows.append(message_row)

            # BUG FIX: the csv module requires files opened with newline=""
            # (otherwise blank rows appear on Windows).
            with open(path, "w", newline="") as f:
                writer = csv.writer(f)
                writer.writerow(header)
                writer.writerows(message_rows)
        elif path.endswith(".json"):
            for message in messages:
                message_row = {
                    "agent_name": message.agent_name,
                    "content": message.content,
                    "turn": message.turn,
                    "timestamp": str(message.timestamp),
                    "visible_to": message.visible_to,
                    "msg_type": message.msg_type,
                }
                message_rows.append(message_row)

            with open(path, "w") as f:
                json.dump({
                    "experiment_setting": self.environment.experiment_setting,
                    "messages": message_rows,
                }, f, indent=2)
        else:
            raise ValueError("Invalid file format")
|
agentreview/paper_review_message.py
ADDED
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
from typing import List
|
3 |
+
|
4 |
+
from agentreview.message import MessagePool, Message
|
5 |
+
|
6 |
+
|
7 |
+
class PaperReviewMessagePool(MessagePool):
    """
    A pool to manage the messages in the paper review environment.

    Visibility is determined by the review phase rather than by message turn.
    """

    def __init__(self, experiment_setting: dict):
        super().__init__()
        # Experiment configuration; phase-4 visibility reads the AC type from it.
        self.experiment_setting = experiment_setting

    def get_visible_messages_for_paper_review(self, agent_name, phase_index: int,
                                              next_player_idx: int, player_names: List[str]) \
            -> List[Message]:
        """
        Get all the messages that are visible to a given agent in a given phase.

        Parameters:
            agent_name (str): The name of the agent.
            phase_index (int): The specified phase in paper reviewing process.
            next_player_idx (int): Index of the next player within the sorted
                reviewer list (used in phase 2 to select one reviewer's review).
            player_names (List[str]): Names of all players in the environment.

        Returns:
            List[Message]: A list of visible messages.
        """

        reviewer_names = sorted([name for name in player_names if name.startswith("Reviewer")])

        # Phase-based visibility ignores message turns entirely.
        prev_messages = self._messages

        # Default: nothing visible (avoids NameError on unhandled role/phase
        # combinations).
        visible_messages = []

        if phase_index in [0, 1]:
            # Only the extracted paper content is visible.
            visible_messages = [message for message in prev_messages if message.agent_name == "Paper Extractor"]

        elif phase_index == 2:
            # The author can see the paper content and each reviewer's review
            for message in prev_messages:
                if message.agent_name == "Paper Extractor" or message.agent_name == reviewer_names[next_player_idx]:
                    visible_messages.append(message)

        elif phase_index == 3:
            # BUG FIX: the original tested `if [agent_name.startswith(p) for p in ...]:`
            # — a non-empty list, hence always truthy — so every agent
            # (including authors) saw all messages in this phase.
            if any(agent_name.startswith(prefix) for prefix in ["AC", "Reviewer", "Paper Extractor"]):
                # Both area chairs and reviewers can see all the reviews and rebuttals
                visible_messages = prev_messages

            elif agent_name.startswith("Author"):
                visible_messages = []

        elif phase_index == 4:
            if agent_name.startswith("AC"):
                area_chair_type = self.experiment_setting['players']['AC'][0]["area_chair_type"]

                # 'BASELINE' means we do not specify the area chair's characteristics in the config file
                if area_chair_type in ["inclusive", "BASELINE"]:
                    # An inclusive area chair can see all the reviews and rebuttals
                    visible_messages = prev_messages

                elif area_chair_type == "conformist":
                    # A conformist AC only reads reviewer/author messages.
                    visible_messages = [
                        message for message in prev_messages
                        if message.agent_name.startswith("Author") or message.agent_name.startswith("Reviewer")
                    ]

                elif area_chair_type == "authoritarian":
                    # An authoritarian AC ignores reviewer/author messages.
                    visible_messages = [
                        message for message in prev_messages
                        if not (message.agent_name.startswith("Author")
                                or message.agent_name.startswith("Reviewer"))
                    ]

                else:
                    raise ValueError(f"Unknown Area chair type: {area_chair_type}.")

        else:
            # Fallback for any other phase: standard visible_to filtering.
            for message in prev_messages:
                if (
                        message.visible_to == "all"
                        or agent_name in message.visible_to
                        or agent_name == "Moderator"
                ):
                    visible_messages.append(message)

        logging.info(f"Phase {phase_index}: {agent_name} sees {len(visible_messages)} messages from "
                     f"{','.join([agent.agent_name for agent in visible_messages]) if visible_messages else 'None'}")

        return visible_messages
|
agentreview/paper_review_player.py
ADDED
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import logging
|
3 |
+
import os
|
4 |
+
from pathlib import Path
|
5 |
+
from typing import List, Union
|
6 |
+
|
7 |
+
from llama_index.readers.file.docs import PDFReader
|
8 |
+
|
9 |
+
from agentreview.agent import Player
|
10 |
+
from .backends import IntelligenceBackend
|
11 |
+
from .config import BackendConfig
|
12 |
+
from .message import Message
|
13 |
+
|
14 |
+
|
15 |
+
class AreaChair(Player):
    """An area chair (AC) player.

    Behavior depends on ``env_type``: in the "paper_review" environment the
    AC moderates the rebuttal phase; in "paper_decision" it delegates to the
    backend (via ``Player.act``) to write metareviews / decisions.
    """

    def __init__(
        self,
        name: str,
        role_desc: str,
        env_type: str,
        backend: Union[BackendConfig, IntelligenceBackend],
        global_prompt: str = None,
        **kwargs,
    ):
        super().__init__(name, role_desc, backend, global_prompt, **kwargs)
        self.env_type = env_type
        self.role_desc = role_desc

    def act(self, observation: List[Message]) -> str:
        """Produce the AC's next utterance given the observed messages."""
        if self.env_type == "paper_review":
            # The authors just finished their rebuttals (last speaker is an
            # Author), so the AC asks each reviewer to update their reviews.
            last_speaker_is_author = bool(observation) and observation[-1].agent_name.startswith("Author")
            if last_speaker_is_author:
                return "Dear reviewers, please update your reviews based on the author's rebuttals."
            return super().act(observation)

        if self.env_type == "paper_decision":
            return super().act(observation)

        raise ValueError(f"Unknown env_type: {self.env_type}")
|
47 |
+
|
48 |
+
|
49 |
+
class Reviewer(Player):
    """A player that writes peer reviews of a paper submission.

    The actual review text is generated by the configured backend via
    ``Player.act``; this subclass only exists to give the role a distinct type.
    """

    def __init__(
        self,
        name: str,
        role_desc: str,
        backend: Union[BackendConfig, IntelligenceBackend],
        global_prompt: str = None,
        **kwargs,
    ):
        # Fix: replaced leftover debug `print("kwargs"); print(kwargs)` with
        # a debug-level log so normal runs are not polluted on stdout.
        logging.debug("Reviewer %s constructed with kwargs: %s", name, kwargs)
        super().__init__(name, role_desc, backend, global_prompt, **kwargs)

    def act(self, observation: List[Message]) -> str:
        """Generate a review/response for the observed messages (delegates to the backend)."""
        return super().act(observation)
|
65 |
+
|
66 |
+
|
67 |
+
class PaperExtractorPlayer(Player):
    """A player for solely extracting contents from a paper.

    No API calls are made by this player: ``act`` reads the paper PDF from
    disk and returns a (word-capped) plain-text dump of its contents.
    """

    def __init__(
        self,
        name: str,
        role_desc: str,
        paper_id: int,
        paper_decision: str,
        conference: str,
        backend: Union[BackendConfig, IntelligenceBackend],
        paper_pdf_path: str = None,
        global_prompt: str = None,
        **kwargs,
    ):
        super().__init__(name, role_desc, backend, global_prompt, **kwargs)
        self.paper_id = paper_id
        self.paper_decision = paper_decision
        self.conference: str = conference

        # Fix: always assign the attribute. The original only set it when a
        # path was provided, so `act()` raised AttributeError (not the
        # intended fallback) whenever `paper_pdf_path` was omitted.
        self.paper_pdf_path = paper_pdf_path

    def act(self, observation: List[Message]) -> str:
        """
        Take an action based on the observation (Generate a response), which can later be parsed to actual actions that affect the game dynamics.

        Parameters:
            observation (List[Message]): The messages that the player has observed from the environment.

        Returns:
            str: The paper contents, truncated to `self.args.max_num_words` words.
        """
        if self.paper_pdf_path is not None:
            logging.info(f"Loading paper from {self.paper_pdf_path} ...")
            document_path = Path(self.paper_pdf_path)
        else:
            logging.info(f"Loading {self.conference} paper {self.paper_id} ({self.paper_decision}) ...")
            # NOTE(review): `self.args` is not set in this class's __init__ —
            # presumably injected by Player via **kwargs; confirm.
            document_path = Path(os.path.join(self.args.data_dir, self.conference, "paper", self.paper_decision,
                                              f"{self.paper_id}.pdf"))

        loader = PDFReader()
        documents = loader.load_data(file=document_path)

        num_words = 0
        main_contents = "Contents of this paper:\n\n"
        reached_word_limit = False

        for doc in documents:
            words = doc.text.split(' ')
            # Truncate the final page so the total stays within the word budget.
            if len(words) + num_words > self.args.max_num_words:
                words = words[:self.args.max_num_words - num_words]
                reached_word_limit = True
            num_words += len(words)
            main_contents += " ".join(words) + ' '
            if reached_word_limit:
                break

        # Fix: demoted the unconditional `print(main_contents)` to debug logging.
        logging.debug(main_contents)

        return main_contents
|
agentreview/paper_review_settings.py
ADDED
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Union
|
2 |
+
|
3 |
+
# Default attribute set for a reviewer persona. `None` means the trait is
# unspecified and is omitted from the generated reviewer biography; numeric
# ratings are requested by default.
default_reviewer_setting = {
    "is_benign": None,
    "is_knowledgeable": None,
    "is_responsible": None,
    "provides_numeric_rating": True,
}
|
9 |
+
|
10 |
+
|
11 |
+
def get_experiment_settings(paper_id: Union[int, None] = None, paper_decision: Union[str, None] = None,
                            setting: dict = None):
    """
    Generate experiment settings based on provided configurations for area chairs (AC) and reviewers.

    Args:
        paper_id: Identifier of the paper under review (may be None).
        paper_decision: Ground-truth decision for the paper (may be None).
        setting (dict): A dictionary containing configuration for AC, reviewers, authors, and global settings.

    Returns:
        dict: Experiment settings including players (Paper Extractor, AC, Author, Reviewer)
            and global settings.
    """
    player_settings = {
        # Paper Extractor is a special player that extracts a paper from the
        # dataset; its constructor takes no arguments.
        "Paper Extractor": [{}],

        # Assume there is only one area chair (AC) in the experiment.
        "AC": [get_ac_setting_from_ac_type(ac_type) for ac_type in setting['AC']],

        # Author role with default configuration.
        "Author": [{}],

        # Reviewer settings are derived from the reviewer types in `setting`.
        "Reviewer": [get_reviewer_setting_from_reviewer_type(rtype) for rtype in setting['reviewer']],
    }

    return {
        "paper_id": paper_id,
        "paper_decision": paper_decision,
        "players": player_settings,
        "global_settings": setting['global_settings'],
    }
|
46 |
+
|
47 |
+
|
48 |
+
def get_reviewer_setting_from_reviewer_type(reviewer_type: str):
    """
    Map a reviewer type (e.g., 'benign', 'malicious') to a reviewer setting dictionary.

    Args:
        reviewer_type (str): The type of reviewer (e.g., 'benign', 'malicious', 'knowledgeable').
            Keywords may be combined, e.g. "malicious_knowledgeable_responsible".

    Returns:
        dict: A dictionary with keys is_benign, is_knowledgeable, is_responsible
            (each True/False/None), plus "knows_authors" for the special types
            "authors_are_famous" / "authors_are_unfamous". Unrecognized keywords
            simply leave the corresponding traits as None (no error is raised).
    """
    reviewer_setting = {
        "is_benign": None,
        "is_knowledgeable": None,
        "is_responsible": None
    }

    # Intention.
    if "malicious" in reviewer_type:
        reviewer_setting["is_benign"] = False
    elif "benign" in reviewer_type:
        reviewer_setting["is_benign"] = True

    # Knowledgeability. Fix: check "unknowledgeable" first because it contains
    # "knowledgeable" as a substring — the original set True then relied on a
    # later assignment overwriting it.
    if "unknowledgeable" in reviewer_type:
        reviewer_setting["is_knowledgeable"] = False
    elif "knowledgeable" in reviewer_type:
        reviewer_setting["is_knowledgeable"] = True

    # Commitment. Same substring hazard: "irresponsible" contains "responsible".
    if "irresponsible" in reviewer_type:
        reviewer_setting["is_responsible"] = False
    elif "responsible" in reviewer_type:
        reviewer_setting["is_responsible"] = True

    if reviewer_type in ["authors_are_famous"]:
        reviewer_setting["knows_authors"] = "famous"

    elif reviewer_type in ["authors_are_unfamous"]:
        reviewer_setting["knows_authors"] = "unfamous"

    return reviewer_setting
|
93 |
+
|
94 |
+
|
95 |
+
def get_ac_setting_from_ac_type(ac_type: str):
    """
    Generate the area chair (AC) settings based on the type of AC.

    Args:
        ac_type (str): The type of area chair (e.g., 'senior', 'junior').

    Returns:
        dict: A dictionary containing the area chair type.
    """
    return {"area_chair_type": ac_type}
|
agentreview/role_descriptions.py
ADDED
@@ -0,0 +1,515 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
|
6 |
+
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
7 |
+
|
8 |
+
from agentreview import const
|
9 |
+
from agentreview.config import AgentConfig
|
10 |
+
|
11 |
+
# Shared LLM backend configuration used when constructing every role-playing agent.
PLAYER_BACKEND = {
    "backend_type": "openai-chat",
    "temperature": 0.9,
    "max_tokens": 4096
}

# archived. If we use this rubric, the scores given by the reviewers are too high
RUBRICS_v1 = ("Rubrics: 10 for strong accept (top 5% of accepted papers), "
              "8 for accept (top 50% of accepted papers), "
              "6 for borderline accept, "
              "5 for borderline reject, "
              "3 for reject, and 1 for strong reject. ")

# archived: per-score explanations that accompanied the rubric above.
SCORE_CALCULATION_v1 = {
    10: "This study is among the top 0.5% of all papers",
    8: "This study is one of the most thorough I have seen. It changed my thinking on this topic. I would fight for it to be accepted",
    6: "This study provides sufficient support for all of its claims/arguments. Some extra experiments are needed, but not essential. The method is highly original and generalizable to various fields. It deepens the understanding of some phenomenons or lowers the barriers to an existing research direction",
    5: "This study provides sufficient support for its major claims/arguments, some minor points may need extra support or details. The method is moderately original and generalizable to various relevant fields. The work it describes is not particularly interesting and/or novel, so it will not be a big loss if people don’t see it in this conference",
    3: "Some of the main claims/arguments are not sufficiently supported, there are major technical/methodological problems. The proposed method is somewhat original and generalizable to various relevant fields. I am leaning towards rejection, but I can be persuaded if my co-reviewers think otherwise",
    1: "This study is not yet sufficiently thorough to warrant publication or is not relevant to the conference. This paper makes marginal contributions"
}

# Start to use this rubric from 2024.1.23 as SCORE_CALCULATION_v1 is too harsh
SCORE_CALCULATION = {
    10: "This study is among the top 2% of all papers. It is one of the most thorough I have seen. It changed my "
        "thinking on this topic. I would fight for it to be accepted",
    8: "This study is among the top 10% of all papers. It provides sufficient support for all of its claims/arguments. "
       "Some extra experiments are needed, "
       "but not essential. The method is highly original and generalizable to various fields. It deepens the understanding of some phenomenons or lowers the barriers to an existing research direction",
    6: "This study provides sufficient support for its major claims/arguments, some minor points may need extra support or details. The method is moderately original and generalizable to various relevant fields. The work it describes is not particularly interesting and/or novel, so it will not be a big loss if people don’t see it in this conference",
    5: "Some of the main claims/arguments are not sufficiently supported, there are major technical/methodological "
       "problems. The proposed method is somewhat original and generalizable to various relevant fields. I am leaning towards rejection, but I can be persuaded if my co-reviewers think otherwise",
    3: "This paper makes marginal contributions",
    1: "This study is not yet sufficiently thorough to warrant publication or is not relevant to the conference"
}

# NOTE: this rebinding shadows the archived RUBRICS_v1 defined above; the
# archived literal is kept only for reference and is never used at runtime.
RUBRICS_v1 = ("Rubrics: "
              f"10 for strong accept ({SCORE_CALCULATION[10]}); "
              f"8 for accept ({SCORE_CALCULATION[8]}); "
              f"6 for borderline accept ({SCORE_CALCULATION[6]}); "
              f"5 for borderline reject ({SCORE_CALCULATION[5]}); "
              f"3 for reject ({SCORE_CALCULATION[3]}); "
              f"1 for strong reject ({SCORE_CALCULATION[1]}); ")

# Prompt fragment injected when an agent is told the authors are famous.
INSTRUCTIONS_FOR_FAMOUS_AUTHORS = ("You know that the authors of the paper are from a very famous lab and have "
                                   "several publication in "
                                   "this "
                                   "field. Be sure to consider that when writing the paper reviews. "
                                   "\n\n")

# Bullet-point rubric appended to reviewer / area-chair prompts.
RUBRICS = (f"* 10: {SCORE_CALCULATION[10]};\n"
           f"* 8: {SCORE_CALCULATION[8]};\n"
           f"* 6: {SCORE_CALCULATION[6]};\n"
           f"* 5: {SCORE_CALCULATION[5]};\n"
           f"* 3: {SCORE_CALCULATION[3]};\n"
           f"* 1: {SCORE_CALCULATION[1]}. ")

# Try to lower the score
SCORE_CONTROL = ("This is a very rigorous top-tier conference. "
                 "Most papers get scores <=5 before the rebuttal. ")

# Need to explain this
EXPLANATION_FOR_NOT_UPDATING_MANUSCRIPT = (f"Note: Do not mention that the authors did not update the manuscripts "
                                           f"and do not penalize them for "
                                           f"not revising their papers. They cannot do it now. Just assume they have revised their "
                                           f"manuscripts according to their rebuttals.")
|
77 |
+
|
78 |
+
|
79 |
+
def get_instructions_for_overall_scores(author_type: str) -> str:
    """Build the numeric-score instructions inserted into a reviewer/AC prompt.

    Papers whose authors are not "famous" get an extra warning against
    inflated pre-rebuttal scores.

    NOTE(review): some call sites appear to pass a bool (``knows_authors``)
    rather than an author-type string, in which case the warning is always
    included — confirm against callers before relying on the "famous" branch.
    """
    pieces = ["Do not write any reasons. "]

    if author_type != "famous":
        pieces.append("Do not assign scores of 7 or higher before the rebuttal unless the paper "
                      "demonstrates exceptional originality and "
                      "significantly advances the state-of-the-art in machine learning. ")

    pieces.append("Intermediary integer scores such as 9, 7, 4, and 2 are allowed. ")

    return "".join(pieces)
|
90 |
+
|
91 |
+
def get_reviewer_description(is_benign: bool = None, is_knowledgeable: bool = None, is_responsible: bool = None,
                             provides_numeric_rating: bool = True, knows_authors: bool = False,
                             phase: str = "reviewer_write_reviews"):
    """Build the role description (system prompt) for a reviewer agent.

    Args:
        is_benign: Whether the reviewer is constructive (None = trait omitted from the bio).
        is_knowledgeable: Whether the reviewer has domain expertise (None = omitted).
        is_responsible: Whether the reviewer is diligent (None = omitted).
        provides_numeric_rating: If True, the prompt asks for an "Overall rating"
            and appends the scoring rubric.
        knows_authors: If True, the prompt reveals that the authors are famous;
            otherwise the score-control fragment is added instead.
        phase: Either "reviewer_write_reviews" (initial review) or
            "reviewer_ac_discussion" (post-rebuttal update).

    Returns:
        str: The complete role description text.

    Raises:
        ValueError: If `phase` is not a recognized reviewer phase (defensive;
            the leading assert already restricts it).
    """
    assert phase in ["reviewer_write_reviews", 'reviewer_ac_discussion']
    assert provides_numeric_rating in [True, False]
    bio = ("You are a reviewer. You write peer review of academic papers by evaluating their technical "
           f"quality, originality, and clarity. ")

    # The reviewer's famous identities are known to the AC
    if knows_authors:
        bio += "\n\n" + INSTRUCTIONS_FOR_FAMOUS_AUTHORS

    else:
        bio += f"{SCORE_CONTROL}\n\n"

    bio += "## Review Guidelines\n"

    if phase in ["reviewer_write_reviews"]:

        guideline = "Write a peer review using the following format:\n\n"
        guideline += "```\n"
        if provides_numeric_rating:
            # NOTE(review): `knows_authors` is a bool, but the callee's
            # parameter is `author_type: str` (compared against "famous"),
            # so the strict-scoring warning is always included — confirm
            # whether this is intended before changing.
            guideline += f"Overall rating: ... # {get_instructions_for_overall_scores(knows_authors)}\n\n"

        # The bare string below is an intentionally disabled alternative
        # review format (kept verbatim; it has no runtime effect).
        """
        # Review formats used in most ICLR conferences
        guideline += "Summary: ... # Provide a brief summary of the paper, such as its main contributions.\n\n"
        guideline += "Strengths: ... # Give a list of strengths for the paper.\n\n"
        guideline += "Weaknesses: ...<EOS> # Give a list of weaknesses and questions for the paper.\n\n"
        """

        # Review formats used in [Stanford's Nature Submission](https://arxiv.org/abs/2310.01783)
        guideline += "Significance and novelty: ... \n\n"
        guideline += "Reasons for acceptance: ... # List 4 key reasons. \n\n"
        guideline += "Reasons for rejection: ... # List 4 key reasons. For each of 4 key reasons, use **>=2 sub bullet points** to further clarify and support your arguments in painstaking details \n\n"
        guideline += "Suggestions for improvement: ... <EOS> # List 4 key suggestions \n\n"

    elif phase in ["reviewer_ac_discussion"]:

        guideline = "Based on the authors' responses, write an updated paper review in the reviewer-AC discussion."

        if provides_numeric_rating:
            guideline += (
                "Decrease your score if the authors fail to address your or other reviewers' concerns, or "
                f"provide very vague responses. "
                f"Increase your score only if the authors have "
                f"addressed all your and other "
                f"reviewers' concerns, and have comprehensively described how they plan to update the manuscript. Keep "
                f"the "
                f"score "
                f"unchanged "
                f"otherwise. {EXPLANATION_FOR_NOT_UPDATING_MANUSCRIPT}"
                "\n\n## Format for the updated review\n\n```\n")

            guideline += ("Overall rating: ... # Provide an updated overall rating using an integer from 1 to 10. Do "
                          "not penalize the authors for not updating their manuscripts. They cannot revise their "
                          "manuscripts now.")
        else:
            guideline += "\n\n```\n"

        guideline += (f"Summary: ... <EOS> # "
                      f"{'Provide a justification on your updated score.' if provides_numeric_rating else ''} Comment on "
                      f"whether the "
                      "author has "
                      "addressed "
                      "your questions and concerns. Note that authors cannot revise their "
                      "manuscripts now.\n")

    else:
        raise ValueError(f"Invalid phase for a reviewer: {phase}")

    bio += f"{guideline}```\n\n"

    # Only emit the biography header when at least one trait is specified.
    if not all([x is None for x in [is_benign, is_knowledgeable, is_responsible]]):
        bio += "## Your Biography\n"

    # Knowledgeability
    desc_knowledgeable_reviewer = (
        "You are knowledgeable, with a strong background and a PhD degree in the subject areas "
        "related to this paper. "
        "You possess the expertise necessary to scrutinize "
        "and provide insightful feedback to this paper.")

    desc_unknowledgeable_reviewer = (
        "You are not knowledgeable and do not have strong background in the subject areas related to "
        "this paper.")

    if is_knowledgeable is not None:
        if is_knowledgeable:
            desc = desc_knowledgeable_reviewer
        else:
            desc = desc_unknowledgeable_reviewer

        bio += f"Knowledgeability: {desc}\n\n"

    # Responsible vs. lazy
    desc_responsible_reviewer = ("As a responsible reviewer, you highly responsibly write paper reviews and actively "
                                 "participate in reviewer-AC discussions. "
                                 "You meticulously assess a research "
                                 "paper's "
                                 "technical accuracy, innovation, and relevance. You thoroughly read the paper, "
                                 "critically analyze the methodologies, and carefully consider the paper's "
                                 "contribution to the field. ")

    desc_irresponsible_reviewer = ("As a lazy reviewer, your reviews tend to be superficial and hastily done. You do not like "
                                   "to discuss in the reviewer-AC discussion. "
                                   "Your assessments might overlook critical details, lack depth in analysis, "
                                   "fail to recognize novel contributions, "
                                   "or offer generic feedback that does little to advance the paper's quality.")

    if is_responsible is not None:

        if is_responsible:
            desc = desc_responsible_reviewer
        else:
            desc = desc_irresponsible_reviewer

        bio += f"Responsibility: {desc}\n\n"

    # Benign (Good) vs. Malicious
    desc_benign_reviewer = ("As a benign reviewer, your approach to reviewing is guided by a genuine intention "
                            "to aid authors in enhancing their work. You provide detailed, constructive feedback, "
                            "aimed at both validating robust research and guiding authors to refine and improve their work. "
                            "You are also critical of technical flaws in the paper. ")

    desc_malicious_reviewer = ("As a mean reviewer, your reviewing style is often harsh and overly critical, "
                               "with a tendency towards negative bias. Your reviews may focus excessively on "
                               "faults, sometimes overlooking the paper's merits. Your feedback can be discouraging, "
                               "offering minimal guidance for improvement, and often aims more at rejection than constructive critique. ")

    if is_benign is not None:

        if is_benign:
            desc = desc_benign_reviewer
        else:
            desc = desc_malicious_reviewer

        bio += f"Intention: {desc}\n\n"

    if provides_numeric_rating:
        bio += f"## Rubrics for Overall Rating\n\n{RUBRICS}"

    return bio
|
241 |
+
|
242 |
+
|
243 |
+
def get_author_description() -> str:
    """Return the role description (system prompt) for an author agent."""
    sections = [
        "You are an author. You write research papers and submit them to conferences. During the rebuttal phase, "
        "you carefully read the reviews from the reviewers and respond to each of them.\n\n",
        "## Author Guidelines\n",
        "Write a response to the reviews using the following format:\n\n",
        "```\n",
        "Response: ... # Provide a brief response to each review. Address each question and weakness mentioned "
        "by the reviewer. No need to respond to the strengths they mentioned. \n\n",
    ]
    return "".join(sections)
|
255 |
+
|
256 |
+
|
257 |
+
def get_ac_description(area_chair_type: str, phase: str, scoring_method: str, num_papers_per_area_chair: int,
                       knows_authors: bool = False, **kwargs) -> str:
    """Build the role description (system prompt) for an area chair (AC).

    Note: We assume that the AC definitely provides a score so that the papers can be compared.

    Args:
        area_chair_type (str): One of "inclusive", "conformist", "authoritarian", or "BASELINE".
        phase (str): The phase of the conference. Must be either "ac_write_metareviews" or "ac_make_decisions".
        scoring_method (str): The method used by the area chair to make the final decision. Must be either of
            "recommendation": directly make a recommendation (e.g. "Accept", "Reject") for each paper
            "ranking": rank the papers using your willingness to accept
        num_papers_per_area_chair (int): Number of papers assigned to this AC (caps acceptances).
        knows_authors (bool): Whether the AC is told the authors are famous.
        **kwargs: Optional "acceptance_rate" (float, default 0.32).

    Returns:
        str: The complete role description.

    Raises:
        ValueError: If `phase` or `area_chair_type` is invalid.
        NotImplementedError: If `scoring_method` is unknown.
    """

    acceptance_rate = kwargs.get('acceptance_rate', 0.32)
    bio = "You are a very knowledgeable and experienced area chair in a top-tier machine learning conference. "

    if phase == "ac_write_metareviews":
        bio += ("You evaluate the reviews provided by reviewers and write metareviews. Later, you will decide which "
                "paper gets accepted or rejected based on your metareviews. ")

    elif phase == "ac_make_decisions":
        bio += "Based on the metareviews you wrote previously, you decide if a paper is accepted or rejected. "

    # The authors' famous identities are known to the AC
    if knows_authors:
        bio += INSTRUCTIONS_FOR_FAMOUS_AUTHORS + SCORE_CONTROL

    bio += "\n\n## Area Chair Guidelines\n"

    if phase == "ac_write_metareviews":

        guideline = "Write a metareview using the following format:\n\n"
        guideline += "```\n"
        guideline += (
            f"Score: ... # Provide a score for the paper in the range from 1 to 10. {get_instructions_for_overall_scores(knows_authors)}Fractions such as "
            "6.5 is allowed.\n\n")
        guideline += ("Summary: ... <EOS> # Provide a summary of the paper based on the paper contents (if provided), "
                      "reviewers' "
                      "reviews and discussions (if provided), authors' rebuttal, and your own expertise. "
                      f"{EXPLANATION_FOR_NOT_UPDATING_MANUSCRIPT}\n")

        bio += guideline
        bio += "```\n\n"

    elif phase == "ac_make_decisions":
        # Cap acceptances so the AC roughly matches the conference acceptance rate.
        max_num_accepted_papers = int(np.floor(num_papers_per_area_chair * acceptance_rate))

        # The area chair usually accepts more papers than s/he should,
        # so we either state an explicit quota ("recommendation") or
        # force a total order ("ranking").
        if scoring_method == "recommendation":
            CONTROL_NUM_ACCEPTED_PAPERS = (f"You must accept around "
                                           f"{max_num_accepted_papers} out of {num_papers_per_area_chair} papers, "
                                           )
            guideline = (f"Carefully decide if a paper is accepted or rejected using the metareview. Use the following "
                         f"format ")
            guideline += f"({CONTROL_NUM_ACCEPTED_PAPERS})"
            guideline += f":\n\n"

            guideline += "```\n"
            guideline += ("Paper ID: ... # Provide the first paper ID. \n"
                          "Decision: ... # Provide a decision for the paper. Must be one of "
                          "'Reject' and 'Accept'.\n"
                          "Paper ID: ... # Provide the second paper ID. \n"
                          f"... # Likewise\n")
            guideline += "```\n\n"

            bio += guideline

        elif scoring_method == "ranking":

            guideline = (f"Rank the papers from the paper you are most willing to accept to the least willing to "
                         f"accept. '1' indicates "
                         f"the paper "
                         f"you are most "
                         f"willing to accept. "
                         f"Use this format:\n\n")
            guideline += "```\n"
            guideline += "Paper ID: 1 # The paper ID you most want to accept.\n"
            guideline += "Willingness to accept: 1 # This integer must be unique for each paper. \n"
            guideline += "Paper ID: ... # The second paper ID you most want to accept .. \n...\n"
            guideline += "Willingness to accept: 2 \n"
            guideline += "...\n```\n\n"

            bio += guideline

        else:
            raise NotImplementedError(f"Unknown scoring method: {scoring_method}")

    else:
        raise ValueError(f"Invalid phase for an area chair: {phase}")

    if phase == "ac_write_metareviews":
        bio += f"## Rubrics for Overall Rating\n\n{RUBRICS}\n\n"

    desc_inclusive_ac = ("You are an inclusive area chair. You tend to hear from all reviewers' opinions and combine "
                         "them with your own judgments to make the final decision.")

    desc_conformist_ac = ("You are a conformist area chair who perfunctorily handle area chair duties. You "
                          "mostly follow "
                          "the reviewers' suggestions to write your metareview, score the paper, and decide whether "
                          "to accept a paper.")

    desc_authoritarian_ac = ("You are an authoritarian area chair. You tend to read the paper on your own, follow your "
                             "own "
                             "judgment and mostly ignore "
                             "the reviewers' opinions.")

    desc = ""

    if phase == "ac_write_metareviews":

        if area_chair_type == "inclusive":
            desc = desc_inclusive_ac
        elif area_chair_type == "conformist":
            desc = desc_conformist_ac
        elif area_chair_type == "authoritarian":
            desc = desc_authoritarian_ac
        elif area_chair_type == "BASELINE":
            desc = ""
        else:
            # Fix: the original raised this error only in the phase `else`
            # branch, which was unreachable (an invalid phase already raised
            # above), so unknown AC types silently produced an empty biography.
            raise ValueError(f"Invalid area chair type: {area_chair_type}. Choose from {','.join(const.AREA_CHAIR_TYPES)}.")

    elif phase == "ac_make_decisions":
        # We do not introduce different types of ACs in the decision phase
        desc = ""

    if desc != "":
        bio += f"## Your Biography\n{desc}\n\n"

    return bio
|
400 |
+
|
401 |
+
|
402 |
+
def get_reviewer_player_config(reviewer_index: int, is_benign: bool, is_knowledgeable: bool, is_responsible: bool,
|
403 |
+
global_settings: dict) -> dict:
|
404 |
+
"""
|
405 |
+
|
406 |
+
Get a Player object that represents a reviewer.
|
407 |
+
|
408 |
+
Args:
|
409 |
+
reviewer_index:
|
410 |
+
is_benign (bool): If the reviewer has good intention and provides constructive feedback. If None, we do not add this field to the bio.
|
411 |
+
is_knowledgeable (bool): If the reviewer is knowledgeable and has a strong background in the subject areas related
|
412 |
+
to the paper. If None, we do not add this field to the bio.
|
413 |
+
is_responsible (bool): If the reviewer is responsible and provides detailed feedback.
|
414 |
+
provides_numeric_rating (bool): If the reviewer provides an overall rating (e.g. accept, weak accept) to the
|
415 |
+
paper. If None, we do not add this field to the bio.
|
416 |
+
knows_authors (str): The type of the authors of the paper under review. Must be one of "famous",
|
417 |
+
"unfamous", None (Default. Author type is unknown)
|
418 |
+
|
419 |
+
Return
|
420 |
+
player (dict): A player object that represents the reviewer.
|
421 |
+
|
422 |
+
"""
|
423 |
+
|
424 |
+
knows_authors = "reviewer" in global_settings['persons_aware_of_authors_identities']
|
425 |
+
provides_numeric_rating = "reviewer" in global_settings['provides_numeric_rating']
|
426 |
+
|
427 |
+
reviewer = {
|
428 |
+
"name": f"Reviewer {reviewer_index}",
|
429 |
+
"role_desc": get_reviewer_description(is_benign, is_knowledgeable, is_responsible, provides_numeric_rating,
|
430 |
+
knows_authors),
|
431 |
+
# "role_desc": get_reviewer_description(is_benign, is_knowledgeable, is_responsible, provides_numeric_rating),
|
432 |
+
"backend": PLAYER_BACKEND,
|
433 |
+
"metadata": {
|
434 |
+
"is_benign": is_benign,
|
435 |
+
"is_knowledgeable": is_knowledgeable,
|
436 |
+
"is_responsible": is_responsible,
|
437 |
+
"knows_authors": knows_authors,
|
438 |
+
}
|
439 |
+
}
|
440 |
+
|
441 |
+
return AgentConfig(**reviewer)
|
442 |
+
|
443 |
+
|
444 |
+
def get_author_config() -> dict:
|
445 |
+
author = {
|
446 |
+
"name": f"Author",
|
447 |
+
"role_desc": get_author_description(),
|
448 |
+
"backend": PLAYER_BACKEND
|
449 |
+
}
|
450 |
+
|
451 |
+
return AgentConfig(**author)
|
452 |
+
|
453 |
+
|
454 |
+
def get_paper_extractor_config(**kwargs) -> dict:
|
455 |
+
max_tokens = kwargs.pop('max_tokens', 2048)
|
456 |
+
|
457 |
+
paper_extractor = {
|
458 |
+
"name": f"Paper Extractor",
|
459 |
+
"role_desc": "This is a player that only extracts content from the paper. No API calls are made",
|
460 |
+
"backend": {
|
461 |
+
"backend_type": "dummy",
|
462 |
+
# "temperature": 0.,
|
463 |
+
"max_tokens": max_tokens,
|
464 |
+
},
|
465 |
+
}
|
466 |
+
|
467 |
+
return AgentConfig(**paper_extractor)
|
468 |
+
|
469 |
+
|
470 |
+
def get_ac_config(**kwargs) -> dict:
|
471 |
+
"""
|
472 |
+
|
473 |
+
Get a Player object that represents an area chair.
|
474 |
+
|
475 |
+
Args:
|
476 |
+
index_ac (int):
|
477 |
+
is_benign (bool): If the reviewer has good intention and provides constructive feedback.
|
478 |
+
is_knowledgeable: If the reviewer is knowledgeable and has a strong background in the subject areas related
|
479 |
+
to the paper.
|
480 |
+
is_responsible (bool): If the reviewer is responsible and provides detailed feedback.
|
481 |
+
provides_numeric_rating (bool): If the reviewer provides an overall rating (e.g. accept, weak accept) to the
|
482 |
+
paper.
|
483 |
+
|
484 |
+
scoring_method (str): Scoring method for the area chair.
|
485 |
+
|
486 |
+
Return
|
487 |
+
player (dict): A player object that represents the area chair.
|
488 |
+
|
489 |
+
"""
|
490 |
+
|
491 |
+
env_type = kwargs.pop('env_type')
|
492 |
+
global_settings = kwargs.get('global_settings', {})
|
493 |
+
|
494 |
+
if env_type == "paper_review":
|
495 |
+
phase = "ac_write_metareviews"
|
496 |
+
|
497 |
+
elif env_type == "paper_decision":
|
498 |
+
phase = "ac_make_decisions"
|
499 |
+
|
500 |
+
else:
|
501 |
+
raise NotImplementedError
|
502 |
+
|
503 |
+
kwargs['phase'] = phase
|
504 |
+
kwargs['knows_authors'] = "ac" in global_settings['persons_aware_of_authors_identities']
|
505 |
+
|
506 |
+
area_chair = {
|
507 |
+
"name": "AC", # We assume there is only 1 AC for now
|
508 |
+
"role_desc": get_ac_description(**kwargs),
|
509 |
+
"backend": {'backend_type': 'openai-chat',
|
510 |
+
'temperature': 0.0, # make the AC decision deterministic
|
511 |
+
'max_tokens': 4096},
|
512 |
+
"env_type": env_type,
|
513 |
+
}
|
514 |
+
|
515 |
+
return AgentConfig(**area_chair)
|
agentreview/ui/__init__.py
ADDED
File without changes
|
agentreview/ui/cli.py
ADDED
@@ -0,0 +1,259 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import os
|
3 |
+
import os.path as osp
|
4 |
+
from typing import Union
|
5 |
+
|
6 |
+
from colorama import Fore
|
7 |
+
from colorama import Style as CRStyle
|
8 |
+
from prompt_toolkit import prompt
|
9 |
+
from prompt_toolkit.completion import WordCompleter
|
10 |
+
from prompt_toolkit.styles import Style
|
11 |
+
from rich.console import Console
|
12 |
+
|
13 |
+
from agentreview.utility.utils import get_rebuttal_dir, load_llm_ac_decisions, \
|
14 |
+
save_llm_ac_decisions
|
15 |
+
from ..arena import Arena, TooManyInvalidActions
|
16 |
+
from ..backends.human import HumanBackendError
|
17 |
+
from ..const import AGENTREVIEW_LOGO
|
18 |
+
from ..environments import PaperReview, PaperDecision
|
19 |
+
|
20 |
+
# Get the ASCII art from https://patorjk.com/software/taag/#p=display&f=Big&t=Chat%20Arena
|
21 |
+
|
22 |
+
color_dict = {
|
23 |
+
"red": Fore.RED,
|
24 |
+
"green": Fore.GREEN,
|
25 |
+
|
26 |
+
"blue": Fore.BLUE, # Paper Extractor
|
27 |
+
"light_red": Fore.LIGHTRED_EX, # AC
|
28 |
+
"light_green": Fore.LIGHTGREEN_EX, # Author
|
29 |
+
"yellow": Fore.YELLOW, # R1
|
30 |
+
"magenta": Fore.MAGENTA, # R2
|
31 |
+
"cyan": Fore.CYAN,
|
32 |
+
"white": Fore.WHITE,
|
33 |
+
"black": Fore.BLACK,
|
34 |
+
|
35 |
+
"light_yellow": Fore.LIGHTYELLOW_EX,
|
36 |
+
"light_blue": Fore.LIGHTBLUE_EX,
|
37 |
+
"light_magenta": Fore.LIGHTMAGENTA_EX,
|
38 |
+
"light_cyan": Fore.LIGHTCYAN_EX,
|
39 |
+
"light_white": Fore.LIGHTWHITE_EX,
|
40 |
+
"light_black": Fore.LIGHTBLACK_EX,
|
41 |
+
|
42 |
+
}
|
43 |
+
|
44 |
+
visible_colors = [
|
45 |
+
color
|
46 |
+
for color in color_dict # ANSI_COLOR_NAMES.keys()
|
47 |
+
if color not in ["black", "white", "red", "green"] and "grey" not in color
|
48 |
+
]
|
49 |
+
|
50 |
+
try:
|
51 |
+
import colorama
|
52 |
+
except ImportError:
|
53 |
+
raise ImportError(
|
54 |
+
"Please install colorama: `pip install colorama`"
|
55 |
+
)
|
56 |
+
|
57 |
+
MAX_STEPS = 20 # We should not need this parameter for paper reviews anyway
|
58 |
+
|
59 |
+
# Set logging level to ERROR
|
60 |
+
logging.getLogger().setLevel(logging.ERROR)
|
61 |
+
|
62 |
+
|
63 |
+
class ArenaCLI:
|
64 |
+
"""The CLI user interface for ChatArena."""
|
65 |
+
|
66 |
+
def __init__(self, arena: Arena):
|
67 |
+
self.arena = arena
|
68 |
+
self.args = arena.args
|
69 |
+
|
70 |
+
|
71 |
+
def launch(self, max_steps: int = None, interactive: bool = True):
|
72 |
+
"""Run the CLI."""
|
73 |
+
|
74 |
+
if not interactive and max_steps is None:
|
75 |
+
max_steps = MAX_STEPS
|
76 |
+
|
77 |
+
args = self.args
|
78 |
+
|
79 |
+
console = Console()
|
80 |
+
# Print ascii art
|
81 |
+
timestep = self.arena.reset()
|
82 |
+
console.print("🎓AgentReview Initialized!", style="bold green")
|
83 |
+
|
84 |
+
env: Union[PaperReview, PaperDecision] = self.arena.environment
|
85 |
+
players = self.arena.players
|
86 |
+
|
87 |
+
env_desc = self.arena.global_prompt
|
88 |
+
num_players = env.num_players
|
89 |
+
player_colors = visible_colors[:num_players] # sample different colors for players
|
90 |
+
name_to_color = dict(zip(env.player_names, player_colors))
|
91 |
+
|
92 |
+
print("name_to_color: ", name_to_color)
|
93 |
+
# System and Moderator messages are printed in red
|
94 |
+
name_to_color["System"] = "red"
|
95 |
+
name_to_color["Moderator"] = "red"
|
96 |
+
|
97 |
+
console.print(
|
98 |
+
f"[bold green underline]Environment ({env.type_name}) description:[/]\n{env_desc}"
|
99 |
+
)
|
100 |
+
|
101 |
+
# Print the player name, role_desc and backend_type
|
102 |
+
for i, player in enumerate(players):
|
103 |
+
player_name_str = f"[{player.name} ({player.backend.type_name})] Role Description:"
|
104 |
+
# player_name = Text(player_name_str)
|
105 |
+
# player_name.stylize(f"bold {name_to_color[player.name]} underline")
|
106 |
+
# console.print(player_name)
|
107 |
+
# console.print(player.role_desc)
|
108 |
+
|
109 |
+
logging.info(color_dict[name_to_color[player.name]] + player_name_str + CRStyle.RESET_ALL)
|
110 |
+
logging.info(color_dict[name_to_color[player.name]] + player.role_desc + CRStyle.RESET_ALL)
|
111 |
+
|
112 |
+
console.print(Fore.GREEN + "\n========= Arena Start! ==========\n" + CRStyle.RESET_ALL)
|
113 |
+
|
114 |
+
step = 0
|
115 |
+
while not timestep.terminal:
|
116 |
+
if env.type_name == "paper_review":
|
117 |
+
if env.phase_index > 4:
|
118 |
+
break
|
119 |
+
|
120 |
+
elif env.type_name == "paper_decision":
|
121 |
+
# Phase 5: AC makes decisions
|
122 |
+
if env.phase_index > 5:
|
123 |
+
break
|
124 |
+
|
125 |
+
else:
|
126 |
+
raise NotImplementedError(f"Unknown environment type: {env.type_name}")
|
127 |
+
|
128 |
+
if interactive:
|
129 |
+
command = prompt(
|
130 |
+
[("class:command", "command (n/r/q/s/h) > ")],
|
131 |
+
style=Style.from_dict({"command": "blue"}),
|
132 |
+
completer=WordCompleter(
|
133 |
+
[
|
134 |
+
"next",
|
135 |
+
"n",
|
136 |
+
"reset",
|
137 |
+
"r",
|
138 |
+
"exit",
|
139 |
+
"quit",
|
140 |
+
"q",
|
141 |
+
"help",
|
142 |
+
"h",
|
143 |
+
"save",
|
144 |
+
"s",
|
145 |
+
]
|
146 |
+
),
|
147 |
+
)
|
148 |
+
command = command.strip()
|
149 |
+
|
150 |
+
if command == "help" or command == "h":
|
151 |
+
console.print("Available commands:")
|
152 |
+
console.print(" [bold]next or n or <Enter>[/]: next step")
|
153 |
+
console.print(" [bold]exit or quit or q[/]: exit the game")
|
154 |
+
console.print(" [bold]help or h[/]: print this message")
|
155 |
+
console.print(" [bold]reset or r[/]: reset the game")
|
156 |
+
console.print(" [bold]save or s[/]: save the history to file")
|
157 |
+
continue
|
158 |
+
elif command == "exit" or command == "quit" or command == "q":
|
159 |
+
break
|
160 |
+
elif command == "reset" or command == "r":
|
161 |
+
timestep = self.arena.reset()
|
162 |
+
console.print(
|
163 |
+
"\n========= Arena Reset! ==========\n", style="bold green"
|
164 |
+
)
|
165 |
+
continue
|
166 |
+
elif command == "next" or command == "n" or command == "":
|
167 |
+
pass
|
168 |
+
elif command == "save" or command == "s":
|
169 |
+
# Prompt to get the file path
|
170 |
+
file_path = prompt(
|
171 |
+
[("class:command", "save file path > ")],
|
172 |
+
style=Style.from_dict({"command": "blue"}),
|
173 |
+
)
|
174 |
+
file_path = file_path.strip()
|
175 |
+
# Save the history to file
|
176 |
+
self.arena.save_history(file_path)
|
177 |
+
# Print the save success message
|
178 |
+
console.print(f"History saved to {file_path}", style="bold green")
|
179 |
+
else:
|
180 |
+
console.print(f"Invalid command: {command}", style="bold red")
|
181 |
+
continue
|
182 |
+
|
183 |
+
try:
|
184 |
+
timestep = self.arena.step()
|
185 |
+
except HumanBackendError as e:
|
186 |
+
# Handle human input and recover with the game update
|
187 |
+
human_player_name = env.get_next_player()
|
188 |
+
if interactive:
|
189 |
+
human_input = prompt(
|
190 |
+
[
|
191 |
+
(
|
192 |
+
"class:user_prompt",
|
193 |
+
f"Type your input for {human_player_name}: ",
|
194 |
+
)
|
195 |
+
],
|
196 |
+
style=Style.from_dict({"user_prompt": "ansicyan underline"}),
|
197 |
+
)
|
198 |
+
# If not, the conversation does not stop
|
199 |
+
timestep = env.step(human_player_name, human_input)
|
200 |
+
else:
|
201 |
+
raise e # cannot recover from this error in non-interactive mode
|
202 |
+
except TooManyInvalidActions as e:
|
203 |
+
# Print the error message
|
204 |
+
# console.print(f"Too many invalid actions: {e}", style="bold red")
|
205 |
+
print(Fore.RED + "This will be red text" + CRStyle.RESET_ALL)
|
206 |
+
break
|
207 |
+
|
208 |
+
# The messages that are not yet logged
|
209 |
+
messages = [msg for msg in env.get_observation() if not msg.logged]
|
210 |
+
|
211 |
+
# Print the new messages
|
212 |
+
for msg in messages:
|
213 |
+
message_str = f"[{msg.agent_name}->{msg.visible_to}]: {msg.content}"
|
214 |
+
if self.args.skip_logging:
|
215 |
+
console.print(color_dict[name_to_color[msg.agent_name]] + message_str + CRStyle.RESET_ALL)
|
216 |
+
msg.logged = True
|
217 |
+
|
218 |
+
step += 1
|
219 |
+
if max_steps is not None and step >= max_steps:
|
220 |
+
break
|
221 |
+
|
222 |
+
console.print("\n========= Arena Ended! ==========\n", style="bold red")
|
223 |
+
|
224 |
+
if env.type_name == "paper_review":
|
225 |
+
|
226 |
+
paper_id = self.arena.environment.paper_id
|
227 |
+
rebuttal_dir = get_rebuttal_dir(output_dir=self.args.output_dir,
|
228 |
+
paper_id=paper_id,
|
229 |
+
experiment_name=self.args.experiment_name,
|
230 |
+
model_name=self.args.model_name,
|
231 |
+
conference=self.args.conference)
|
232 |
+
|
233 |
+
os.makedirs(rebuttal_dir, exist_ok=True)
|
234 |
+
|
235 |
+
path_review_history = f"{rebuttal_dir}/{paper_id}.json"
|
236 |
+
|
237 |
+
if osp.exists(path_review_history):
|
238 |
+
raise Exception(f"History already exists!! ({path_review_history}). There must be something wrong with "
|
239 |
+
f"the path to save the history ")
|
240 |
+
|
241 |
+
self.arena.save_history(path_review_history)
|
242 |
+
|
243 |
+
elif env.type_name == "paper_decision":
|
244 |
+
ac_decisions = load_llm_ac_decisions(output_dir=args.output_dir,
|
245 |
+
conference=args.conference,
|
246 |
+
model_name=args.model_name,
|
247 |
+
ac_scoring_method=args.ac_scoring_method,
|
248 |
+
experiment_name=args.experiment_name,
|
249 |
+
num_papers_per_area_chair=args.num_papers_per_area_chair)
|
250 |
+
|
251 |
+
|
252 |
+
ac_decisions += [env.ac_decisions]
|
253 |
+
|
254 |
+
save_llm_ac_decisions(ac_decisions,
|
255 |
+
output_dir=args.output_dir,
|
256 |
+
conference=args.conference,
|
257 |
+
model_name=args.model_name,
|
258 |
+
ac_scoring_method=args.ac_scoring_method,
|
259 |
+
experiment_name=args.experiment_name)
|
agentreview/utility/__init__.py
ADDED
File without changes
|
agentreview/utility/authentication_utils.py
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import os
|
3 |
+
|
4 |
+
import openai
|
5 |
+
|
6 |
+
logging.basicConfig(level=logging.INFO)
|
7 |
+
|
8 |
+
|
9 |
+
def get_openai_client(client_type: str):
|
10 |
+
"""
|
11 |
+
|
12 |
+
Refer to [this page](https://platform.openai.com/docs/models) for authentication using OpenAI.
|
13 |
+
Refer to [this page](https://learn.microsoft.com/en-us/azure/ai-services/openai/how-to/switching-endpoints) for
|
14 |
+
authentication using Azure OpenAI.
|
15 |
+
"""
|
16 |
+
|
17 |
+
assert client_type in ["azure_openai", "openai"]
|
18 |
+
|
19 |
+
if not os.environ.get('OPENAI_API_VERSION'):
|
20 |
+
os.environ['OPENAI_API_VERSION'] = "2023-05-15"
|
21 |
+
|
22 |
+
if client_type == "openai":
|
23 |
+
client = openai.OpenAI(
|
24 |
+
api_key=os.environ['OPENAI_API_KEY']
|
25 |
+
)
|
26 |
+
|
27 |
+
elif client_type == "azure_openai":
|
28 |
+
endpoint: str = os.environ['AZURE_ENDPOINT']
|
29 |
+
|
30 |
+
if not endpoint.startswith("https://"):
|
31 |
+
endpoint = f"https://{endpoint}.openai.azure.com"
|
32 |
+
|
33 |
+
os.environ['AZURE_ENDPOINT'] = endpoint
|
34 |
+
|
35 |
+
client = openai.AzureOpenAI(
|
36 |
+
api_key=os.environ['AZURE_OPENAI_KEY'],
|
37 |
+
azure_endpoint=os.environ['AZURE_ENDPOINT'], # f"https://YOUR_END_POINT.openai.azure.com"
|
38 |
+
azure_deployment=os.environ['AZURE_DEPLOYMENT']
|
39 |
+
)
|
40 |
+
|
41 |
+
else:
|
42 |
+
raise NotImplementedError
|
43 |
+
|
44 |
+
return client
|
agentreview/utility/data_utils.py
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
import pandas as pd
|
4 |
+
|
5 |
+
|
6 |
+
def save_to_excel(df, path, sheet_name, index: bool=False):
|
7 |
+
"""
|
8 |
+
Save a pandas dataframe to an Excel sheet. If the file exists, replace the specified sheet
|
9 |
+
without impacting other sheets. If the file does not exist, create it.
|
10 |
+
|
11 |
+
Parameters:
|
12 |
+
df (pd.DataFrame): Dataframe to save.
|
13 |
+
path (str): Path to the Excel file.
|
14 |
+
sheet_name (str): Name of the sheet to save the dataframe to.
|
15 |
+
"""
|
16 |
+
# Check if the file exists
|
17 |
+
if os.path.exists(path):
|
18 |
+
# Load the existing workbook
|
19 |
+
with pd.ExcelWriter(path, engine='openpyxl', mode='a') as writer:
|
20 |
+
# Remove the existing sheet if it exists
|
21 |
+
if sheet_name in writer.book.sheetnames:
|
22 |
+
del writer.book[sheet_name]
|
23 |
+
# Write the dataframe to the specified sheet
|
24 |
+
df.to_excel(writer, sheet_name=sheet_name, index=index)
|
25 |
+
else:
|
26 |
+
# Create a new workbook and write the dataframe
|
27 |
+
with pd.ExcelWriter(path, engine='openpyxl') as writer:
|
28 |
+
df.to_excel(writer, sheet_name=sheet_name, index=index)
|
agentreview/utility/experiment_utils.py
ADDED
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
|
4 |
+
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
5 |
+
|
6 |
+
from agentreview.agent import Player
|
7 |
+
from agentreview.paper_review_player import PaperExtractorPlayer, AreaChair, Reviewer
|
8 |
+
from agentreview.role_descriptions import get_ac_config, get_reviewer_player_config, get_author_config, \
|
9 |
+
get_paper_extractor_config
|
10 |
+
|
11 |
+
|
12 |
+
def initialize_players(experiment_setting: dict, args):
|
13 |
+
paper_id = experiment_setting['paper_id']
|
14 |
+
paper_decision = experiment_setting['paper_decision']
|
15 |
+
|
16 |
+
if args.task == "paper_decision":
|
17 |
+
experiment_setting["players"] = {k: v for k, v in experiment_setting["players"].items() if k.startswith("AC")}
|
18 |
+
|
19 |
+
players = []
|
20 |
+
|
21 |
+
for role, players_list in experiment_setting["players"].items():
|
22 |
+
|
23 |
+
for i, player_config in enumerate(players_list):
|
24 |
+
if role == "AC":
|
25 |
+
|
26 |
+
# For AC, `env_type` is either "paper_decision" or "paper_review"
|
27 |
+
player_config = get_ac_config(env_type=args.task,
|
28 |
+
scoring_method=args.ac_scoring_method,
|
29 |
+
num_papers_per_area_chair=args.num_papers_per_area_chair,
|
30 |
+
global_settings=experiment_setting['global_settings'],
|
31 |
+
acceptance_rate=args.acceptance_rate,
|
32 |
+
**player_config)
|
33 |
+
|
34 |
+
player_config['model'] = args.model_name
|
35 |
+
|
36 |
+
player = AreaChair(data_dir=args.data_dir,
|
37 |
+
conference=args.conference,
|
38 |
+
args=args,
|
39 |
+
**player_config)
|
40 |
+
|
41 |
+
|
42 |
+
elif args.task == "paper_review":
|
43 |
+
|
44 |
+
|
45 |
+
if role == "Paper Extractor":
|
46 |
+
|
47 |
+
player_config = get_paper_extractor_config(global_settings=experiment_setting['global_settings'])
|
48 |
+
|
49 |
+
player = PaperExtractorPlayer(data_dir=args.data_dir, paper_id=paper_id,
|
50 |
+
paper_decision=paper_decision,
|
51 |
+
args=args,
|
52 |
+
conference=args.conference, **player_config)
|
53 |
+
|
54 |
+
|
55 |
+
|
56 |
+
elif role == "Author":
|
57 |
+
|
58 |
+
# Author requires no behavior customization.
|
59 |
+
# So we directly use the Player class
|
60 |
+
player_config = get_author_config()
|
61 |
+
player = Player(data_dir=args.data_dir,
|
62 |
+
conference=args.conference,
|
63 |
+
args=args,
|
64 |
+
**player_config)
|
65 |
+
|
66 |
+
|
67 |
+
|
68 |
+
elif role == "Reviewer":
|
69 |
+
player_config = get_reviewer_player_config(reviewer_index=i + 1,
|
70 |
+
global_settings=experiment_setting['global_settings'],
|
71 |
+
**player_config)
|
72 |
+
player_config['model'] = args.model_name
|
73 |
+
player = Reviewer(data_dir=args.data_dir, conference=args.conference, args=args, **player_config)
|
74 |
+
|
75 |
+
|
76 |
+
else:
|
77 |
+
raise NotImplementedError(f"Unknown role for paper review (stage 1-4): {role}")
|
78 |
+
|
79 |
+
else:
|
80 |
+
raise NotImplementedError(f"Unknown role for paper decision (stage 5): {role}")
|
81 |
+
|
82 |
+
players.append(player)
|
83 |
+
|
84 |
+
return players
|
agentreview/utility/general_utils.py
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import random
|
3 |
+
from os import path as osp
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
|
8 |
+
def check_cwd():
|
9 |
+
basename = osp.basename(osp.normpath(os.getcwd()))
|
10 |
+
assert basename.lower() in [
|
11 |
+
"agentreview"], "Please run this file from the repository root directory (AgentReview/)"
|
12 |
+
|
13 |
+
|
14 |
+
def set_seed(seed):
|
15 |
+
random.seed(seed)
|
16 |
+
np.random.seed(seed)
|
agentreview/utility/metrics_utils.py
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
|
4 |
+
|
5 |
+
def pairwise_cos_sim(A: torch.Tensor, B: torch.Tensor, device="cuda:0"):
|
6 |
+
if isinstance(A, np.ndarray):
|
7 |
+
A = torch.from_numpy(A).to(device)
|
8 |
+
|
9 |
+
if isinstance(B, np.ndarray):
|
10 |
+
B = torch.from_numpy(B).to(device)
|
11 |
+
|
12 |
+
from torch.nn.functional import normalize
|
13 |
+
A_norm = normalize(A, dim=1) # Normalize the rows of A
|
14 |
+
B_norm = normalize(B, dim=1) # Normalize the rows of B
|
15 |
+
cos_sim = torch.matmul(A_norm,
|
16 |
+
B_norm.t()) # Calculate the cosine similarity
|
17 |
+
return cos_sim
|
agentreview/utility/text_utils.py
ADDED
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
|
3 |
+
from colorama import Fore
|
4 |
+
|
5 |
+
problem_pattern = r"(problem|issue|challenge)"
|
6 |
+
importance_pattern = r"\b(important|challenging|crucial|critical|vital|significant)\b"
|
7 |
+
interesting_pattern = r"\b(interesting|fascinating|intriguing|exciting)\b"
|
8 |
+
|
9 |
+
novel_pattern = r"\b(novel|new|original|innovative|creative)\b"
|
10 |
+
model_pattern = r"(model|architecture|method|approach|framework)"
|
11 |
+
|
12 |
+
experiment_pattern =r"\b(evaluate|evaluation|comparison|result|experiment[s])\b|\b(analysis|analyses)\b"
|
13 |
+
result_pattern = r"\b(result|performance|metric|score)\b"
|
14 |
+
|
15 |
+
insight_pattern = r"\b(insight|observation|finding|trend)\b"
|
16 |
+
|
17 |
+
|
18 |
+
limitation_pattern = r"(limitation|drawback|weakness|challenge)"
|
19 |
+
|
20 |
+
scalability_pattern = r"\b(complexity|scalability|scalable)\b"
|
21 |
+
|
22 |
+
real_world_pattern = r"realistic|real\-world|practical"
|
23 |
+
|
24 |
+
|
25 |
+
theoretical_pattern = r"\b(math|theoretical|justification|justify|foundation|theory)\b"
|
26 |
+
|
27 |
+
|
28 |
+
|
29 |
+
def text_match_problem(text):
|
30 |
+
"""Regex patterns to search for problem-related and importance-related keywords"""
|
31 |
+
|
32 |
+
# Check if both patterns are found in the text
|
33 |
+
return re.search(problem_pattern, text, re.IGNORECASE) and (re.search(importance_pattern, text, re.IGNORECASE)
|
34 |
+
or re.search(interesting_pattern, text, re.IGNORECASE))
|
35 |
+
|
36 |
+
|
37 |
+
def text_match_dataset(text):
|
38 |
+
"""
|
39 |
+
Check if the strengths mentioned by reviewers contains the word "dataset" or "data set".
|
40 |
+
|
41 |
+
"""
|
42 |
+
return re.search(r"\b(dataset|data set)\b", text, re.IGNORECASE)
|
43 |
+
|
44 |
+
|
45 |
+
def text_match_model(text):
|
46 |
+
"""
|
47 |
+
Check if the strengths mentioned by reviewers contains the word "model".
|
48 |
+
"""
|
49 |
+
return re.search(model_pattern, text, re.IGNORECASE)
|
50 |
+
|
51 |
+
def text_match_experiment(text):
|
52 |
+
"""
|
53 |
+
Check if the strengths mentioned by reviewers contains the word "experiment".
|
54 |
+
"""
|
55 |
+
return re.search(experiment_pattern, text, re.IGNORECASE) or re.search(result_pattern, text, re.IGNORECASE)
|
56 |
+
|
57 |
+
def text_match_real_world(text):
|
58 |
+
"""
|
59 |
+
Check if the strengths mentioned by reviewers contains the word "model".
|
60 |
+
"""
|
61 |
+
return re.search(real_world_pattern, text, re.IGNORECASE)
|
62 |
+
|
63 |
+
def text_match_theoretical(text):
|
64 |
+
"""
|
65 |
+
Check if the strengths mentioned by reviewers are related to theoretical analysis / foundation.
|
66 |
+
"""
|
67 |
+
return re.search(theoretical_pattern, text, re.IGNORECASE)
|
68 |
+
|
69 |
+
def text_match_scalability(text):
|
70 |
+
"""
|
71 |
+
Check if the strengths mentioned by reviewers are related to scalability or time/space complexity.
|
72 |
+
"""
|
73 |
+
return re.search(scalability_pattern, text, re.IGNORECASE)
|
74 |
+
|
75 |
+
def match_strengths(text):
|
76 |
+
"""
|
77 |
+
Check if the strengths mentioned by reviewers contains the word "model".
|
78 |
+
"""
|
79 |
+
for category, f in {"problem": text_match_problem,
|
80 |
+
"dataset": text_match_dataset,
|
81 |
+
"model": text_match_model,
|
82 |
+
"limitation": text_match_limitation,
|
83 |
+
"experiment": text_match_experiment,
|
84 |
+
"theoretical-analysis": text_match_theoretical,
|
85 |
+
"scalability": text_match_scalability,
|
86 |
+
}.items():
|
87 |
+
if f(text):
|
88 |
+
return category
|
89 |
+
|
90 |
+
print(
|
91 |
+
Fore.RED
|
92 |
+
+ f"No category found for weaknesses: {text}"
|
93 |
+
+ Fore.BLACK
|
94 |
+
)
|
95 |
+
|
96 |
+
def match_weaknesses(text):
|
97 |
+
for category, f in {"problem": text_match_problem,
|
98 |
+
"dataset": text_match_dataset,
|
99 |
+
"model": text_match_model,
|
100 |
+
"limitation": text_match_limitation,
|
101 |
+
"experiment": text_match_experiment,
|
102 |
+
"real-world": text_match_real_world,
|
103 |
+
"theoretical-analysis": text_match_theoretical,
|
104 |
+
"scalability": text_match_scalability,
|
105 |
+
}:
|
106 |
+
if f(text):
|
107 |
+
return category
|
108 |
+
|
109 |
+
print(
|
110 |
+
Fore.RED
|
111 |
+
+ f"No category found for weaknesses: {text}"
|
112 |
+
+ Fore.BLACK
|
113 |
+
)
|
114 |
+
return None
|
115 |
+
|
116 |
+
raise ValueError(f"No category found for weaknesses: {text}")
|
117 |
+
|
118 |
+
|
119 |
+
|
120 |
+
|
121 |
+
def text_match_limitation(text):
|
122 |
+
"""
|
123 |
+
Check if the strengths mentioned by reviewers contains the word "model".
|
124 |
+
"""
|
125 |
+
return re.search(limitation_pattern, text, re.IGNORECASE)
|
agentreview/utility/utils.py
ADDED
@@ -0,0 +1,582 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import os
|
3 |
+
import os.path as osp
|
4 |
+
import random
|
5 |
+
import re
|
6 |
+
from collections import Counter
|
7 |
+
from typing import Union, List, Dict, Tuple
|
8 |
+
|
9 |
+
import numpy as np
|
10 |
+
import pandas as pd
|
11 |
+
|
12 |
+
from agentreview import const
|
13 |
+
from agentreview.utility.general_utils import check_cwd, set_seed
|
14 |
+
|
15 |
+
|
16 |
+
def generate_num_papers_to_accept(n, batch_number, shuffle=True):
|
17 |
+
# Calculate the base value (minimum value in the array)
|
18 |
+
base_value = int(n // batch_number)
|
19 |
+
|
20 |
+
# Calculate how many elements need to be base_value + 1
|
21 |
+
remainder = int(n % batch_number)
|
22 |
+
|
23 |
+
# Initialize the array
|
24 |
+
array = []
|
25 |
+
|
26 |
+
# Add the elements to the array
|
27 |
+
for i in range(batch_number):
|
28 |
+
if i < remainder:
|
29 |
+
array.append(base_value + 1)
|
30 |
+
else:
|
31 |
+
array.append(base_value)
|
32 |
+
|
33 |
+
if shuffle:
|
34 |
+
random.shuffle(array)
|
35 |
+
|
36 |
+
return array
|
37 |
+
|
38 |
+
|
39 |
+
def get_papers_accepted_by_llm(llm_ac_decisions, acceptance_rate: float) -> list:
    """Pick the top-ranked papers from each batch of LLM AC decisions.

    Each batch maps paper id -> rank (lower rank is better). The total number
    of acceptances is ``acceptance_rate`` times the overall paper count,
    distributed across batches by ``generate_num_papers_to_accept``.

    Returns:
        list[int]: Ids of the papers accepted by the LLM.
    """
    total = sum(len(batch) for batch in llm_ac_decisions)

    if total == 0:
        raise ValueError("No papers found in batch")

    quotas = generate_num_papers_to_accept(n=acceptance_rate * total,
                                           batch_number=len(llm_ac_decisions))

    accepted = []
    for quota, batch in zip(quotas, llm_ac_decisions):
        # Best (lowest-rank) papers first, then keep this batch's quota.
        ranked = sorted(batch.items(), key=lambda item: item[1])
        accepted.extend(int(paper_id) for paper_id, _ in ranked[:quota])

    return accepted
|
58 |
+
|
59 |
+
|
60 |
+
def get_paper_decision_mapping(data_dir: str, conference: str, verbose: bool = False):
    """Build (and cache on disk) mappings between paper ids and decisions.

    On first use the mappings are derived from the directory layout under
    ``<data_dir>/<conference>/notes/<decision>/<paper_id>.json`` and written to
    ``id2decision.json`` / ``decision2ids.json``; later calls load the cache.

    Args:
        data_dir: Root data directory.
        conference: Conference name (e.g. "ICLR2023").
        verbose: If True, print per-decision paper counts / load summary.

    Returns:
        tuple[dict, dict]: (paper_id -> decision, decision -> sorted paper ids).
    """
    paper_id2decision, paper_decision2ids = {}, {}
    path_paper_id2decision = os.path.join(data_dir, conference, "id2decision.json")
    path_paper_decision2ids = os.path.join(data_dir, conference, "decision2ids.json")

    if osp.exists(path_paper_id2decision) and osp.exists(path_paper_decision2ids):
        # Load the cached mappings; context managers close the handles promptly
        # (the original left the files open).
        with open(path_paper_id2decision, 'r', encoding='utf-8') as f:
            paper_id2decision = json.load(f)
        with open(path_paper_decision2ids, 'r', encoding='utf-8') as f:
            paper_decision2ids = json.load(f)

        # JSON object keys are strings; restore the integer paper ids.
        paper_id2decision = {int(k): v for k, v in paper_id2decision.items()}

        if verbose:
            print(f"Loaded {len(paper_id2decision)} paper IDs to decisions from {path_paper_id2decision}")

    else:
        PAPER_DECISIONS = get_all_paper_decisions(conference)

        for paper_decision in PAPER_DECISIONS:
            # Each decision directory holds one <paper_id>.json per paper.
            paper_ids = os.listdir(os.path.join(data_dir, conference, "notes", paper_decision))
            paper_ids = sorted(
                [int(paper_id.split(".json")[0]) for paper_id in paper_ids if paper_id.endswith(".json")])

            paper_id2decision.update({paper_id: paper_decision for paper_id in paper_ids})
            paper_decision2ids[paper_decision] = paper_ids

            if verbose:
                print(f"{paper_decision}: {len(paper_ids)} papers")

        # Cache both mappings for subsequent runs.
        with open(path_paper_id2decision, 'w', encoding='utf-8') as f:
            json.dump(paper_id2decision, f, indent=2)
        with open(path_paper_decision2ids, 'w', encoding='utf-8') as f:
            json.dump(paper_decision2ids, f, indent=2)

    return paper_id2decision, paper_decision2ids
|
94 |
+
|
95 |
+
|
96 |
+
def project_setup():
    """One-time process setup: cwd sanity check, warning filters, pandas display options, RNG seed."""
    check_cwd()
    import warnings
    import pandas as pd
    # Silence FutureWarnings that would otherwise clutter experiment logs.
    warnings.simplefilter(action='ignore', category=FutureWarning)
    pd.set_option('display.max_rows', 40)
    pd.set_option('display.max_columns', 20)
    # Fix all RNGs so experiment runs are reproducible.
    set_seed(42)
|
104 |
+
|
105 |
+
|
106 |
+
def get_next_review_id(path: str) -> int:
    """Return the smallest positive integer not already used as a review id.

    Review files in ``path`` are expected to be named ``<prefix>_<id>.json``;
    the id is the integer between the underscore and the extension.
    """
    taken = {int(fname.split('.json')[0].split('_')[1]) for fname in os.listdir(path)}

    candidate = 1
    while candidate in taken:
        candidate += 1

    print(f"Next review ID: {candidate}")
    return candidate
|
113 |
+
|
114 |
+
|
115 |
+
|
116 |
+
|
117 |
+
def filter_paper_ids_from_initial_experiments(sampled_paper_ids: List[int]) -> List[int]:
    """Drop paper ids already used in the initial experiments.

    Reads the previously-used ids from
    ``outputs/paper_ids_initial_experiments.json`` (relative to the current
    working directory) and returns the remaining ids sorted ascending.

    Args:
        sampled_paper_ids: Candidate paper ids.

    Returns:
        List[int]: Sorted ids not present in the initial-experiments file.
    """
    # Context manager closes the handle (original left the file open and used
    # an f-string with no placeholder).
    with open("outputs/paper_ids_initial_experiments.json", encoding="utf-8") as f:
        paper_ids_initial_experiments = json.load(f)

    return sorted(set(sampled_paper_ids) - set(paper_ids_initial_experiments))
|
122 |
+
|
123 |
+
|
124 |
+
def get_paper_review_and_rebuttal_dir(reviewer_type: str, conference: str, model_name: str, paper_id: int = None):
    """Return the output directory for review/rebuttal artifacts.

    Layout: ``outputs/paper_review_and_rebuttal/<conference>/<model>/<reviewer_type>[/<paper_id>]``.
    """
    # Legacy alias: the no-overall-score reviewer shares the baseline directory.
    if reviewer_type == "NoOverallScore":
        reviewer_type = "BASELINE"

    segments = ["outputs/paper_review_and_rebuttal", conference,
                get_model_name_short(model_name), reviewer_type]
    if paper_id is not None:
        segments.append(str(paper_id))

    return "/".join(segments)
|
136 |
+
|
137 |
+
|
138 |
+
def get_rebuttal_dir(output_dir: str,
                     paper_id: Union[str, int, None],
                     experiment_name: str,
                     model_name: str,
                     conference: str):
    """Return the directory holding paper-review outputs for one experiment.

    Layout: ``<output_dir>/paper_review/<conference>/<model>/<experiment>[/<paper_id>]``.
    """
    base = os.path.join(output_dir, "paper_review", conference,
                        get_model_name_short(model_name), experiment_name)

    if paper_id is None:
        return base
    return f"{base}/{paper_id}"
|
151 |
+
|
152 |
+
|
153 |
+
def print_colored(text, color='red'):
    """Print ``text`` in the given ANSI foreground color on a terminal.

    Inside IPython/Jupyter (where ANSI escapes may not render) the text is
    printed plain instead.

    Args:
        text: The text to print.
        color: Color name; unknown names fall back to red.
    """
    # ANSI foreground color codes.
    foreground_colors = {
        'black': 30,
        'red': 31,
        'green': 32,
        'yellow': 33,
        'blue': 34,
        'magenta': 35,
        'cyan': 36,
        'white': 37,
    }
    try:
        # `get_ipython` only exists inside IPython/Jupyter sessions; we use
        # its presence to detect a notebook environment.
        get_ipython
        print(text)  # Plain text in Jupyter
    except NameError:
        # Not a notebook: emit ANSI escape codes. (The original used a bare
        # `except:`, which also swallowed KeyboardInterrupt/SystemExit.)
        color_code = foreground_colors.get(color, 31)  # Default to red
        print(f"\033[{color_code}m{text}\033[0m")
|
176 |
+
|
177 |
+
|
178 |
+
def get_ac_decision_path(output_dir: str, conference: str, model_name: str,
                         ac_scoring_method: str, experiment_name: str):
    """Return the path for AC decision JSON, creating its directory.

    With a string ``experiment_name`` the full file path
    ``.../decision_<experiment_name>.json`` is returned; otherwise just the
    (created) decision directory.
    """
    decision_dir = os.path.join(output_dir, "decisions", conference,
                                get_model_name_short(model_name),
                                f"decisions_thru_{ac_scoring_method}")
    # Ensure callers can write to the returned path immediately.
    os.makedirs(decision_dir, exist_ok=True)

    if not isinstance(experiment_name, str):
        return decision_dir
    return decision_dir + f"/decision_{experiment_name}.json"
|
189 |
+
|
190 |
+
|
191 |
+
def load_metareview(paper_id: int, **kwargs):
    """Load the meta-review text for one paper, or None if unavailable.

    Args:
        paper_id: Numeric paper identifier.
        **kwargs: Forwarded to ``get_rebuttal_dir`` (output_dir,
            experiment_name, model_name, conference).

    Returns:
        The meta-review content string, or None when the file is missing or
        its last message was not produced by an area chair.
    """
    rebuttal_dir = get_rebuttal_dir(paper_id=paper_id, **kwargs)

    path = f"{rebuttal_dir}/{paper_id}.json"

    if not osp.exists(path):
        print(f"Not Found: {path}")
        return None

    try:
        reviews = json.load(open(path))

        # The meta-review is expected to be the final message, authored by an
        # area chair (agent names prefixed with "AC").
        metareview = reviews["messages"][-1]
        if not metareview["agent_name"].startswith("AC"):
            return None

        return metareview['content']

    except FileNotFoundError:
        # Race between the exists() check above and the open().
        return None
|
211 |
+
|
212 |
+
|
213 |
+
def get_reviewer_type_from_profile(profile: dict):
    """
    Get a short name for the reviewer's type from the reviewer's experiment profile.

    At most one of ``is_benign`` / ``is_knowledgeable`` / ``is_responsible`` may
    be explicitly True or False; the rest must be None. A profile with all
    three set to None maps to "BASELINE" (the default reviewer), unless
    ``provides_numeric_rating`` is False or ``knows_authors`` is set.

    Example inputs/outputs:
        {'is_benign': True,  'is_knowledgeable': None, 'is_responsible': None,
         'provides_numeric_rating': True}  -> "benign"
        {'is_benign': False, 'is_knowledgeable': None, 'is_responsible': None,
         'provides_numeric_rating': True}  -> "malicious"
        {'is_benign': None,  'is_knowledgeable': None, 'is_responsible': None,
         'provides_numeric_rating': True}  -> "BASELINE"
    """

    reviewer_attributes = Counter([profile[k] for k in ["is_benign", 'is_knowledgeable', 'is_responsible']])

    # Sanity check: at most one trait explicitly True and one False, so at
    # least two of the three must be None.
    assert (reviewer_attributes[True] <= 1 and reviewer_attributes[False] <= 1) and reviewer_attributes[None] >= 2, \
        ("A reviewer can only have 0 or 1 of "
         "these "
         "properties profile to True or False")

    if profile['is_benign']:
        return "benign"
    elif profile['is_benign'] == False:
        # NOTE: We must compare with `== False` (not `not profile[...]`) so
        # that None falls through to the remaining checks.
        return "malicious"

    elif profile['is_knowledgeable']:
        return "knowledgeable"

    elif profile['is_knowledgeable'] == False:
        # Same `== False` rationale as above.
        return "unknowledgeable"

    elif profile['is_responsible']:
        return "responsible"
    elif profile['is_responsible'] == False:
        # Same `== False` rationale as above.
        return "irresponsible"

    elif profile['provides_numeric_rating'] == False:
        return "NoOverallScore"

    elif profile.get('knows_authors') == "famous":
        return "authors_are_famous"

    elif profile.get('knows_authors') == "unfamous":
        return "authors_are_unfamous"

    else:
        return "BASELINE"
|
294 |
+
|
295 |
+
|
296 |
+
def get_ac_type_from_profile(profile: dict):
    """Placeholder: deriving a short AC-type name from a profile is not implemented; always returns None."""
    return None
|
298 |
+
|
299 |
+
|
300 |
+
# def get_ac_type_from_profile(profile: dict):
|
301 |
+
# """
|
302 |
+
# Get a short name for the area chair's type from their profile in the experiment setting.
|
303 |
+
#
|
304 |
+
# """
|
305 |
+
|
306 |
+
def format_metareviews(metareviews: List[str], paper_ids: List[int]):
    """Concatenate meta-reviews into one prompt-ready string.

    Each entry becomes ``Paper ID: <id>\\nMetareview: <text>`` followed by a
    ``-----`` separator line; runs of blank lines in the review text are
    collapsed to single newlines.
    """
    separator = '-' * 5
    chunks = []
    for pid, review in zip(paper_ids, metareviews):
        collapsed = re.sub('\n+', '\n', review)
        chunks.append(f"Paper ID: {pid}\nMetareview: {collapsed}\n{separator}\n")

    return "".join(chunks)
|
315 |
+
|
316 |
+
|
317 |
+
def get_all_paper_decisions(conference: str) -> List[str]:
    """Return the decision categories used by the given conference.

    ICLR 2018/2019 used a different set of decision labels than later years.
    """
    uses_legacy_labels = conference in {"ICLR2019", "ICLR2018"}
    return const.PAPER_DECISIONS_ICLR2019 if uses_legacy_labels else const.PAPER_DECISIONS
|
323 |
+
|
324 |
+
|
325 |
+
def get_paper_ids_of_known_authors(conference: str, num_papers: int, decision: str = None,
                                   data_dir: str = "data"):
    """Return up to ``num_papers`` paper ids that received ``decision``.

    Args:
        conference: Conference name.
        num_papers: Maximum number of ids to return.
        decision: Decision category to select from.
        data_dir: Root data directory forwarded to
            ``get_paper_decision_mapping`` (default ``"data"`` — TODO confirm
            against the project's data layout).

    Returns:
        list[int]: The first ``num_papers`` ids for the decision.

    Note:
        The original called ``get_paper_decision_mapping(conference)``, which
        requires ``data_dir`` as its first argument and therefore always
        raised a TypeError; ``data_dir`` is now a backward-compatible
        keyword parameter.
    """
    _, paper_decision2ids = get_paper_decision_mapping(data_dir, conference)
    return paper_decision2ids[decision][:num_papers]
|
329 |
+
|
330 |
+
|
331 |
+
def get_experiment_names(conference: str = "ICLR2023"):
    """Enumerate the experiment settings evaluated for a conference.

    Always includes the baseline, single-reviewer perturbations (commitment,
    intention, knowledgeability) and the three AC styles; for ICLR2023 the
    multi-reviewer and ablation settings are added as well.
    """
    names = [
        "BASELINE",
        # Reviewer commitment
        "responsible_Rx1",
        "irresponsible_Rx1",
        # Reviewer intention
        "benign_Rx1",
        "malicious_Rx1",
        # Reviewer knowledgeability
        "knowledgeable_Rx1",
        "unknowledgeable_Rx1",
        # Area-chair styles
        "conformist_ACx1",
        "authoritarian_ACx1",
        "inclusive_ACx1",
    ]

    # Extra settings only evaluated on ICLR2023.
    if conference == "ICLR2023":
        names += [
            "no_rebuttal",
            "no_overall_score",
            "malicious_Rx2",
            "malicious_Rx3",
            "irresponsible_Rx2",
            "irresponsible_Rx3",
            "authors_are_famous_Rx1",
            "authors_are_famous_Rx2",
            "authors_are_famous_Rx3",
        ]

    return names
|
363 |
+
|
364 |
+
|
365 |
+
def load_llm_ac_decisions_as_array(
        output_dir: str,
        experiment_name: str,
        ac_scoring_method: str,
        acceptance_rate: float,
        conference: str,
        model_name: str,
        num_papers_per_area_chair: int
) -> Tuple[np.ndarray, np.ndarray]:
    """Loads and processes GPT-4 generated area chair (AC) decisions for an experiment.

    Args:
        output_dir (str): Root directory the decisions were written under.
        experiment_name (str): Name of the experiment.
        ac_scoring_method (str): Method used for AC scoring ('ranking' or 'recommendation').
        acceptance_rate (float): Acceptance rate for the conference.
        conference (str): Name of the conference.
        model_name (str): Model name used to generate AC decisions.
        num_papers_per_area_chair (int): Number of papers assigned to each area chair.

    Returns:
        Tuple[np.ndarray, np.ndarray]: An array of decisions (True for accept, False for reject)
        and an array of paper IDs in the order processed.

    Raises:
        ValueError: If the 'ranking' method finds duplicate paper ids.
        NotImplementedError: If `ac_scoring_method` is not 'ranking' or 'recommendation'.
    """
    print("=" * 30)
    print(f"Experiment Name: {experiment_name}")

    llm_ac_decisions = load_llm_ac_decisions(
        output_dir=output_dir,
        conference=conference,
        model_name=model_name,
        ac_scoring_method=ac_scoring_method,
        experiment_name=experiment_name,
        num_papers_per_area_chair=num_papers_per_area_chair
    )

    # Flatten the batch keys into one ascending id list; this fixes the order
    # of the returned decision array.
    paper_ids = sorted(
        int(paper_id) for batch in llm_ac_decisions for paper_id in batch
    )

    if ac_scoring_method == "ranking":
        # Under ranking each paper must appear in exactly one batch.
        if len(paper_ids) != len(set(paper_ids)):
            raise ValueError(f"Duplicate paper_ids found in the AC decisions: {Counter(paper_ids)}")

        papers_accepted_by_llm = get_papers_accepted_by_llm(llm_ac_decisions, acceptance_rate)
        decisions_llm = np.array([paper_id in papers_accepted_by_llm for paper_id in paper_ids])

    elif ac_scoring_method == "recommendation":
        # Decisions are free-text recommendations; any value starting with
        # "Accept" counts as an acceptance.
        llm_ac_decisions_flat = {int(k): v for batch in llm_ac_decisions for k, v in batch.items()}
        decisions_llm = np.array(
            [llm_ac_decisions_flat[paper_id].startswith("Accept") for paper_id in paper_ids]
        )
    else:
        raise NotImplementedError(f"Scoring method '{ac_scoring_method}' not implemented.")

    return decisions_llm, np.array(paper_ids)
|
423 |
+
|
424 |
+
|
425 |
+
def load_llm_ac_decisions(
        output_dir: str,
        conference: str,
        model_name: str,
        ac_scoring_method: str,
        experiment_name: str,
        num_papers_per_area_chair: int
) -> List[Dict[str, str]]:
    """Loads GPT-4 generated area chair (AC) decisions from a specified path.

    Args:
        output_dir (str): Root directory the decisions were written under.
        conference (str): Name of the conference.
        model_name (str): Model name used to generate AC decisions.
        ac_scoring_method (str): Method used for AC scoring ('ranking' or 'recommendation').
        experiment_name (str): Name of the experiment.
        num_papers_per_area_chair (int): Number of papers assigned to each area chair.

    Returns:
        List[Dict[str, str]]: List of batches, where each batch contains paper ID and decision.

    Raises:
        AssertionError: If a non-final batch has a paper count different from `num_papers_per_area_chair`.
    """
    path = get_ac_decision_path(
        output_dir=output_dir,
        conference=conference,
        model_name=model_name,
        ac_scoring_method=ac_scoring_method,
        experiment_name=experiment_name
    )

    # Missing file is treated as "no decisions yet", not an error.
    if osp.exists(path):
        with open(path, 'r', encoding='utf-8') as file:
            ac_decision = json.load(file)
            print(f"Loaded {len(ac_decision)} batches of existing AC decisions from {path}")
    else:
        ac_decision = []
        print(f"No existing AC decisions found at {path}")

    ac_decision = [batch for batch in ac_decision if batch]  # Remove empty batches

    # Every batch except possibly the last must be one AC's full paper load.
    for i, batch in enumerate(ac_decision):
        if i != len(ac_decision) - 1:
            if len(batch) != num_papers_per_area_chair:
                raise AssertionError(
                    f"Batch {i} has {len(batch)} papers, expected {num_papers_per_area_chair} for non-final batches."
                )

    return ac_decision
|
474 |
+
|
475 |
+
def write_to_excel(data, file_path, sheet_name):
    """
    Write a DataFrame to one sheet of an Excel workbook.

    If the workbook already exists the target sheet is replaced in place;
    otherwise a fresh workbook is created.

    Parameters:
        data (pd.DataFrame): The data to write to the Excel file.
        file_path (str): The path to the Excel file.
        sheet_name (str): The name of the sheet to write to.
    """
    writer_kwargs = {"engine": "openpyxl"}
    if os.path.exists(file_path):
        # Append mode so other sheets in the existing workbook are preserved.
        writer_kwargs.update(mode="a", if_sheet_exists="replace")

    with pd.ExcelWriter(file_path, **writer_kwargs) as writer:
        data.to_excel(writer, sheet_name=sheet_name, index=False)
|
493 |
+
|
494 |
+
|
495 |
+
def save_llm_ac_decisions(ac_decisions: List[dict], **kwargs):
    """Persist LLM-generated area-chair decisions as indented JSON.

    Args:
        ac_decisions: List of decision batches to serialize.
        **kwargs: Forwarded to ``get_ac_decision_path`` (output_dir,
            conference, model_name, ac_scoring_method, experiment_name).
    """
    path = get_ac_decision_path(**kwargs)

    # Context manager closes the handle even if serialization fails (the
    # original left the file object open).
    with open(path, 'w', encoding='utf-8') as f:
        json.dump(ac_decisions, f, indent=2)
|
499 |
+
|
500 |
+
|
501 |
+
def get_model_name_short(name: str):
    """
    Convert a long model name (e.g. ``gpt-35-turbo``) to its short form (e.g. ``gpt-35``).

    Args:
        name (str): long model name; must start with ``gpt-``.

    Returns:
        str: short model name (first two dash-separated components).
    """
    assert name.startswith('gpt-')
    parts = name.split('-')
    return f"{parts[0]}-{parts[1]}"
|
513 |
+
|
514 |
+
|
515 |
+
def get_reviewer_types_from_experiment_name(experiment_name: str):
    """Map an experiment name to the three reviewer types it uses.

    Each experiment configures exactly three reviewers; a special type fills
    the first ``count`` slots and "BASELINE" fills the rest.

    Args:
        experiment_name: One of the known experiment setting names.

    Returns:
        list[str]: A fresh list of three reviewer type names.

    Raises:
        NotImplementedError: For unknown experiment names.
    """
    # (special reviewer type, how many of the three slots it occupies).
    experiment_to_lineup = {
        "BASELINE": ("BASELINE", 3),
        "inclusive_ACx1": ("BASELINE", 3),
        "authoritarian_ACx1": ("BASELINE", 3),
        "conformist_ACx1": ("BASELINE", 3),
        "no_rebuttal": ("BASELINE", 3),
        "benign_Rx1": ("benign", 1),
        "benign_Rx2": ("benign", 2),
        "malicious_Rx1": ("malicious", 1),
        "malicious_Rx2": ("malicious", 2),
        "malicious_Rx3": ("malicious", 3),
        "knowledgeable_Rx1": ("knowledgeable", 1),
        "unknowledgeable_Rx1": ("unknowledgeable", 1),
        "responsible_Rx1": ("responsible", 1),
        "irresponsible_Rx1": ("irresponsible", 1),
        "irresponsible_Rx2": ("irresponsible", 2),
        "irresponsible_Rx3": ("irresponsible", 3),
        "no_overall_score": ("NoOverallScore", 3),
        "authors_are_famous_Rx1": ("authors_are_famous", 1),
        "authors_are_famous_Rx1_no_rebuttal": ("authors_are_famous", 1),
        "authors_are_famous_Rx2": ("authors_are_famous", 2),
        "authors_are_famous_Rx2_no_rebuttal": ("authors_are_famous", 2),
        "authors_are_famous_Rx3": ("authors_are_famous", 3),
        "authors_are_famous_Rx3_no_rebuttal": ("authors_are_famous", 3),
    }

    try:
        reviewer_type, count = experiment_to_lineup[experiment_name]
    except KeyError:
        raise NotImplementedError(f"Unknown experiment name: {experiment_name}") from None

    # Build a fresh list per call so callers may mutate it safely.
    return [reviewer_type] * count + ["BASELINE"] * (3 - count)
|
agentreview/utils.py
ADDED
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import re
|
3 |
+
|
4 |
+
|
5 |
+
def is_json(myjson):
    """
    Checks whether a given string is a valid JSON.

    Parameters:
        myjson (str): The string to be checked.

    Returns:
        bool: True if the string is a valid JSON, False otherwise.
    """
    try:
        json.loads(myjson)
    except ValueError:
        return False
    else:
        return True
|
20 |
+
|
21 |
+
|
22 |
+
def is_json_inside(text):
    """
    Checks whether a given string contains valid JSON(s).

    Parameters:
        text (str): The string to be checked.

    Returns:
        bool: True if the string contains valid JSON(s), False otherwise.
    """
    # Collapse whitespace, then try to parse each non-greedy {...} span.
    normalized = re.sub(r"\s+", " ", text)
    for candidate in re.findall(r"\{.*?\}", normalized):
        try:
            json.loads(candidate)
        except ValueError:
            continue
        return True
    return False
|
38 |
+
|
39 |
+
|
40 |
+
def extract_jsons(text):
    """
    Extracts all valid JSON objects from a given string.

    Parameters:
        text (str): The string from which JSON objects are to be extracted.

    Returns:
        List[Dict]: A list of all extracted JSON objects.
    """
    collapsed = re.sub(r"\s+", " ", text)
    results = []
    for candidate in re.findall(r"\{.*?\}", collapsed):
        try:
            results.append(json.loads(candidate))
        except ValueError:
            # Brace span that is not valid JSON; skip it.
            pass
    return results
|
60 |
+
|
61 |
+
|
62 |
+
def extract_code(text):
    """
    Extracts all code blocks encapsulated by '```' from a given string.

    Parameters:
        text (str): The string from which Python code blocks are to be extracted.

    Returns:
        List[str]: A list of all extracted Python code blocks.
    """
    # Normalize "```python" fences to bare "```" so one pattern matches all.
    normalized = re.sub("```python", "```", text)
    return re.findall(r"```(.*?)```", normalized, re.DOTALL)
|
78 |
+
|
79 |
+
|
80 |
+
class AttributedDict(dict):
    """
    A dictionary class whose keys are automatically set as attributes of the class.

    The dictionary is serializable to JSON; keys are restricted to strings.

    Inherits from:
        dict: Built-in dictionary class in Python.

    Note:
        This class provides attribute-style access to dictionary keys, meaning you can use dot notation
        (like `my_dict.my_key`) in addition to the traditional bracket notation (`my_dict['my_key']`).
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def __setattr__(self, key, value):
        # Attribute assignment is routed through __setitem__, which enforces
        # the string-key constraint.
        self[key] = value

    def __getattr__(self, key):
        # Only called when normal attribute lookup fails.
        if key in self:
            return self[key]
        # Include the missing name in the error for easier debugging (the
        # original raised a bare AttributeError with no message).
        raise AttributeError(key)

    def __delattr__(self, key):
        del self[key]

    # check whether the key is string when adding the key
    def __setitem__(self, key, value):
        if not isinstance(key, str):
            raise ValueError("The key must be a string")
        super().__setitem__(key, value)

    def update(self, *args, **kwargs):
        # Route through __setitem__ (plain dict.update would bypass the
        # string-key check).
        for key, value in dict(*args, **kwargs).items():
            self[key] = value
|
app.py
ADDED
@@ -0,0 +1,612 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import re
|
3 |
+
from glob import glob
|
4 |
+
from argparse import Namespace
|
5 |
+
|
6 |
+
import gradio as gr
|
7 |
+
|
8 |
+
|
9 |
+
from agentreview import const
|
10 |
+
from agentreview.config import AgentConfig
|
11 |
+
from agentreview.agent import Player
|
12 |
+
from agentreview.backends import BACKEND_REGISTRY
|
13 |
+
from agentreview.environments import PaperReview
|
14 |
+
from agentreview.paper_review_arena import PaperReviewArena
|
15 |
+
from agentreview.utility.experiment_utils import initialize_players
|
16 |
+
from agentreview.paper_review_player import PaperExtractorPlayer, AreaChair, Reviewer
|
17 |
+
from agentreview.role_descriptions import get_reviewer_description, get_ac_description, get_author_config, get_paper_extractor_config
|
18 |
+
|
19 |
+
# This module drives the front end: it builds the Gradio page, collects the
# user's configuration, hands it to the backend to run, and streams results
# back into the corresponding UI components.

css = """#col-container {max-width: 90%; margin-left: auto; margin-right: auto; display: flex; flex-direction: column;}
#header {text-align: center;}
#col-chatbox {flex: 1; max-height: min(900px, 100%);}
#label {font-size: 2em; padding: 0.5em; margin: 0;}
.message {font-size: 1.2em;}
.message-wrap {max-height: min(700px, 100vh);}
"""
# Previously used CSS rules, kept for reference:
# .wrap {min-width: min(640px, 100vh)}
# #env-desc {max-height: 100px; overflow-y: auto;}
# .textarea {height: 100px; max-height: 100px;}
# #chatbot-tab-all {height: 750px; max-height: min(750px, 100%);}
# #chatbox {height: min(750px, 100%); max-height: min(750px, 100%);}
# #chatbox.block {height: 730px}
# .wrap {max-height: 680px;}
# .scroll-hide {overflow-y: scroll; max-height: 100px;}

DEBUG = False  # Set True for verbose local debugging.

DEFAULT_BACKEND = "openai-chat"  # Backend key used when none is selected.
MAX_NUM_PLAYERS = 4
DEFAULT_NUM_PLAYERS = 4
# Index of the current review-process step; read by the UI to pick the phase.
CURRENT_STEP_INDEX = 0
|
43 |
+
|
44 |
+
def load_examples(examples_dir: str = "examples"):
    """Load example configurations from a directory of JSON files.

    Args:
        examples_dir: Directory to scan for ``*.json`` files. Defaults to
            ``"examples"`` for backward compatibility with the original
            hard-coded path.

    Returns:
        dict: Mapping from each example's "name" field to its parsed config.
            Files without a "name" field are skipped with a warning.
    """
    example_configs = {}
    for example_file in glob(f"{examples_dir}/*.json"):
        with open(example_file, encoding="utf-8") as f:
            example = json.load(f)
        try:
            example_configs[example["name"]] = example
        except KeyError:
            print(f"Example {example_file} is missing a name field. Skipping.")
    return example_configs
|
56 |
+
|
57 |
+
|
58 |
+
# Example configurations discovered in examples/*.json, keyed by their "name".
EXAMPLE_REGISTRY = load_examples()

# DB = SupabaseDB() if supabase_available else None
|
61 |
+
|
62 |
+
def get_player_components(name, visible):
    """Build the Gradio components for configuring one reviewer player.

    Args:
        name: Display name of the player (e.g. "Reviewer 1").
        visible: Whether the configurable components are initially visible.

    Returns:
        Components in a fixed order:
        [role_name, Intention_config, Knowledge_config, Responsibility_config,
         backend_type, accordion, temperature, max_tokens].
        Callers (players_idx2comp consumers) rely on this exact ordering.
    """
    with gr.Row():
        with gr.Column():
            # Hidden textbox carrying the player's name through the component graph.
            role_name = gr.Textbox(
                lines=1,
                show_label=False,
                interactive=True,
                visible=False,
                value=name,
            )

            # is benign, is_knowledgeable, is_responsible,
            # player_config = gr.CheckboxGroup(
            #     choices=["Benign", "Knowledgeable", "Responsible"],
            #     label="Reviewer Type",
            #     visible=visible,
            # )

            with gr.Row():
                # Expose the three reviewer attributes as dropdowns.
                Intention_config = gr.Dropdown(
                    choices=["Benign", "Malicious", "Neutral"],
                    interactive=True,
                    label = "Intention",
                    show_label=True,
                    value="Neutral",
                )

                Knowledge_config = gr.Dropdown(
                    choices=["Knowledgeable", "Unknownledgeable", "Normal"],
                    interactive=True,
                    label = "Knowledgeability",
                    show_label=True,
                    value="Normal",
                )

                Responsibility_config = gr.Dropdown(
                    choices=["Responsible", "Lazy", "Normal"],
                    interactive=True,
                    label = "Responsibility",
                    show_label=True,
                    value="Normal",
                )

            # Free-text role description, regenerated whenever a dropdown changes.
            role_desc = gr.Textbox(
                lines=8,
                max_lines=8,
                show_label=False,
                interactive=True,
                visible=visible,
                autoscroll=False,
                value=get_reviewer_description()
            )

            # role_desc = gr.Markdown(value=get_reviewer_description(),
            #                         visible=visible)

            def update_role_desc(Intention_config, Knowledge_config, Responsibility_config):
                """Regenerate the reviewer description from the dropdown values."""
                # Map each selection onto a tri-state flag:
                # True / False for the explicit choices, None for the neutral default.
                is_benign = True if Intention_config == "Benign" else (False if Intention_config == "Malicious" else None)
                is_knowledgeable = True if Knowledge_config == "Knowledgeable" else (False if Knowledge_config == "Unknownledgeable" else None)
                is_responsible = True if Responsibility_config == "Responsible" else (False if Responsibility_config == "Lazy" else None)

                # Phases 0-1 are review writing; later phases are reviewer-AC discussion.
                phase = 'reviewer_write_reviews' if CURRENT_STEP_INDEX < 2 else 'reviewer_ac_discussion'
                return get_reviewer_description(is_benign, is_knowledgeable, is_responsible, phase=phase)  # FIXME: should vary with the current phase

            Intention_config.select(fn=update_role_desc, inputs=[Intention_config, Knowledge_config, Responsibility_config], outputs=[role_desc])
            Knowledge_config.select(fn=update_role_desc, inputs=[Intention_config, Knowledge_config, Responsibility_config], outputs=[role_desc])
            Responsibility_config.select(fn=update_role_desc, inputs=[Intention_config, Knowledge_config, Responsibility_config], outputs=[role_desc])

        with gr.Column():
            # LLM backend selection for this player.
            backend_type = gr.Dropdown(
                show_label=False,
                choices=list(BACKEND_REGISTRY.keys()),
                interactive=True,
                visible=visible,
                value=DEFAULT_BACKEND,
            )
            with gr.Accordion(
                f"{name} Parameters", open=False, visible=visible
            ) as accordion:
                temperature = gr.Slider(
                    minimum=0,
                    maximum=2.0,
                    step=0.1,
                    interactive=True,
                    visible=visible,
                    label="temperature",
                    value=0.7,
                )
                max_tokens = gr.Slider(
                    minimum=10,
                    maximum=500,
                    step=10,
                    interactive=True,
                    visible=visible,
                    label="max tokens",
                    value=200,
                )

    return [role_name, Intention_config, Knowledge_config, Responsibility_config, backend_type, accordion, temperature, max_tokens]
|
164 |
+
|
165 |
+
def get_area_chair_components(name, visible):
    """Build the Gradio components for configuring the area chair (AC) player.

    Args:
        name: Display name of the player (expected to be "AC").
        visible: Whether the configurable components are initially visible.

    Returns:
        Components in a fixed order:
        [role_name, AC_type, backend_type, accordion, temperature, max_tokens].
        Callers (players_idx2comp consumers) rely on this exact ordering.
    """
    with gr.Row():
        with gr.Column():

            # Hidden textbox carrying the player's name through the component graph.
            role_name = gr.Textbox(
                lines=1,
                show_label=False,
                interactive=True,
                visible=False,
                value=name,
            )

            AC_type = gr.Dropdown(
                label = "AC Type",
                show_label=True,
                choices=["Inclusive", "Conformist", "Authoritarian", "Normal"],
                interactive=True,
                visible=visible,
                value="Normal",
            )

            # Role description, regenerated whenever the AC type changes.
            role_desc = gr.Textbox(
                lines=8,
                max_lines=8,
                show_label=False,
                interactive=True,
                visible=visible,
                value=get_ac_description("BASELINE", "ac_write_metareviews", 'None', 1),
            )

            def update_role_desc(AC_type):
                """Regenerate the AC description from the selected AC type."""
                # "Normal" maps to the BASELINE persona; other types are lowercased.
                ac_type = 'BASELINE' if AC_type == "Normal" else AC_type.lower()
                return get_ac_description(ac_type, "ac_write_metareviews", "None", 1)  # FIXME: should vary with the current phase

            AC_type.select(fn=update_role_desc, inputs=[AC_type], outputs=[role_desc])

        with gr.Column():
            # LLM backend selection for the AC.
            backend_type = gr.Dropdown(
                show_label=False,
                choices=list(BACKEND_REGISTRY.keys()),
                interactive=True,
                visible=visible,
                value=DEFAULT_BACKEND,
            )
            with gr.Accordion(
                f"{name} Parameters", open=False, visible=visible
            ) as accordion:
                temperature = gr.Slider(
                    minimum=0,
                    maximum=2.0,
                    step=0.1,
                    interactive=True,
                    visible=visible,
                    label="temperature",
                    value=0.7,
                )
                max_tokens = gr.Slider(
                    minimum=10,
                    maximum=500,
                    step=10,
                    interactive=True,
                    visible=visible,
                    label="max tokens",
                    value=200,
                )

    return [role_name, AC_type, backend_type, accordion, temperature, max_tokens]
|
232 |
+
|
233 |
+
|
234 |
+
def get_empty_state():
    """Create the initial per-session state: no arena has been built yet."""
    initial_state = {"arena": None}
    return gr.State(initial_state)
|
236 |
+
|
237 |
+
|
238 |
+
# Top-level UI layout: header, environment description, chat tabs (one per
# player plus an "All" tab), per-player configuration tabs, file upload, and
# the Submit/Clear buttons. `all_components` accumulates every interactive
# component so the event handlers below can receive their values.
with gr.Blocks(css=css) as demo:
    state = get_empty_state()
    all_components = []

    with gr.Column(elem_id="col-container"):
        gr.Markdown(
            """# 🤖 AgentReview<br>
            Using Multi-Agent to review your paper!.
            **[Project Homepage](https://github.com/Ahren09/AgentReview)**""",
            elem_id="header",
        )

        # Environment configuration
        env_desc_textbox = gr.Textbox(
            show_label=True,
            lines=2,
            visible=True,
            label="Environment Description",
            interactive=True,
            # placeholder="Enter a description of a scenario or the game rules.",
            value=const.GLOBAL_PROMPT,
        )

        all_components += [env_desc_textbox]

        with gr.Row():
            with gr.Column(elem_id="col-chatbox"):
                # "All" tab shows the combined conversation.
                with gr.Tab("All", visible=True):
                    chatbot = gr.Chatbot(
                        elem_id="chatbox", visible=True, show_label=False, height=600
                    )

                # One chat tab per player; the last slot is the AC.
                player_chatbots = []
                for i in range(MAX_NUM_PLAYERS):
                    player_name = f"Reviewer {i + 1}" if i < MAX_NUM_PLAYERS-1 else "AC"
                    with gr.Tab(player_name, visible=(i < DEFAULT_NUM_PLAYERS)):
                        player_chatbot = gr.Chatbot(
                            elem_id=f"chatbox-{i}",
                            visible=i < DEFAULT_NUM_PLAYERS,
                            label=player_name,
                            show_label=False,
                            height=600,  # FIXME: setting appears to have no effect
                        )
                        player_chatbots.append(player_chatbot)

                all_components += [chatbot, *player_chatbots]

            with gr.Column(elem_id="col-config"):  # Player Configuration
                # gr.Markdown("Player Configuration")

                # parallel_checkbox = gr.Checkbox(
                #     label="Parallel Actions", value=False, visible=True
                # )

                # players_idx2comp maps player index -> its config components
                # (plus the enclosing tab); consumed by
                # _create_arena_config_from_components.
                all_players_components, players_idx2comp = [], {}
                with gr.Blocks():
                    for i in range(MAX_NUM_PLAYERS):

                        player_name = f"Reviewer {i + 1}" if i < MAX_NUM_PLAYERS-1 else "AC"
                        with gr.Tab(
                            player_name, visible=(i < DEFAULT_NUM_PLAYERS)
                        ) as tab:
                            if player_name != "AC":
                                player_comps = get_player_components(
                                    player_name, visible=(i < DEFAULT_NUM_PLAYERS)
                                )
                            else:
                                player_comps = get_area_chair_components(
                                    player_name, visible=(i < DEFAULT_NUM_PLAYERS)
                                )

                        players_idx2comp[i] = player_comps + [tab]
                        all_players_components += player_comps + [tab]

                all_components += all_players_components

                # human_input_textbox = gr.Textbox(
                #     show_label=True,
                #     label="Human Input",
                #     lines=1,
                #     visible=True,
                #     interactive=True,
                #     placeholder="Upload your paper here",
                # )

                # Paper PDF upload widget.
                upload_file_box = gr.File(
                    visible=True,
                    height = 100,
                )

                with gr.Row():
                    btn_step = gr.Button("Submit")
                    btn_restart = gr.Button("Clear")

                all_components += [upload_file_box, btn_step, btn_restart]
335 |
+
def _convert_to_chatbot_output(all_messages, display_recv=False):
|
336 |
+
chatbot_output = []
|
337 |
+
for i, message in enumerate(all_messages):
|
338 |
+
agent_name, msg, recv = (
|
339 |
+
message.agent_name,
|
340 |
+
message.content,
|
341 |
+
str(message.visible_to),
|
342 |
+
)
|
343 |
+
new_msg = re.sub(
|
344 |
+
r"\n+", "<br>", msg.strip()
|
345 |
+
) # Preprocess message for chatbot output
|
346 |
+
if display_recv:
|
347 |
+
new_msg = f"**{agent_name} (-> {recv})**: {new_msg}" # Add role to the message
|
348 |
+
else:
|
349 |
+
new_msg = f"**{agent_name}**: {new_msg}"
|
350 |
+
|
351 |
+
if agent_name == "Moderator":
|
352 |
+
chatbot_output.append((new_msg, None))
|
353 |
+
else:
|
354 |
+
chatbot_output.append((None, new_msg))
|
355 |
+
return chatbot_output
|
356 |
+
|
357 |
+
    def _create_arena_config_from_components(all_comps: dict):
        """Build a PaperReviewArena from the current values of the UI components.

        Args:
            all_comps: Mapping from Gradio component -> current value, as
                supplied by Gradio event handlers.

        Returns:
            An initialized PaperReviewArena wrapping the reviewers, AC,
            paper extractor, author, and the PaperReview environment.
        """
        env_desc = all_comps[env_desc_textbox]
        paper_pdf_path = all_comps[upload_file_box]

        # Step 1: Initialize the players
        num_players = MAX_NUM_PLAYERS

        # Placeholder values to satisfy the legacy interface.
        conference = "EMNLP 2024"
        paper_decision = "Accept"
        data_dir = ''
        paper_id = "12345"

        # NOTE: experiment_name is a dummy value required by the interface.
        args = Namespace(openai_client_type="openai",
                         experiment_name="test",
                         max_num_words=16384)

        # During the paper_decision phase only the AC is active.
        players = []

        # role_desc cannot be read from the UI directly; it must be regenerated
        # from the Intention/Knowledge/Responsibility dropdown values below.
        # self.environment.experiment_setting["players"]['Reviewer'][reviewer_index - 1]

        experiment_setting = {
            "paper_id": paper_id,
            "paper_decision": paper_decision,
            "players": {

                # Paper Extractor is a special player that extracts a paper from the dataset.
                # Its constructor does not take any arguments.
                "Paper Extractor": [{}],

                # Assume there is only one area chair (AC) in the experiment.
                "AC": [],

                # Author role with default configuration.
                "Author": [{}],

                # Reviewer settings are generated based on reviewer types provided in the settings.
                "Reviewer": [],
            },
            # "global_settings": setting['global_settings']
        }

        for i in range(num_players):
            if i < num_players-1:  # reviewer
                # Unpack the config components for reviewer i, skipping the
                # Accordion and Tab containers (they carry no values).
                role_name, intention_config, knowledge_config, responsibility_config, backend_type, temperature, max_tokens = (
                    all_comps[c]
                    for c in players_idx2comp[i]
                    if not isinstance(c, (gr.Accordion, gr.Tab))
                )

                # Map dropdown selections onto tri-state flags (True/False/None).
                is_benign = True if intention_config == "Benign" else (False if intention_config == "Malicious" else None)
                is_knowledgeable = True if knowledge_config == "Knowledgeable" else (False if knowledge_config == "Unknownledgeable" else None)
                is_responsible = True if responsibility_config == "Responsible" else (False if responsibility_config == "Lazy" else None)

                experiment_setting["players"]['Reviewer'].append({"is_benign": is_benign,
                                                                  "is_knowledgeable": is_knowledgeable,
                                                                  "is_responsible": is_responsible,
                                                                  "knows_authors": 'unfamous'})

                role_desc = get_reviewer_description(is_benign, is_knowledgeable, is_responsible)

            if i == num_players-1:  # AC
                role_name, ac_type, backend_type, temperature, max_tokens = (
                    all_comps[c]
                    for c in players_idx2comp[i]
                    if not isinstance(c, (gr.Accordion, gr.Tab))
                )

                # "Normal" maps to the BASELINE persona; other types are lowercased.
                ac_type = 'BASELINE' if ac_type == "Normal" else ac_type.lower()

                experiment_setting["players"]['AC'].append({"area_chair_type": ac_type})

                role_desc = get_ac_description(ac_type, "ac_write_metareviews", "None", 1)

            # common config for all players
            player_config = {
                "name": role_name,
                "role_desc": role_desc,
                "global_prompt": env_desc,
                "backend": {
                    "backend_type": backend_type,
                    "temperature": temperature,
                    "max_tokens": max_tokens,
                },
            }

            player_config = AgentConfig(**player_config)

            if i < num_players-1:
                player = Reviewer(data_dir=data_dir, conference=conference, args=args, **player_config)
            else:
                player_config["env_type"] = "paper_review"
                player = AreaChair(data_dir=data_dir, conference=conference, args=args, **player_config)

            players.append(player)

        # Construct the remaining players for this phase from fixed configs.
        # if CURRENT_STEP == "paper_review":

        # Manually add the paper extractor player.
        paper_extractor_config = get_paper_extractor_config(max_tokens=2048)

        paper_extractor = PaperExtractorPlayer( paper_pdf_path=paper_pdf_path,
                                                data_dir=data_dir, paper_id=paper_id,
                                                paper_decision=paper_decision, args=args,
                                                conference=conference, **paper_extractor_config)
        players.append(paper_extractor)

        # Manually add the author player.
        author_config = get_author_config()
        author = Player(data_dir=data_dir, conference=conference, args=args,
                        **author_config)

        players.append(author)

        player_names = [player.name for player in players]

        # Step 2: Initialize the environment
        env = PaperReview(player_names=player_names, paper_decision=paper_decision, paper_id=paper_id,
                          args=args, experiment_setting=experiment_setting)

        # Step 3: Initialize the Arena
        arena = PaperReviewArena(players=players, environment=env, args=args, global_prompt=env_desc)

        return arena
|
488 |
+
|
489 |
+
    def step_game(all_comps: dict):
        """Advance the review simulation by one environment step.

        Generator handler for the Submit button: first yields an update that
        disables both buttons while the arena runs, then yields the refreshed
        chat histories once the step completes. Lazily creates the arena from
        the current UI configuration on the first call of a session.
        """
        global CURRENT_STEP_INDEX

        yield {
            btn_step: gr.update(value="Running...", interactive=False),
            btn_restart: gr.update(interactive=False),
        }

        cur_state = all_comps[state]

        # If arena is not yet created, create it
        if cur_state["arena"] is None:
            # Create the Arena
            arena = _create_arena_config_from_components(all_comps)
            cur_state["arena"] = arena
        else:
            arena = cur_state["arena"]

        # TODO: support running multiple steps continuously

        timestep = arena.step()

        CURRENT_STEP_INDEX = int(arena.environment.phase_index)

        # Push the new messages to the frontend.
        if timestep:
            all_messages = timestep.observation
            all_messages[0].content = 'Paper content has been extracted.'
            chatbot_output = _convert_to_chatbot_output(all_messages, display_recv=True)
            update_dict = {
                chatbot: chatbot_output,
                btn_step: gr.update(
                    value="Next Step", interactive=not timestep.terminal
                ),
                btn_restart: gr.update(interactive=True),
                state: cur_state,
            }

            # Reviewer 1, 2, 3 Area Chair, Paper Extractor, Author
            for i, player in enumerate(arena.players):
                if 'Reviewer' in player.name and arena.environment.phase_index < 4:  # FIXME: temporary logic
                    player_messages = arena.environment.get_observation(player.name)
                    # Hide the long first message; show only a short
                    # "extracted" notice instead of the full paper content.
                    player_messages[0].content = 'Paper content has been extracted.'
                    player_output = _convert_to_chatbot_output(player_messages)
                    # Update the player's chatbot output
                    update_dict[player_chatbots[i]] = player_output
                elif arena.environment.phase_index in [4, 5]:  # FIXME: temporary logic
                    player_messages = arena.environment.get_observation('AC')
                    player_messages[0].content = 'Paper content has been extracted.'
                    player_output = _convert_to_chatbot_output(player_messages)
                    # Update the player's chatbot output
                    update_dict[player_chatbots[3]] = player_output

            yield update_dict
|
545 |
+
|
546 |
+
|
547 |
+
    def restart_game(all_comps: dict):
        """Reset the session: clear the chat, drop the arena, re-enable Start.

        Generator handler for the Clear button: first yields an update that
        empties the chat and disables the buttons, then yields a second
        update restoring the Start button and clearing the file upload.
        """
        global CURRENT_STEP_INDEX
        CURRENT_STEP_INDEX = 0

        cur_state = all_comps[state]
        cur_state["arena"] = None
        yield {
            chatbot: [],
            btn_restart: gr.update(interactive=False),
            btn_step: gr.update(interactive=False),
            state: cur_state,
        }

        # arena_config = _create_arena_config_from_components(all_comps)
        # arena = Arena.from_config(arena_config)
        # log_arena(arena, database=DB)
        # cur_state["arena"] = arena

        yield {
            btn_step: gr.update(value="Start", interactive=True),
            btn_restart: gr.update(interactive=True),
            upload_file_box: gr.update(value=None),
            state: cur_state,
        }
|
571 |
+
|
572 |
+
    # Remove Accordion and Tab from the list of components
    all_components = [
        comp for comp in all_components if not isinstance(comp, (gr.Accordion, gr.Tab))
    ]

    # update component
    # env_desc_textbox.change()

    # If any of the Textbox, Slider, Checkbox, Dropdown, RadioButtons is changed, the Step button is disabled
    for comp in all_components:

        def _disable_step_button(state):
            # Once an arena exists, configuration changes would invalidate it,
            # so lock the Step button until the session is cleared.
            if state["arena"] is not None:
                return gr.update(interactive=False)
            else:
                return gr.update()

        if (
            isinstance(
                comp, (gr.Textbox, gr.Slider, gr.Checkbox, gr.Dropdown, gr.Radio)
            )
            and comp is not upload_file_box
        ):
            comp.change(_disable_step_button, state, btn_step)

    # print(set(all_components + [state]))
    # Wire the buttons: both handlers receive every component's value and
    # update the chatbots, buttons, state, and upload box.
    btn_step.click(
        step_game,
        set(all_components + [state]),
        [chatbot, *player_chatbots, btn_step, btn_restart, state, upload_file_box],
    )

    btn_restart.click(
        restart_game,
        set(all_components + [state]),
        [chatbot, *player_chatbots, btn_step, btn_restart, state, upload_file_box],
    )


demo.queue()
demo.launch(debug=DEBUG, server_port=8082)
|
docs/devdoc/design.md
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Key Design Choices
|
2 |
+
In this document, we will discuss the key concepts and design choices of ChatArena.
|
3 |
+
We expect this will be helpful particularly for developers who want to contribute to ChatArena or build their own environments.
|
4 |
+
|
5 |
+
## Agent Environment Cycle
|
6 |
+
ChatArena in general follows the design principle of openAI gym [1] and pettingzoo [2]. Any agent will interact with the environment and other agents through the agent environment cycle.
|
7 |
+
For every single cycle,
|
8 |
+
1. the agent observes the environment
|
9 |
+
2. the agent outputs an action
|
10 |
+
3. the environment makes a state transition given the action
|
11 |
+
|
12 |
+
As an optional feature, in each cycle, the environment can also compute a scalar reward for every single agent, along with a terminal signal for the environment.
|
13 |
+
|
14 |
+
[1] Greg Brockman, Vicki Cheung, Ludwig Pettersson, Jonas Schneider, John Schulman, Jie Tang, Wojciech Zaremba: OpenAI Gym. CoRR abs/1606.01540 (2016)
|
15 |
+
|
16 |
+
[2] Justin K. Terry, Benjamin Black, Nathaniel Grammel, Mario Jayakumar, Ananth Hari, Ryan Sullivan, Luis S. Santos, Clemens Dieffendahl, Caroline Horsch, Rodrigo Perez-Vicente, Niall L. Williams, Yashas Lokesh, Praveen Ravi: PettingZoo: Gym for Multi-Agent Reinforcement Learning. NeurIPS 2021: 15032-15043
|
17 |
+
|
18 |
+
### Actions
|
19 |
+
|
20 |
+
In the current version of ChatArena, all the actions are represented as plain text. More structured text outputs, like json or code, can be generated by prompting the LLM to do so.
|
21 |
+
We provide simple utilities to extract json and code (with markdown syntax), which should cover common use cases but can break for intentionally crafted edge cases.
|
22 |
+
|
23 |
+
### Observations
|
24 |
+
|
25 |
+
An observation is a list of messages, each with a sender and content. The sender can be any agent in the environment or the environment's built-in moderator. The content is again plain text.
|
26 |
+
|
27 |
+
## Message Pool and Visibility Control
|
28 |
+
|
29 |
+
In ChatArena, agents cannot directly talk to each other but exchange information with a [message pool](https://github.com/chatarena/chatarena/blob/main/chatarena/message.py) as a proxy. The message pool is a utility abstraction that can serve as a part of the game state.
|
30 |
+
|
31 |
+
When an agent takes an action, a message can be created and appended to the message pool. In the message pool, each message will have a receiver, which can be decided by the environment dynamics (game rules) or by the agent itself. The environment itself can also create messages under the name of the moderator which can provide other state information or extra instructions given the current state.
|
32 |
+
|
33 |
+
To render an observation, the message pool will collect all the messages that are visible to the agent and return a list of these messages.
|
34 |
+
|
35 |
+
In particular, some environments require parallel moves — say, rock-paper-scissors — where an agent shouldn't see the moves of other agents in the same turn. Such a mechanism is also implemented in the message pool: one can specify the "current turn", and messages from the current turn and later turns will be ignored when rendering observations.
|
36 |
+
|
37 |
+
## Intelligence Backends
|
38 |
+
|
39 |
+
In ChatArena, each agent will usually be powered by a language backend. These backends can be LLM APIs (say, from [OpenAI](https://github.com/chatarena/chatarena/blob/main/chatarena/backends/openai.py), [Anthropic](https://github.com/chatarena/chatarena/blob/main/chatarena/backends/anthropic.py) or [Cohere](https://github.com/chatarena/chatarena/blob/main/chatarena/backends/cohere.py)), [local LLM](https://github.com/chatarena/chatarena/blob/main/chatarena/backends/hf_transformers.py) or just [humans](https://github.com/chatarena/chatarena/blob/main/chatarena/backends/human.py) behind a user interface. In [backends](https://github.com/chatarena/chatarena/tree/main/chatarena/backends), we render the observations (list of messages) into the required formats for the downstream models. And the returned text will be the agent’s action [by default](https://github.com/chatarena/chatarena/blob/55c9e6ee4e09d72905eceb0a0e09e93a4179ca39/chatarena/agent.py#L28).
|
docs/devdoc/mainloop.md
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
### Step 1: Define Multiple Players with LLM Backend
|
2 |
+
|
3 |
+
```python
|
4 |
+
from agentreview.agent import Player
|
5 |
+
from agentreview.backends import OpenAIChat
|
6 |
+
|
7 |
+
# Describe the environment (which is shared by all players)
|
8 |
+
environment_description = "It is in a university classroom ..."
|
9 |
+
|
10 |
+
# A "Professor" player
|
11 |
+
player1 = Player(name="Professor", backend=OpenAIChat(),
|
12 |
+
role_desc="You are a professor in ...",
|
13 |
+
global_prompt=environment_description)
|
14 |
+
# A "Student" player
|
15 |
+
player2 = Player(name="Student", backend=OpenAIChat(),
|
16 |
+
role_desc="You are a student who is interested in ...",
|
17 |
+
global_prompt=environment_description)
|
18 |
+
# A "Teaching Assistant" player
|
19 |
+
player3 = Player(name="Teaching assistant", backend=OpenAIChat(),
|
20 |
+
role_desc="You are a teaching assistant of the module ...",
|
21 |
+
global_prompt=environment_description)
|
22 |
+
```
|
23 |
+
|
24 |
+
### Step 2: Create a Language Game Environment
|
25 |
+
|
26 |
+
You can also create a language model-driven environment and add it to the ChatArena:
|
27 |
+
|
28 |
+
```python
|
29 |
+
from agentreview.environments.conversation import Conversation
|
30 |
+
|
31 |
+
env = Conversation(player_names=[p.name for p in [player1, player2, player3]])
|
32 |
+
```
|
33 |
+
|
34 |
+
### Step 3: Run the Language Game using Arena
|
35 |
+
|
36 |
+
`Arena` is a utility class to help you run language games:
|
37 |
+
|
38 |
+
```python
|
39 |
+
from agentreview.arena import Arena
|
40 |
+
|
41 |
+
arena = Arena(players=[player1, player2, player3],
|
42 |
+
environment=env, global_prompt=environment_description)
|
43 |
+
# Run the game for 10 steps
|
44 |
+
arena.run(num_steps=10)
|
45 |
+
|
46 |
+
# Alternatively, you can run your own main loop
|
47 |
+
for _ in range(10):
|
48 |
+
arena.step()
|
49 |
+
# Your code goes here ...
|
50 |
+
```
|
51 |
+
|
52 |
+
You can easily save your gameplay history to file:
|
53 |
+
|
54 |
+
```python
|
55 |
+
arena.save_history(path=...)
|
56 |
+
```
|
57 |
+
|
58 |
+
and save your game config to file:
|
59 |
+
|
60 |
+
```python
|
61 |
+
arena.save_config(path=...)
|
62 |
+
```
|