umwyf commited on
Commit
acad479
·
1 Parent(s): 5533858

Upload 14 files

Browse files
Files changed (14) hide show
  1. Hi-ToM_data.json +0 -0
  2. README.md +16 -193
  3. actions.py +270 -0
  4. clause.py +26 -0
  5. create_world.py +248 -0
  6. dynamic_actions.py +369 -0
  7. generate_prompts.py +31 -0
  8. generate_tasks.py +180 -0
  9. oracle.py +147 -0
  10. stringify.py +47 -0
  11. tasks.py +518 -0
  12. test_azure.py +43 -0
  13. utils.py +44 -0
  14. world.py +47 -0
Hi-ToM_data.json ADDED
The diff for this file is too large to render. See raw diff
 
README.md CHANGED
@@ -1,202 +1,25 @@
1
- ---
2
- metrics:
3
- - accuracy
4
- pipeline_tag: question-answering
5
- tags:
6
- - code
7
- ---
8
 
9
- # Model Card for Model ID
10
 
11
- <!-- Provide a quick summary of what the model is/does. -->
12
 
13
- This modelcard aims to be a base template for new models. It has been generated using [this raw template](https://github.com/huggingface/huggingface_hub/blob/main/src/huggingface_hub/templates/modelcard_template.md?plain=1).
14
 
15
- ## Model Details
 
16
 
17
- ### Model Description
 
 
 
 
18
 
19
- <!-- Provide a longer summary of what this model is. -->
20
 
 
 
21
 
 
22
 
23
- - **Developed by:** [More Information Needed]
24
- - **Funded by [optional]:** [More Information Needed]
25
- - **Shared by [optional]:** [More Information Needed]
26
- - **Model type:** [More Information Needed]
27
- - **Language(s) (NLP):** [More Information Needed]
28
- - **License:** [More Information Needed]
29
- - **Finetuned from model [optional]:** [More Information Needed]
30
-
31
- ### Model Sources [optional]
32
-
33
- <!-- Provide the basic links for the model. -->
34
-
35
- - **Repository:** [More Information Needed]
36
- - **Paper [optional]:** [More Information Needed]
37
- - **Demo [optional]:** [More Information Needed]
38
-
39
- ## Uses
40
-
41
- <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
42
-
43
- ### Direct Use
44
-
45
- <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
46
-
47
- [More Information Needed]
48
-
49
- ### Downstream Use [optional]
50
-
51
- <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
52
-
53
- [More Information Needed]
54
-
55
- ### Out-of-Scope Use
56
-
57
- <!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
58
-
59
- [More Information Needed]
60
-
61
- ## Bias, Risks, and Limitations
62
-
63
- <!-- This section is meant to convey both technical and sociotechnical limitations. -->
64
-
65
- [More Information Needed]
66
-
67
- ### Recommendations
68
-
69
- <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
70
-
71
- Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
72
-
73
- ## How to Get Started with the Model
74
-
75
- Use the code below to get started with the model.
76
-
77
- [More Information Needed]
78
-
79
- ## Training Details
80
-
81
- ### Training Data
82
-
83
- <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
84
-
85
- [More Information Needed]
86
-
87
- ### Training Procedure
88
-
89
- <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
90
-
91
- #### Preprocessing [optional]
92
-
93
- [More Information Needed]
94
-
95
-
96
- #### Training Hyperparameters
97
-
98
- - **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
99
-
100
- #### Speeds, Sizes, Times [optional]
101
-
102
- <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
103
-
104
- [More Information Needed]
105
-
106
- ## Evaluation
107
-
108
- <!-- This section describes the evaluation protocols and provides the results. -->
109
-
110
- ### Testing Data, Factors & Metrics
111
-
112
- #### Testing Data
113
-
114
- <!-- This should link to a Dataset Card if possible. -->
115
-
116
- [More Information Needed]
117
-
118
- #### Factors
119
-
120
- <!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
121
-
122
- [More Information Needed]
123
-
124
- #### Metrics
125
-
126
- <!-- These are the evaluation metrics being used, ideally with a description of why. -->
127
-
128
- [More Information Needed]
129
-
130
- ### Results
131
-
132
- [More Information Needed]
133
-
134
- #### Summary
135
-
136
-
137
-
138
- ## Model Examination [optional]
139
-
140
- <!-- Relevant interpretability work for the model goes here -->
141
-
142
- [More Information Needed]
143
-
144
- ## Environmental Impact
145
-
146
- <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
147
-
148
- Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
149
-
150
- - **Hardware Type:** [More Information Needed]
151
- - **Hours used:** [More Information Needed]
152
- - **Cloud Provider:** [More Information Needed]
153
- - **Compute Region:** [More Information Needed]
154
- - **Carbon Emitted:** [More Information Needed]
155
-
156
- ## Technical Specifications [optional]
157
-
158
- ### Model Architecture and Objective
159
-
160
- [More Information Needed]
161
-
162
- ### Compute Infrastructure
163
-
164
- [More Information Needed]
165
-
166
- #### Hardware
167
-
168
- [More Information Needed]
169
-
170
- #### Software
171
-
172
- [More Information Needed]
173
-
174
- ## Citation [optional]
175
-
176
- <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
177
-
178
- **BibTeX:**
179
-
180
- [More Information Needed]
181
-
182
- **APA:**
183
-
184
- [More Information Needed]
185
-
186
- ## Glossary [optional]
187
-
188
- <!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
189
-
190
- [More Information Needed]
191
-
192
- ## More Information [optional]
193
-
194
- [More Information Needed]
195
-
196
- ## Model Card Authors [optional]
197
-
198
- [More Information Needed]
199
-
200
- ## Model Card Contact
201
-
202
- [More Information Needed]
 
1
+ # Hi-ToM Dataset
 
 
 
 
 
 
2
 
3
+ This is the dataset for the paper "Hi-ToM: A Benchmark for Evaluating Higher-Order Theory of Mind Reasoning in Large Language Models".
4
 
5
+ <img src=media/Picture1.png height=430>
6
 
7
+ ### The `Hi-ToM_data` folder
8
 
9
+ Contains ToMh data consisting of story-question pairs and the corresponding answers.
10
+ The names of subfolder branches have the following meanings:
11
 
12
+ - `Tell` / `No_Tell`: whether or not the stories contain communications among agents.
13
+ - `MC` / `CoT`: the prompting style. `MC` corresponds to Vanilla Prompting (VP) in the paper, while `CoT` stands for Chain-of-Thought Prompting (CoTP).
14
+ - `length_n`: the story length, i.e. the number of chapters in a story. From 1 to 3.
15
+ - `sample_n`: the numbering of different sample stories.
16
+ - `order_n`: the ToM order of the question. From 0 to 4.
17
 
18
+ ### The `Hi-ToM_prompt` folder
19
 
20
+ Contains prompt files that can be directly input to API.
21
+ The data in it are almost the same as `Hi-ToM_data`, except that answers are eliminated.
22
 
23
+ ### Generate new data and prompts
24
 
25
+ Run the script `generate_tomh.sh`.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
actions.py ADDED
@@ -0,0 +1,270 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+
4
+ class Action(object):
5
+
6
+ def __init__(self, templates):
7
+ self.templates = templates
8
+
9
+ def render_declarative(self, *args):
10
+ assert 'declarative' in self.templates and \
11
+ len(self.templates['declarative']) > 0
12
+ return np.random.choice(self.templates['declarative']) % args
13
+
14
+ def render_interrogative(self, *args):
15
+ assert 'interrogative' in self.templates and \
16
+ len(self.templates['interrogative']) > 0, str(self.templates)
17
+ return np.random.choice(self.templates['interrogative']) % args
18
+
19
+
20
+ class ExistBeginning(Action):
21
+
22
+ def __init__(self):
23
+ templates = {
24
+ 'interrogative': [
25
+ 'Where was the %s at the beginning?\t%s',
26
+ 'Where was the %s before?\t%s',
27
+ ]
28
+ }
29
+ super().__init__(templates)
30
+
31
+
32
+ class Exist(Action):
33
+
34
+ def __init__(self):
35
+ templates = {
36
+ 'interrogative': [
37
+ 'Where is the %s?\t%s',
38
+ 'Where is the %s located?\t%s',
39
+ ]
40
+ }
41
+ super().__init__(templates)
42
+
43
+
44
+ class PlaceAction(Action):
45
+
46
+ def __init__(self):
47
+ templates = {
48
+ 'declarative': [
49
+ '%s placed the %s in the %s.',
50
+ '%s put the %s in the %s.',
51
+ ],
52
+ 'interrogative': [
53
+ 'Where did %s place the %s?\t%s',
54
+ 'Where did %s put the %s?\t%s',
55
+ ]
56
+ }
57
+ super().__init__(templates)
58
+
59
+
60
+ class SearchAction(Action):
61
+
62
+ def __init__(self):
63
+ templates = {
64
+ 'declarative': [
65
+ '%s searched for the %s in the %s.',
66
+ '%s looked for the %s in the %s.',
67
+ ],
68
+ 'interrogative': [
69
+ 'Where did %s search for the %s?\t%s',
70
+ 'Where did %s look for the %s?\t%s',
71
+ ],
72
+ }
73
+ super().__init__(templates)
74
+
75
+
76
+ class TransportAction(Action):
77
+
78
+ def __init__(self):
79
+ templates = {
80
+ 'declarative': [
81
+ '%s shifted the %s from the %s to the %s.',
82
+ ],
83
+ }
84
+ super().__init__(templates)
85
+
86
+
87
+ class EnterAction(Action):
88
+
89
+ def __init__(self):
90
+ templates = {
91
+ 'declarative': [
92
+ '%s entered the %s.',
93
+ '%s came into the %s.',
94
+ ],
95
+ }
96
+ super().__init__(templates)
97
+
98
+
99
+ class ExitAction(Action):
100
+
101
+ def __init__(self):
102
+ templates = {
103
+ 'declarative': [
104
+ '%s exited the %s.',
105
+ '%s left the %s.',
106
+ '%s went out of the %s.',
107
+ ],
108
+ }
109
+ super().__init__(templates)
110
+
111
+
112
+ class BelieveLocationAction(Action):
113
+
114
+ def __init__(self):
115
+ templates = {
116
+ 'declarative': [
117
+ '%s thinks the %s is in the %s.',
118
+ '%s believes the %s is in the %s.',
119
+ ],
120
+ 'interrogative': [
121
+ 'Where does %s think the %s is?\t%s',
122
+ 'Where does %s believe the %s is?\t%s',
123
+ ],
124
+ }
125
+ super().__init__(templates)
126
+
127
+
128
+ class BelieveAgentBelieveLocationAction(Action):
129
+
130
+ def __init__(self):
131
+ templates = {
132
+ 'interrogative': [
133
+ 'Where does %s think that %s believes the %s is?\t%s',
134
+ 'Where does %s believe that %s believes the %s is?\t%s',
135
+ 'Where does %s think that %s thinks the %s is?\t%s',
136
+ 'Where does %s believe that %s thinks the %s is?\t%s',
137
+ ],
138
+ }
139
+ super().__init__(templates)
140
+
141
+
142
+ class BelieveAgentSearchLocationAction(Action):
143
+
144
+ def __init__(self):
145
+ templates = {
146
+ 'interrogative': [
147
+ 'Where does %s think that %s looks for the %s?\t%s',
148
+ 'Where does %s believe that %s looks for the %s?\t%s',
149
+ 'Where does %s think that %s searches for the %s?\t%s',
150
+ 'Where does %s believe that %s search for the %s?\t%s',
151
+ ],
152
+ }
153
+ super().__init__(templates)
154
+
155
+
156
+ class InformLocationAction(Action):
157
+
158
+ def __init__(self):
159
+ templates = {
160
+ 'declarative': [
161
+ '%s told %s that the %s is in the %s.',
162
+ '%s informed %s that the %s is in the %s.',
163
+ ],
164
+ }
165
+ super().__init__(templates)
166
+
167
+ ####################################################
168
+ ####### Deterministic Actions for New Task #######
169
+ ####################################################
170
+
171
+ class FirstQ(Action):
172
+
173
+ def __init__(self):
174
+ templates = {
175
+ 'interrogative': [
176
+ 'Where will %s look for the %s?\t%s',
177
+ ]
178
+ }
179
+ super().__init__(templates)
180
+
181
+ class SecondQ(Action):
182
+
183
+ def __init__(self):
184
+ templates = {
185
+ 'interrogative': [
186
+ 'Where does %s think that %s searches for the %s?\t%s',
187
+ ]
188
+ }
189
+ super().__init__(templates)
190
+
191
+ class ZeroQ(Action):
192
+
193
+ def __init__(self):
194
+ templates = {
195
+ 'interrogative': [
196
+ 'Where is the %s really?\t%s',
197
+ ]
198
+ }
199
+ super().__init__(templates)
200
+
201
+ class MemoryAction(Action):
202
+
203
+ def __init__(self):
204
+ templates = {
205
+ 'interrogative': [
206
+ 'Where was the %s at the beginning?\t%s',
207
+ ]
208
+ }
209
+ super().__init__(templates)
210
+
211
+ class LocationAction(Action):
212
+
213
+ def __init__(self):
214
+ templates = {
215
+ 'declarative': [
216
+ '%s and %s are in the %s.',
217
+ ]
218
+ }
219
+ super().__init__(templates)
220
+
221
+ class ObjectLocAction(Action):
222
+
223
+ def __init__(self):
224
+ templates = {
225
+ 'declarative': [
226
+ 'The %s is in the %s.',
227
+ ]
228
+ }
229
+ super().__init__(templates)
230
+
231
+ class ExitedAction(Action):
232
+
233
+ def __init__(self):
234
+ templates = {
235
+ 'declarative': [
236
+ '%s exited the %s.',
237
+ ]
238
+ }
239
+ super().__init__(templates)
240
+
241
+ class MoveAction(Action):
242
+
243
+ def __init__(self):
244
+ templates = {
245
+ 'declarative': [
246
+ '%s moved the %s to the %s.',
247
+ ]
248
+ }
249
+ super().__init__(templates)
250
+
251
+ class TellAction(Action):
252
+
253
+ def __init__(self):
254
+ templates = {
255
+ 'declarative': [
256
+ '%s told %s where the %s is.',
257
+ ]
258
+ }
259
+ super().__init__(templates)
260
+
261
+ class EnterAction(Action):
262
+
263
+ def __init__(self):
264
+ templates = {
265
+ 'declarative': [
266
+ '%s entered the %s.',
267
+ ]
268
+ }
269
+ super().__init__(templates)
270
+
clause.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+
4
+ class Clause(object):
5
+
6
+ def __init__(self, action):
7
+
8
+ # if observers is not None:
9
+ # assert 0 not in observers, "Observer IDs must be 1-indexed"
10
+ # self.observers = observers
11
+ self.action = action
12
+
13
+ def render(self):
14
+ return self.action.render_declarative() # + \
15
+ # ('\t' + ' '.join([str(x) for x in self.observers])
16
+ # if self.observers is not None else '')
17
+
18
+
19
+ class Question(Clause):
20
+
21
+ def __init__(self, idx_support, action):
22
+ self.idx_support = idx_support
23
+ super().__init__(action)
24
+
25
+ def render(self):
26
+ return self.action.render_interrogative()
create_world.py ADDED
@@ -0,0 +1,248 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+ SIZE_TINY = 5
4
+ SIZE_SMALL = 10
5
+ SIZE_LARGE = 30
6
+ SIZE_XLARGE = 50
7
+
8
+ locations = [
9
+ "attic",
10
+ "back_yard",
11
+ "basement",
12
+ "bathroom",
13
+ "bedroom",
14
+ "cellar",
15
+ "closet",
16
+ "crawlspace",
17
+ "den",
18
+ "dining_room",
19
+ "front_yard",
20
+ "garage",
21
+ "garden",
22
+ "hall",
23
+ "hallway",
24
+ "kitchen",
25
+ "laundry",
26
+ "living_room",
27
+ "lounge",
28
+ "master_bedroom",
29
+ "office",
30
+ "pantry",
31
+ "patio",
32
+ "playroom",
33
+ "porch",
34
+ "staircase",
35
+ "study",
36
+ "sunroom",
37
+ "TV_room",
38
+ "workshop",
39
+ ]
40
+
41
+ clothing = [
42
+ "belt",
43
+ "boots",
44
+ "cap",
45
+ "coat",
46
+ "dress",
47
+ "gloves",
48
+ "hat",
49
+ "jacket",
50
+ "jeans",
51
+ "pajamas",
52
+ "pants",
53
+ "raincoat",
54
+ "scarf",
55
+ "shirt",
56
+ "shoes",
57
+ "skirt",
58
+ "slacks",
59
+ "slippers",
60
+ "socks",
61
+ "stockings",
62
+ "suit",
63
+ "sweater",
64
+ "sweatshirt",
65
+ "t-shirt",
66
+ "tie",
67
+ "trousers",
68
+ "underclothes",
69
+ "underpants",
70
+ "undershirt",
71
+ ]
72
+
73
+ fruit = [
74
+ "apple",
75
+ "banana",
76
+ "cherry",
77
+ "grapefruit",
78
+ "grapes",
79
+ "lemon",
80
+ "lime",
81
+ "melon",
82
+ "orange",
83
+ "peach",
84
+ "pear",
85
+ "persimmon",
86
+ "pineapple",
87
+ "plum",
88
+ "strawberry",
89
+ "tangerine",
90
+ "watermelon",
91
+ ]
92
+
93
+ vegetables = [
94
+ "asparagus",
95
+ "beans",
96
+ "broccoli",
97
+ "cabbage",
98
+ "carrot",
99
+ "celery",
100
+ "corn",
101
+ "cucumber",
102
+ "eggplant",
103
+ "green_pepper",
104
+ "lettuce",
105
+ "onion",
106
+ "peas",
107
+ "potato",
108
+ "pumpkin",
109
+ "radish",
110
+ "spinach",
111
+ "sweet_potato",
112
+ "tomato",
113
+ "turnip",
114
+ ]
115
+
116
+ objects = fruit + vegetables
117
+
118
+ containers = [
119
+ "box",
120
+ "pantry",
121
+ "bathtub",
122
+ "envelope",
123
+ "drawer",
124
+ "bottle",
125
+ "cupboard",
126
+ "basket",
127
+ "crate",
128
+ "suitcase",
129
+ "bucket",
130
+ "container",
131
+ "treasure_chest",
132
+ ]
133
+
134
+ colors = ['green', 'blue', 'red']
135
+
136
+ containers = ['_'.join([color, container])
137
+ for container in containers
138
+ for color in colors]
139
+
140
+ names = [
141
+ "Oliver",
142
+ "Ethan",
143
+ "Liam",
144
+ "Benjamin",
145
+ "Lucas",
146
+ "Alexander",
147
+ "Jacob",
148
+ "Mason",
149
+ "William",
150
+ "Gracie",
151
+ "James",
152
+ "Logan",
153
+ "Owen",
154
+ "Noah",
155
+ "Carter",
156
+ "Nathan",
157
+ "Jack",
158
+ "Aiden",
159
+ "Jackson",
160
+ "Jayden",
161
+ "Emma",
162
+ "Olivia",
163
+ "Emily",
164
+ "Sophia",
165
+ "Ava",
166
+ "Chloe",
167
+ "Charlotte",
168
+ "Abigail",
169
+ "Amelia",
170
+ "Ella",
171
+ "Hannah",
172
+ "Isabella",
173
+ "Aria",
174
+ "Lily",
175
+ "Mia",
176
+ "Isla",
177
+ "Avery",
178
+ "Elizabeth",
179
+ "Mila",
180
+ "Evelyn",
181
+ ]
182
+
183
+ assert len(locations) >= SIZE_LARGE
184
+ assert len(objects) >= SIZE_LARGE
185
+ assert len(containers) >= SIZE_LARGE
186
+ assert len(names) >= SIZE_LARGE
187
+
188
+
189
+ def write_world(filepath, locs, objs, conts, nams):
190
+
191
+ with open(filepath, 'w') as f:
192
+
193
+ f.write('# locations\n')
194
+
195
+ for loc in locs:
196
+
197
+ f.write('\n')
198
+ f.write('create %s\n' % loc)
199
+ f.write('set %s is_thing\n' % loc)
200
+ f.write('set %s is_location\n' % loc)
201
+
202
+ f.write('\n')
203
+ f.write('# objects\n')
204
+
205
+ for obj in objs:
206
+
207
+ f.write('\n')
208
+ f.write('create %s\n' % obj)
209
+ f.write('set %s is_thing\n' % obj)
210
+ f.write('set %s is_gettable\n' % obj)
211
+
212
+ f.write('\n')
213
+ f.write('# containers\n')
214
+
215
+ for cont in conts:
216
+
217
+ f.write('\n')
218
+ f.write('create %s\n' % cont)
219
+ f.write('set %s is_thing\n' % cont)
220
+ f.write('set %s is_container\n' % cont)
221
+
222
+ f.write('\n')
223
+ f.write('# actors\n')
224
+
225
+ for nam in nams:
226
+
227
+ f.write('\n')
228
+ f.write('create %s\n' % nam)
229
+ f.write('set %s is_actor\n' % nam)
230
+ f.write('set %s is_god\n' % nam)
231
+
232
+ write_world('world_tiny.txt',
233
+ np.random.choice(locations, SIZE_TINY, replace=False),
234
+ np.random.choice(objects, SIZE_TINY, replace=False),
235
+ np.random.choice(containers, SIZE_TINY, replace=False),
236
+ np.random.choice(names, SIZE_TINY, replace=False))
237
+
238
+ write_world('world_small.txt',
239
+ np.random.choice(locations, SIZE_SMALL, replace=False),
240
+ np.random.choice(objects, SIZE_SMALL, replace=False),
241
+ np.random.choice(containers, SIZE_SMALL, replace=False),
242
+ np.random.choice(names, SIZE_SMALL, replace=False))
243
+
244
+ write_world('world_large.txt',
245
+ np.random.choice(locations, SIZE_LARGE, replace=False),
246
+ np.random.choice(objects, SIZE_LARGE, replace=False),
247
+ np.random.choice(containers, SIZE_LARGE, replace=False),
248
+ np.random.choice(names, SIZE_LARGE, replace=False))
dynamic_actions.py ADDED
@@ -0,0 +1,369 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import random
3
+ from itertools import combinations
4
+ from itertools import permutations
5
+
6
+
7
+ class Action(object):
8
+
9
+ def __init__(self, templates):
10
+ self.templates = templates
11
+
12
+ def render_declarative(self):
13
+ assert 'declarative' in self.templates and \
14
+ len(self.templates['declarative']) > 0
15
+ return np.random.choice(self.templates['declarative'])
16
+
17
+ def render_interrogative(self):
18
+ assert 'interrogative' in self.templates and \
19
+ len(self.templates['interrogative']) > 0, str(self.templates)
20
+ return np.random.choice(self.templates['interrogative'])
21
+
22
+
23
+ class ExitAction(Action):
24
+
25
+ def __init__(self):
26
+ templates = {
27
+ 'declarative': [
28
+ '%s exited the %s.',
29
+ '%s left the %s.',
30
+ '%s went out of the %s.',
31
+ ],
32
+ }
33
+ super().__init__(templates)
34
+
35
+ #########################################
36
+ ############### Questions ###############
37
+ #########################################
38
+
39
+
40
+ class ZeroQ(Action):
41
+
42
+ def __init__(self, oracle, obj):
43
+
44
+ fill = (obj, oracle.get_object_container(obj))
45
+ templates = {
46
+ 'interrogative': [
47
+ 'Question: Where is the %s really?\nAnswer: %s' % fill,
48
+ ]
49
+ }
50
+ super().__init__(templates)
51
+
52
+
53
+ class FirstQ(Action):
54
+
55
+ def __init__(self, oracle, agent, obj):
56
+ fill = (agent, obj, oracle.get_first_belief(agent, obj))
57
+ templates = {
58
+ 'interrogative': [
59
+ 'Question: Where does %s really think the %s is?\nAnswer: %s' % fill,
60
+ ]
61
+ }
62
+ super().__init__(templates)
63
+
64
+
65
+ class SecondQ(Action):
66
+
67
+ def __init__(self, oracle, a1, a2, obj):
68
+ fill = (a1, a2, obj, oracle.get_second_belief(a1, a2, obj))
69
+ templates = {
70
+ 'interrogative': [
71
+ 'Question: Where does %s think %s thinks the %s is?\nAnswer: %s' % fill,
72
+ ]
73
+ }
74
+ super().__init__(templates)
75
+
76
+
77
+ class ThirdQ(Action):
78
+
79
+ def __init__(self, oracle, a1, a2, a3, obj):
80
+ fill = (a1, a2, a3, obj, oracle.get_third_belief(a1, a2, a3, obj))
81
+ templates = {
82
+ 'interrogative': [
83
+ 'Question: Where does %s think %s thinks %s thinks the %s is?\nAnswer: %s' % fill,
84
+ ]
85
+ }
86
+ super().__init__(templates)
87
+
88
+
89
+ class FourthQ(Action):
90
+
91
+ def __init__(self, oracle, a1, a2, a3, a4, obj):
92
+ fill = (a1, a2, a3, a4, obj,
93
+ oracle.get_fourth_belief(a1, a2, a3, a4, obj))
94
+ templates = {
95
+ 'interrogative': [
96
+ 'Question: Where does %s think %s thinks %s thinks %s thinks the %s is?\nAnswer: %s' % fill,
97
+ ]
98
+ }
99
+ super().__init__(templates)
100
+
101
+ # class MemoryAction(Action):
102
+
103
+ # def __init__(self, oracle_start_state, obj):
104
+ # fill = (obj, oracle_start_state[obj])
105
+ # templates = {
106
+ # 'interrogative': [
107
+ # 'Where was the %s at the beginning?\t%s' % fill,
108
+ # ]
109
+ # }
110
+ # super().__init__(templates)
111
+
112
+ # class LocationAction(Action):
113
+ # def __init__(self, oracle, args):
114
+ # """
115
+ # Creaters string with args and modifies
116
+ # oracle in accordance with action.
117
+ # """
118
+ # if len(args) == 2:
119
+ # statement = '%s is in the %s.' % args
120
+ # a1, loc = args
121
+ # # may be redundant
122
+ # oracle.set_location(a1, loc)
123
+ # else : # 2 people
124
+ # statement = '%s and %s are in the %s.' % args
125
+ # a1, a2, loc = args
126
+ # # may be redundant
127
+ # oracle.set_location(a1, loc)
128
+ # oracle.set_location(a2, loc)
129
+
130
+ # templates = {
131
+ # 'declarative': [
132
+ # statement,
133
+ # ]
134
+ # }
135
+
136
+ # super().__init__(templates)
137
+
138
+
139
+ class ObjectLocAction(Action):
140
+
141
+ def __init__(self, oracle, obj, observers):
142
+ container = oracle.get_object_container(obj)
143
+ templates = {
144
+ 'declarative': [
145
+ 'The %s is in the %s.' % (obj, container),
146
+ ]
147
+ }
148
+
149
+ # set first beliefs
150
+ for observer in observers:
151
+ oracle.set_first_belief(observer, obj, container)
152
+
153
+ # set second beliefs
154
+ if len(observers) >= 2:
155
+ for observer1, observer2 in combinations(observers, 2):
156
+ oracle.set_second_belief(observer1, observer2, obj, container)
157
+ oracle.set_second_belief(observer2, observer1, obj, container)
158
+
159
+ # set third beliefs
160
+ if len(observers) >= 3:
161
+ for chosen_observers in combinations(observers, 3):
162
+ for observer1, observer2, observer3 in permutations(chosen_observers):
163
+ oracle.set_third_belief(
164
+ observer1, observer2, observer3, obj, container)
165
+
166
+ # set fourth beliefs
167
+ if len(observers) >= 4:
168
+ for chosen_observers in combinations(observers, 4):
169
+ for observer1, observer2, observer3, observer4 in permutations(chosen_observers):
170
+ oracle.set_fourth_belief(
171
+ observer1, observer2, observer3, observer4, obj, container)
172
+ super().__init__(templates)
173
+
174
+
175
+ class ExitedAction(Action):
176
+
177
+ def __init__(self, oracle, agent):
178
+ fill = (agent, oracle.get_location(agent))
179
+
180
+ templates = {
181
+ 'declarative': [
182
+ '%s exited the %s.' % fill,
183
+ ]
184
+ }
185
+ oracle.set_location(agent, None)
186
+ super().__init__(templates)
187
+
188
+
189
+ class MoveAction(Action):
190
+
191
+ def __init__(self, oracle, args, observers=None, move=True):
192
+ agent, obj, container = args
193
+ if not move:
194
+ location = oracle.get_container_location(container)
195
+ templates = {
196
+ 'declarative': [
197
+ f'{args[0]} made no movements and stayed in the {location} for 1 minute.',
198
+ ]
199
+ }
200
+
201
+ else:
202
+ templates = {
203
+ 'declarative': [
204
+ '%s moved the %s to the %s.' % args,
205
+ ]
206
+ }
207
+
208
+ oracle.set_object_container(obj, container)
209
+
210
+ if not observers:
211
+ observers = []
212
+ observers.append(agent)
213
+
214
+ # set first beliefs
215
+ for observer in observers:
216
+ oracle.set_first_belief(observer, obj, container)
217
+
218
+ # set second beliefs
219
+ if len(observers) >= 2:
220
+ for observer1, observer2 in combinations(observers, 2):
221
+ oracle.set_second_belief(
222
+ observer1, observer2, obj, container)
223
+ oracle.set_second_belief(
224
+ observer2, observer1, obj, container)
225
+
226
+ # set third beliefs
227
+ if len(observers) >= 3:
228
+ for chosen_observers in combinations(observers, 3):
229
+ for observer1, observer2, observer3 in permutations(chosen_observers):
230
+ oracle.set_third_belief(
231
+ observer1, observer2, observer3, obj, container)
232
+
233
+ # set fourth beliefs
234
+ if len(observers) >= 4:
235
+ for chosen_observers in combinations(observers, 4):
236
+ for observer1, observer2, observer3, observer4 in permutations(chosen_observers):
237
+ oracle.set_fourth_belief(
238
+ observer1, observer2, observer3, observer4, obj, container)
239
+
240
+ super().__init__(templates)
241
+
242
+
243
+ class PublicTellAction(Action):
244
+
245
+ def __init__(self, oracle, speaker, obj, container, listeners=None, believers=None):
246
+ templates = {
247
+ 'declarative': [
248
+ '%s publicly claimed that %s is in the %s now.' % (
249
+ speaker, obj, container),
250
+ ]
251
+ }
252
+ disbelievers = [
253
+ listener for listener in listeners if listener not in believers]
254
+
255
+ # All listeners would think others believe the claim
256
+ # for believer in believers:
257
+ # for disbeliever in disbelievers:
258
+ # oracle.set_second_belief(believer, disbeliever, obj, container)
259
+ # oracle.set_second_belief(disbeliever, believer, obj, container)
260
+
261
+ # A believer would think speaker also believes the obj is in container, speaker would think his words are trusted
262
+ for believer in believers:
263
+ oracle.set_first_belief(believer, obj, container)
264
+ oracle.set_second_belief(believer, speaker, obj, container)
265
+ oracle.set_second_belief(speaker, believer, obj, container)
266
+
267
+ for disbeliever in disbelievers:
268
+ oracle.set_second_belief(speaker, disbeliever, obj, container)
269
+
270
+ # for listener in listeners:
271
+ # # the speaker believes that all the listeners believe him
272
+ # oracle.set_second_belief(speaker, listener, obj, container)
273
+ # # all listeners know the believers based on the exiting order
274
+ # for believer in believers:
275
+ # oracle.set_second_belief(listener, believer, obj, container)
276
+
277
+ super().__init__(templates)
278
+
279
+
280
+ class PrivateTellAction(Action):
281
+
282
+ def __init__(self, oracle, speaker, listener, obj, container, trust=True):
283
+ templates = {
284
+ 'declarative': [
285
+ '%s privately told %s that the %s is in the %s now.' % (
286
+ speaker, listener, obj, container),
287
+ ]
288
+ }
289
+
290
+ # when the listener has less information (exit the room earlier), he'll trust the speaker
291
+ if trust:
292
+ oracle.set_first_belief(listener, obj, container)
293
+ oracle.set_second_belief(listener, speaker, obj, container)
294
+ oracle.set_second_belief(speaker, listener, obj, container)
295
+ super().__init__(templates)
296
+
297
+
298
+ class EnterAction(Action):
299
+
300
+ def __init__(self, oracle, args, observers=None, no_world_adjust=False):
301
+ templates = {
302
+ 'declarative': [
303
+ ', '.join(args[:-2]) + ' and ' + args[-2] +
304
+ ' entered the ' + args[-1] + '.',
305
+ ]
306
+ }
307
+
308
+ agents = args[:-1]
309
+ location = args[-1]
310
+ if location == 'waiting_room':
311
+ super().__init__(templates)
312
+ return
313
+ for agent in agents:
314
+ oracle.set_location(agent, location)
315
+ objs = oracle.get_objects_at_location(location)
316
+ observers = agents
317
+
318
+ # agent knows location of everything
319
+ if not no_world_adjust:
320
+ for obj in objs:
321
+ container = oracle.get_object_container(obj)
322
+ # oracle.set_first_belief(agent, obj, container)
323
+ # set first beliefs
324
+ if len(observers) >= 1:
325
+ for observer in observers:
326
+ oracle.set_first_belief(observer, obj, container)
327
+
328
+ # set second beliefs
329
+ if len(observers) >= 2:
330
+ for observer1, observer2 in combinations(observers, 2):
331
+ oracle.set_second_belief(
332
+ observer1, observer2, obj, container)
333
+ oracle.set_second_belief(
334
+ observer2, observer1, obj, container)
335
+
336
+ # set third beliefs
337
+ if len(observers) >= 3:
338
+ for chosen_observers in combinations(observers, 3):
339
+ for observer1, observer2, observer3 in permutations(chosen_observers):
340
+ oracle.set_third_belief(
341
+ observer1, observer2, observer3, obj, container)
342
+
343
+ # set fourth beliefs
344
+ if len(observers) >= 4:
345
+ for chosen_observers in combinations(observers, 4):
346
+ for observer1, observer2, observer3, observer4 in permutations(chosen_observers):
347
+ oracle.set_fourth_belief(
348
+ observer1, observer2, observer3, observer4, obj, container)
349
+
350
+ super().__init__(templates)
351
+
352
+
353
+ class NoiseAction(Action):
354
+
355
+ def __init__(self, agents, containers, objects):
356
+ animals = ['cat', 'dog', 'monkey', 'mouse']
357
+ personal_items = ['watch', 'gloves', 'phone']
358
+ distractors = [
359
+ f'{random.choice(agents)} saw a {random.choice(animals)}.',
360
+ f'{random.choice(agents)} lost his {random.choice(personal_items)}.',
361
+ f'{random.choice(agents)} likes the {random.choice(containers)}.',
362
+ f'{random.choice(agents)} dislikes the {random.choice(objects)}.',
363
+ ]
364
+ templates = {
365
+ 'declarative': [
366
+ random.choice(distractors)
367
+ ]
368
+ }
369
+ super().__init__(templates)
generate_prompts.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import itertools
4
+
5
+
6
+ def main():
7
+ input_folder = 'data_ToMh'
8
+ output_folder = 'prompt_ToMh'
9
+ lengths = [1, 2, 3]
10
+ orders = [0, 1, 2, 3, 4]
11
+ prompts = ['CoT', 'MC']
12
+ tells = ['No_Tell', 'Tell']
13
+ for tell, prompt, length, order, sample_num in itertools.product(tells, prompts, lengths, orders, range(1, 21)):
14
+ input_fn = os.path.join(input_folder, tell, prompt, f'length_{length}', f'sample_{sample_num}',
15
+ f'order_{order}.txt')
16
+ output_fn = os.path.join(output_folder, tell, prompt, f'length_{length}', f'sample_{sample_num}',
17
+ f'order_{order}.txt')
18
+ with open(input_fn, 'r') as file:
19
+ lines = file.readlines()
20
+ new_lines = [line for line in lines if line ==
21
+ '\n' or line.split()[0] != 'Answer:']
22
+ if not os.path.exists(os.path.join(output_folder, tell, prompt, f'length_{length}', f'sample_{sample_num}')):
23
+ os.makedirs(os.path.join(output_folder, tell, prompt,
24
+ f'length_{length}', f'sample_{sample_num}'))
25
+ with open(output_fn, 'w') as file:
26
+ file.writelines(new_lines)
27
+
28
+
29
+
30
+ if __name__ == "__main__":
31
+ sys.exit(main())
generate_tasks.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import logging
3
+ import glob
4
+ import numpy as np
5
+ import os
6
+ import sys
7
+ import random
8
+ import itertools
9
+
10
+ from stringify import stringify
11
+ from tasks import Specify_Tasks
12
+ from utils import is_file, mkdir_p, remove_extension
13
+ from world import World
14
+
15
+
16
+ def generate_story_with_specified_chapters(
17
+ world_paths, output_dir_path, n, noise=0.1, train_noise=False, order=-1, num_chapter=-1, exist_tell_in_story=False, prompt='CoT', exist_answer=False
18
+ ): # prompt is dummy
19
+ """Generates stories with guarantee that each task is seen n times."""
20
+ mkdir_p(output_dir_path)
21
+ n = n[0]
22
+
23
+ for world in world_paths:
24
+
25
+ w = World()
26
+ w.load(world)
27
+ world_name = remove_extension(world)
28
+
29
+ # Define task creator and task types
30
+ task = Specify_Tasks()
31
+ tasks_per_length = np.array([
32
+ [('A5', True)], # 1 chapter
33
+ [('A5', False), ('A3', True)], # 2 chapters
34
+ [('A5', True), ('A3', False), ('A4', True)], # 3 chapters
35
+ [('A5', False), ('A3', True),
36
+ ('A4', False), ('A2', True)], # 4 chapters
37
+ ], dtype=object)
38
+
39
+ # If order and num_chapter are not specified
40
+ orders = [0, 1, 2, 3, 4] if order == -1 else [order]
41
+ num_chapters = [1, 2, 3] if num_chapter == -1 else [num_chapter]
42
+ modes = ['MC', 'CoT']
43
+ for length_of_story in num_chapters:
44
+ # Create folder to contain data
45
+ folder_name_2 = f'length_{length_of_story}'
46
+ logging.info("Creating New task in %s..." % folder_name_2)
47
+
48
+ for i in range(1, n+1):
49
+ folder_name_3 = f'sample_{i}'
50
+ story = task.generate_story_qs_at_end(
51
+ w, length_of_story, tasks_per_length[length_of_story -
52
+ 1], num_agents=5,
53
+ num_locations=3, statement_noise=noise, order=0, exist_tell_in_story=exist_tell_in_story
54
+ ) # order = 0 is dummy here.
55
+ for mode in modes:
56
+ folder_name_1 = mode
57
+ for order_of_story in orders:
58
+ file_name = f'order_{order_of_story}.txt'
59
+ os.makedirs(os.path.join(
60
+ output_dir_path, folder_name_1, folder_name_2, folder_name_3), exist_ok=True)
61
+ path = os.path.join(
62
+ output_dir_path, folder_name_1, folder_name_2, folder_name_3, file_name)
63
+
64
+ with open(path, 'w', encoding='utf-8') as f:
65
+ if mode == 'MC':
66
+ f.write(
67
+ 'The following story happens in chronological order. You will be given a multiple-choice question and a note at the end. Directly output the answer without explanation.\n')
68
+ else:
69
+ f.write(
70
+ 'The following story happens in chronological order. You will be given a multiple-choice question and a note at the end. First give step-by-step analysis about the question. Then output the answer.\n')
71
+ f.write('Story:\n')
72
+ <<<<<<< HEAD
73
+ f.write('\n'.join(stringify(story, exist_answer=exist_answer, order=order_of_story))) #exist_answer is dummy
74
+ f.write('\nNote: You should assume the following. (1) A character witnesses everything and every movements before exiting a location. (2) A character A can infer another character B\'s mental state only if A and B have been in the same location, or have private or public interactions. (3) Note that every character tend to lie. A character tend to trust a character that exited the room later than himself. (4) Agents in private communications know that others won\'t hear them, but anyone can hear any public claims. (5) What a character tells others does not affect his actual belief.\n')
75
+ =======
76
+ # exist_answer is dummy
77
+ f.write(
78
+ '\n'.join(stringify(story, exist_answer=exist_answer, order=order_of_story)))
79
+ f.write('\nNote: You should assume the following. (1) An agent witnesses everything and every movement before exiting a location. (2) An agent A can infer another agent B\'s mental state only if A and B have been in the same location, or have private or public interactions. (3) Note that every agent tends to lie. What an agent A tells others doesn\'t affect A\'s actual belief. An agent tends to trust an agent that exited the room later than himself. The exit order is known to all agents. (4) Agents in private communications know that others won\'t hear them, but they know that anyone can hear any public claims.\n')
80
+ >>>>>>> 50242d0343261b6c95293fc995711b384ff3c1fe
81
+
82
+
83
+ def parse_args(args):
84
+
85
+ parser = argparse.ArgumentParser(
86
+ description='Process command-line arguments.'
87
+ )
88
+
89
+ parser.add_argument(
90
+ '-w', '--world_path', dest='world_paths', type=is_file, required=True,
91
+ action='append', help='Path to a world definition file'
92
+ )
93
+
94
+ parser.add_argument(
95
+ '-o', '--output_dir_path', dest='output_dir_path', type=mkdir_p,
96
+ default='data_ToMh', help='Output directory path'
97
+ )
98
+
99
+ # parser.add_argument(
100
+ # '-b', '--babi_dir_path', dest='babi_dir_path', type=str, required=True,
101
+ # help='Path to directory containing the 20 bAbi task train + test data'
102
+ # )
103
+
104
+ parser.add_argument(
105
+ '-l', '--logging', type=str, default='INFO', metavar='logging',
106
+ choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
107
+ help='Logging level'
108
+ )
109
+
110
+ parser.add_argument(
111
+ '-n', '--num_stories', dest='num_stories_choices', type=int,
112
+ action='append', required=True,
113
+ help='Number of stories (examples) in a task)'
114
+ )
115
+
116
+ parser.add_argument(
117
+ '-ptn', '--prob_test_noise', dest='test_noise', type=float,
118
+ required=True, help='Probability of encountering random noise sentence'
119
+ )
120
+
121
+ parser.add_argument(
122
+ '-tn', '--train_noise', dest='train_noise', type=bool, default=False,
123
+ help='Whether or not to include noise at training time'
124
+ )
125
+ parser.add_argument(
126
+ '-ord', '--order', dest='order', type=int, default=-1,
127
+ help='The range of question orders'
128
+ )
129
+ parser.add_argument(
130
+ '-len', '--length', dest='num_chapter', type=int, default=-1,
131
+ help='The range of story lengths'
132
+ )
133
+ parser.add_argument(
134
+ '-t', '--tell', dest='exist_tell', type=bool, default=False,
135
+ help='Whether or not the story has communications between agents'
136
+ )
137
+ parser.add_argument(
138
+ '-p', '--prompt', dest='prompt_type', type=str, default='CoT',
139
+ choices=['MC', 'CoT'],
140
+ help='The type of prompt chosen between MC and CoT'
141
+ )
142
+ parser.add_argument(
143
+ '-a', '--answer', dest='exist_answer', type=bool, default=False,
144
+ help='Whether or not the data has answers'
145
+ )
146
+
147
+ parsed = parser.parse_args(args)
148
+
149
+ return parsed
150
+
151
+
152
+ def main(args=sys.argv[1:]):
153
+ """Main function to generate all the story-question pairs."""
154
+ args = parse_args(args)
155
+ logging.basicConfig(
156
+ level=args.logging, format='%(asctime)s\t%(levelname)-8s\t%(message)s'
157
+ )
158
+ folder_name = 'Tell/' if args.exist_tell else 'No_Tell/'
159
+
160
+ # folder_name += args.prompt_type
161
+ # output_dir_path = os.path.join(args.output_dir_path, folder_name) if args.exist_answer else os.path.join('prompt_ToMh', folder_name)
162
+
163
+ output_dir_path = os.path.join(args.output_dir_path, folder_name)
164
+
165
+ generate_story_with_specified_chapters(
166
+ world_paths=args.world_paths,
167
+ output_dir_path=output_dir_path,
168
+ n=args.num_stories_choices,
169
+ noise=args.test_noise,
170
+ train_noise=args.train_noise,
171
+ order=args.order,
172
+ num_chapter=args.num_chapter,
173
+ exist_tell_in_story=args.exist_tell,
174
+ prompt=args.prompt_type,
175
+ exist_answer=args.exist_answer,
176
+ )
177
+
178
+
179
+ if __name__ == "__main__":
180
+ sys.exit(main())
oracle.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ The Oracle class keeps track of all object
3
+ and agent locations as well as a map of
4
+ beliefs about object and agent locations.
5
+ """
6
+ import copy
7
+
8
+ class LocationMap(object):
9
+
10
+ def __init__(self, agents, locations, objects, containers):
11
+
12
+ # Maps agents to their locations.
13
+ self.locations = {agent : None for agent in agents}
14
+
15
+ # Maps agents to their locations.
16
+ self.container_locations = {container : None for container in containers}
17
+
18
+ # Maps locations to their containers
19
+ self.containers = {location : None for location in locations}
20
+
21
+ # Maps containers to the objects they hold
22
+ self.container_objs = {container : [] for container in containers}
23
+
24
+ # Maps objects to their containers
25
+ self.obj_containers = {obj : None for obj in objects}
26
+
27
+ class MemoryMap(object):
28
+
29
+ def __init__(self, agents, objects):
30
+
31
+ zero_dict = {obj : None for obj in objects}
32
+ first_dict = {agent : copy.deepcopy(zero_dict) for agent in agents}
33
+ second_dict = {agent : copy.deepcopy(first_dict) for agent in agents}
34
+ third_dict = {agent : copy.deepcopy(second_dict) for agent in agents}
35
+ fourth_dict = {agent : copy.deepcopy(third_dict) for agent in agents}
36
+
37
+ # Dictionary of dictionaries mapping
38
+ # agents to objects to containers. Represents
39
+ # agents' belief about location of containers.
40
+ self.first_belief = copy.deepcopy(first_dict)
41
+
42
+ # Dictionary of dictionaries of dictionaries
43
+ # mapping agents to direct belief dictionaries.
44
+ # Represents agents' belief about other agents'
45
+ # beliefs about location of containers.
46
+ self.second_belief = copy.deepcopy(second_dict)
47
+ self.third_belief = copy.deepcopy(third_dict)
48
+ self.fourth_belief = copy.deepcopy(fourth_dict)
49
+
50
+ class Oracle(object):
51
+
52
+ def __init__(self, agents, locations, objects, containers):
53
+ self.memory_map = MemoryMap(agents, objects)
54
+ self.locations = LocationMap(agents, locations, objects, containers)
55
+
56
+ #########################################
57
+ ################ Beliefs ################
58
+ #########################################
59
+
60
+ def get_first_belief(self, agent, obj):
61
+ beliefs = self.memory_map.first_belief
62
+ return beliefs[agent][obj]
63
+
64
+ def set_first_belief(self, agent, obj, container):
65
+ beliefs = self.memory_map.first_belief
66
+ beliefs[agent][obj] = container
67
+
68
+ def get_second_belief(self, a1, a2, obj):
69
+ second_belief = self.memory_map.second_belief
70
+ return second_belief[a1][a2][obj]
71
+
72
+ def set_second_belief(self, a1, a2, obj, container):
73
+ second_belief = self.memory_map.second_belief
74
+ second_belief[a1][a2][obj] = container
75
+
76
+ def get_third_belief(self, a1, a2, a3, obj):
77
+ third_belief = self.memory_map.third_belief
78
+ return third_belief[a1][a2][a3][obj]
79
+
80
+ def set_third_belief(self, a1, a2, a3, obj, container):
81
+ third_belief = self.memory_map.third_belief
82
+ third_belief[a1][a2][a3][obj] = container
83
+
84
+ def get_fourth_belief(self, a1, a2, a3, a4, obj):
85
+ fourth_belief = self.memory_map.fourth_belief
86
+ return fourth_belief[a1][a2][a3][a4][obj]
87
+
88
+ def set_fourth_belief(self, a1, a2, a3, a4, obj, container):
89
+ fourth_belief = self.memory_map.fourth_belief
90
+ fourth_belief[a1][a2][a3][a4][obj] = container
91
+
92
+ #########################################
93
+ ############### Locations ###############
94
+ #########################################
95
+
96
+ def get_location(self, agent):
97
+ return self.locations.locations[agent]
98
+
99
+ def set_location(self, agent, location):
100
+ self.locations.locations[agent] = location
101
+
102
+ def get_containers(self, location):
103
+ # Returns a list of containers at location
104
+ return self.locations.containers[location]
105
+
106
+ def set_containers(self, location, containers):
107
+ # May need to change to move containers bt locs
108
+ # Containers is a list of containers at location
109
+ for container in containers:
110
+ self._set_container_location(container, location)
111
+ self.locations.containers[location] = containers
112
+
113
+ def get_objects_at_location(self, location):
114
+ objects = []
115
+ for container in self.get_containers(location):
116
+ objects.extend(self.get_container_obj(container))
117
+ return objects
118
+
119
+ def get_container_location(self, container):
120
+ return self.locations.container_locations[container]
121
+
122
+ def _set_container_location(self, container, location):
123
+ self.locations.container_locations[container] = location
124
+
125
+ def get_container_obj(self, container):
126
+ # get list of objects in container
127
+ return self.locations.container_objs[container]
128
+
129
+ def _add_container_obj(self, container, obj):
130
+ self.locations.container_objs[container].append(obj)
131
+
132
+ def _remove_container_obj(self, container, obj):
133
+ self.locations.container_objs[container].remove(obj)
134
+
135
+ def get_object_container(self, obj):
136
+ # get container that holds object
137
+ return self.locations.obj_containers[obj]
138
+
139
+ def set_object_container(self, obj, container):
140
+ # set container that holds object
141
+ prev_container = self.get_object_container(obj)
142
+ if prev_container:
143
+ self._remove_container_obj(prev_container, obj)
144
+ self._add_container_obj(container, obj)
145
+ self.locations.obj_containers[obj] = container
146
+
147
+
stringify.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+
4
+ def stringify(story, exist_answer=False, order=0): # exist_answer is dummy
5
+
6
+ lines = []
7
+
8
+ i = 0 # The number of descriptions processed
9
+ j = 0 # The number of lines output
10
+ count_order = 0
11
+
12
+ while True:
13
+
14
+
15
+ if isinstance(story[i], str):
16
+ line = story[i]
17
+ else:
18
+ line = story[i].render()
19
+ # Capitalize the line
20
+ line = line[0].upper() + line[1:]
21
+
22
+ # Prepend the line number
23
+ if line.split()[0] != 'Question:' and line.split()[0] != 'Choices:':
24
+ line = '%d %s' % (i + 1, line)
25
+ else: # Start with 'Choice'
26
+ if line.split()[0] == 'Choices:':
27
+ lines.append(line)
28
+ break
29
+ else: # Start with 'Question'
30
+ if count_order == order:
31
+ lines.append(line)
32
+ count_order += 1
33
+ i += 1
34
+ continue
35
+ lines.append(line)
36
+ # Increment counters
37
+ i += 1
38
+
39
+ # Append supporting lines indices if necessary
40
+ # if hasattr(story[i], 'idx_support') and story[i].idx_support:
41
+ # line += '\t%s' % ' '.join([str(x + 1)
42
+ # for x in story[i].idx_support])
43
+
44
+ if i >= len(story):
45
+ break
46
+
47
+ return lines
tasks.py ADDED
@@ -0,0 +1,518 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import random
3
+ import copy
4
+
5
+ from clause import Clause, Question
6
+ from oracle import Oracle
7
+ from dynamic_actions import *
8
+ from collections import defaultdict
9
+
10
+
11
+ def sample_question(oracle_start_state, oracle, random_actors, obj, question_idx=0):
12
+ idx_dummy = [0]
13
+ a1, a2, a3, a4, _ = random_actors
14
+ questions = [Question(idx_dummy, ZeroQ(oracle, obj)),
15
+ Question(idx_dummy, FirstQ(oracle, a4, obj)),
16
+ Question(idx_dummy, SecondQ(oracle, a3, a4, obj)),
17
+ Question(idx_dummy, ThirdQ(oracle, a2, a3, a4, obj)),
18
+ Question(idx_dummy, FourthQ(oracle, a1, a2, a3, a4, obj))]
19
+ return questions[question_idx]
20
+
21
+ #######################################
22
+ ############## Chapters ###############
23
+ #######################################
24
+
25
+
26
+ def write_A2_chapter(
27
+ start_state, oracle, obj, location, agent_ids, all_agents, movements=None, exist_tell=False, questions=None
28
+ ):
29
+ a1, a2 = all_agents[agent_ids[0]], all_agents[agent_ids[1]]
30
+ outsiders = [agent for agent in all_agents if agent not in [a1, a2]]
31
+ agent_ids = [aid+1 for aid in agent_ids]
32
+
33
+ # Pick containers. The first element is the initial container of obj
34
+ containers = [oracle.get_object_container(obj)]
35
+ container_candidates = oracle.get_containers(location)[:]
36
+ container_candidates.remove(containers[0])
37
+ containers += random.sample(container_candidates, 2)
38
+
39
+ # Fill in the chapter
40
+ chapter = []
41
+
42
+ # All selected agents enter the room and see the object
43
+ chapter.extend([
44
+ Clause(EnterAction(oracle, (a1, a2, location))),
45
+ Clause(ObjectLocAction(oracle, obj, [a1, a2])),
46
+ ])
47
+
48
+ # a1
49
+ chapter.extend([
50
+ Clause(MoveAction(oracle, (a1, obj, containers[1]), [
51
+ a2], move=movements[0])),
52
+ Clause(ExitedAction(oracle, (a1)))
53
+ ])
54
+ # a2
55
+ chapter.extend([
56
+ Clause(MoveAction(
57
+ oracle, (a2, obj, containers[2]), None, move=movements[1])),
58
+ Clause(ExitedAction(oracle, (a2)))
59
+ ])
60
+
61
+ # Everyone enter the waiting room
62
+ chapter.extend([
63
+ Clause(EnterAction(oracle, (a1, a2, 'waiting_room')))
64
+ ])
65
+
66
+ # tell actions has 3 different forms
67
+ if exist_tell:
68
+ tell_containers = random.sample(oracle.get_containers(location)[:], 2)
69
+ tell_form = random.choice(
70
+ range(3)) if outsiders else random.choice(range(2))
71
+ match tell_form:
72
+ case 0:
73
+ chapter.extend([
74
+ Clause(PublicTellAction(
75
+ oracle, a1, obj, tell_containers[0], listeners=all_agents, believers=outsiders)),
76
+ Clause(PrivateTellAction(oracle, a2, a1,
77
+ obj, tell_containers[1], trust=True)),
78
+ ])
79
+ case 1:
80
+ chapter.extend([
81
+ Clause(PublicTellAction(
82
+ oracle, a2, obj, tell_containers[0], listeners=all_agents, believers=[a1] + outsiders)),
83
+ Clause(PrivateTellAction(oracle, a1, a2, obj,
84
+ tell_containers[1], trust=False)),
85
+ ])
86
+ case 2:
87
+ chapter.extend([
88
+ Clause(PrivateTellAction(oracle, a1, random.choice(outsiders),
89
+ obj, tell_containers[0], trust=True))
90
+ ])
91
+ return chapter
92
+
93
+
94
+ def write_A3_chapter(
95
+ start_state, oracle, obj, location, agent_ids, all_agents, movements=None, exist_tell=False, questions=None
96
+ ):
97
+ a1, a2, a3 = all_agents[agent_ids[0]
98
+ ], all_agents[agent_ids[1]], all_agents[agent_ids[2]]
99
+ outsiders = [agent for agent in all_agents if agent not in [a1, a2, a3]]
100
+ agent_ids = [aid+1 for aid in agent_ids]
101
+
102
+ # Pick containers. The first element is the initial container of obj
103
+ containers = [oracle.get_object_container(obj)]
104
+ container_candidates = oracle.get_containers(location)[:]
105
+ container_candidates.remove(containers[0])
106
+ containers += random.sample(container_candidates, 3)
107
+
108
+ # Fill in the chapter
109
+ chapter = []
110
+
111
+ # All selected agents enter the room and see the object
112
+ chapter.extend([
113
+ Clause(EnterAction(oracle, (a1, a2, a3, location))),
114
+ Clause(ObjectLocAction(oracle, obj, [a1, a2, a3])),
115
+ ])
116
+
117
+ # a1
118
+ chapter.extend([
119
+ Clause(MoveAction(oracle, (a1, obj, containers[1]), [
120
+ a2, a3], move=movements[0])),
121
+ Clause(ExitedAction(oracle, (a1)))
122
+ ])
123
+ # a2
124
+ chapter.extend([
125
+ Clause(MoveAction(oracle, (a2, obj, containers[2]), [
126
+ a3], move=movements[1])),
127
+ Clause(ExitedAction(oracle, (a2)))
128
+ ])
129
+ # a3
130
+ chapter.extend([
131
+ Clause(MoveAction(
132
+ oracle, (a3, obj, containers[3]), None, move=movements[2])),
133
+ Clause(ExitedAction(oracle, (a3)))
134
+ ])
135
+
136
+ # Everyone enter the waiting room
137
+ chapter.extend([
138
+ Clause(EnterAction(oracle, (a1, a2, a3, 'waiting_room')))
139
+ ])
140
+
141
+ # tell actions has 4 different forms
142
+ if exist_tell:
143
+ tell_containers = random.sample(oracle.get_containers(location)[:], 2)
144
+ tell_form = random.choice(
145
+ range(4)) if outsiders else random.choice(range(2))
146
+ match tell_form:
147
+ case 0:
148
+ # a2 lies to all, and a3 lies to a2
149
+ chapter.extend([
150
+ Clause(PublicTellAction(
151
+ oracle, a2, obj, tell_containers[0], listeners=all_agents, believers=[a1] + outsiders)),
152
+ Clause(PrivateTellAction(oracle, a3, a2,
153
+ obj, tell_containers[1], trust=True)),
154
+ ])
155
+ case 1:
156
+ # a3 lies to all, and a1 lies to a3
157
+ chapter.extend([
158
+ Clause(PublicTellAction(
159
+ oracle, a3, obj, tell_containers[0], listeners=all_agents, believers=[a1, a2] + outsiders)),
160
+ Clause(PrivateTellAction(oracle, a1, a3, obj,
161
+ tell_containers[1], trust=False)),
162
+ ])
163
+ case 2:
164
+ # a1 lies to all, but a3 tells the true location to an outside agent
165
+ chapter.extend([
166
+ Clause(PublicTellAction(
167
+ oracle, a1, obj, tell_containers[0], listeners=all_agents, believers=outsiders)),
168
+ Clause(PrivateTellAction(oracle, a3, random.choice(outsiders),
169
+ obj, oracle.get_object_container(obj), trust=True))
170
+ ])
171
+ case 3:
172
+ # a2 lies to a3, but a3 tells the true location to an outside agent
173
+ chapter.extend([
174
+ Clause(PrivateTellAction(oracle, a2, a3,
175
+ obj, tell_containers[0], trust=False)),
176
+ Clause(PrivateTellAction(oracle, a3, random.choice(outsiders),
177
+ obj, oracle.get_object_container(obj), trust=True))
178
+ ])
179
+ return chapter
180
+
181
+
182
+ def write_A4_chapter(
183
+ start_state, oracle, obj, location, agent_ids, all_agents, movements=None, exist_tell=False, questions=None
184
+ ):
185
+ a1, a2, a3, a4 = all_agents[agent_ids[0]
186
+ ], all_agents[agent_ids[1]], all_agents[agent_ids[2]], all_agents[agent_ids[3]]
187
+ outsiders = [
188
+ agent for agent in all_agents if agent not in [a1, a2, a3, a4]]
189
+ agent_ids = [aid+1 for aid in agent_ids]
190
+
191
+ # Pick containers. The first element is the initial container of obj
192
+ containers = [oracle.get_object_container(obj)]
193
+ container_candidates = oracle.get_containers(location)[:]
194
+ container_candidates.remove(containers[0])
195
+ containers += random.sample(container_candidates, 4)
196
+
197
+ # Fill in the chapter
198
+ chapter = []
199
+
200
+ # All selected agents enter the room and see the object
201
+ chapter.extend([
202
+ Clause(EnterAction(oracle, (a1, a2, a3, a4, location))),
203
+ Clause(ObjectLocAction(oracle, obj, [a1, a2, a3, a4])),
204
+ ])
205
+
206
+ # a1
207
+ chapter.extend([
208
+ Clause(MoveAction(oracle, (a1, obj, containers[1]), [
209
+ a2, a3, a4], move=movements[0])),
210
+ Clause(ExitedAction(oracle, (a1)))
211
+ ])
212
+ # a2
213
+ chapter.extend([
214
+ Clause(MoveAction(oracle, (a2, obj, containers[2]), [
215
+ a3, a4], move=movements[1])),
216
+ Clause(ExitedAction(oracle, (a2)))
217
+ ])
218
+ # a3
219
+ chapter.extend([
220
+ Clause(MoveAction(oracle, (a3, obj, containers[3]), [
221
+ a4], move=movements[2])),
222
+ Clause(ExitedAction(oracle, (a3)))
223
+ ])
224
+ # a4
225
+ chapter.extend([
226
+ Clause(MoveAction(
227
+ oracle, (a4, obj, containers[4]), None, move=movements[3])),
228
+ Clause(ExitedAction(oracle, (a4)))
229
+ ])
230
+
231
+ # Everyone enter the waiting room
232
+ chapter.extend([
233
+ Clause(EnterAction(oracle, (a1, a2, a3, a4, 'waiting_room')))
234
+ ])
235
+
236
+ # tell actions has 4 different forms
237
+ if exist_tell:
238
+ tell_containers = random.sample(oracle.get_containers(location)[:], 2)
239
+ tell_form = random.choice(
240
+ range(4)) if outsiders else random.choice(range(2))
241
+ match tell_form:
242
+ case 0:
243
+ # a2 lies to all, and a3 lies to a2
244
+ chapter.extend([
245
+ Clause(PublicTellAction(
246
+ oracle, a2, obj, tell_containers[0], listeners=all_agents, believers=[a1] + outsiders)),
247
+ Clause(PrivateTellAction(oracle, a4, a3,
248
+ obj, tell_containers[1], trust=True)),
249
+ ])
250
+ case 1:
251
+ # a3 lies to all, and a1 lies to a4
252
+ chapter.extend([
253
+ Clause(PublicTellAction(
254
+ oracle, a3, obj, tell_containers[0], listeners=all_agents, believers=[a1, a2] + outsiders)),
255
+ Clause(PrivateTellAction(oracle, a1, a4, obj,
256
+ tell_containers[1], trust=False)),
257
+ ])
258
+ case 2:
259
+ outsider = random.choice(outsiders)
260
+ # a1 lies to all, but a4 tells the true location to an outside agent
261
+ chapter.extend([
262
+ Clause(PublicTellAction(
263
+ oracle, a1, obj, tell_containers[0], listeners=all_agents, believers=outsiders)),
264
+ Clause(PrivateTellAction(oracle, a4, outsider,
265
+ obj, oracle.get_object_container(obj), trust=True))
266
+ ])
267
+ case 3:
268
+ outsider = random.choice(outsiders)
269
+ # a2 lies to a3, but a4 tells the true location to an outside agent
270
+ chapter.extend([
271
+ Clause(PrivateTellAction(oracle, a2, a3,
272
+ obj, tell_containers[0], trust=False)),
273
+ Clause(PrivateTellAction(oracle, a4, outsider,
274
+ obj, oracle.get_object_container(obj), trust=True))
275
+ ])
276
+ return chapter
277
+
278
+
279
+ def write_A5_chapter(
280
+ start_state, oracle, obj, location, agent_ids, all_agents, movements=None, exist_tell=False, questions=None
281
+ ):
282
+ a1, a2, a3, a4, a5 = all_agents[agent_ids[0]], all_agents[agent_ids[1]
283
+ ], all_agents[agent_ids[2]], all_agents[agent_ids[3]], all_agents[agent_ids[4]]
284
+ agent_ids = [aid+1 for aid in agent_ids]
285
+
286
+ # Pick containers. The first element is the initial container of obj
287
+ containers = [oracle.get_object_container(obj)]
288
+ container_candidates = oracle.get_containers(location)[:]
289
+ container_candidates.remove(containers[0])
290
+ containers += random.sample(container_candidates, 4)
291
+
292
+ # Fill in the chapter
293
+ chapter = []
294
+
295
+ # All selected agents enter the room and see the object
296
+ chapter.extend([
297
+ Clause(EnterAction(oracle, (a1, a2, a3, a4, a5, location))),
298
+ Clause(ObjectLocAction(oracle, obj, [a1, a2, a3, a4, a5])),
299
+ ])
300
+
301
+ # a1
302
+ chapter.extend([
303
+ Clause(MoveAction(oracle, (a1, obj, containers[1]), [
304
+ a2, a3, a4, a5], move=movements[0])),
305
+ Clause(ExitedAction(oracle, (a1)))
306
+ ])
307
+ # a2
308
+ chapter.extend([
309
+ Clause(MoveAction(oracle, (a2, obj, containers[2]), [
310
+ a3, a4, a5], move=movements[1])),
311
+ Clause(ExitedAction(oracle, (a2)))
312
+ ])
313
+ # a3
314
+ chapter.extend([
315
+ Clause(MoveAction(oracle, (a3, obj, containers[3]), [
316
+ a4, a5], move=movements[2])),
317
+ Clause(ExitedAction(oracle, (a3)))
318
+ ])
319
+ # a4
320
+ chapter.extend([
321
+ Clause(MoveAction(oracle, (a4, obj, containers[4]), [
322
+ a5], move=movements[3])),
323
+ Clause(ExitedAction(oracle, (a4)))
324
+ ])
325
+ # a5
326
+ chapter.extend([
327
+ Clause(MoveAction(
328
+ oracle, (a5, obj, containers[0]), None, move=movements[4])),
329
+ Clause(ExitedAction(oracle, (a5)))
330
+ ])
331
+
332
+ # Everyone enter the waiting room
333
+ chapter.extend([
334
+ Clause(EnterAction(oracle, (a1, a2, a3, a4, a5, 'waiting_room')))
335
+ ])
336
+
337
+ # tell actions has 3 different forms
338
+ if exist_tell:
339
+ tell_containers = random.sample(oracle.get_containers(location)[:], 2)
340
+ tell_form = random.choice(range(3))
341
+ match tell_form:
342
+ case 0:
343
+ # a3 lies to all, and a5 lies to a3
344
+ chapter.extend([
345
+ Clause(PublicTellAction(
346
+ oracle, a3, obj, tell_containers[0], listeners=all_agents, believers=[a1, a2])),
347
+ Clause(PrivateTellAction(oracle, a5, a3,
348
+ obj, tell_containers[1], trust=True)),
349
+ ])
350
+ case 1:
351
+ # a4 lies to all, but a5 tells the true location to a1
352
+ chapter.extend([
353
+ Clause(PublicTellAction(
354
+ oracle, a4, obj, tell_containers[0], listeners=all_agents, believers=[a1, a2, a3])),
355
+ Clause(PrivateTellAction(oracle, a5, a1, obj,
356
+ oracle.get_object_container(obj), trust=True)),
357
+ ])
358
+ case 2:
359
+ # a3 lies a1, and a2 lies to a4
360
+ chapter.extend([
361
+ Clause(PrivateTellAction(oracle, a3, a1,
362
+ obj, tell_containers[0], trust=True))
363
+ ])
364
+ return chapter
365
+
366
+
367
+ #######################################
368
+ ############### Tasks #################
369
+ #######################################
370
+
371
+ class Task(object):
372
+
373
+ def __init__(self,
374
+ num_questions=5,
375
+ exit_prob=1.,
376
+ informant_prob=1.,
377
+ search_prob=1.,
378
+ test_cond='first order'):
379
+
380
+ self.num_questions = num_questions
381
+
382
+ self.search_prob = search_prob
383
+
384
+ self.exit_inform_probs = [1 - exit_prob,
385
+ exit_prob * (1 - informant_prob),
386
+ exit_prob * informant_prob]
387
+ assert sum(self.exit_inform_probs) == 1
388
+
389
+ assert test_cond in ['first order',
390
+ 'second order',
391
+ 'reality',
392
+ 'memory'], \
393
+ "Invalid test condition: %s" % test_cond
394
+ self.test_cond = test_cond
395
+
396
+ def generate_story(self, world):
397
+ raise NotImplementedError("Abstract method.")
398
+
399
+
400
+ class Specify_Tasks(Task):
401
+ def generate_story_qs_at_end(
402
+ self, world, tasks_per_story, tasks, num_agents=5,
403
+ num_locations=3, statement_noise=0.1, order=0, exist_tell_in_story=False
404
+ ):
405
+ """
406
+ Allows user to specify chapter and question for each task in story.
407
+
408
+ :tasks: list with length of tasks per story. Each entry is a string in
409
+ the set {'tb','fb','sofb'}
410
+
411
+ :questions: list with length of tasks per story. Each entry is a string
412
+ in the set {'memory', 'reality', 'belief', 'search'}
413
+
414
+ :statement_noise: probability of encountering noise sentence like 'The
415
+ dog ran through the kitchen.'
416
+ """
417
+
418
+ # Fetch agents and objects and select a random subset
419
+ idx_support_dummy = [0]
420
+ actors = world.get_actors()
421
+ locations = world.get_locations()
422
+ objects = world.get_objects()
423
+ containers = world.get_containers()
424
+
425
+ random_actors = np.random.choice(
426
+ actors, size=num_agents, replace=False
427
+ )
428
+ random_locations = np.random.choice(
429
+ locations, size=num_locations, replace=False
430
+ )
431
+ random_objects = np.random.choice(
432
+ objects, size=num_locations*2, replace=False
433
+ )
434
+ random_containers = np.random.choice(
435
+ containers, size=num_locations*5, replace=False
436
+ )
437
+
438
+ # Create the oracle
439
+ oracle = Oracle(
440
+ random_actors, random_locations, random_objects, random_containers
441
+ )
442
+
443
+ # Populate locations in the oracle with containers
444
+ for i, random_location in enumerate(random_locations):
445
+ location = random_location
446
+ containers = random_containers[5*i:5*i+5]
447
+ oracle.set_containers(location, list(containers))
448
+ # Two of the containers have objects
449
+ oracle.set_object_container(
450
+ random_objects[2*i], containers[0])
451
+ oracle.set_object_container(
452
+ random_objects[2*i+1], containers[1])
453
+
454
+ # Need start state for memory question
455
+ start_state = oracle.locations.obj_containers.copy()
456
+
457
+ # Create story by task
458
+ chapters = {'A2': write_A2_chapter,
459
+ 'A3': write_A3_chapter,
460
+ 'A4': write_A4_chapter,
461
+ 'A5': write_A5_chapter}
462
+ story = []
463
+ obj_pool = []
464
+ obj_in_question = None
465
+
466
+ for i in range(tasks_per_story):
467
+ chapter = chapters[tasks[i][0]]
468
+ location = np.random.choice(random_locations)
469
+ obj = np.random.choice(oracle.get_objects_at_location(location))
470
+ # Use the obj in the first chap as the target
471
+ if i == 0:
472
+ obj_in_question = obj
473
+ obj_pool.append(obj)
474
+ agent_ids = list(range(5))
475
+ random.shuffle(agent_ids)
476
+
477
+ # Randomly choose movements for each agent
478
+ agent_num = int(tasks[i][0][1])
479
+ bools = [True, False]
480
+ movements = [random.choice(bools) for _ in range(agent_num)]
481
+ exist_tell_in_chapter = tasks[i][1] if exist_tell_in_story else False
482
+ story.extend(
483
+ chapter(
484
+ start_state, oracle, obj, location, agent_ids, random_actors, movements=movements, exist_tell=exist_tell_in_chapter
485
+ )
486
+ )
487
+
488
+ # At the end, add noise sentences randomly
489
+ if statement_noise:
490
+ noisy_story = []
491
+ prev_i = 0
492
+ noise = [i for i
493
+ in range(len(story)) if np.random.rand() < statement_noise
494
+ ]
495
+ for i in noise:
496
+ noisy_story.extend(
497
+ story[prev_i:i] +
498
+ [Clause(NoiseAction(random_actors,
499
+ random_containers, random_objects))]
500
+ )
501
+ prev_i = i
502
+ noisy_story.extend(story[prev_i:])
503
+
504
+ # compute questions of all orders
505
+ questioned_actors = copy.deepcopy(random_actors)
506
+ random.shuffle(questioned_actors)
507
+ for idx in range(5):
508
+ noisy_story.append(
509
+ sample_question(
510
+ start_state, oracle, questioned_actors, obj_in_question, question_idx=idx
511
+ )
512
+ )
513
+
514
+ # Generate choices of containers
515
+ choices = ', '.join(f'{chr(65+i)}. {container}' for i,
516
+ container in enumerate(random_containers))
517
+ noisy_story.append('Choices: ' + choices + '\n')
518
+ return noisy_story
test_azure.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import openai
3
+
4
+ def record_progress(filename):
5
+ with open('progress.txt', 'a') as f:
6
+ f.write(filename + '\n')
7
+
8
+ def is_processed(filename):
9
+ with open('progress.txt', 'r') as f:
10
+ processed_files = f.read().splitlines()
11
+ return filename in processed_files
12
+
13
+ openai.api_type = "azure"
14
+ openai.api_base = "https://openaiserviceforclausaeu.openai.azure.com/"
15
+ openai.api_version = "2023-03-15-preview"
16
+ openai.api_key = os.getenv("OPENAI_API_KEY")
17
+
18
+ test_dirs = os.listdir("prompt_ToMh")
19
+ for test_dir in test_dirs:
20
+ test_fns = os.listdir(f"prompt_ToMh/{test_dir}")
21
+ for test_fn in test_fns:
22
+ full_path = f"prompt_ToMh/{test_dir}/{test_fn}"
23
+ if is_processed(full_path):
24
+ continue
25
+ print(test_fn)
26
+ print(f"path: {full_path}")
27
+ with open(full_path, 'r') as f:
28
+ input = f.readlines()
29
+ input = "\n".join([inp.strip() for inp in input])
30
+ response = openai.ChatCompletion.create(
31
+ engine="gpt4-32k",
32
+ messages=[
33
+ {"role":"system","content":"You are an AI assistant that helps people find information."},
34
+ {"role":"user","content": input}
35
+ ],
36
+ temperature=0,
37
+ max_tokens=800,
38
+ top_p=0,
39
+ frequency_penalty=0,
40
+ presence_penalty=0,
41
+ stop=None)
42
+ print(response)
43
+ record_progress(full_path)
utils.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from argparse import ArgumentTypeError
2
+ import errno
3
+ import os
4
+
5
+
6
+ class Error(Exception):
7
+ """Base class for exceptions in this module."""
8
+ pass
9
+
10
+
11
+ class InputError(Error):
12
+ """Exception raised for errors in the input.
13
+
14
+ Attributes:
15
+ expr # input expression in which the error occurred
16
+ msg # explanation of the error
17
+ """
18
+
19
+ def __init__(self, expr, msg):
20
+ self.expr = expr
21
+ self.msg = msg
22
+
23
+
24
+ def is_file(f):
25
+ try:
26
+ open(f, 'r') # return an open file handle
27
+ except IOError:
28
+ raise ArgumentTypeError("{0} does not exist".format(f))
29
+ return f
30
+
31
+
32
+ def mkdir_p(path):
33
+ try:
34
+ os.makedirs(path)
35
+ except OSError as exc: # Python >2.5
36
+ if exc.errno == errno.EEXIST and os.path.isdir(path):
37
+ pass
38
+ else:
39
+ raise
40
+ return path
41
+
42
+
43
+ def remove_extension(path):
44
+ return os.path.splitext(os.path.basename(path))[0]
world.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ class World(object):
2
+
3
+ def __init__(self, world_actions=[], entities={}):
4
+ self.actions = world_actions
5
+ self.entities = entities
6
+
7
+ def load(self, fname):
8
+
9
+ lines = open(fname, 'r').readlines()
10
+ i = 0
11
+
12
+ while i < len(lines):
13
+ line = lines[i].rstrip('\n')
14
+ if line != '' and not line.startswith('#'):
15
+ if line.startswith('create'):
16
+ self.entities[line.split(' ')[1]] = {}
17
+ elif line.startswith('set'):
18
+ self.entities[line.split(' ')[1]][line.split(' ')[-1]] = True
19
+
20
+ i += 1
21
+
22
+ def get_entity(self, predicates):
23
+
24
+ if not isinstance(predicates, list):
25
+ raise InputError(predicates, 'is not a list.')
26
+
27
+ return_val = []
28
+
29
+ for k in self.entities:
30
+ if all([predicate in self.entities[k] and
31
+ self.entities[k][predicate] is True
32
+ for predicate in predicates]):
33
+ return_val += [k]
34
+
35
+ return return_val
36
+
37
+ def get_actors(self):
38
+ return self.get_entity(['is_actor', 'is_god'])
39
+
40
+ def get_containers(self):
41
+ return self.get_entity(['is_thing', 'is_container'])
42
+
43
+ def get_locations(self):
44
+ return self.get_entity(['is_location'])
45
+
46
+ def get_objects(self):
47
+ return self.get_entity(['is_thing', 'is_gettable'])