nbroad HF staff committed on
Commit
f245c03
·
1 Parent(s): 3c20160

Update infer.py

Browse files
Files changed (1) hide show
  1. infer.py +17 -19
infer.py CHANGED
@@ -147,25 +147,23 @@ def get_model_and_tokenizer(model_name: str, optimization_level: str, progress):
147
  )
148
 
149
 
150
- # def collate_fn(examples, tokenizer=None, padding=None, column_name="text"):
151
- # try:
152
- # keys = examples[0].keys()
153
- # except KeyError:
154
- # print(examples)
155
- # else:
156
- # batch = {k: [] for k in examples[0].keys()}
157
-
158
- # tokenized = tokenizer(
159
- # [x[column_name] for x in examples],
160
- # truncation=True,
161
- # padding=padding,
162
- # max_length=512,
163
- # return_tensors="pt"
164
- # )
165
 
166
- # tokenized[column_name] = [x[column_name] for x in examples]
167
 
168
- # return tokenized
169
 
170
 
171
  @torch.inference_mode()
@@ -247,8 +245,8 @@ def batch_embed(
247
 
248
  start_time = time.time()
249
 
250
- collator = DataCollatorWithPadding(
251
- tokenizer, padding=True, max_length=512, pad_to_multiple_of=16
252
  )
253
 
254
  dl = DataLoader(
 
147
  )
148
 
149
 
150
def collate_fn(examples, column_name, tokenizer):
    """Collate pre-tokenized examples into a padded batch of tensors.

    Each example is expected to be a dict containing at least
    ``input_ids``, ``attention_mask``, and the raw-text column
    ``column_name``.  The token features are padded to the longest
    sequence in the batch (in multiples of 16, capped at 512) via
    ``tokenizer.pad``, and the original text column is carried through
    unchanged so downstream code can pair embeddings with their inputs.

    Args:
        examples: list of dicts produced by the dataset/tokenization step.
        column_name: key of the raw-text column to pass through.
        tokenizer: a Hugging Face tokenizer providing ``.pad``.

    Returns:
        A batch mapping with padded ``input_ids``/``attention_mask``
        tensors plus the untouched ``column_name`` list.
    """
    feature_cols = ["input_ids", "attention_mask"]
    # Keep only the token features; tokenizer.pad would choke on
    # non-tensorizable columns such as the raw text.
    features = [{k: x[k] for k in feature_cols} for x in examples]

    # NOTE(review): removed a leftover debug `print(features)` here — it ran
    # for every batch inside the DataLoader and flooded stdout.

    tokenized = tokenizer.pad(
        features,
        padding=True,
        max_length=512,
        return_tensors="pt",
        pad_to_multiple_of=16,
    )

    # Re-attach the raw text column so callers can line up outputs
    # with their source strings.
    tokenized[column_name] = [x[column_name] for x in examples]

    return tokenized
167
 
168
 
169
  @torch.inference_mode()
 
245
 
246
  start_time = time.time()
247
 
248
+ collator = partial(
249
+ collate_fn, column_name=column_name, tokenizer=tokenizer
250
  )
251
 
252
  dl = DataLoader(