Upload 14 files
Browse files- Hi-ToM_data.json +0 -0
- README.md +16 -193
- actions.py +270 -0
- clause.py +26 -0
- create_world.py +248 -0
- dynamic_actions.py +369 -0
- generate_prompts.py +31 -0
- generate_tasks.py +180 -0
- oracle.py +147 -0
- stringify.py +47 -0
- tasks.py +518 -0
- test_azure.py +43 -0
- utils.py +44 -0
- world.py +47 -0
Hi-ToM_data.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
README.md
CHANGED
@@ -1,202 +1,25 @@
|
|
1 |
-
|
2 |
-
metrics:
|
3 |
-
- accuracy
|
4 |
-
pipeline_tag: question-answering
|
5 |
-
tags:
|
6 |
-
- code
|
7 |
-
---
|
8 |
|
9 |
-
|
10 |
|
11 |
-
|
12 |
|
13 |
-
|
14 |
|
15 |
-
|
|
|
16 |
|
17 |
-
|
|
|
|
|
|
|
|
|
18 |
|
19 |
-
|
20 |
|
|
|
|
|
21 |
|
|
|
22 |
|
23 |
-
|
24 |
-
- **Funded by [optional]:** [More Information Needed]
|
25 |
-
- **Shared by [optional]:** [More Information Needed]
|
26 |
-
- **Model type:** [More Information Needed]
|
27 |
-
- **Language(s) (NLP):** [More Information Needed]
|
28 |
-
- **License:** [More Information Needed]
|
29 |
-
- **Finetuned from model [optional]:** [More Information Needed]
|
30 |
-
|
31 |
-
### Model Sources [optional]
|
32 |
-
|
33 |
-
<!-- Provide the basic links for the model. -->
|
34 |
-
|
35 |
-
- **Repository:** [More Information Needed]
|
36 |
-
- **Paper [optional]:** [More Information Needed]
|
37 |
-
- **Demo [optional]:** [More Information Needed]
|
38 |
-
|
39 |
-
## Uses
|
40 |
-
|
41 |
-
<!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
|
42 |
-
|
43 |
-
### Direct Use
|
44 |
-
|
45 |
-
<!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
|
46 |
-
|
47 |
-
[More Information Needed]
|
48 |
-
|
49 |
-
### Downstream Use [optional]
|
50 |
-
|
51 |
-
<!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
|
52 |
-
|
53 |
-
[More Information Needed]
|
54 |
-
|
55 |
-
### Out-of-Scope Use
|
56 |
-
|
57 |
-
<!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
|
58 |
-
|
59 |
-
[More Information Needed]
|
60 |
-
|
61 |
-
## Bias, Risks, and Limitations
|
62 |
-
|
63 |
-
<!-- This section is meant to convey both technical and sociotechnical limitations. -->
|
64 |
-
|
65 |
-
[More Information Needed]
|
66 |
-
|
67 |
-
### Recommendations
|
68 |
-
|
69 |
-
<!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
|
70 |
-
|
71 |
-
Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
|
72 |
-
|
73 |
-
## How to Get Started with the Model
|
74 |
-
|
75 |
-
Use the code below to get started with the model.
|
76 |
-
|
77 |
-
[More Information Needed]
|
78 |
-
|
79 |
-
## Training Details
|
80 |
-
|
81 |
-
### Training Data
|
82 |
-
|
83 |
-
<!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
|
84 |
-
|
85 |
-
[More Information Needed]
|
86 |
-
|
87 |
-
### Training Procedure
|
88 |
-
|
89 |
-
<!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
|
90 |
-
|
91 |
-
#### Preprocessing [optional]
|
92 |
-
|
93 |
-
[More Information Needed]
|
94 |
-
|
95 |
-
|
96 |
-
#### Training Hyperparameters
|
97 |
-
|
98 |
-
- **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
|
99 |
-
|
100 |
-
#### Speeds, Sizes, Times [optional]
|
101 |
-
|
102 |
-
<!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
|
103 |
-
|
104 |
-
[More Information Needed]
|
105 |
-
|
106 |
-
## Evaluation
|
107 |
-
|
108 |
-
<!-- This section describes the evaluation protocols and provides the results. -->
|
109 |
-
|
110 |
-
### Testing Data, Factors & Metrics
|
111 |
-
|
112 |
-
#### Testing Data
|
113 |
-
|
114 |
-
<!-- This should link to a Dataset Card if possible. -->
|
115 |
-
|
116 |
-
[More Information Needed]
|
117 |
-
|
118 |
-
#### Factors
|
119 |
-
|
120 |
-
<!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
|
121 |
-
|
122 |
-
[More Information Needed]
|
123 |
-
|
124 |
-
#### Metrics
|
125 |
-
|
126 |
-
<!-- These are the evaluation metrics being used, ideally with a description of why. -->
|
127 |
-
|
128 |
-
[More Information Needed]
|
129 |
-
|
130 |
-
### Results
|
131 |
-
|
132 |
-
[More Information Needed]
|
133 |
-
|
134 |
-
#### Summary
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
## Model Examination [optional]
|
139 |
-
|
140 |
-
<!-- Relevant interpretability work for the model goes here -->
|
141 |
-
|
142 |
-
[More Information Needed]
|
143 |
-
|
144 |
-
## Environmental Impact
|
145 |
-
|
146 |
-
<!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
|
147 |
-
|
148 |
-
Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
|
149 |
-
|
150 |
-
- **Hardware Type:** [More Information Needed]
|
151 |
-
- **Hours used:** [More Information Needed]
|
152 |
-
- **Cloud Provider:** [More Information Needed]
|
153 |
-
- **Compute Region:** [More Information Needed]
|
154 |
-
- **Carbon Emitted:** [More Information Needed]
|
155 |
-
|
156 |
-
## Technical Specifications [optional]
|
157 |
-
|
158 |
-
### Model Architecture and Objective
|
159 |
-
|
160 |
-
[More Information Needed]
|
161 |
-
|
162 |
-
### Compute Infrastructure
|
163 |
-
|
164 |
-
[More Information Needed]
|
165 |
-
|
166 |
-
#### Hardware
|
167 |
-
|
168 |
-
[More Information Needed]
|
169 |
-
|
170 |
-
#### Software
|
171 |
-
|
172 |
-
[More Information Needed]
|
173 |
-
|
174 |
-
## Citation [optional]
|
175 |
-
|
176 |
-
<!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
|
177 |
-
|
178 |
-
**BibTeX:**
|
179 |
-
|
180 |
-
[More Information Needed]
|
181 |
-
|
182 |
-
**APA:**
|
183 |
-
|
184 |
-
[More Information Needed]
|
185 |
-
|
186 |
-
## Glossary [optional]
|
187 |
-
|
188 |
-
<!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
|
189 |
-
|
190 |
-
[More Information Needed]
|
191 |
-
|
192 |
-
## More Information [optional]
|
193 |
-
|
194 |
-
[More Information Needed]
|
195 |
-
|
196 |
-
## Model Card Authors [optional]
|
197 |
-
|
198 |
-
[More Information Needed]
|
199 |
-
|
200 |
-
## Model Card Contact
|
201 |
-
|
202 |
-
[More Information Needed]
|
|
|
1 |
+
# Hi-ToM Dataset
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
|
3 |
+
This is the dataset for the paper "Hi-ToM: A Benchmark for Evaluating Higher-Order Theory of Mind Reasoning in Large Language Models".
|
4 |
|
5 |
+
<img src=media/Picture1.png height=430>
|
6 |
|
7 |
+
### The `Hi-ToM_data` folder
|
8 |
|
9 |
+
Contains ToMh data consisting of story-question pairs and the corresponding answers.
|
10 |
+
The names of subfolder branches have the following meanings:
|
11 |
|
12 |
+
- `Tell` / `No_Tell`: whether or not the stories contain communications among agents.
|
13 |
+
- `MC` / `CoT`: the prompting style. `MC` corresponds to Vanilla Prompting (VP) in the paper, while `CoT` stands for Chain-of-Thought Prompting (CoTP).
|
14 |
+
- `length_n`: the story length, i.e. the number of chapters in a story. From 1 to 3.
|
15 |
+
- `sample_n`: the numbering of different sample stories.
|
16 |
+
- `order_n`: the ToM order of the question. From 0 to 4.
|
17 |
|
18 |
+
### The `Hi-ToM_prompt` folder
|
19 |
|
20 |
+
Contains prompt files that can be directly input to API.
|
21 |
+
The data in it are almost the same as `Hi-ToM_data`, except that answers are eliminated.
|
22 |
|
23 |
+
### Generate new data and prompts
|
24 |
|
25 |
+
Run the script `generate_tomh.sh`.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
actions.py
ADDED
@@ -0,0 +1,270 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
|
3 |
+
|
4 |
+
class Action(object):
|
5 |
+
|
6 |
+
def __init__(self, templates):
|
7 |
+
self.templates = templates
|
8 |
+
|
9 |
+
def render_declarative(self, *args):
|
10 |
+
assert 'declarative' in self.templates and \
|
11 |
+
len(self.templates['declarative']) > 0
|
12 |
+
return np.random.choice(self.templates['declarative']) % args
|
13 |
+
|
14 |
+
def render_interrogative(self, *args):
|
15 |
+
assert 'interrogative' in self.templates and \
|
16 |
+
len(self.templates['interrogative']) > 0, str(self.templates)
|
17 |
+
return np.random.choice(self.templates['interrogative']) % args
|
18 |
+
|
19 |
+
|
20 |
+
class ExistBeginning(Action):
|
21 |
+
|
22 |
+
def __init__(self):
|
23 |
+
templates = {
|
24 |
+
'interrogative': [
|
25 |
+
'Where was the %s at the beginning?\t%s',
|
26 |
+
'Where was the %s before?\t%s',
|
27 |
+
]
|
28 |
+
}
|
29 |
+
super().__init__(templates)
|
30 |
+
|
31 |
+
|
32 |
+
class Exist(Action):
|
33 |
+
|
34 |
+
def __init__(self):
|
35 |
+
templates = {
|
36 |
+
'interrogative': [
|
37 |
+
'Where is the %s?\t%s',
|
38 |
+
'Where is the %s located?\t%s',
|
39 |
+
]
|
40 |
+
}
|
41 |
+
super().__init__(templates)
|
42 |
+
|
43 |
+
|
44 |
+
class PlaceAction(Action):
|
45 |
+
|
46 |
+
def __init__(self):
|
47 |
+
templates = {
|
48 |
+
'declarative': [
|
49 |
+
'%s placed the %s in the %s.',
|
50 |
+
'%s put the %s in the %s.',
|
51 |
+
],
|
52 |
+
'interrogative': [
|
53 |
+
'Where did %s place the %s?\t%s',
|
54 |
+
'Where did %s put the %s?\t%s',
|
55 |
+
]
|
56 |
+
}
|
57 |
+
super().__init__(templates)
|
58 |
+
|
59 |
+
|
60 |
+
class SearchAction(Action):
|
61 |
+
|
62 |
+
def __init__(self):
|
63 |
+
templates = {
|
64 |
+
'declarative': [
|
65 |
+
'%s searched for the %s in the %s.',
|
66 |
+
'%s looked for the %s in the %s.',
|
67 |
+
],
|
68 |
+
'interrogative': [
|
69 |
+
'Where did %s search for the %s?\t%s',
|
70 |
+
'Where did %s look for the %s?\t%s',
|
71 |
+
],
|
72 |
+
}
|
73 |
+
super().__init__(templates)
|
74 |
+
|
75 |
+
|
76 |
+
class TransportAction(Action):
|
77 |
+
|
78 |
+
def __init__(self):
|
79 |
+
templates = {
|
80 |
+
'declarative': [
|
81 |
+
'%s shifted the %s from the %s to the %s.',
|
82 |
+
],
|
83 |
+
}
|
84 |
+
super().__init__(templates)
|
85 |
+
|
86 |
+
|
87 |
+
class EnterAction(Action):
|
88 |
+
|
89 |
+
def __init__(self):
|
90 |
+
templates = {
|
91 |
+
'declarative': [
|
92 |
+
'%s entered the %s.',
|
93 |
+
'%s came into the %s.',
|
94 |
+
],
|
95 |
+
}
|
96 |
+
super().__init__(templates)
|
97 |
+
|
98 |
+
|
99 |
+
class ExitAction(Action):
|
100 |
+
|
101 |
+
def __init__(self):
|
102 |
+
templates = {
|
103 |
+
'declarative': [
|
104 |
+
'%s exited the %s.',
|
105 |
+
'%s left the %s.',
|
106 |
+
'%s went out of the %s.',
|
107 |
+
],
|
108 |
+
}
|
109 |
+
super().__init__(templates)
|
110 |
+
|
111 |
+
|
112 |
+
class BelieveLocationAction(Action):
|
113 |
+
|
114 |
+
def __init__(self):
|
115 |
+
templates = {
|
116 |
+
'declarative': [
|
117 |
+
'%s thinks the %s is in the %s.',
|
118 |
+
'%s believes the %s is in the %s.',
|
119 |
+
],
|
120 |
+
'interrogative': [
|
121 |
+
'Where does %s think the %s is?\t%s',
|
122 |
+
'Where does %s believe the %s is?\t%s',
|
123 |
+
],
|
124 |
+
}
|
125 |
+
super().__init__(templates)
|
126 |
+
|
127 |
+
|
128 |
+
class BelieveAgentBelieveLocationAction(Action):
|
129 |
+
|
130 |
+
def __init__(self):
|
131 |
+
templates = {
|
132 |
+
'interrogative': [
|
133 |
+
'Where does %s think that %s believes the %s is?\t%s',
|
134 |
+
'Where does %s believe that %s believes the %s is?\t%s',
|
135 |
+
'Where does %s think that %s thinks the %s is?\t%s',
|
136 |
+
'Where does %s believe that %s thinks the %s is?\t%s',
|
137 |
+
],
|
138 |
+
}
|
139 |
+
super().__init__(templates)
|
140 |
+
|
141 |
+
|
142 |
+
class BelieveAgentSearchLocationAction(Action):
|
143 |
+
|
144 |
+
def __init__(self):
|
145 |
+
templates = {
|
146 |
+
'interrogative': [
|
147 |
+
'Where does %s think that %s looks for the %s?\t%s',
|
148 |
+
'Where does %s believe that %s looks for the %s?\t%s',
|
149 |
+
'Where does %s think that %s searches for the %s?\t%s',
|
150 |
+
'Where does %s believe that %s search for the %s?\t%s',
|
151 |
+
],
|
152 |
+
}
|
153 |
+
super().__init__(templates)
|
154 |
+
|
155 |
+
|
156 |
+
class InformLocationAction(Action):
|
157 |
+
|
158 |
+
def __init__(self):
|
159 |
+
templates = {
|
160 |
+
'declarative': [
|
161 |
+
'%s told %s that the %s is in the %s.',
|
162 |
+
'%s informed %s that the %s is in the %s.',
|
163 |
+
],
|
164 |
+
}
|
165 |
+
super().__init__(templates)
|
166 |
+
|
167 |
+
####################################################
|
168 |
+
####### Deterministic Actions for New Task #######
|
169 |
+
####################################################
|
170 |
+
|
171 |
+
class FirstQ(Action):
|
172 |
+
|
173 |
+
def __init__(self):
|
174 |
+
templates = {
|
175 |
+
'interrogative': [
|
176 |
+
'Where will %s look for the %s?\t%s',
|
177 |
+
]
|
178 |
+
}
|
179 |
+
super().__init__(templates)
|
180 |
+
|
181 |
+
class SecondQ(Action):
|
182 |
+
|
183 |
+
def __init__(self):
|
184 |
+
templates = {
|
185 |
+
'interrogative': [
|
186 |
+
'Where does %s think that %s searches for the %s?\t%s',
|
187 |
+
]
|
188 |
+
}
|
189 |
+
super().__init__(templates)
|
190 |
+
|
191 |
+
class ZeroQ(Action):
|
192 |
+
|
193 |
+
def __init__(self):
|
194 |
+
templates = {
|
195 |
+
'interrogative': [
|
196 |
+
'Where is the %s really?\t%s',
|
197 |
+
]
|
198 |
+
}
|
199 |
+
super().__init__(templates)
|
200 |
+
|
201 |
+
class MemoryAction(Action):
|
202 |
+
|
203 |
+
def __init__(self):
|
204 |
+
templates = {
|
205 |
+
'interrogative': [
|
206 |
+
'Where was the %s at the beginning?\t%s',
|
207 |
+
]
|
208 |
+
}
|
209 |
+
super().__init__(templates)
|
210 |
+
|
211 |
+
class LocationAction(Action):
|
212 |
+
|
213 |
+
def __init__(self):
|
214 |
+
templates = {
|
215 |
+
'declarative': [
|
216 |
+
'%s and %s are in the %s.',
|
217 |
+
]
|
218 |
+
}
|
219 |
+
super().__init__(templates)
|
220 |
+
|
221 |
+
class ObjectLocAction(Action):
|
222 |
+
|
223 |
+
def __init__(self):
|
224 |
+
templates = {
|
225 |
+
'declarative': [
|
226 |
+
'The %s is in the %s.',
|
227 |
+
]
|
228 |
+
}
|
229 |
+
super().__init__(templates)
|
230 |
+
|
231 |
+
class ExitedAction(Action):
|
232 |
+
|
233 |
+
def __init__(self):
|
234 |
+
templates = {
|
235 |
+
'declarative': [
|
236 |
+
'%s exited the %s.',
|
237 |
+
]
|
238 |
+
}
|
239 |
+
super().__init__(templates)
|
240 |
+
|
241 |
+
class MoveAction(Action):
|
242 |
+
|
243 |
+
def __init__(self):
|
244 |
+
templates = {
|
245 |
+
'declarative': [
|
246 |
+
'%s moved the %s to the %s.',
|
247 |
+
]
|
248 |
+
}
|
249 |
+
super().__init__(templates)
|
250 |
+
|
251 |
+
class TellAction(Action):
|
252 |
+
|
253 |
+
def __init__(self):
|
254 |
+
templates = {
|
255 |
+
'declarative': [
|
256 |
+
'%s told %s where the %s is.',
|
257 |
+
]
|
258 |
+
}
|
259 |
+
super().__init__(templates)
|
260 |
+
|
261 |
+
class EnterAction(Action):
|
262 |
+
|
263 |
+
def __init__(self):
|
264 |
+
templates = {
|
265 |
+
'declarative': [
|
266 |
+
'%s entered the %s.',
|
267 |
+
]
|
268 |
+
}
|
269 |
+
super().__init__(templates)
|
270 |
+
|
clause.py
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
|
3 |
+
|
4 |
+
class Clause(object):
|
5 |
+
|
6 |
+
def __init__(self, action):
|
7 |
+
|
8 |
+
# if observers is not None:
|
9 |
+
# assert 0 not in observers, "Observer IDs must be 1-indexed"
|
10 |
+
# self.observers = observers
|
11 |
+
self.action = action
|
12 |
+
|
13 |
+
def render(self):
|
14 |
+
return self.action.render_declarative() # + \
|
15 |
+
# ('\t' + ' '.join([str(x) for x in self.observers])
|
16 |
+
# if self.observers is not None else '')
|
17 |
+
|
18 |
+
|
19 |
+
class Question(Clause):
|
20 |
+
|
21 |
+
def __init__(self, idx_support, action):
|
22 |
+
self.idx_support = idx_support
|
23 |
+
super().__init__(action)
|
24 |
+
|
25 |
+
def render(self):
|
26 |
+
return self.action.render_interrogative()
|
create_world.py
ADDED
@@ -0,0 +1,248 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
|
3 |
+
SIZE_TINY = 5
|
4 |
+
SIZE_SMALL = 10
|
5 |
+
SIZE_LARGE = 30
|
6 |
+
SIZE_XLARGE = 50
|
7 |
+
|
8 |
+
locations = [
|
9 |
+
"attic",
|
10 |
+
"back_yard",
|
11 |
+
"basement",
|
12 |
+
"bathroom",
|
13 |
+
"bedroom",
|
14 |
+
"cellar",
|
15 |
+
"closet",
|
16 |
+
"crawlspace",
|
17 |
+
"den",
|
18 |
+
"dining_room",
|
19 |
+
"front_yard",
|
20 |
+
"garage",
|
21 |
+
"garden",
|
22 |
+
"hall",
|
23 |
+
"hallway",
|
24 |
+
"kitchen",
|
25 |
+
"laundry",
|
26 |
+
"living_room",
|
27 |
+
"lounge",
|
28 |
+
"master_bedroom",
|
29 |
+
"office",
|
30 |
+
"pantry",
|
31 |
+
"patio",
|
32 |
+
"playroom",
|
33 |
+
"porch",
|
34 |
+
"staircase",
|
35 |
+
"study",
|
36 |
+
"sunroom",
|
37 |
+
"TV_room",
|
38 |
+
"workshop",
|
39 |
+
]
|
40 |
+
|
41 |
+
clothing = [
|
42 |
+
"belt",
|
43 |
+
"boots",
|
44 |
+
"cap",
|
45 |
+
"coat",
|
46 |
+
"dress",
|
47 |
+
"gloves",
|
48 |
+
"hat",
|
49 |
+
"jacket",
|
50 |
+
"jeans",
|
51 |
+
"pajamas",
|
52 |
+
"pants",
|
53 |
+
"raincoat",
|
54 |
+
"scarf",
|
55 |
+
"shirt",
|
56 |
+
"shoes",
|
57 |
+
"skirt",
|
58 |
+
"slacks",
|
59 |
+
"slippers",
|
60 |
+
"socks",
|
61 |
+
"stockings",
|
62 |
+
"suit",
|
63 |
+
"sweater",
|
64 |
+
"sweatshirt",
|
65 |
+
"t-shirt",
|
66 |
+
"tie",
|
67 |
+
"trousers",
|
68 |
+
"underclothes",
|
69 |
+
"underpants",
|
70 |
+
"undershirt",
|
71 |
+
]
|
72 |
+
|
73 |
+
fruit = [
|
74 |
+
"apple",
|
75 |
+
"banana",
|
76 |
+
"cherry",
|
77 |
+
"grapefruit",
|
78 |
+
"grapes",
|
79 |
+
"lemon",
|
80 |
+
"lime",
|
81 |
+
"melon",
|
82 |
+
"orange",
|
83 |
+
"peach",
|
84 |
+
"pear",
|
85 |
+
"persimmon",
|
86 |
+
"pineapple",
|
87 |
+
"plum",
|
88 |
+
"strawberry",
|
89 |
+
"tangerine",
|
90 |
+
"watermelon",
|
91 |
+
]
|
92 |
+
|
93 |
+
vegetables = [
|
94 |
+
"asparagus",
|
95 |
+
"beans",
|
96 |
+
"broccoli",
|
97 |
+
"cabbage",
|
98 |
+
"carrot",
|
99 |
+
"celery",
|
100 |
+
"corn",
|
101 |
+
"cucumber",
|
102 |
+
"eggplant",
|
103 |
+
"green_pepper",
|
104 |
+
"lettuce",
|
105 |
+
"onion",
|
106 |
+
"peas",
|
107 |
+
"potato",
|
108 |
+
"pumpkin",
|
109 |
+
"radish",
|
110 |
+
"spinach",
|
111 |
+
"sweet_potato",
|
112 |
+
"tomato",
|
113 |
+
"turnip",
|
114 |
+
]
|
115 |
+
|
116 |
+
objects = fruit + vegetables
|
117 |
+
|
118 |
+
containers = [
|
119 |
+
"box",
|
120 |
+
"pantry",
|
121 |
+
"bathtub",
|
122 |
+
"envelope",
|
123 |
+
"drawer",
|
124 |
+
"bottle",
|
125 |
+
"cupboard",
|
126 |
+
"basket",
|
127 |
+
"crate",
|
128 |
+
"suitcase",
|
129 |
+
"bucket",
|
130 |
+
"container",
|
131 |
+
"treasure_chest",
|
132 |
+
]
|
133 |
+
|
134 |
+
colors = ['green', 'blue', 'red']
|
135 |
+
|
136 |
+
containers = ['_'.join([color, container])
|
137 |
+
for container in containers
|
138 |
+
for color in colors]
|
139 |
+
|
140 |
+
names = [
|
141 |
+
"Oliver",
|
142 |
+
"Ethan",
|
143 |
+
"Liam",
|
144 |
+
"Benjamin",
|
145 |
+
"Lucas",
|
146 |
+
"Alexander",
|
147 |
+
"Jacob",
|
148 |
+
"Mason",
|
149 |
+
"William",
|
150 |
+
"Gracie",
|
151 |
+
"James",
|
152 |
+
"Logan",
|
153 |
+
"Owen",
|
154 |
+
"Noah",
|
155 |
+
"Carter",
|
156 |
+
"Nathan",
|
157 |
+
"Jack",
|
158 |
+
"Aiden",
|
159 |
+
"Jackson",
|
160 |
+
"Jayden",
|
161 |
+
"Emma",
|
162 |
+
"Olivia",
|
163 |
+
"Emily",
|
164 |
+
"Sophia",
|
165 |
+
"Ava",
|
166 |
+
"Chloe",
|
167 |
+
"Charlotte",
|
168 |
+
"Abigail",
|
169 |
+
"Amelia",
|
170 |
+
"Ella",
|
171 |
+
"Hannah",
|
172 |
+
"Isabella",
|
173 |
+
"Aria",
|
174 |
+
"Lily",
|
175 |
+
"Mia",
|
176 |
+
"Isla",
|
177 |
+
"Avery",
|
178 |
+
"Elizabeth",
|
179 |
+
"Mila",
|
180 |
+
"Evelyn",
|
181 |
+
]
|
182 |
+
|
183 |
+
assert len(locations) >= SIZE_LARGE
|
184 |
+
assert len(objects) >= SIZE_LARGE
|
185 |
+
assert len(containers) >= SIZE_LARGE
|
186 |
+
assert len(names) >= SIZE_LARGE
|
187 |
+
|
188 |
+
|
189 |
+
def write_world(filepath, locs, objs, conts, nams):
|
190 |
+
|
191 |
+
with open(filepath, 'w') as f:
|
192 |
+
|
193 |
+
f.write('# locations\n')
|
194 |
+
|
195 |
+
for loc in locs:
|
196 |
+
|
197 |
+
f.write('\n')
|
198 |
+
f.write('create %s\n' % loc)
|
199 |
+
f.write('set %s is_thing\n' % loc)
|
200 |
+
f.write('set %s is_location\n' % loc)
|
201 |
+
|
202 |
+
f.write('\n')
|
203 |
+
f.write('# objects\n')
|
204 |
+
|
205 |
+
for obj in objs:
|
206 |
+
|
207 |
+
f.write('\n')
|
208 |
+
f.write('create %s\n' % obj)
|
209 |
+
f.write('set %s is_thing\n' % obj)
|
210 |
+
f.write('set %s is_gettable\n' % obj)
|
211 |
+
|
212 |
+
f.write('\n')
|
213 |
+
f.write('# containers\n')
|
214 |
+
|
215 |
+
for cont in conts:
|
216 |
+
|
217 |
+
f.write('\n')
|
218 |
+
f.write('create %s\n' % cont)
|
219 |
+
f.write('set %s is_thing\n' % cont)
|
220 |
+
f.write('set %s is_container\n' % cont)
|
221 |
+
|
222 |
+
f.write('\n')
|
223 |
+
f.write('# actors\n')
|
224 |
+
|
225 |
+
for nam in nams:
|
226 |
+
|
227 |
+
f.write('\n')
|
228 |
+
f.write('create %s\n' % nam)
|
229 |
+
f.write('set %s is_actor\n' % nam)
|
230 |
+
f.write('set %s is_god\n' % nam)
|
231 |
+
|
232 |
+
write_world('world_tiny.txt',
|
233 |
+
np.random.choice(locations, SIZE_TINY, replace=False),
|
234 |
+
np.random.choice(objects, SIZE_TINY, replace=False),
|
235 |
+
np.random.choice(containers, SIZE_TINY, replace=False),
|
236 |
+
np.random.choice(names, SIZE_TINY, replace=False))
|
237 |
+
|
238 |
+
write_world('world_small.txt',
|
239 |
+
np.random.choice(locations, SIZE_SMALL, replace=False),
|
240 |
+
np.random.choice(objects, SIZE_SMALL, replace=False),
|
241 |
+
np.random.choice(containers, SIZE_SMALL, replace=False),
|
242 |
+
np.random.choice(names, SIZE_SMALL, replace=False))
|
243 |
+
|
244 |
+
write_world('world_large.txt',
|
245 |
+
np.random.choice(locations, SIZE_LARGE, replace=False),
|
246 |
+
np.random.choice(objects, SIZE_LARGE, replace=False),
|
247 |
+
np.random.choice(containers, SIZE_LARGE, replace=False),
|
248 |
+
np.random.choice(names, SIZE_LARGE, replace=False))
|
dynamic_actions.py
ADDED
@@ -0,0 +1,369 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import random
|
3 |
+
from itertools import combinations
|
4 |
+
from itertools import permutations
|
5 |
+
|
6 |
+
|
7 |
+
class Action(object):
|
8 |
+
|
9 |
+
def __init__(self, templates):
|
10 |
+
self.templates = templates
|
11 |
+
|
12 |
+
def render_declarative(self):
|
13 |
+
assert 'declarative' in self.templates and \
|
14 |
+
len(self.templates['declarative']) > 0
|
15 |
+
return np.random.choice(self.templates['declarative'])
|
16 |
+
|
17 |
+
def render_interrogative(self):
|
18 |
+
assert 'interrogative' in self.templates and \
|
19 |
+
len(self.templates['interrogative']) > 0, str(self.templates)
|
20 |
+
return np.random.choice(self.templates['interrogative'])
|
21 |
+
|
22 |
+
|
23 |
+
class ExitAction(Action):
|
24 |
+
|
25 |
+
def __init__(self):
|
26 |
+
templates = {
|
27 |
+
'declarative': [
|
28 |
+
'%s exited the %s.',
|
29 |
+
'%s left the %s.',
|
30 |
+
'%s went out of the %s.',
|
31 |
+
],
|
32 |
+
}
|
33 |
+
super().__init__(templates)
|
34 |
+
|
35 |
+
#########################################
|
36 |
+
############### Questions ###############
|
37 |
+
#########################################
|
38 |
+
|
39 |
+
|
40 |
+
class ZeroQ(Action):
|
41 |
+
|
42 |
+
def __init__(self, oracle, obj):
|
43 |
+
|
44 |
+
fill = (obj, oracle.get_object_container(obj))
|
45 |
+
templates = {
|
46 |
+
'interrogative': [
|
47 |
+
'Question: Where is the %s really?\nAnswer: %s' % fill,
|
48 |
+
]
|
49 |
+
}
|
50 |
+
super().__init__(templates)
|
51 |
+
|
52 |
+
|
53 |
+
class FirstQ(Action):
|
54 |
+
|
55 |
+
def __init__(self, oracle, agent, obj):
|
56 |
+
fill = (agent, obj, oracle.get_first_belief(agent, obj))
|
57 |
+
templates = {
|
58 |
+
'interrogative': [
|
59 |
+
'Question: Where does %s really think the %s is?\nAnswer: %s' % fill,
|
60 |
+
]
|
61 |
+
}
|
62 |
+
super().__init__(templates)
|
63 |
+
|
64 |
+
|
65 |
+
class SecondQ(Action):
|
66 |
+
|
67 |
+
def __init__(self, oracle, a1, a2, obj):
|
68 |
+
fill = (a1, a2, obj, oracle.get_second_belief(a1, a2, obj))
|
69 |
+
templates = {
|
70 |
+
'interrogative': [
|
71 |
+
'Question: Where does %s think %s thinks the %s is?\nAnswer: %s' % fill,
|
72 |
+
]
|
73 |
+
}
|
74 |
+
super().__init__(templates)
|
75 |
+
|
76 |
+
|
77 |
+
class ThirdQ(Action):
|
78 |
+
|
79 |
+
def __init__(self, oracle, a1, a2, a3, obj):
|
80 |
+
fill = (a1, a2, a3, obj, oracle.get_third_belief(a1, a2, a3, obj))
|
81 |
+
templates = {
|
82 |
+
'interrogative': [
|
83 |
+
'Question: Where does %s think %s thinks %s thinks the %s is?\nAnswer: %s' % fill,
|
84 |
+
]
|
85 |
+
}
|
86 |
+
super().__init__(templates)
|
87 |
+
|
88 |
+
|
89 |
+
class FourthQ(Action):
|
90 |
+
|
91 |
+
def __init__(self, oracle, a1, a2, a3, a4, obj):
|
92 |
+
fill = (a1, a2, a3, a4, obj,
|
93 |
+
oracle.get_fourth_belief(a1, a2, a3, a4, obj))
|
94 |
+
templates = {
|
95 |
+
'interrogative': [
|
96 |
+
'Question: Where does %s think %s thinks %s thinks %s thinks the %s is?\nAnswer: %s' % fill,
|
97 |
+
]
|
98 |
+
}
|
99 |
+
super().__init__(templates)
|
100 |
+
|
101 |
+
# class MemoryAction(Action):
|
102 |
+
|
103 |
+
# def __init__(self, oracle_start_state, obj):
|
104 |
+
# fill = (obj, oracle_start_state[obj])
|
105 |
+
# templates = {
|
106 |
+
# 'interrogative': [
|
107 |
+
# 'Where was the %s at the beginning?\t%s' % fill,
|
108 |
+
# ]
|
109 |
+
# }
|
110 |
+
# super().__init__(templates)
|
111 |
+
|
112 |
+
# class LocationAction(Action):
|
113 |
+
# def __init__(self, oracle, args):
|
114 |
+
# """
|
115 |
+
# Creaters string with args and modifies
|
116 |
+
# oracle in accordance with action.
|
117 |
+
# """
|
118 |
+
# if len(args) == 2:
|
119 |
+
# statement = '%s is in the %s.' % args
|
120 |
+
# a1, loc = args
|
121 |
+
# # may be redundant
|
122 |
+
# oracle.set_location(a1, loc)
|
123 |
+
# else : # 2 people
|
124 |
+
# statement = '%s and %s are in the %s.' % args
|
125 |
+
# a1, a2, loc = args
|
126 |
+
# # may be redundant
|
127 |
+
# oracle.set_location(a1, loc)
|
128 |
+
# oracle.set_location(a2, loc)
|
129 |
+
|
130 |
+
# templates = {
|
131 |
+
# 'declarative': [
|
132 |
+
# statement,
|
133 |
+
# ]
|
134 |
+
# }
|
135 |
+
|
136 |
+
# super().__init__(templates)
|
137 |
+
|
138 |
+
|
139 |
+
class ObjectLocAction(Action):
|
140 |
+
|
141 |
+
def __init__(self, oracle, obj, observers):
|
142 |
+
container = oracle.get_object_container(obj)
|
143 |
+
templates = {
|
144 |
+
'declarative': [
|
145 |
+
'The %s is in the %s.' % (obj, container),
|
146 |
+
]
|
147 |
+
}
|
148 |
+
|
149 |
+
# set first beliefs
|
150 |
+
for observer in observers:
|
151 |
+
oracle.set_first_belief(observer, obj, container)
|
152 |
+
|
153 |
+
# set second beliefs
|
154 |
+
if len(observers) >= 2:
|
155 |
+
for observer1, observer2 in combinations(observers, 2):
|
156 |
+
oracle.set_second_belief(observer1, observer2, obj, container)
|
157 |
+
oracle.set_second_belief(observer2, observer1, obj, container)
|
158 |
+
|
159 |
+
# set third beliefs
|
160 |
+
if len(observers) >= 3:
|
161 |
+
for chosen_observers in combinations(observers, 3):
|
162 |
+
for observer1, observer2, observer3 in permutations(chosen_observers):
|
163 |
+
oracle.set_third_belief(
|
164 |
+
observer1, observer2, observer3, obj, container)
|
165 |
+
|
166 |
+
# set fourth beliefs
|
167 |
+
if len(observers) >= 4:
|
168 |
+
for chosen_observers in combinations(observers, 4):
|
169 |
+
for observer1, observer2, observer3, observer4 in permutations(chosen_observers):
|
170 |
+
oracle.set_fourth_belief(
|
171 |
+
observer1, observer2, observer3, observer4, obj, container)
|
172 |
+
super().__init__(templates)
|
173 |
+
|
174 |
+
|
175 |
+
class ExitedAction(Action):
|
176 |
+
|
177 |
+
def __init__(self, oracle, agent):
|
178 |
+
fill = (agent, oracle.get_location(agent))
|
179 |
+
|
180 |
+
templates = {
|
181 |
+
'declarative': [
|
182 |
+
'%s exited the %s.' % fill,
|
183 |
+
]
|
184 |
+
}
|
185 |
+
oracle.set_location(agent, None)
|
186 |
+
super().__init__(templates)
|
187 |
+
|
188 |
+
|
189 |
+
class MoveAction(Action):
|
190 |
+
|
191 |
+
def __init__(self, oracle, args, observers=None, move=True):
|
192 |
+
agent, obj, container = args
|
193 |
+
if not move:
|
194 |
+
location = oracle.get_container_location(container)
|
195 |
+
templates = {
|
196 |
+
'declarative': [
|
197 |
+
f'{args[0]} made no movements and stayed in the {location} for 1 minute.',
|
198 |
+
]
|
199 |
+
}
|
200 |
+
|
201 |
+
else:
|
202 |
+
templates = {
|
203 |
+
'declarative': [
|
204 |
+
'%s moved the %s to the %s.' % args,
|
205 |
+
]
|
206 |
+
}
|
207 |
+
|
208 |
+
oracle.set_object_container(obj, container)
|
209 |
+
|
210 |
+
if not observers:
|
211 |
+
observers = []
|
212 |
+
observers.append(agent)
|
213 |
+
|
214 |
+
# set first beliefs
|
215 |
+
for observer in observers:
|
216 |
+
oracle.set_first_belief(observer, obj, container)
|
217 |
+
|
218 |
+
# set second beliefs
|
219 |
+
if len(observers) >= 2:
|
220 |
+
for observer1, observer2 in combinations(observers, 2):
|
221 |
+
oracle.set_second_belief(
|
222 |
+
observer1, observer2, obj, container)
|
223 |
+
oracle.set_second_belief(
|
224 |
+
observer2, observer1, obj, container)
|
225 |
+
|
226 |
+
# set third beliefs
|
227 |
+
if len(observers) >= 3:
|
228 |
+
for chosen_observers in combinations(observers, 3):
|
229 |
+
for observer1, observer2, observer3 in permutations(chosen_observers):
|
230 |
+
oracle.set_third_belief(
|
231 |
+
observer1, observer2, observer3, obj, container)
|
232 |
+
|
233 |
+
# set fourth beliefs
|
234 |
+
if len(observers) >= 4:
|
235 |
+
for chosen_observers in combinations(observers, 4):
|
236 |
+
for observer1, observer2, observer3, observer4 in permutations(chosen_observers):
|
237 |
+
oracle.set_fourth_belief(
|
238 |
+
observer1, observer2, observer3, observer4, obj, container)
|
239 |
+
|
240 |
+
super().__init__(templates)
|
241 |
+
|
242 |
+
|
243 |
+
class PublicTellAction(Action):
|
244 |
+
|
245 |
+
def __init__(self, oracle, speaker, obj, container, listeners=None, believers=None):
|
246 |
+
templates = {
|
247 |
+
'declarative': [
|
248 |
+
'%s publicly claimed that %s is in the %s now.' % (
|
249 |
+
speaker, obj, container),
|
250 |
+
]
|
251 |
+
}
|
252 |
+
disbelievers = [
|
253 |
+
listener for listener in listeners if listener not in believers]
|
254 |
+
|
255 |
+
# All listeners would think others believe the claim
|
256 |
+
# for believer in believers:
|
257 |
+
# for disbeliever in disbelievers:
|
258 |
+
# oracle.set_second_belief(believer, disbeliever, obj, container)
|
259 |
+
# oracle.set_second_belief(disbeliever, believer, obj, container)
|
260 |
+
|
261 |
+
# A believer would think speaker also believes the obj is in container, speaker would think his words are trusted
|
262 |
+
for believer in believers:
|
263 |
+
oracle.set_first_belief(believer, obj, container)
|
264 |
+
oracle.set_second_belief(believer, speaker, obj, container)
|
265 |
+
oracle.set_second_belief(speaker, believer, obj, container)
|
266 |
+
|
267 |
+
for disbeliever in disbelievers:
|
268 |
+
oracle.set_second_belief(speaker, disbeliever, obj, container)
|
269 |
+
|
270 |
+
# for listener in listeners:
|
271 |
+
# # the speaker believes that all the listeners believe him
|
272 |
+
# oracle.set_second_belief(speaker, listener, obj, container)
|
273 |
+
# # all listeners know the believers based on the exiting order
|
274 |
+
# for believer in believers:
|
275 |
+
# oracle.set_second_belief(listener, believer, obj, container)
|
276 |
+
|
277 |
+
super().__init__(templates)
|
278 |
+
|
279 |
+
|
280 |
+
class PrivateTellAction(Action):
|
281 |
+
|
282 |
+
def __init__(self, oracle, speaker, listener, obj, container, trust=True):
|
283 |
+
templates = {
|
284 |
+
'declarative': [
|
285 |
+
'%s privately told %s that the %s is in the %s now.' % (
|
286 |
+
speaker, listener, obj, container),
|
287 |
+
]
|
288 |
+
}
|
289 |
+
|
290 |
+
# when the listener has less information (exit the room earlier), he'll trust the speaker
|
291 |
+
if trust:
|
292 |
+
oracle.set_first_belief(listener, obj, container)
|
293 |
+
oracle.set_second_belief(listener, speaker, obj, container)
|
294 |
+
oracle.set_second_belief(speaker, listener, obj, container)
|
295 |
+
super().__init__(templates)
|
296 |
+
|
297 |
+
|
298 |
+
class EnterAction(Action):
|
299 |
+
|
300 |
+
def __init__(self, oracle, args, observers=None, no_world_adjust=False):
|
301 |
+
templates = {
|
302 |
+
'declarative': [
|
303 |
+
', '.join(args[:-2]) + ' and ' + args[-2] +
|
304 |
+
' entered the ' + args[-1] + '.',
|
305 |
+
]
|
306 |
+
}
|
307 |
+
|
308 |
+
agents = args[:-1]
|
309 |
+
location = args[-1]
|
310 |
+
if location == 'waiting_room':
|
311 |
+
super().__init__(templates)
|
312 |
+
return
|
313 |
+
for agent in agents:
|
314 |
+
oracle.set_location(agent, location)
|
315 |
+
objs = oracle.get_objects_at_location(location)
|
316 |
+
observers = agents
|
317 |
+
|
318 |
+
# agent knows location of everything
|
319 |
+
if not no_world_adjust:
|
320 |
+
for obj in objs:
|
321 |
+
container = oracle.get_object_container(obj)
|
322 |
+
# oracle.set_first_belief(agent, obj, container)
|
323 |
+
# set first beliefs
|
324 |
+
if len(observers) >= 1:
|
325 |
+
for observer in observers:
|
326 |
+
oracle.set_first_belief(observer, obj, container)
|
327 |
+
|
328 |
+
# set second beliefs
|
329 |
+
if len(observers) >= 2:
|
330 |
+
for observer1, observer2 in combinations(observers, 2):
|
331 |
+
oracle.set_second_belief(
|
332 |
+
observer1, observer2, obj, container)
|
333 |
+
oracle.set_second_belief(
|
334 |
+
observer2, observer1, obj, container)
|
335 |
+
|
336 |
+
# set third beliefs
|
337 |
+
if len(observers) >= 3:
|
338 |
+
for chosen_observers in combinations(observers, 3):
|
339 |
+
for observer1, observer2, observer3 in permutations(chosen_observers):
|
340 |
+
oracle.set_third_belief(
|
341 |
+
observer1, observer2, observer3, obj, container)
|
342 |
+
|
343 |
+
# set fourth beliefs
|
344 |
+
if len(observers) >= 4:
|
345 |
+
for chosen_observers in combinations(observers, 4):
|
346 |
+
for observer1, observer2, observer3, observer4 in permutations(chosen_observers):
|
347 |
+
oracle.set_fourth_belief(
|
348 |
+
observer1, observer2, observer3, observer4, obj, container)
|
349 |
+
|
350 |
+
super().__init__(templates)
|
351 |
+
|
352 |
+
|
353 |
+
class NoiseAction(Action):
|
354 |
+
|
355 |
+
def __init__(self, agents, containers, objects):
|
356 |
+
animals = ['cat', 'dog', 'monkey', 'mouse']
|
357 |
+
personal_items = ['watch', 'gloves', 'phone']
|
358 |
+
distractors = [
|
359 |
+
f'{random.choice(agents)} saw a {random.choice(animals)}.',
|
360 |
+
f'{random.choice(agents)} lost his {random.choice(personal_items)}.',
|
361 |
+
f'{random.choice(agents)} likes the {random.choice(containers)}.',
|
362 |
+
f'{random.choice(agents)} dislikes the {random.choice(objects)}.',
|
363 |
+
]
|
364 |
+
templates = {
|
365 |
+
'declarative': [
|
366 |
+
random.choice(distractors)
|
367 |
+
]
|
368 |
+
}
|
369 |
+
super().__init__(templates)
|
generate_prompts.py
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
import itertools
|
4 |
+
|
5 |
+
|
6 |
+
def main():
|
7 |
+
input_folder = 'data_ToMh'
|
8 |
+
output_folder = 'prompt_ToMh'
|
9 |
+
lengths = [1, 2, 3]
|
10 |
+
orders = [0, 1, 2, 3, 4]
|
11 |
+
prompts = ['CoT', 'MC']
|
12 |
+
tells = ['No_Tell', 'Tell']
|
13 |
+
for tell, prompt, length, order, sample_num in itertools.product(tells, prompts, lengths, orders, range(1, 21)):
|
14 |
+
input_fn = os.path.join(input_folder, tell, prompt, f'length_{length}', f'sample_{sample_num}',
|
15 |
+
f'order_{order}.txt')
|
16 |
+
output_fn = os.path.join(output_folder, tell, prompt, f'length_{length}', f'sample_{sample_num}',
|
17 |
+
f'order_{order}.txt')
|
18 |
+
with open(input_fn, 'r') as file:
|
19 |
+
lines = file.readlines()
|
20 |
+
new_lines = [line for line in lines if line ==
|
21 |
+
'\n' or line.split()[0] != 'Answer:']
|
22 |
+
if not os.path.exists(os.path.join(output_folder, tell, prompt, f'length_{length}', f'sample_{sample_num}')):
|
23 |
+
os.makedirs(os.path.join(output_folder, tell, prompt,
|
24 |
+
f'length_{length}', f'sample_{sample_num}'))
|
25 |
+
with open(output_fn, 'w') as file:
|
26 |
+
file.writelines(new_lines)
|
27 |
+
|
28 |
+
|
29 |
+
|
30 |
+
if __name__ == "__main__":
|
31 |
+
sys.exit(main())
|
generate_tasks.py
ADDED
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import logging
|
3 |
+
import glob
|
4 |
+
import numpy as np
|
5 |
+
import os
|
6 |
+
import sys
|
7 |
+
import random
|
8 |
+
import itertools
|
9 |
+
|
10 |
+
from stringify import stringify
|
11 |
+
from tasks import Specify_Tasks
|
12 |
+
from utils import is_file, mkdir_p, remove_extension
|
13 |
+
from world import World
|
14 |
+
|
15 |
+
|
16 |
+
def generate_story_with_specified_chapters(
|
17 |
+
world_paths, output_dir_path, n, noise=0.1, train_noise=False, order=-1, num_chapter=-1, exist_tell_in_story=False, prompt='CoT', exist_answer=False
|
18 |
+
): # prompt is dummy
|
19 |
+
"""Generates stories with guarantee that each task is seen n times."""
|
20 |
+
mkdir_p(output_dir_path)
|
21 |
+
n = n[0]
|
22 |
+
|
23 |
+
for world in world_paths:
|
24 |
+
|
25 |
+
w = World()
|
26 |
+
w.load(world)
|
27 |
+
world_name = remove_extension(world)
|
28 |
+
|
29 |
+
# Define task creator and task types
|
30 |
+
task = Specify_Tasks()
|
31 |
+
tasks_per_length = np.array([
|
32 |
+
[('A5', True)], # 1 chapter
|
33 |
+
[('A5', False), ('A3', True)], # 2 chapters
|
34 |
+
[('A5', True), ('A3', False), ('A4', True)], # 3 chapters
|
35 |
+
[('A5', False), ('A3', True),
|
36 |
+
('A4', False), ('A2', True)], # 4 chapters
|
37 |
+
], dtype=object)
|
38 |
+
|
39 |
+
# If order and num_chapter are not specified
|
40 |
+
orders = [0, 1, 2, 3, 4] if order == -1 else [order]
|
41 |
+
num_chapters = [1, 2, 3] if num_chapter == -1 else [num_chapter]
|
42 |
+
modes = ['MC', 'CoT']
|
43 |
+
for length_of_story in num_chapters:
|
44 |
+
# Create folder to contain data
|
45 |
+
folder_name_2 = f'length_{length_of_story}'
|
46 |
+
logging.info("Creating New task in %s..." % folder_name_2)
|
47 |
+
|
48 |
+
for i in range(1, n+1):
|
49 |
+
folder_name_3 = f'sample_{i}'
|
50 |
+
story = task.generate_story_qs_at_end(
|
51 |
+
w, length_of_story, tasks_per_length[length_of_story -
|
52 |
+
1], num_agents=5,
|
53 |
+
num_locations=3, statement_noise=noise, order=0, exist_tell_in_story=exist_tell_in_story
|
54 |
+
) # order = 0 is dummy here.
|
55 |
+
for mode in modes:
|
56 |
+
folder_name_1 = mode
|
57 |
+
for order_of_story in orders:
|
58 |
+
file_name = f'order_{order_of_story}.txt'
|
59 |
+
os.makedirs(os.path.join(
|
60 |
+
output_dir_path, folder_name_1, folder_name_2, folder_name_3), exist_ok=True)
|
61 |
+
path = os.path.join(
|
62 |
+
output_dir_path, folder_name_1, folder_name_2, folder_name_3, file_name)
|
63 |
+
|
64 |
+
with open(path, 'w', encoding='utf-8') as f:
|
65 |
+
if mode == 'MC':
|
66 |
+
f.write(
|
67 |
+
'The following story happens in chronological order. You will be given a multiple-choice question and a note at the end. Directly output the answer without explanation.\n')
|
68 |
+
else:
|
69 |
+
f.write(
|
70 |
+
'The following story happens in chronological order. You will be given a multiple-choice question and a note at the end. First give step-by-step analysis about the question. Then output the answer.\n')
|
71 |
+
f.write('Story:\n')
|
72 |
+
<<<<<<< HEAD
|
73 |
+
f.write('\n'.join(stringify(story, exist_answer=exist_answer, order=order_of_story))) #exist_answer is dummy
|
74 |
+
f.write('\nNote: You should assume the following. (1) A character witnesses everything and every movements before exiting a location. (2) A character A can infer another character B\'s mental state only if A and B have been in the same location, or have private or public interactions. (3) Note that every character tend to lie. A character tend to trust a character that exited the room later than himself. (4) Agents in private communications know that others won\'t hear them, but anyone can hear any public claims. (5) What a character tells others does not affect his actual belief.\n')
|
75 |
+
=======
|
76 |
+
# exist_answer is dummy
|
77 |
+
f.write(
|
78 |
+
'\n'.join(stringify(story, exist_answer=exist_answer, order=order_of_story)))
|
79 |
+
f.write('\nNote: You should assume the following. (1) An agent witnesses everything and every movement before exiting a location. (2) An agent A can infer another agent B\'s mental state only if A and B have been in the same location, or have private or public interactions. (3) Note that every agent tends to lie. What an agent A tells others doesn\'t affect A\'s actual belief. An agent tends to trust an agent that exited the room later than himself. The exit order is known to all agents. (4) Agents in private communications know that others won\'t hear them, but they know that anyone can hear any public claims.\n')
|
80 |
+
>>>>>>> 50242d0343261b6c95293fc995711b384ff3c1fe
|
81 |
+
|
82 |
+
|
83 |
+
def parse_args(args):
|
84 |
+
|
85 |
+
parser = argparse.ArgumentParser(
|
86 |
+
description='Process command-line arguments.'
|
87 |
+
)
|
88 |
+
|
89 |
+
parser.add_argument(
|
90 |
+
'-w', '--world_path', dest='world_paths', type=is_file, required=True,
|
91 |
+
action='append', help='Path to a world definition file'
|
92 |
+
)
|
93 |
+
|
94 |
+
parser.add_argument(
|
95 |
+
'-o', '--output_dir_path', dest='output_dir_path', type=mkdir_p,
|
96 |
+
default='data_ToMh', help='Output directory path'
|
97 |
+
)
|
98 |
+
|
99 |
+
# parser.add_argument(
|
100 |
+
# '-b', '--babi_dir_path', dest='babi_dir_path', type=str, required=True,
|
101 |
+
# help='Path to directory containing the 20 bAbi task train + test data'
|
102 |
+
# )
|
103 |
+
|
104 |
+
parser.add_argument(
|
105 |
+
'-l', '--logging', type=str, default='INFO', metavar='logging',
|
106 |
+
choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
|
107 |
+
help='Logging level'
|
108 |
+
)
|
109 |
+
|
110 |
+
parser.add_argument(
|
111 |
+
'-n', '--num_stories', dest='num_stories_choices', type=int,
|
112 |
+
action='append', required=True,
|
113 |
+
help='Number of stories (examples) in a task)'
|
114 |
+
)
|
115 |
+
|
116 |
+
parser.add_argument(
|
117 |
+
'-ptn', '--prob_test_noise', dest='test_noise', type=float,
|
118 |
+
required=True, help='Probability of encountering random noise sentence'
|
119 |
+
)
|
120 |
+
|
121 |
+
parser.add_argument(
|
122 |
+
'-tn', '--train_noise', dest='train_noise', type=bool, default=False,
|
123 |
+
help='Whether or not to include noise at training time'
|
124 |
+
)
|
125 |
+
parser.add_argument(
|
126 |
+
'-ord', '--order', dest='order', type=int, default=-1,
|
127 |
+
help='The range of question orders'
|
128 |
+
)
|
129 |
+
parser.add_argument(
|
130 |
+
'-len', '--length', dest='num_chapter', type=int, default=-1,
|
131 |
+
help='The range of story lengths'
|
132 |
+
)
|
133 |
+
parser.add_argument(
|
134 |
+
'-t', '--tell', dest='exist_tell', type=bool, default=False,
|
135 |
+
help='Whether or not the story has communications between agents'
|
136 |
+
)
|
137 |
+
parser.add_argument(
|
138 |
+
'-p', '--prompt', dest='prompt_type', type=str, default='CoT',
|
139 |
+
choices=['MC', 'CoT'],
|
140 |
+
help='The type of prompt chosen between MC and CoT'
|
141 |
+
)
|
142 |
+
parser.add_argument(
|
143 |
+
'-a', '--answer', dest='exist_answer', type=bool, default=False,
|
144 |
+
help='Whether or not the data has answers'
|
145 |
+
)
|
146 |
+
|
147 |
+
parsed = parser.parse_args(args)
|
148 |
+
|
149 |
+
return parsed
|
150 |
+
|
151 |
+
|
152 |
+
def main(args=sys.argv[1:]):
|
153 |
+
"""Main function to generate all the story-question pairs."""
|
154 |
+
args = parse_args(args)
|
155 |
+
logging.basicConfig(
|
156 |
+
level=args.logging, format='%(asctime)s\t%(levelname)-8s\t%(message)s'
|
157 |
+
)
|
158 |
+
folder_name = 'Tell/' if args.exist_tell else 'No_Tell/'
|
159 |
+
|
160 |
+
# folder_name += args.prompt_type
|
161 |
+
# output_dir_path = os.path.join(args.output_dir_path, folder_name) if args.exist_answer else os.path.join('prompt_ToMh', folder_name)
|
162 |
+
|
163 |
+
output_dir_path = os.path.join(args.output_dir_path, folder_name)
|
164 |
+
|
165 |
+
generate_story_with_specified_chapters(
|
166 |
+
world_paths=args.world_paths,
|
167 |
+
output_dir_path=output_dir_path,
|
168 |
+
n=args.num_stories_choices,
|
169 |
+
noise=args.test_noise,
|
170 |
+
train_noise=args.train_noise,
|
171 |
+
order=args.order,
|
172 |
+
num_chapter=args.num_chapter,
|
173 |
+
exist_tell_in_story=args.exist_tell,
|
174 |
+
prompt=args.prompt_type,
|
175 |
+
exist_answer=args.exist_answer,
|
176 |
+
)
|
177 |
+
|
178 |
+
|
179 |
+
if __name__ == "__main__":
|
180 |
+
sys.exit(main())
|
oracle.py
ADDED
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
The Oracle class keeps track of all object
|
3 |
+
and agent locations as well as a map of
|
4 |
+
beliefs about object and agent locations.
|
5 |
+
"""
|
6 |
+
import copy
|
7 |
+
|
8 |
+
class LocationMap(object):
|
9 |
+
|
10 |
+
def __init__(self, agents, locations, objects, containers):
|
11 |
+
|
12 |
+
# Maps agents to their locations.
|
13 |
+
self.locations = {agent : None for agent in agents}
|
14 |
+
|
15 |
+
# Maps agents to their locations.
|
16 |
+
self.container_locations = {container : None for container in containers}
|
17 |
+
|
18 |
+
# Maps locations to their containers
|
19 |
+
self.containers = {location : None for location in locations}
|
20 |
+
|
21 |
+
# Maps containers to the objects they hold
|
22 |
+
self.container_objs = {container : [] for container in containers}
|
23 |
+
|
24 |
+
# Maps objects to their containers
|
25 |
+
self.obj_containers = {obj : None for obj in objects}
|
26 |
+
|
27 |
+
class MemoryMap(object):
|
28 |
+
|
29 |
+
def __init__(self, agents, objects):
|
30 |
+
|
31 |
+
zero_dict = {obj : None for obj in objects}
|
32 |
+
first_dict = {agent : copy.deepcopy(zero_dict) for agent in agents}
|
33 |
+
second_dict = {agent : copy.deepcopy(first_dict) for agent in agents}
|
34 |
+
third_dict = {agent : copy.deepcopy(second_dict) for agent in agents}
|
35 |
+
fourth_dict = {agent : copy.deepcopy(third_dict) for agent in agents}
|
36 |
+
|
37 |
+
# Dictionary of dictionaries mapping
|
38 |
+
# agents to objects to containers. Represents
|
39 |
+
# agents' belief about location of containers.
|
40 |
+
self.first_belief = copy.deepcopy(first_dict)
|
41 |
+
|
42 |
+
# Dictionary of dictionaries of dictionaries
|
43 |
+
# mapping agents to direct belief dictionaries.
|
44 |
+
# Represents agents' belief about other agents'
|
45 |
+
# beliefs about location of containers.
|
46 |
+
self.second_belief = copy.deepcopy(second_dict)
|
47 |
+
self.third_belief = copy.deepcopy(third_dict)
|
48 |
+
self.fourth_belief = copy.deepcopy(fourth_dict)
|
49 |
+
|
50 |
+
class Oracle(object):
|
51 |
+
|
52 |
+
def __init__(self, agents, locations, objects, containers):
|
53 |
+
self.memory_map = MemoryMap(agents, objects)
|
54 |
+
self.locations = LocationMap(agents, locations, objects, containers)
|
55 |
+
|
56 |
+
#########################################
|
57 |
+
################ Beliefs ################
|
58 |
+
#########################################
|
59 |
+
|
60 |
+
def get_first_belief(self, agent, obj):
|
61 |
+
beliefs = self.memory_map.first_belief
|
62 |
+
return beliefs[agent][obj]
|
63 |
+
|
64 |
+
def set_first_belief(self, agent, obj, container):
|
65 |
+
beliefs = self.memory_map.first_belief
|
66 |
+
beliefs[agent][obj] = container
|
67 |
+
|
68 |
+
def get_second_belief(self, a1, a2, obj):
|
69 |
+
second_belief = self.memory_map.second_belief
|
70 |
+
return second_belief[a1][a2][obj]
|
71 |
+
|
72 |
+
def set_second_belief(self, a1, a2, obj, container):
|
73 |
+
second_belief = self.memory_map.second_belief
|
74 |
+
second_belief[a1][a2][obj] = container
|
75 |
+
|
76 |
+
def get_third_belief(self, a1, a2, a3, obj):
|
77 |
+
third_belief = self.memory_map.third_belief
|
78 |
+
return third_belief[a1][a2][a3][obj]
|
79 |
+
|
80 |
+
def set_third_belief(self, a1, a2, a3, obj, container):
|
81 |
+
third_belief = self.memory_map.third_belief
|
82 |
+
third_belief[a1][a2][a3][obj] = container
|
83 |
+
|
84 |
+
def get_fourth_belief(self, a1, a2, a3, a4, obj):
|
85 |
+
fourth_belief = self.memory_map.fourth_belief
|
86 |
+
return fourth_belief[a1][a2][a3][a4][obj]
|
87 |
+
|
88 |
+
def set_fourth_belief(self, a1, a2, a3, a4, obj, container):
|
89 |
+
fourth_belief = self.memory_map.fourth_belief
|
90 |
+
fourth_belief[a1][a2][a3][a4][obj] = container
|
91 |
+
|
92 |
+
#########################################
|
93 |
+
############### Locations ###############
|
94 |
+
#########################################
|
95 |
+
|
96 |
+
def get_location(self, agent):
|
97 |
+
return self.locations.locations[agent]
|
98 |
+
|
99 |
+
def set_location(self, agent, location):
|
100 |
+
self.locations.locations[agent] = location
|
101 |
+
|
102 |
+
def get_containers(self, location):
|
103 |
+
# Returns a list of containers at location
|
104 |
+
return self.locations.containers[location]
|
105 |
+
|
106 |
+
def set_containers(self, location, containers):
|
107 |
+
# May need to change to move containers bt locs
|
108 |
+
# Containers is a list of containers at location
|
109 |
+
for container in containers:
|
110 |
+
self._set_container_location(container, location)
|
111 |
+
self.locations.containers[location] = containers
|
112 |
+
|
113 |
+
def get_objects_at_location(self, location):
|
114 |
+
objects = []
|
115 |
+
for container in self.get_containers(location):
|
116 |
+
objects.extend(self.get_container_obj(container))
|
117 |
+
return objects
|
118 |
+
|
119 |
+
def get_container_location(self, container):
|
120 |
+
return self.locations.container_locations[container]
|
121 |
+
|
122 |
+
def _set_container_location(self, container, location):
|
123 |
+
self.locations.container_locations[container] = location
|
124 |
+
|
125 |
+
def get_container_obj(self, container):
|
126 |
+
# get list of objects in container
|
127 |
+
return self.locations.container_objs[container]
|
128 |
+
|
129 |
+
def _add_container_obj(self, container, obj):
|
130 |
+
self.locations.container_objs[container].append(obj)
|
131 |
+
|
132 |
+
def _remove_container_obj(self, container, obj):
|
133 |
+
self.locations.container_objs[container].remove(obj)
|
134 |
+
|
135 |
+
def get_object_container(self, obj):
|
136 |
+
# get container that holds object
|
137 |
+
return self.locations.obj_containers[obj]
|
138 |
+
|
139 |
+
def set_object_container(self, obj, container):
|
140 |
+
# set container that holds object
|
141 |
+
prev_container = self.get_object_container(obj)
|
142 |
+
if prev_container:
|
143 |
+
self._remove_container_obj(prev_container, obj)
|
144 |
+
self._add_container_obj(container, obj)
|
145 |
+
self.locations.obj_containers[obj] = container
|
146 |
+
|
147 |
+
|
stringify.py
ADDED
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
|
3 |
+
|
4 |
+
def stringify(story, exist_answer=False, order=0): # exist_answer is dummy
|
5 |
+
|
6 |
+
lines = []
|
7 |
+
|
8 |
+
i = 0 # The number of descriptions processed
|
9 |
+
j = 0 # The number of lines output
|
10 |
+
count_order = 0
|
11 |
+
|
12 |
+
while True:
|
13 |
+
|
14 |
+
|
15 |
+
if isinstance(story[i], str):
|
16 |
+
line = story[i]
|
17 |
+
else:
|
18 |
+
line = story[i].render()
|
19 |
+
# Capitalize the line
|
20 |
+
line = line[0].upper() + line[1:]
|
21 |
+
|
22 |
+
# Prepend the line number
|
23 |
+
if line.split()[0] != 'Question:' and line.split()[0] != 'Choices:':
|
24 |
+
line = '%d %s' % (i + 1, line)
|
25 |
+
else: # Start with 'Choice'
|
26 |
+
if line.split()[0] == 'Choices:':
|
27 |
+
lines.append(line)
|
28 |
+
break
|
29 |
+
else: # Start with 'Question'
|
30 |
+
if count_order == order:
|
31 |
+
lines.append(line)
|
32 |
+
count_order += 1
|
33 |
+
i += 1
|
34 |
+
continue
|
35 |
+
lines.append(line)
|
36 |
+
# Increment counters
|
37 |
+
i += 1
|
38 |
+
|
39 |
+
# Append supporting lines indices if necessary
|
40 |
+
# if hasattr(story[i], 'idx_support') and story[i].idx_support:
|
41 |
+
# line += '\t%s' % ' '.join([str(x + 1)
|
42 |
+
# for x in story[i].idx_support])
|
43 |
+
|
44 |
+
if i >= len(story):
|
45 |
+
break
|
46 |
+
|
47 |
+
return lines
|
tasks.py
ADDED
@@ -0,0 +1,518 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import random
|
3 |
+
import copy
|
4 |
+
|
5 |
+
from clause import Clause, Question
|
6 |
+
from oracle import Oracle
|
7 |
+
from dynamic_actions import *
|
8 |
+
from collections import defaultdict
|
9 |
+
|
10 |
+
|
11 |
+
def sample_question(oracle_start_state, oracle, random_actors, obj, question_idx=0):
|
12 |
+
idx_dummy = [0]
|
13 |
+
a1, a2, a3, a4, _ = random_actors
|
14 |
+
questions = [Question(idx_dummy, ZeroQ(oracle, obj)),
|
15 |
+
Question(idx_dummy, FirstQ(oracle, a4, obj)),
|
16 |
+
Question(idx_dummy, SecondQ(oracle, a3, a4, obj)),
|
17 |
+
Question(idx_dummy, ThirdQ(oracle, a2, a3, a4, obj)),
|
18 |
+
Question(idx_dummy, FourthQ(oracle, a1, a2, a3, a4, obj))]
|
19 |
+
return questions[question_idx]
|
20 |
+
|
21 |
+
#######################################
|
22 |
+
############## Chapters ###############
|
23 |
+
#######################################
|
24 |
+
|
25 |
+
|
26 |
+
def write_A2_chapter(
|
27 |
+
start_state, oracle, obj, location, agent_ids, all_agents, movements=None, exist_tell=False, questions=None
|
28 |
+
):
|
29 |
+
a1, a2 = all_agents[agent_ids[0]], all_agents[agent_ids[1]]
|
30 |
+
outsiders = [agent for agent in all_agents if agent not in [a1, a2]]
|
31 |
+
agent_ids = [aid+1 for aid in agent_ids]
|
32 |
+
|
33 |
+
# Pick containers. The first element is the initial container of obj
|
34 |
+
containers = [oracle.get_object_container(obj)]
|
35 |
+
container_candidates = oracle.get_containers(location)[:]
|
36 |
+
container_candidates.remove(containers[0])
|
37 |
+
containers += random.sample(container_candidates, 2)
|
38 |
+
|
39 |
+
# Fill in the chapter
|
40 |
+
chapter = []
|
41 |
+
|
42 |
+
# All selected agents enter the room and see the object
|
43 |
+
chapter.extend([
|
44 |
+
Clause(EnterAction(oracle, (a1, a2, location))),
|
45 |
+
Clause(ObjectLocAction(oracle, obj, [a1, a2])),
|
46 |
+
])
|
47 |
+
|
48 |
+
# a1
|
49 |
+
chapter.extend([
|
50 |
+
Clause(MoveAction(oracle, (a1, obj, containers[1]), [
|
51 |
+
a2], move=movements[0])),
|
52 |
+
Clause(ExitedAction(oracle, (a1)))
|
53 |
+
])
|
54 |
+
# a2
|
55 |
+
chapter.extend([
|
56 |
+
Clause(MoveAction(
|
57 |
+
oracle, (a2, obj, containers[2]), None, move=movements[1])),
|
58 |
+
Clause(ExitedAction(oracle, (a2)))
|
59 |
+
])
|
60 |
+
|
61 |
+
# Everyone enter the waiting room
|
62 |
+
chapter.extend([
|
63 |
+
Clause(EnterAction(oracle, (a1, a2, 'waiting_room')))
|
64 |
+
])
|
65 |
+
|
66 |
+
# tell actions has 3 different forms
|
67 |
+
if exist_tell:
|
68 |
+
tell_containers = random.sample(oracle.get_containers(location)[:], 2)
|
69 |
+
tell_form = random.choice(
|
70 |
+
range(3)) if outsiders else random.choice(range(2))
|
71 |
+
match tell_form:
|
72 |
+
case 0:
|
73 |
+
chapter.extend([
|
74 |
+
Clause(PublicTellAction(
|
75 |
+
oracle, a1, obj, tell_containers[0], listeners=all_agents, believers=outsiders)),
|
76 |
+
Clause(PrivateTellAction(oracle, a2, a1,
|
77 |
+
obj, tell_containers[1], trust=True)),
|
78 |
+
])
|
79 |
+
case 1:
|
80 |
+
chapter.extend([
|
81 |
+
Clause(PublicTellAction(
|
82 |
+
oracle, a2, obj, tell_containers[0], listeners=all_agents, believers=[a1] + outsiders)),
|
83 |
+
Clause(PrivateTellAction(oracle, a1, a2, obj,
|
84 |
+
tell_containers[1], trust=False)),
|
85 |
+
])
|
86 |
+
case 2:
|
87 |
+
chapter.extend([
|
88 |
+
Clause(PrivateTellAction(oracle, a1, random.choice(outsiders),
|
89 |
+
obj, tell_containers[0], trust=True))
|
90 |
+
])
|
91 |
+
return chapter
|
92 |
+
|
93 |
+
|
94 |
+
def write_A3_chapter(
|
95 |
+
start_state, oracle, obj, location, agent_ids, all_agents, movements=None, exist_tell=False, questions=None
|
96 |
+
):
|
97 |
+
a1, a2, a3 = all_agents[agent_ids[0]
|
98 |
+
], all_agents[agent_ids[1]], all_agents[agent_ids[2]]
|
99 |
+
outsiders = [agent for agent in all_agents if agent not in [a1, a2, a3]]
|
100 |
+
agent_ids = [aid+1 for aid in agent_ids]
|
101 |
+
|
102 |
+
# Pick containers. The first element is the initial container of obj
|
103 |
+
containers = [oracle.get_object_container(obj)]
|
104 |
+
container_candidates = oracle.get_containers(location)[:]
|
105 |
+
container_candidates.remove(containers[0])
|
106 |
+
containers += random.sample(container_candidates, 3)
|
107 |
+
|
108 |
+
# Fill in the chapter
|
109 |
+
chapter = []
|
110 |
+
|
111 |
+
# All selected agents enter the room and see the object
|
112 |
+
chapter.extend([
|
113 |
+
Clause(EnterAction(oracle, (a1, a2, a3, location))),
|
114 |
+
Clause(ObjectLocAction(oracle, obj, [a1, a2, a3])),
|
115 |
+
])
|
116 |
+
|
117 |
+
# a1
|
118 |
+
chapter.extend([
|
119 |
+
Clause(MoveAction(oracle, (a1, obj, containers[1]), [
|
120 |
+
a2, a3], move=movements[0])),
|
121 |
+
Clause(ExitedAction(oracle, (a1)))
|
122 |
+
])
|
123 |
+
# a2
|
124 |
+
chapter.extend([
|
125 |
+
Clause(MoveAction(oracle, (a2, obj, containers[2]), [
|
126 |
+
a3], move=movements[1])),
|
127 |
+
Clause(ExitedAction(oracle, (a2)))
|
128 |
+
])
|
129 |
+
# a3
|
130 |
+
chapter.extend([
|
131 |
+
Clause(MoveAction(
|
132 |
+
oracle, (a3, obj, containers[3]), None, move=movements[2])),
|
133 |
+
Clause(ExitedAction(oracle, (a3)))
|
134 |
+
])
|
135 |
+
|
136 |
+
# Everyone enter the waiting room
|
137 |
+
chapter.extend([
|
138 |
+
Clause(EnterAction(oracle, (a1, a2, a3, 'waiting_room')))
|
139 |
+
])
|
140 |
+
|
141 |
+
# tell actions has 4 different forms
|
142 |
+
if exist_tell:
|
143 |
+
tell_containers = random.sample(oracle.get_containers(location)[:], 2)
|
144 |
+
tell_form = random.choice(
|
145 |
+
range(4)) if outsiders else random.choice(range(2))
|
146 |
+
match tell_form:
|
147 |
+
case 0:
|
148 |
+
# a2 lies to all, and a3 lies to a2
|
149 |
+
chapter.extend([
|
150 |
+
Clause(PublicTellAction(
|
151 |
+
oracle, a2, obj, tell_containers[0], listeners=all_agents, believers=[a1] + outsiders)),
|
152 |
+
Clause(PrivateTellAction(oracle, a3, a2,
|
153 |
+
obj, tell_containers[1], trust=True)),
|
154 |
+
])
|
155 |
+
case 1:
|
156 |
+
# a3 lies to all, and a1 lies to a3
|
157 |
+
chapter.extend([
|
158 |
+
Clause(PublicTellAction(
|
159 |
+
oracle, a3, obj, tell_containers[0], listeners=all_agents, believers=[a1, a2] + outsiders)),
|
160 |
+
Clause(PrivateTellAction(oracle, a1, a3, obj,
|
161 |
+
tell_containers[1], trust=False)),
|
162 |
+
])
|
163 |
+
case 2:
|
164 |
+
# a1 lies to all, but a3 tells the true location to an outside agent
|
165 |
+
chapter.extend([
|
166 |
+
Clause(PublicTellAction(
|
167 |
+
oracle, a1, obj, tell_containers[0], listeners=all_agents, believers=outsiders)),
|
168 |
+
Clause(PrivateTellAction(oracle, a3, random.choice(outsiders),
|
169 |
+
obj, oracle.get_object_container(obj), trust=True))
|
170 |
+
])
|
171 |
+
case 3:
|
172 |
+
# a2 lies to a3, but a3 tells the true location to an outside agent
|
173 |
+
chapter.extend([
|
174 |
+
Clause(PrivateTellAction(oracle, a2, a3,
|
175 |
+
obj, tell_containers[0], trust=False)),
|
176 |
+
Clause(PrivateTellAction(oracle, a3, random.choice(outsiders),
|
177 |
+
obj, oracle.get_object_container(obj), trust=True))
|
178 |
+
])
|
179 |
+
return chapter
|
180 |
+
|
181 |
+
|
182 |
+
def write_A4_chapter(
|
183 |
+
start_state, oracle, obj, location, agent_ids, all_agents, movements=None, exist_tell=False, questions=None
|
184 |
+
):
|
185 |
+
a1, a2, a3, a4 = all_agents[agent_ids[0]
|
186 |
+
], all_agents[agent_ids[1]], all_agents[agent_ids[2]], all_agents[agent_ids[3]]
|
187 |
+
outsiders = [
|
188 |
+
agent for agent in all_agents if agent not in [a1, a2, a3, a4]]
|
189 |
+
agent_ids = [aid+1 for aid in agent_ids]
|
190 |
+
|
191 |
+
# Pick containers. The first element is the initial container of obj
|
192 |
+
containers = [oracle.get_object_container(obj)]
|
193 |
+
container_candidates = oracle.get_containers(location)[:]
|
194 |
+
container_candidates.remove(containers[0])
|
195 |
+
containers += random.sample(container_candidates, 4)
|
196 |
+
|
197 |
+
# Fill in the chapter
|
198 |
+
chapter = []
|
199 |
+
|
200 |
+
# All selected agents enter the room and see the object
|
201 |
+
chapter.extend([
|
202 |
+
Clause(EnterAction(oracle, (a1, a2, a3, a4, location))),
|
203 |
+
Clause(ObjectLocAction(oracle, obj, [a1, a2, a3, a4])),
|
204 |
+
])
|
205 |
+
|
206 |
+
# a1
|
207 |
+
chapter.extend([
|
208 |
+
Clause(MoveAction(oracle, (a1, obj, containers[1]), [
|
209 |
+
a2, a3, a4], move=movements[0])),
|
210 |
+
Clause(ExitedAction(oracle, (a1)))
|
211 |
+
])
|
212 |
+
# a2
|
213 |
+
chapter.extend([
|
214 |
+
Clause(MoveAction(oracle, (a2, obj, containers[2]), [
|
215 |
+
a3, a4], move=movements[1])),
|
216 |
+
Clause(ExitedAction(oracle, (a2)))
|
217 |
+
])
|
218 |
+
# a3
|
219 |
+
chapter.extend([
|
220 |
+
Clause(MoveAction(oracle, (a3, obj, containers[3]), [
|
221 |
+
a4], move=movements[2])),
|
222 |
+
Clause(ExitedAction(oracle, (a3)))
|
223 |
+
])
|
224 |
+
# a4
|
225 |
+
chapter.extend([
|
226 |
+
Clause(MoveAction(
|
227 |
+
oracle, (a4, obj, containers[4]), None, move=movements[3])),
|
228 |
+
Clause(ExitedAction(oracle, (a4)))
|
229 |
+
])
|
230 |
+
|
231 |
+
# Everyone enter the waiting room
|
232 |
+
chapter.extend([
|
233 |
+
Clause(EnterAction(oracle, (a1, a2, a3, a4, 'waiting_room')))
|
234 |
+
])
|
235 |
+
|
236 |
+
# tell actions has 4 different forms
|
237 |
+
if exist_tell:
|
238 |
+
tell_containers = random.sample(oracle.get_containers(location)[:], 2)
|
239 |
+
tell_form = random.choice(
|
240 |
+
range(4)) if outsiders else random.choice(range(2))
|
241 |
+
match tell_form:
|
242 |
+
case 0:
|
243 |
+
# a2 lies to all, and a3 lies to a2
|
244 |
+
chapter.extend([
|
245 |
+
Clause(PublicTellAction(
|
246 |
+
oracle, a2, obj, tell_containers[0], listeners=all_agents, believers=[a1] + outsiders)),
|
247 |
+
Clause(PrivateTellAction(oracle, a4, a3,
|
248 |
+
obj, tell_containers[1], trust=True)),
|
249 |
+
])
|
250 |
+
case 1:
|
251 |
+
# a3 lies to all, and a1 lies to a4
|
252 |
+
chapter.extend([
|
253 |
+
Clause(PublicTellAction(
|
254 |
+
oracle, a3, obj, tell_containers[0], listeners=all_agents, believers=[a1, a2] + outsiders)),
|
255 |
+
Clause(PrivateTellAction(oracle, a1, a4, obj,
|
256 |
+
tell_containers[1], trust=False)),
|
257 |
+
])
|
258 |
+
case 2:
|
259 |
+
outsider = random.choice(outsiders)
|
260 |
+
# a1 lies to all, but a4 tells the true location to an outside agent
|
261 |
+
chapter.extend([
|
262 |
+
Clause(PublicTellAction(
|
263 |
+
oracle, a1, obj, tell_containers[0], listeners=all_agents, believers=outsiders)),
|
264 |
+
Clause(PrivateTellAction(oracle, a4, outsider,
|
265 |
+
obj, oracle.get_object_container(obj), trust=True))
|
266 |
+
])
|
267 |
+
case 3:
|
268 |
+
outsider = random.choice(outsiders)
|
269 |
+
# a2 lies to a3, but a4 tells the true location to an outside agent
|
270 |
+
chapter.extend([
|
271 |
+
Clause(PrivateTellAction(oracle, a2, a3,
|
272 |
+
obj, tell_containers[0], trust=False)),
|
273 |
+
Clause(PrivateTellAction(oracle, a4, outsider,
|
274 |
+
obj, oracle.get_object_container(obj), trust=True))
|
275 |
+
])
|
276 |
+
return chapter
|
277 |
+
|
278 |
+
|
279 |
+
def write_A5_chapter(
|
280 |
+
start_state, oracle, obj, location, agent_ids, all_agents, movements=None, exist_tell=False, questions=None
|
281 |
+
):
|
282 |
+
a1, a2, a3, a4, a5 = all_agents[agent_ids[0]], all_agents[agent_ids[1]
|
283 |
+
], all_agents[agent_ids[2]], all_agents[agent_ids[3]], all_agents[agent_ids[4]]
|
284 |
+
agent_ids = [aid+1 for aid in agent_ids]
|
285 |
+
|
286 |
+
# Pick containers. The first element is the initial container of obj
|
287 |
+
containers = [oracle.get_object_container(obj)]
|
288 |
+
container_candidates = oracle.get_containers(location)[:]
|
289 |
+
container_candidates.remove(containers[0])
|
290 |
+
containers += random.sample(container_candidates, 4)
|
291 |
+
|
292 |
+
# Fill in the chapter
|
293 |
+
chapter = []
|
294 |
+
|
295 |
+
# All selected agents enter the room and see the object
|
296 |
+
chapter.extend([
|
297 |
+
Clause(EnterAction(oracle, (a1, a2, a3, a4, a5, location))),
|
298 |
+
Clause(ObjectLocAction(oracle, obj, [a1, a2, a3, a4, a5])),
|
299 |
+
])
|
300 |
+
|
301 |
+
# a1
|
302 |
+
chapter.extend([
|
303 |
+
Clause(MoveAction(oracle, (a1, obj, containers[1]), [
|
304 |
+
a2, a3, a4, a5], move=movements[0])),
|
305 |
+
Clause(ExitedAction(oracle, (a1)))
|
306 |
+
])
|
307 |
+
# a2
|
308 |
+
chapter.extend([
|
309 |
+
Clause(MoveAction(oracle, (a2, obj, containers[2]), [
|
310 |
+
a3, a4, a5], move=movements[1])),
|
311 |
+
Clause(ExitedAction(oracle, (a2)))
|
312 |
+
])
|
313 |
+
# a3
|
314 |
+
chapter.extend([
|
315 |
+
Clause(MoveAction(oracle, (a3, obj, containers[3]), [
|
316 |
+
a4, a5], move=movements[2])),
|
317 |
+
Clause(ExitedAction(oracle, (a3)))
|
318 |
+
])
|
319 |
+
# a4
|
320 |
+
chapter.extend([
|
321 |
+
Clause(MoveAction(oracle, (a4, obj, containers[4]), [
|
322 |
+
a5], move=movements[3])),
|
323 |
+
Clause(ExitedAction(oracle, (a4)))
|
324 |
+
])
|
325 |
+
# a5
|
326 |
+
chapter.extend([
|
327 |
+
Clause(MoveAction(
|
328 |
+
oracle, (a5, obj, containers[0]), None, move=movements[4])),
|
329 |
+
Clause(ExitedAction(oracle, (a5)))
|
330 |
+
])
|
331 |
+
|
332 |
+
# Everyone enter the waiting room
|
333 |
+
chapter.extend([
|
334 |
+
Clause(EnterAction(oracle, (a1, a2, a3, a4, a5, 'waiting_room')))
|
335 |
+
])
|
336 |
+
|
337 |
+
# tell actions has 3 different forms
|
338 |
+
if exist_tell:
|
339 |
+
tell_containers = random.sample(oracle.get_containers(location)[:], 2)
|
340 |
+
tell_form = random.choice(range(3))
|
341 |
+
match tell_form:
|
342 |
+
case 0:
|
343 |
+
# a3 lies to all, and a5 lies to a3
|
344 |
+
chapter.extend([
|
345 |
+
Clause(PublicTellAction(
|
346 |
+
oracle, a3, obj, tell_containers[0], listeners=all_agents, believers=[a1, a2])),
|
347 |
+
Clause(PrivateTellAction(oracle, a5, a3,
|
348 |
+
obj, tell_containers[1], trust=True)),
|
349 |
+
])
|
350 |
+
case 1:
|
351 |
+
# a4 lies to all, but a5 tells the true location to a1
|
352 |
+
chapter.extend([
|
353 |
+
Clause(PublicTellAction(
|
354 |
+
oracle, a4, obj, tell_containers[0], listeners=all_agents, believers=[a1, a2, a3])),
|
355 |
+
Clause(PrivateTellAction(oracle, a5, a1, obj,
|
356 |
+
oracle.get_object_container(obj), trust=True)),
|
357 |
+
])
|
358 |
+
case 2:
|
359 |
+
# a3 lies a1, and a2 lies to a4
|
360 |
+
chapter.extend([
|
361 |
+
Clause(PrivateTellAction(oracle, a3, a1,
|
362 |
+
obj, tell_containers[0], trust=True))
|
363 |
+
])
|
364 |
+
return chapter
|
365 |
+
|
366 |
+
|
367 |
+
#######################################
|
368 |
+
############### Tasks #################
|
369 |
+
#######################################
|
370 |
+
|
371 |
+
class Task(object):
|
372 |
+
|
373 |
+
def __init__(self,
|
374 |
+
num_questions=5,
|
375 |
+
exit_prob=1.,
|
376 |
+
informant_prob=1.,
|
377 |
+
search_prob=1.,
|
378 |
+
test_cond='first order'):
|
379 |
+
|
380 |
+
self.num_questions = num_questions
|
381 |
+
|
382 |
+
self.search_prob = search_prob
|
383 |
+
|
384 |
+
self.exit_inform_probs = [1 - exit_prob,
|
385 |
+
exit_prob * (1 - informant_prob),
|
386 |
+
exit_prob * informant_prob]
|
387 |
+
assert sum(self.exit_inform_probs) == 1
|
388 |
+
|
389 |
+
assert test_cond in ['first order',
|
390 |
+
'second order',
|
391 |
+
'reality',
|
392 |
+
'memory'], \
|
393 |
+
"Invalid test condition: %s" % test_cond
|
394 |
+
self.test_cond = test_cond
|
395 |
+
|
396 |
+
def generate_story(self, world):
|
397 |
+
raise NotImplementedError("Abstract method.")
|
398 |
+
|
399 |
+
|
400 |
+
class Specify_Tasks(Task):
|
401 |
+
def generate_story_qs_at_end(
|
402 |
+
self, world, tasks_per_story, tasks, num_agents=5,
|
403 |
+
num_locations=3, statement_noise=0.1, order=0, exist_tell_in_story=False
|
404 |
+
):
|
405 |
+
"""
|
406 |
+
Allows user to specify chapter and question for each task in story.
|
407 |
+
|
408 |
+
:tasks: list with length of tasks per story. Each entry is a string in
|
409 |
+
the set {'tb','fb','sofb'}
|
410 |
+
|
411 |
+
:questions: list with length of tasks per story. Each entry is a string
|
412 |
+
in the set {'memory', 'reality', 'belief', 'search'}
|
413 |
+
|
414 |
+
:statement_noise: probability of encountering noise sentence like 'The
|
415 |
+
dog ran through the kitchen.'
|
416 |
+
"""
|
417 |
+
|
418 |
+
# Fetch agents and objects and select a random subset
|
419 |
+
idx_support_dummy = [0]
|
420 |
+
actors = world.get_actors()
|
421 |
+
locations = world.get_locations()
|
422 |
+
objects = world.get_objects()
|
423 |
+
containers = world.get_containers()
|
424 |
+
|
425 |
+
random_actors = np.random.choice(
|
426 |
+
actors, size=num_agents, replace=False
|
427 |
+
)
|
428 |
+
random_locations = np.random.choice(
|
429 |
+
locations, size=num_locations, replace=False
|
430 |
+
)
|
431 |
+
random_objects = np.random.choice(
|
432 |
+
objects, size=num_locations*2, replace=False
|
433 |
+
)
|
434 |
+
random_containers = np.random.choice(
|
435 |
+
containers, size=num_locations*5, replace=False
|
436 |
+
)
|
437 |
+
|
438 |
+
# Create the oracle
|
439 |
+
oracle = Oracle(
|
440 |
+
random_actors, random_locations, random_objects, random_containers
|
441 |
+
)
|
442 |
+
|
443 |
+
# Populate locations in the oracle with containers
|
444 |
+
for i, random_location in enumerate(random_locations):
|
445 |
+
location = random_location
|
446 |
+
containers = random_containers[5*i:5*i+5]
|
447 |
+
oracle.set_containers(location, list(containers))
|
448 |
+
# Two of the containers have objects
|
449 |
+
oracle.set_object_container(
|
450 |
+
random_objects[2*i], containers[0])
|
451 |
+
oracle.set_object_container(
|
452 |
+
random_objects[2*i+1], containers[1])
|
453 |
+
|
454 |
+
# Need start state for memory question
|
455 |
+
start_state = oracle.locations.obj_containers.copy()
|
456 |
+
|
457 |
+
# Create story by task
|
458 |
+
chapters = {'A2': write_A2_chapter,
|
459 |
+
'A3': write_A3_chapter,
|
460 |
+
'A4': write_A4_chapter,
|
461 |
+
'A5': write_A5_chapter}
|
462 |
+
story = []
|
463 |
+
obj_pool = []
|
464 |
+
obj_in_question = None
|
465 |
+
|
466 |
+
for i in range(tasks_per_story):
|
467 |
+
chapter = chapters[tasks[i][0]]
|
468 |
+
location = np.random.choice(random_locations)
|
469 |
+
obj = np.random.choice(oracle.get_objects_at_location(location))
|
470 |
+
# Use the obj in the first chap as the target
|
471 |
+
if i == 0:
|
472 |
+
obj_in_question = obj
|
473 |
+
obj_pool.append(obj)
|
474 |
+
agent_ids = list(range(5))
|
475 |
+
random.shuffle(agent_ids)
|
476 |
+
|
477 |
+
# Randomly choose movements for each agent
|
478 |
+
agent_num = int(tasks[i][0][1])
|
479 |
+
bools = [True, False]
|
480 |
+
movements = [random.choice(bools) for _ in range(agent_num)]
|
481 |
+
exist_tell_in_chapter = tasks[i][1] if exist_tell_in_story else False
|
482 |
+
story.extend(
|
483 |
+
chapter(
|
484 |
+
start_state, oracle, obj, location, agent_ids, random_actors, movements=movements, exist_tell=exist_tell_in_chapter
|
485 |
+
)
|
486 |
+
)
|
487 |
+
|
488 |
+
# At the end, add noise sentences randomly
|
489 |
+
if statement_noise:
|
490 |
+
noisy_story = []
|
491 |
+
prev_i = 0
|
492 |
+
noise = [i for i
|
493 |
+
in range(len(story)) if np.random.rand() < statement_noise
|
494 |
+
]
|
495 |
+
for i in noise:
|
496 |
+
noisy_story.extend(
|
497 |
+
story[prev_i:i] +
|
498 |
+
[Clause(NoiseAction(random_actors,
|
499 |
+
random_containers, random_objects))]
|
500 |
+
)
|
501 |
+
prev_i = i
|
502 |
+
noisy_story.extend(story[prev_i:])
|
503 |
+
|
504 |
+
# compute questions of all orders
|
505 |
+
questioned_actors = copy.deepcopy(random_actors)
|
506 |
+
random.shuffle(questioned_actors)
|
507 |
+
for idx in range(5):
|
508 |
+
noisy_story.append(
|
509 |
+
sample_question(
|
510 |
+
start_state, oracle, questioned_actors, obj_in_question, question_idx=idx
|
511 |
+
)
|
512 |
+
)
|
513 |
+
|
514 |
+
# Generate choices of containers
|
515 |
+
choices = ', '.join(f'{chr(65+i)}. {container}' for i,
|
516 |
+
container in enumerate(random_containers))
|
517 |
+
noisy_story.append('Choices: ' + choices + '\n')
|
518 |
+
return noisy_story
|
test_azure.py
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import openai
|
3 |
+
|
4 |
+
def record_progress(filename):
|
5 |
+
with open('progress.txt', 'a') as f:
|
6 |
+
f.write(filename + '\n')
|
7 |
+
|
8 |
+
def is_processed(filename):
|
9 |
+
with open('progress.txt', 'r') as f:
|
10 |
+
processed_files = f.read().splitlines()
|
11 |
+
return filename in processed_files
|
12 |
+
|
13 |
+
openai.api_type = "azure"
|
14 |
+
openai.api_base = "https://openaiserviceforclausaeu.openai.azure.com/"
|
15 |
+
openai.api_version = "2023-03-15-preview"
|
16 |
+
openai.api_key = os.getenv("OPENAI_API_KEY")
|
17 |
+
|
18 |
+
test_dirs = os.listdir("prompt_ToMh")
|
19 |
+
for test_dir in test_dirs:
|
20 |
+
test_fns = os.listdir(f"prompt_ToMh/{test_dir}")
|
21 |
+
for test_fn in test_fns:
|
22 |
+
full_path = f"prompt_ToMh/{test_dir}/{test_fn}"
|
23 |
+
if is_processed(full_path):
|
24 |
+
continue
|
25 |
+
print(test_fn)
|
26 |
+
print(f"path: {full_path}")
|
27 |
+
with open(full_path, 'r') as f:
|
28 |
+
input = f.readlines()
|
29 |
+
input = "\n".join([inp.strip() for inp in input])
|
30 |
+
response = openai.ChatCompletion.create(
|
31 |
+
engine="gpt4-32k",
|
32 |
+
messages=[
|
33 |
+
{"role":"system","content":"You are an AI assistant that helps people find information."},
|
34 |
+
{"role":"user","content": input}
|
35 |
+
],
|
36 |
+
temperature=0,
|
37 |
+
max_tokens=800,
|
38 |
+
top_p=0,
|
39 |
+
frequency_penalty=0,
|
40 |
+
presence_penalty=0,
|
41 |
+
stop=None)
|
42 |
+
print(response)
|
43 |
+
record_progress(full_path)
|
utils.py
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from argparse import ArgumentTypeError
|
2 |
+
import errno
|
3 |
+
import os
|
4 |
+
|
5 |
+
|
6 |
+
class Error(Exception):
|
7 |
+
"""Base class for exceptions in this module."""
|
8 |
+
pass
|
9 |
+
|
10 |
+
|
11 |
+
class InputError(Error):
|
12 |
+
"""Exception raised for errors in the input.
|
13 |
+
|
14 |
+
Attributes:
|
15 |
+
expr # input expression in which the error occurred
|
16 |
+
msg # explanation of the error
|
17 |
+
"""
|
18 |
+
|
19 |
+
def __init__(self, expr, msg):
|
20 |
+
self.expr = expr
|
21 |
+
self.msg = msg
|
22 |
+
|
23 |
+
|
24 |
+
def is_file(f):
|
25 |
+
try:
|
26 |
+
open(f, 'r') # return an open file handle
|
27 |
+
except IOError:
|
28 |
+
raise ArgumentTypeError("{0} does not exist".format(f))
|
29 |
+
return f
|
30 |
+
|
31 |
+
|
32 |
+
def mkdir_p(path):
|
33 |
+
try:
|
34 |
+
os.makedirs(path)
|
35 |
+
except OSError as exc: # Python >2.5
|
36 |
+
if exc.errno == errno.EEXIST and os.path.isdir(path):
|
37 |
+
pass
|
38 |
+
else:
|
39 |
+
raise
|
40 |
+
return path
|
41 |
+
|
42 |
+
|
43 |
+
def remove_extension(path):
|
44 |
+
return os.path.splitext(os.path.basename(path))[0]
|
world.py
ADDED
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
class World(object):
|
2 |
+
|
3 |
+
def __init__(self, world_actions=[], entities={}):
|
4 |
+
self.actions = world_actions
|
5 |
+
self.entities = entities
|
6 |
+
|
7 |
+
def load(self, fname):
|
8 |
+
|
9 |
+
lines = open(fname, 'r').readlines()
|
10 |
+
i = 0
|
11 |
+
|
12 |
+
while i < len(lines):
|
13 |
+
line = lines[i].rstrip('\n')
|
14 |
+
if line != '' and not line.startswith('#'):
|
15 |
+
if line.startswith('create'):
|
16 |
+
self.entities[line.split(' ')[1]] = {}
|
17 |
+
elif line.startswith('set'):
|
18 |
+
self.entities[line.split(' ')[1]][line.split(' ')[-1]] = True
|
19 |
+
|
20 |
+
i += 1
|
21 |
+
|
22 |
+
def get_entity(self, predicates):
|
23 |
+
|
24 |
+
if not isinstance(predicates, list):
|
25 |
+
raise InputError(predicates, 'is not a list.')
|
26 |
+
|
27 |
+
return_val = []
|
28 |
+
|
29 |
+
for k in self.entities:
|
30 |
+
if all([predicate in self.entities[k] and
|
31 |
+
self.entities[k][predicate] is True
|
32 |
+
for predicate in predicates]):
|
33 |
+
return_val += [k]
|
34 |
+
|
35 |
+
return return_val
|
36 |
+
|
37 |
+
def get_actors(self):
|
38 |
+
return self.get_entity(['is_actor', 'is_god'])
|
39 |
+
|
40 |
+
def get_containers(self):
|
41 |
+
return self.get_entity(['is_thing', 'is_container'])
|
42 |
+
|
43 |
+
def get_locations(self):
|
44 |
+
return self.get_entity(['is_location'])
|
45 |
+
|
46 |
+
def get_objects(self):
|
47 |
+
return self.get_entity(['is_thing', 'is_gettable'])
|