michal-stefanik committed on
Commit
463ea04
·
1 Parent(s): 1f82ada

Upload 2.1_construct_qa_dataset.ipynb

Browse files
Files changed (1) hide show
  1. 2.1_construct_qa_dataset.ipynb +882 -0
2.1_construct_qa_dataset.ipynb ADDED
@@ -0,0 +1,882 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "942fa22a-c776-4a44-bde9-75b7cb4202ba",
6
+ "metadata": {},
7
+ "source": [
8
+ "## Outline\n",
9
+ "\n",
10
+ "1. We collect a dataset consisting of (user_question, answer_context, dialogue_history -> answer)\n",
11
+ "2. We duplicate a small portion of dataset, where we remove answer_context\n",
12
+ "2. We augment 'answer_context' with (non_answer) picked by a reasonably-performing QA system: variable ordering, consistent number of answers\n",
13
+ "3. We train the model for exact-match generation \n",
14
+ "- Also evaluate the exact-match ratio\n",
15
+ "- Separately evaluate with full-context questions"
16
+ ]
17
+ },
18
+ {
19
+ "cell_type": "markdown",
20
+ "id": "766c4c50-6e72-41b2-b6d7-1e4c3c309a68",
21
+ "metadata": {},
22
+ "source": [
23
+ "### 1. Positive contexts collection"
24
+ ]
25
+ },
26
+ {
27
+ "cell_type": "code",
28
+ "execution_count": 1,
29
+ "id": "33d57a85-c079-4cf1-b9ad-3b00ce916720",
30
+ "metadata": {},
31
+ "outputs": [],
32
+ "source": [
33
+ "import datasets"
34
+ ]
35
+ },
36
+ {
37
+ "cell_type": "code",
38
+ "execution_count": 2,
39
+ "id": "0434a258-27ca-4cec-bb85-60673fea2b16",
40
+ "metadata": {},
41
+ "outputs": [
42
+ {
43
+ "name": "stderr",
44
+ "output_type": "stream",
45
+ "text": [
46
+ "Using custom data configuration default-8d557d41fc795903\n",
47
+ "Found cached dataset json (/home/xstefan3/.cache/huggingface/datasets/json/default-8d557d41fc795903/0.0.0/0f7e3662623656454fcd2b650f34e886a7db4b9104504885bd462096cc7a9f51)\n"
48
+ ]
49
+ },
50
+ {
51
+ "data": {
52
+ "application/vnd.jupyter.widget-view+json": {
53
+ "model_id": "366db7856ce341a6854a08c244aa5db1",
54
+ "version_major": 2,
55
+ "version_minor": 0
56
+ },
57
+ "text/plain": [
58
+ " 0%| | 0/1 [00:00<?, ?it/s]"
59
+ ]
60
+ },
61
+ "metadata": {},
62
+ "output_type": "display_data"
63
+ }
64
+ ],
65
+ "source": [
66
+ "canard_train = datasets.load_dataset(\"json\", data_files=\"datasets/CANARD_Release/train.json\")[\"train\"]"
67
+ ]
68
+ },
69
+ {
70
+ "cell_type": "code",
71
+ "execution_count": 3,
72
+ "id": "73cf434e-8d36-4680-a80d-a9304ef801f2",
73
+ "metadata": {},
74
+ "outputs": [
75
+ {
76
+ "data": {
77
+ "text/plain": [
78
+ "Dataset({\n",
79
+ " features: ['History', 'QuAC_dialog_id', 'Question', 'Question_no', 'Rewrite'],\n",
80
+ " num_rows: 31526\n",
81
+ "})"
82
+ ]
83
+ },
84
+ "execution_count": 3,
85
+ "metadata": {},
86
+ "output_type": "execute_result"
87
+ }
88
+ ],
89
+ "source": [
90
+ "canard_train"
91
+ ]
92
+ },
93
+ {
94
+ "cell_type": "code",
95
+ "execution_count": 4,
96
+ "id": "02eba563-5810-4b0a-b130-920d163a54ac",
97
+ "metadata": {},
98
+ "outputs": [
99
+ {
100
+ "data": {
101
+ "text/plain": [
102
+ "{'History': ['Johnny Unitas', '1964 MVP season'],\n",
103
+ " 'QuAC_dialog_id': 'C_2ba58216460d43aa986fc0e897537239_0',\n",
104
+ " 'Question': 'what team did unitas play for',\n",
105
+ " 'Question_no': 1,\n",
106
+ " 'Rewrite': 'what team did Johnny Unitas play for?'}"
107
+ ]
108
+ },
109
+ "execution_count": 4,
110
+ "metadata": {},
111
+ "output_type": "execute_result"
112
+ }
113
+ ],
114
+ "source": [
115
+ "canard_train[0]"
116
+ ]
117
+ },
118
+ {
119
+ "cell_type": "code",
120
+ "execution_count": 5,
121
+ "id": "b73c5e59-2430-4b00-aa8b-0f926729ada1",
122
+ "metadata": {},
123
+ "outputs": [
124
+ {
125
+ "name": "stderr",
126
+ "output_type": "stream",
127
+ "text": [
128
+ "Found cached dataset quac (/home/xstefan3/.cache/huggingface/datasets/quac/plain_text/1.1.0/4170258e7e72d7c81bd6441b3f3489ea1544f0ff226ce61e22bb00c6e9d01fb6)\n"
129
+ ]
130
+ }
131
+ ],
132
+ "source": [
133
+ "quac_train = datasets.load_dataset(\"quac\", split=\"train\")"
134
+ ]
135
+ },
136
+ {
137
+ "cell_type": "code",
138
+ "execution_count": 6,
139
+ "id": "21e21544-b65c-433d-86c1-30d4507088e7",
140
+ "metadata": {},
141
+ "outputs": [
142
+ {
143
+ "data": {
144
+ "text/html": [
145
+ "<div>\n",
146
+ "<style scoped>\n",
147
+ " .dataframe tbody tr th:only-of-type {\n",
148
+ " vertical-align: middle;\n",
149
+ " }\n",
150
+ "\n",
151
+ " .dataframe tbody tr th {\n",
152
+ " vertical-align: top;\n",
153
+ " }\n",
154
+ "\n",
155
+ " .dataframe thead th {\n",
156
+ " text-align: right;\n",
157
+ " }\n",
158
+ "</style>\n",
159
+ "<table border=\"1\" class=\"dataframe\">\n",
160
+ " <thead>\n",
161
+ " <tr style=\"text-align: right;\">\n",
162
+ " <th></th>\n",
163
+ " <th>wikipedia_page_title</th>\n",
164
+ " <th>background</th>\n",
165
+ " <th>section_title</th>\n",
166
+ " <th>context</th>\n",
167
+ " <th>turn_ids</th>\n",
168
+ " <th>questions</th>\n",
169
+ " <th>followups</th>\n",
170
+ " <th>yesnos</th>\n",
171
+ " <th>answers</th>\n",
172
+ " <th>orig_answers</th>\n",
173
+ " </tr>\n",
174
+ " <tr>\n",
175
+ " <th>dialogue_id</th>\n",
176
+ " <th></th>\n",
177
+ " <th></th>\n",
178
+ " <th></th>\n",
179
+ " <th></th>\n",
180
+ " <th></th>\n",
181
+ " <th></th>\n",
182
+ " <th></th>\n",
183
+ " <th></th>\n",
184
+ " <th></th>\n",
185
+ " <th></th>\n",
186
+ " </tr>\n",
187
+ " </thead>\n",
188
+ " <tbody>\n",
189
+ " <tr>\n",
190
+ " <th>C_69758fcdfc1f46baba0e92c0f3b0919c_1</th>\n",
191
+ " <td>Malayali</td>\n",
192
+ " <td>The Malayali people or Keralite people (also s...</td>\n",
193
+ " <td>Geographic distribution and population</td>\n",
194
+ " <td>According to the Indian census of 2001, there ...</td>\n",
195
+ " <td>[C_69758fcdfc1f46baba0e92c0f3b0919c_1_q#0, C_6...</td>\n",
196
+ " <td>[Where is Malayali located?, What other langua...</td>\n",
197
+ " <td>[2, 1, 1, 1, 1, 1, 1]</td>\n",
198
+ " <td>[2, 2, 2, 2, 2, 0, 2]</td>\n",
199
+ " <td>{'texts': [['30,803,747 speakers of Malayalam ...</td>\n",
200
+ " <td>{'texts': ['30,803,747 speakers of Malayalam i...</td>\n",
201
+ " </tr>\n",
202
+ " <tr>\n",
203
+ " <th>C_69758fcdfc1f46baba0e92c0f3b0919c_0</th>\n",
204
+ " <td>Malayali</td>\n",
205
+ " <td>The Malayali people or Keralite people (also s...</td>\n",
206
+ " <td>Language and literature</td>\n",
207
+ " <td>Malayalam is the language spoken by the Malaya...</td>\n",
208
+ " <td>[C_69758fcdfc1f46baba0e92c0f3b0919c_0_q#0, C_6...</td>\n",
209
+ " <td>[what language do they speak?, Do they speak a...</td>\n",
210
+ " <td>[0, 0, 0, 0, 0, 0, 0]</td>\n",
211
+ " <td>[2, 2, 2, 2, 2, 2, 2]</td>\n",
212
+ " <td>{'texts': [['Malayalam is the language spoken ...</td>\n",
213
+ " <td>{'texts': ['Malayalam is the language spoken b...</td>\n",
214
+ " </tr>\n",
215
+ " </tbody>\n",
216
+ "</table>\n",
217
+ "</div>"
218
+ ],
219
+ "text/plain": [
220
+ " wikipedia_page_title \\\n",
221
+ "dialogue_id \n",
222
+ "C_69758fcdfc1f46baba0e92c0f3b0919c_1 Malayali \n",
223
+ "C_69758fcdfc1f46baba0e92c0f3b0919c_0 Malayali \n",
224
+ "\n",
225
+ " background \\\n",
226
+ "dialogue_id \n",
227
+ "C_69758fcdfc1f46baba0e92c0f3b0919c_1 The Malayali people or Keralite people (also s... \n",
228
+ "C_69758fcdfc1f46baba0e92c0f3b0919c_0 The Malayali people or Keralite people (also s... \n",
229
+ "\n",
230
+ " section_title \\\n",
231
+ "dialogue_id \n",
232
+ "C_69758fcdfc1f46baba0e92c0f3b0919c_1 Geographic distribution and population \n",
233
+ "C_69758fcdfc1f46baba0e92c0f3b0919c_0 Language and literature \n",
234
+ "\n",
235
+ " context \\\n",
236
+ "dialogue_id \n",
237
+ "C_69758fcdfc1f46baba0e92c0f3b0919c_1 According to the Indian census of 2001, there ... \n",
238
+ "C_69758fcdfc1f46baba0e92c0f3b0919c_0 Malayalam is the language spoken by the Malaya... \n",
239
+ "\n",
240
+ " turn_ids \\\n",
241
+ "dialogue_id \n",
242
+ "C_69758fcdfc1f46baba0e92c0f3b0919c_1 [C_69758fcdfc1f46baba0e92c0f3b0919c_1_q#0, C_6... \n",
243
+ "C_69758fcdfc1f46baba0e92c0f3b0919c_0 [C_69758fcdfc1f46baba0e92c0f3b0919c_0_q#0, C_6... \n",
244
+ "\n",
245
+ " questions \\\n",
246
+ "dialogue_id \n",
247
+ "C_69758fcdfc1f46baba0e92c0f3b0919c_1 [Where is Malayali located?, What other langua... \n",
248
+ "C_69758fcdfc1f46baba0e92c0f3b0919c_0 [what language do they speak?, Do they speak a... \n",
249
+ "\n",
250
+ " followups \\\n",
251
+ "dialogue_id \n",
252
+ "C_69758fcdfc1f46baba0e92c0f3b0919c_1 [2, 1, 1, 1, 1, 1, 1] \n",
253
+ "C_69758fcdfc1f46baba0e92c0f3b0919c_0 [0, 0, 0, 0, 0, 0, 0] \n",
254
+ "\n",
255
+ " yesnos \\\n",
256
+ "dialogue_id \n",
257
+ "C_69758fcdfc1f46baba0e92c0f3b0919c_1 [2, 2, 2, 2, 2, 0, 2] \n",
258
+ "C_69758fcdfc1f46baba0e92c0f3b0919c_0 [2, 2, 2, 2, 2, 2, 2] \n",
259
+ "\n",
260
+ " answers \\\n",
261
+ "dialogue_id \n",
262
+ "C_69758fcdfc1f46baba0e92c0f3b0919c_1 {'texts': [['30,803,747 speakers of Malayalam ... \n",
263
+ "C_69758fcdfc1f46baba0e92c0f3b0919c_0 {'texts': [['Malayalam is the language spoken ... \n",
264
+ "\n",
265
+ " orig_answers \n",
266
+ "dialogue_id \n",
267
+ "C_69758fcdfc1f46baba0e92c0f3b0919c_1 {'texts': ['30,803,747 speakers of Malayalam i... \n",
268
+ "C_69758fcdfc1f46baba0e92c0f3b0919c_0 {'texts': ['Malayalam is the language spoken b... "
269
+ ]
270
+ },
271
+ "execution_count": 6,
272
+ "metadata": {},
273
+ "output_type": "execute_result"
274
+ }
275
+ ],
276
+ "source": [
277
+ "quac_train_df = quac_train.to_pandas().set_index(\"dialogue_id\", drop=True)\n",
278
+ "quac_train_df.head(2)"
279
+ ]
280
+ },
281
+ {
282
+ "cell_type": "code",
283
+ "execution_count": 7,
284
+ "id": "01b2994d-3ba3-4ce0-9531-20e1858ee878",
285
+ "metadata": {},
286
+ "outputs": [
287
+ {
288
+ "data": {
289
+ "text/plain": [
290
+ "array([array(['what team did unitas play for',\n",
291
+ " 'how many games did the colts win',\n",
292
+ " 'who did they play in the playoffs', 'did they win the super bowl',\n",
293
+ " 'who did they play in the super bowl', 'what were unitas stats'],\n",
294
+ " dtype=object) ,\n",
295
+ " {'texts': array([array(['The Colts'], dtype=object),\n",
296
+ " array(['the Colts ran off 10 straight victories to finish with a 12-2 record.'],\n",
297
+ " dtype=object) ,\n",
298
+ " array(['Cleveland Browns'], dtype=object),\n",
299
+ " array(['losing 27-0.'], dtype=object),\n",
300
+ " array(['the Packers.'], dtype=object),\n",
301
+ " array(['Gary Cuozzo also suffered a season-ending injury the following'],\n",
302
+ " dtype=object) ],\n",
303
+ " dtype=object), 'answer_starts': array([array([920], dtype=int32), array([142], dtype=int32),\n",
304
+ " array([552], dtype=int32), array([604], dtype=int32),\n",
305
+ " array([1487], dtype=int32), array([1292], dtype=int32)],\n",
306
+ " dtype=object)} ],\n",
307
+ " dtype=object)"
308
+ ]
309
+ },
310
+ "execution_count": 7,
311
+ "metadata": {},
312
+ "output_type": "execute_result"
313
+ }
314
+ ],
315
+ "source": [
316
+ "quac_train_df.loc['C_2ba58216460d43aa986fc0e897537239_0'][[\"questions\", \"answers\"]].values"
317
+ ]
318
+ },
319
+ {
320
+ "cell_type": "code",
321
+ "execution_count": 8,
322
+ "id": "e680d1ec-ca04-452c-ad44-eae3b43559cc",
323
+ "metadata": {},
324
+ "outputs": [],
325
+ "source": [
326
+ "def answer_for_question(questions: dict, answers: list, question: str) -> str:\n",
327
+ " answers = [anss[0] for anss in answers[\"texts\"]]\n",
328
+ " # print(questions)\n",
329
+ " # print(question)\n",
330
+ " assert question in questions\n",
331
+ " assert len(answers) == len(questions)\n",
332
+ " \n",
333
+ " return next(a for i, a in enumerate(answers) if questions[i] == question)"
334
+ ]
335
+ },
336
+ {
337
+ "cell_type": "code",
338
+ "execution_count": 9,
339
+ "id": "9a58cc80-a4b8-4b59-b5c3-e8b095e2c281",
340
+ "metadata": {},
341
+ "outputs": [
342
+ {
343
+ "data": {
344
+ "application/vnd.jupyter.widget-view+json": {
345
+ "model_id": "786095c0ea3e432dac7dd1912cb3832d",
346
+ "version_major": 2,
347
+ "version_minor": 0
348
+ },
349
+ "text/plain": [
350
+ " 0%| | 0/31526 [00:00<?, ?ex/s]"
351
+ ]
352
+ },
353
+ "metadata": {},
354
+ "output_type": "display_data"
355
+ }
356
+ ],
357
+ "source": [
358
+ "canard_train = canard_train.map(lambda row: \n",
359
+ "{\n",
360
+ " \"true_contexts\": quac_train_df.loc[row[\"QuAC_dialog_id\"]][\"context\"],\n",
361
+ " \"true_page\": quac_train_df.loc[row[\"QuAC_dialog_id\"]][\"wikipedia_page_title\"],\n",
362
+ " \"answer\": answer_for_question(*quac_train_df.loc[row[\"QuAC_dialog_id\"]][[\"questions\", \"answers\"]].values, row[\"Question\"])\n",
363
+ "})"
364
+ ]
365
+ },
366
+ {
367
+ "cell_type": "code",
368
+ "execution_count": 10,
369
+ "id": "d894b2ca-2d74-412d-bdb3-fc42637dea18",
370
+ "metadata": {},
371
+ "outputs": [
372
+ {
373
+ "data": {
374
+ "text/plain": [
375
+ "{'History': ['Johnny Unitas', '1964 MVP season'],\n",
376
+ " 'QuAC_dialog_id': 'C_2ba58216460d43aa986fc0e897537239_0',\n",
377
+ " 'Question': 'what team did unitas play for',\n",
378
+ " 'Question_no': 1,\n",
379
+ " 'Rewrite': 'what team did Johnny Unitas play for?',\n",
380
+ " 'true_contexts': \"The 1964 season would see the Colts return to the top of the Western Conference. After dropping their season opener to the Minnesota Vikings, the Colts ran off 10 straight victories to finish with a 12-2 record. The season was one of Unitas' best as he finished with 2,824 yards passing, a league-best 9.26 yards per pass attempt, 19 touchdown passes and only 6 interceptions. He was named the NFL's Most Valuable Player by the AP and UPI for a second time. However, the season would end on a disappointing note for the Colts as they were upset by the Cleveland Browns in the 1964 NFL Championship Game, losing 27-0. Unitas resumed his torrid passing in 1965, as he threw for 2,530 yards, 23 touchdowns and finished with a league-high and career best 97.1 passer rating. But he was lost for the balance of the season due to a knee injury in a week 12 loss to the Bears. More postseason heartbreak would follow in 1965. The Colts and Packers finished in a tie for first place in the Western Conference and a one-game playoff was played in Green Bay to decide who would be the conference representative in the 1965 NFL Championship Game. The Colts lost in overtime 13-10 due in large part to a game-tying field goal by Don Chandler that many say was incorrectly ruled good. Backup quarterback Gary Cuozzo also suffered a season-ending injury the following week and it would be running back Tom Matte who filled in as the emergency QB for the regular-season finale and the playoff loss to the Packers. Unitas, healthy once more, threw for 2748 yards and 22 touchdowns in 1966 in a return to Pro Bowl form. However, he posted a league-high 24 interceptions. CANNOTANSWER\",\n",
381
+ " 'true_page': 'Johnny Unitas',\n",
382
+ " 'answer': 'The Colts'}"
383
+ ]
384
+ },
385
+ "execution_count": 10,
386
+ "metadata": {},
387
+ "output_type": "execute_result"
388
+ }
389
+ ],
390
+ "source": [
391
+ "canard_train[0]"
392
+ ]
393
+ },
394
+ {
395
+ "cell_type": "code",
396
+ "execution_count": 11,
397
+ "id": "0dfa9659-93b9-4264-9693-08b3410c869e",
398
+ "metadata": {},
399
+ "outputs": [
400
+ {
401
+ "data": {
402
+ "text/plain": [
403
+ "(7881, 31526)"
404
+ ]
405
+ },
406
+ "execution_count": 11,
407
+ "metadata": {},
408
+ "output_type": "execute_result"
409
+ }
410
+ ],
411
+ "source": [
412
+ "import random\n",
413
+ "\n",
414
+ "canard_negative_subsample = canard_train.select(random.sample(list(range(len(canard_train))), k=len(canard_train)//4))\n",
415
+ "\n",
416
+ "len(canard_negative_subsample), len(canard_train)"
417
+ ]
418
+ },
419
+ {
420
+ "cell_type": "code",
421
+ "execution_count": 12,
422
+ "id": "50ed6cc6-0d60-4cd5-b284-899f179420e4",
423
+ "metadata": {},
424
+ "outputs": [
425
+ {
426
+ "data": {
427
+ "application/vnd.jupyter.widget-view+json": {
428
+ "model_id": "9ab7324375c04ead8d6f97164a56b8f7",
429
+ "version_major": 2,
430
+ "version_minor": 0
431
+ },
432
+ "text/plain": [
433
+ " 0%| | 0/7881 [00:00<?, ?ex/s]"
434
+ ]
435
+ },
436
+ "metadata": {},
437
+ "output_type": "display_data"
438
+ }
439
+ ],
440
+ "source": [
441
+ "canard_negative_subsample = canard_negative_subsample.map(lambda row: {\"true_contexts\": \"\"})"
442
+ ]
443
+ },
444
+ {
445
+ "cell_type": "code",
446
+ "execution_count": 13,
447
+ "id": "ee19b78e-cdf0-404c-87a3-37d4e179d988",
448
+ "metadata": {},
449
+ "outputs": [
450
+ {
451
+ "data": {
452
+ "text/plain": [
453
+ "{'History': [\"Dinesh D'Souza\",\n",
454
+ " \"Hillary's America: The Secret History of the Democratic Party\",\n",
455
+ " \"Is Hillary's America a documentary?\",\n",
456
+ " \"On July 25, 2016, D'Souza released the documentary film Hillary's America:\",\n",
457
+ " 'Was it released in theaters?',\n",
458
+ " \"I don't know.\",\n",
459
+ " 'What was the documentary about?',\n",
460
+ " 'The film criticizes the Democratic Party and Hillary Clinton,'],\n",
461
+ " 'QuAC_dialog_id': 'C_31bfdcd402d44289a6206d9b34765869_0',\n",
462
+ " 'Question': 'How did the critics feel about it?',\n",
463
+ " 'Question_no': 4,\n",
464
+ " 'Rewrite': \"How did the critics feel about the film Hillary's America?\",\n",
465
+ " 'true_contexts': '',\n",
466
+ " 'true_page': \"Dinesh D'Souza\",\n",
467
+ " 'answer': 'The film was universally panned by professional film critics.'}"
468
+ ]
469
+ },
470
+ "execution_count": 13,
471
+ "metadata": {},
472
+ "output_type": "execute_result"
473
+ }
474
+ ],
475
+ "source": [
476
+ "canard_negative_subsample[0]"
477
+ ]
478
+ },
479
+ {
480
+ "cell_type": "code",
481
+ "execution_count": 14,
482
+ "id": "4f3869f7-3b9d-4efc-b194-4e2f2cb38798",
483
+ "metadata": {},
484
+ "outputs": [
485
+ {
486
+ "data": {
487
+ "text/plain": [
488
+ "39407"
489
+ ]
490
+ },
491
+ "execution_count": 14,
492
+ "metadata": {},
493
+ "output_type": "execute_result"
494
+ }
495
+ ],
496
+ "source": [
497
+ "canard_train = datasets.concatenate_datasets([canard_train, canard_negative_subsample])\n",
498
+ "\n",
499
+ "len(canard_train)"
500
+ ]
501
+ },
502
+ {
503
+ "cell_type": "markdown",
504
+ "id": "2fd2f319-764a-4b31-baad-ea9a4b037e1d",
505
+ "metadata": {},
506
+ "source": [
507
+ "### 2. Negative contexts collection\n",
508
+ "\n",
509
+ "We use BM25 to collect a realistic set of contexts retrieved by the IR search"
510
+ ]
511
+ },
512
+ {
513
+ "cell_type": "code",
514
+ "execution_count": 15,
515
+ "id": "92b7428e-8975-49b0-9482-539a454a5f9d",
516
+ "metadata": {},
517
+ "outputs": [],
518
+ "source": [
519
+ "from BM25_irsystem import BM25PlusSystem, SimpleDocProcessing"
520
+ ]
521
+ },
522
+ {
523
+ "cell_type": "code",
524
+ "execution_count": 16,
525
+ "id": "35e76ccb-a2ca-4cfe-b8a1-71c665e237f7",
526
+ "metadata": {},
527
+ "outputs": [],
528
+ "source": [
529
+ "from pv211_utils.trec.entities import TrecDocumentBase, TrecQueryBase"
530
+ ]
531
+ },
532
+ {
533
+ "cell_type": "code",
534
+ "execution_count": 17,
535
+ "id": "92d0e8e0-e7ff-48ad-a35c-543ffdd76d7f",
536
+ "metadata": {},
537
+ "outputs": [],
538
+ "source": [
539
+ "documents = {str(i): TrecDocumentBase(document_id=i, body=context) for i, context in enumerate(quac_train_df.context)}\n",
540
+ "\n",
541
+ "irsystem = BM25PlusSystem(documents, preprocessing=SimpleDocProcessing())"
542
+ ]
543
+ },
544
+ {
545
+ "cell_type": "code",
546
+ "execution_count": 18,
547
+ "id": "0e906f2f-b2b6-4f78-bd98-254cb1279181",
548
+ "metadata": {},
549
+ "outputs": [],
550
+ "source": [
551
+ "def get_negative_question_responses(question: str, num_responses: 5):\n",
552
+ " # TODO: add contexts' titles\n",
553
+ " unique_responses = []\n",
554
+ " # question = \"What team did Johnny Unitas play for?\"\n",
555
+ "\n",
556
+ " for response_doc in irsystem.search(TrecQueryBase(query_id=0, title=\"\", body=question, narrative=\"\")):\n",
557
+ " if response_doc.body not in unique_responses:\n",
558
+ " unique_responses.append(response_doc.body)\n",
559
+ " if len(unique_responses) >= num_responses:\n",
560
+ " break\n",
561
+ "\n",
562
+ " return unique_responses"
563
+ ]
564
+ },
565
+ {
566
+ "cell_type": "code",
567
+ "execution_count": 19,
568
+ "id": "f3587fed-f244-45f0-af28-6e5e36ac15b9",
569
+ "metadata": {},
570
+ "outputs": [
571
+ {
572
+ "data": {
573
+ "application/vnd.jupyter.widget-view+json": {
574
+ "model_id": "5318a11fc3484b188ef6c15a6b952d95",
575
+ "version_major": 2,
576
+ "version_minor": 0
577
+ },
578
+ "text/plain": [
579
+ " 0%| | 0/39407 [00:00<?, ?ex/s]"
580
+ ]
581
+ },
582
+ "metadata": {},
583
+ "output_type": "display_data"
584
+ },
585
+ {
586
+ "data": {
587
+ "application/vnd.jupyter.widget-view+json": {
588
+ "model_id": "",
589
+ "version_major": 2,
590
+ "version_minor": 0
591
+ },
592
+ "text/plain": [
593
+ "Saving the dataset (0/1 shards): 0%| | 0/39407 [00:00<?, ? examples/s]"
594
+ ]
595
+ },
596
+ "metadata": {},
597
+ "output_type": "display_data"
598
+ }
599
+ ],
600
+ "source": [
601
+ "canard_train_augm = canard_train.map(\n",
602
+ " lambda row: {\"retrieved_contexts\": get_negative_question_responses(row[\"Question\"], num_responses=4) \n",
603
+ " if row[\"true_contexts\"] else get_negative_question_responses(row[\"Question\"], num_responses=5)},\n",
604
+ " # keep_in_memory=True,\n",
605
+ " # num_proc=60\n",
606
+ ")\n",
607
+ "canard_train_augm.save_to_disk(\"canard_train_augm_full.hf\")"
608
+ ]
609
+ },
610
+ {
611
+ "cell_type": "code",
612
+ "execution_count": 20,
613
+ "id": "b244a1b0-58c9-408f-91c1-10a29ccd43ce",
614
+ "metadata": {},
615
+ "outputs": [
616
+ {
617
+ "data": {
618
+ "text/plain": [
619
+ "Dataset({\n",
620
+ " features: ['History', 'QuAC_dialog_id', 'Question', 'Question_no', 'Rewrite', 'true_contexts', 'true_page', 'answer', 'retrieved_contexts'],\n",
621
+ " num_rows: 39407\n",
622
+ "})"
623
+ ]
624
+ },
625
+ "execution_count": 20,
626
+ "metadata": {},
627
+ "output_type": "execute_result"
628
+ }
629
+ ],
630
+ "source": [
631
+ "datasets.load_from_disk(\"canard_train_augm_full.hf\")"
632
+ ]
633
+ },
634
+ {
635
+ "cell_type": "markdown",
636
+ "id": "2c4c5925-9abb-4d95-941b-6634f1bdb597",
637
+ "metadata": {},
638
+ "source": [
639
+ "## Test dataset generation"
640
+ ]
641
+ },
642
+ {
643
+ "cell_type": "code",
644
+ "execution_count": 21,
645
+ "id": "4c5f5447-38fc-4f7d-8b46-292de36bba8a",
646
+ "metadata": {},
647
+ "outputs": [
648
+ {
649
+ "name": "stderr",
650
+ "output_type": "stream",
651
+ "text": [
652
+ "Using custom data configuration default-a7ce477a9c57a36e\n",
653
+ "Found cached dataset json (/home/xstefan3/.cache/huggingface/datasets/json/default-a7ce477a9c57a36e/0.0.0/0f7e3662623656454fcd2b650f34e886a7db4b9104504885bd462096cc7a9f51)\n"
654
+ ]
655
+ },
656
+ {
657
+ "data": {
658
+ "application/vnd.jupyter.widget-view+json": {
659
+ "model_id": "ece835047b44454a9f31ff14f6986640",
660
+ "version_major": 2,
661
+ "version_minor": 0
662
+ },
663
+ "text/plain": [
664
+ " 0%| | 0/1 [00:00<?, ?it/s]"
665
+ ]
666
+ },
667
+ "metadata": {},
668
+ "output_type": "display_data"
669
+ },
670
+ {
671
+ "name": "stderr",
672
+ "output_type": "stream",
673
+ "text": [
674
+ "Found cached dataset quac (/home/xstefan3/.cache/huggingface/datasets/quac/plain_text/1.1.0/4170258e7e72d7c81bd6441b3f3489ea1544f0ff226ce61e22bb00c6e9d01fb6)\n"
675
+ ]
676
+ }
677
+ ],
678
+ "source": [
679
+ "import datasets\n",
680
+ "\n",
681
+ "# make sure that we test with conversations that the model has not seen before\n",
682
+ "canard_test = datasets.load_dataset(\"json\", data_files=\"datasets/CANARD_Release/test.json\")[\"train\"]\n",
683
+ "quac_test = datasets.load_dataset(\"quac\", split=\"validation\")\n",
684
+ "quac_test_df = quac_test.to_pandas().set_index(\"dialogue_id\", drop=True)"
685
+ ]
686
+ },
687
+ {
688
+ "cell_type": "code",
689
+ "execution_count": 22,
690
+ "id": "68f61770-cb34-43f5-aa51-6e1a68e0d362",
691
+ "metadata": {},
692
+ "outputs": [
693
+ {
694
+ "data": {
695
+ "text/plain": [
696
+ "{'History': ['Ursula K. Le Guin',\n",
697
+ " 'Influences',\n",
698
+ " 'what influenced her?',\n",
699
+ " 'Le Guin was influenced by fantasy writers,',\n",
700
+ " 'who were they?',\n",
701
+ " 'J. R. R. Tolkien, by science fiction writers,',\n",
702
+ " 'how did they influence her?',\n",
703
+ " 'her influences, she replied: Once I learned to read, I read everything. I read all the famous fantasies'],\n",
704
+ " 'QuAC_dialog_id': 'C_420bfcf5d8b344a480ac654f08ee55ad_1',\n",
705
+ " 'Question': 'which other fantasy writer influenced her?',\n",
706
+ " 'Question_no': 4,\n",
707
+ " 'Rewrite': 'Besides J. R. R. Tolkien which other fantasy writer influenced Le Guin?'}"
708
+ ]
709
+ },
710
+ "execution_count": 22,
711
+ "metadata": {},
712
+ "output_type": "execute_result"
713
+ }
714
+ ],
715
+ "source": [
716
+ "# check the match on QuAC_dialog_id\n",
717
+ "canard_test[102]"
718
+ ]
719
+ },
720
+ {
721
+ "cell_type": "code",
722
+ "execution_count": 23,
723
+ "id": "8f21d07e-9d94-4d29-bb6f-a4bbd5ef833e",
724
+ "metadata": {},
725
+ "outputs": [
726
+ {
727
+ "data": {
728
+ "text/plain": [
729
+ "wikipedia_page_title Ursula K. Le Guin\n",
730
+ "background Ursula Kroeber Le Guin (; October 21, 1929 - J...\n",
731
+ "section_title Influences\n",
732
+ "context Le Guin was influenced by fantasy writers, inc...\n",
733
+ "turn_ids [C_420bfcf5d8b344a480ac654f08ee55ad_1_q#0, C_4...\n",
734
+ "questions [what influenced her?, who were they?, how did...\n",
735
+ "followups [0, 0, 0, 0, 1, 0, 0, 0, 1]\n",
736
+ "yesnos [2, 2, 2, 2, 2, 0, 2, 2, 2]\n",
737
+ "answers {'texts': [['Le Guin was influenced by fantasy...\n",
738
+ "orig_answers {'texts': ['Le Guin was influenced by fantasy ...\n",
739
+ "Name: C_420bfcf5d8b344a480ac654f08ee55ad_1, dtype: object"
740
+ ]
741
+ },
742
+ "execution_count": 23,
743
+ "metadata": {},
744
+ "output_type": "execute_result"
745
+ }
746
+ ],
747
+ "source": [
748
+ "quac_test_df.loc[\"C_420bfcf5d8b344a480ac654f08ee55ad_1\"]"
749
+ ]
750
+ },
751
+ {
752
+ "cell_type": "code",
753
+ "execution_count": 24,
754
+ "id": "798ad24f-1cb3-4b2d-8fe0-d92e4aa97ff5",
755
+ "metadata": {},
756
+ "outputs": [
757
+ {
758
+ "name": "stderr",
759
+ "output_type": "stream",
760
+ "text": [
761
+ "Loading cached processed dataset at /home/xstefan3/.cache/huggingface/datasets/json/default-a7ce477a9c57a36e/0.0.0/0f7e3662623656454fcd2b650f34e886a7db4b9104504885bd462096cc7a9f51/cache-e93159eee5490cc9.arrow\n"
762
+ ]
763
+ }
764
+ ],
765
+ "source": [
766
+ "canard_test = canard_test.map(lambda row: \n",
767
+ "{\n",
768
+ " \"true_contexts\": quac_test_df.loc[row[\"QuAC_dialog_id\"]][\"context\"],\n",
769
+ " \"true_page\": quac_test_df.loc[row[\"QuAC_dialog_id\"]][\"wikipedia_page_title\"],\n",
770
+ " \"answer\": answer_for_question(*quac_test_df.loc[row[\"QuAC_dialog_id\"]][[\"questions\", \"answers\"]].values, row[\"Question\"])\n",
771
+ "})"
772
+ ]
773
+ },
774
+ {
775
+ "cell_type": "markdown",
776
+ "id": "cc8fda7c-c140-438d-ac16-7df672d492eb",
777
+ "metadata": {},
778
+ "source": [
779
+ "### 2. Negative contexts collection\n"
780
+ ]
781
+ },
782
+ {
783
+ "cell_type": "code",
784
+ "execution_count": 25,
785
+ "id": "500d28a4-2dde-45ca-a12b-48d6ceefb7fd",
786
+ "metadata": {},
787
+ "outputs": [],
788
+ "source": [
789
+ "# We initialize a new IR system for response - pessimistic scenario\n",
790
+ "from BM25_irsystem import BM25PlusSystem, SimpleDocProcessing"
791
+ ]
792
+ },
793
+ {
794
+ "cell_type": "code",
795
+ "execution_count": 26,
796
+ "id": "de060c5b-00eb-4d8a-b645-fedfd4ec0b29",
797
+ "metadata": {},
798
+ "outputs": [],
799
+ "source": [
800
+ "from pv211_utils.trec.entities import TrecDocumentBase, TrecQueryBase"
801
+ ]
802
+ },
803
+ {
804
+ "cell_type": "code",
805
+ "execution_count": 27,
806
+ "id": "63b93efa-5e2d-4bb9-a02d-3cdc23ee1e42",
807
+ "metadata": {},
808
+ "outputs": [],
809
+ "source": [
810
+ "documents = {str(i): TrecDocumentBase(document_id=i, body=context) for i, context in enumerate(quac_test_df.context)}\n",
811
+ "\n",
812
+ "irsystem = BM25PlusSystem(documents, preprocessing=SimpleDocProcessing())"
813
+ ]
814
+ },
815
+ {
816
+ "cell_type": "code",
817
+ "execution_count": 28,
818
+ "id": "923bfc60-6ac0-4c9c-9f78-c2852541a7fc",
819
+ "metadata": {},
820
+ "outputs": [
821
+ {
822
+ "data": {
823
+ "application/vnd.jupyter.widget-view+json": {
824
+ "model_id": "5a77c28c178243888d2672ac3f133078",
825
+ "version_major": 2,
826
+ "version_minor": 0
827
+ },
828
+ "text/plain": [
829
+ " 0%| | 0/5571 [00:00<?, ?ex/s]"
830
+ ]
831
+ },
832
+ "metadata": {},
833
+ "output_type": "display_data"
834
+ },
835
+ {
836
+ "data": {
837
+ "application/vnd.jupyter.widget-view+json": {
838
+ "model_id": "db0ef84b89d841d7b75942cb31a3c3ac",
839
+ "version_major": 2,
840
+ "version_minor": 0
841
+ },
842
+ "text/plain": [
843
+ "Saving the dataset (0/1 shards): 0%| | 0/5571 [00:00<?, ? examples/s]"
844
+ ]
845
+ },
846
+ "metadata": {},
847
+ "output_type": "display_data"
848
+ }
849
+ ],
850
+ "source": [
851
+ "canard_test_augm = canard_test.map(\n",
852
+ " lambda row: {\"retrieved_contexts\": get_negative_question_responses(row[\"Question\"], num_responses=4) \n",
853
+ " if row[\"true_contexts\"] else get_negative_question_responses(row[\"Question\"], num_responses=5)},\n",
854
+ " # keep_in_memory=True,\n",
855
+ " # num_proc=60\n",
856
+ ")\n",
857
+ "canard_test_augm.save_to_disk(\"canard_test_augm_full.hf\")"
858
+ ]
859
+ }
860
+ ],
861
+ "metadata": {
862
+ "kernelspec": {
863
+ "display_name": "Python 3 (ipykernel)",
864
+ "language": "python",
865
+ "name": "python3"
866
+ },
867
+ "language_info": {
868
+ "codemirror_mode": {
869
+ "name": "ipython",
870
+ "version": 3
871
+ },
872
+ "file_extension": ".py",
873
+ "mimetype": "text/x-python",
874
+ "name": "python",
875
+ "nbconvert_exporter": "python",
876
+ "pygments_lexer": "ipython3",
877
+ "version": "3.10.8"
878
+ }
879
+ },
880
+ "nbformat": 4,
881
+ "nbformat_minor": 5
882
+ }