MengniWang commited on
Commit
03e8016
·
1 Parent(s): aa281eb

add batch inference code

Browse files
Files changed (1) hide show
  1. evaluation.ipynb +100 -0
evaluation.ipynb CHANGED
@@ -103,6 +103,106 @@
103
  "print('acc: ', acc)"
104
  ]
105
  },
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
  {
107
  "attachments": {},
108
  "cell_type": "markdown",
 
103
  "print('acc: ', acc)"
104
  ]
105
  },
106
+ {
107
+ "cell_type": "code",
108
+ "execution_count": null,
109
+ "metadata": {
110
+ "vscode": {
111
+ "languageId": "plaintext"
112
+ }
113
+ },
114
+ "outputs": [],
115
+ "source": [
116
+ "# batch inference\n",
117
+ "\n",
118
+ "from transformers import AutoTokenizer\n",
119
+ "import torch\n",
120
+ "import numpy as np\n",
121
+ "from datasets import load_dataset\n",
122
+ "import onnxruntime as ort\n",
123
+ "from torch.nn.functional import pad\n",
124
+ "from torch.utils.data import DataLoader\n",
125
+ "\n",
126
+ "batch_size = 2\n",
127
+ "pad_max = 196\n",
128
+ "\n",
129
+ "# load model\n",
130
+ "model_id = \"EleutherAI/gpt-j-6B\"\n",
131
+ "tokenizer = AutoTokenizer.from_pretrained(model_id)\n",
132
+ "\n",
133
+ "def tokenize_function(examples):\n",
134
+ " example = tokenizer(examples['text'])\n",
135
+ " return example\n",
136
+ "\n",
137
+ "# create dataloader\n",
138
+ "class Dataloader:\n",
139
+ " def __init__(self, pad_max=196, batch_size=1, sub_folder='validation'):\n",
140
+ " self.pad_max = pad_max\n",
141
+ " self.batch_size=batch_size\n",
142
+ " dataset = load_dataset('lambada', split=sub_folder)\n",
143
+ " dataset = dataset.map(tokenize_function, batched=True)\n",
144
+ " dataset.set_format(type=\"torch\", columns=[\"input_ids\", \"attention_mask\"])\n",
145
+ " self.dataloader = DataLoader(\n",
146
+ " dataset,\n",
147
+ " batch_size=self.batch_size,\n",
148
+ " shuffle=False,\n",
149
+ " collate_fn=self.collate_batch,\n",
150
+ " )\n",
151
+ "\n",
152
+ " def collate_batch(self, batch):\n",
153
+ " input_ids_padded = []\n",
154
+ " attention_mask_padded = []\n",
155
+ " last_ind = []\n",
156
+ " for text in batch:\n",
157
+ " input_ids = text[\"input_ids\"] if text[\"input_ids\"].shape[0] <= self.pad_max else text[\"input_ids\"][0:int(self.pad_max-1)]\n",
158
+ " pad_len = self.pad_max - input_ids.shape[0]\n",
159
+ " last_ind.append(input_ids.shape[0] - 1)\n",
160
+ " input_ids = pad(input_ids, (0, pad_len), value=1)\n",
161
+ " input_ids_padded.append(input_ids)\n",
162
+ " attention_mask = torch.ones(input_ids.shape[0] + 1)\n",
163
+ " attention_mask_padded.append(attention_mask)\n",
164
+ " return (torch.vstack(input_ids_padded), torch.vstack(attention_mask_padded)), torch.tensor(last_ind)\n",
165
+ "\n",
166
+ " def __iter__(self):\n",
167
+ " try:\n",
168
+ " for (input_ids, attention_mask), last_ind in self.dataloader:\n",
169
+ " data = [input_ids.detach().cpu().numpy().astype('int64')]\n",
170
+ " data.append(attention_mask.detach().cpu().numpy().astype('int64'))\n",
171
+ " yield data, last_ind.detach().cpu().numpy()\n",
172
+ " except StopIteration:\n",
173
+ " return\n",
174
+ "\n",
175
+ "# create session\n",
176
+ "options = ort.SessionOptions()\n",
177
+ "options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL\n",
178
+ "session = ort.InferenceSession('/path/to/model.onnx', options, providers=ort.get_available_providers())\n",
179
+ "total, hit = 0, 0\n",
180
+ "\n",
181
+ "dataloader = Dataloader(pad_max=pad_max, batch_size=batch_size)\n",
182
+ "\n",
183
+ "# inference\n",
184
+ "for idx, (batch, last_ind) in enumerate(dataloader):\n",
185
+ " label = torch.from_numpy(batch[0][torch.arange(len(last_ind)), last_ind])\n",
186
+ " pad_len = pad_max - last_ind - 1\n",
187
+ " ort_inputs = {\n",
188
+ " 'input_ids': batch[0],\n",
189
+ " 'attention_mask': batch[1]\n",
190
+ " }\n",
191
+ " for i in range(28):\n",
192
+ " ort_inputs[\"past_key_values.{}.key\".format(i)] = np.zeros((batch_size,16,1,256), dtype='float32')\n",
193
+ " ort_inputs[\"past_key_values.{}.value\".format(i)] = np.zeros((batch_size,16,1,256), dtype='float32')\n",
194
+ " \n",
195
+ " predictions = session.run(None, ort_inputs)\n",
196
+ " outputs = torch.from_numpy(predictions[0])\n",
197
+ " last_token_logits = outputs[torch.arange(len(last_ind)), -2 - pad_len, :]\n",
198
+ " pred = last_token_logits.argmax(dim=-1)\n",
199
+ " total += len(label)\n",
200
+ " hit += (pred == label).sum().item()\n",
201
+ "\n",
202
+ "acc = hit / total\n",
203
+ "print('acc: ', acc)"
204
+ ]
205
+ },
206
  {
207
  "attachments": {},
208
  "cell_type": "markdown",