ravfogs commited on
Commit
3eef34e
·
1 Parent(s): b035a1d

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +4 -3
README.md CHANGED
@@ -25,12 +25,13 @@ def load_finetuned_model():
25
  return tokenizer, query_encoder, sentence_encoder
26
 
27
 
28
- def encode_batch(model, tokenizer, sentences, device):
29
  input_ids = tokenizer(sentences, padding=True, max_length=128, truncation=True, return_tensors="pt",
30
  add_special_tokens=True).to(device)
31
  features = model(**input_ids)[0]
32
- features = torch.sum(features[:,1:,:] * input_ids["attention_mask"][:,1:].unsqueeze(-1), dim=1) / torch.clamp(torch.sum(input_ids["attention_mask"][:,1:], dim=1, keepdims=True), min=1e-9)
33
-
 
34
  return features
35
 
36
  ```
 
25
  return tokenizer, query_encoder, sentence_encoder
26
 
27
 
28
+ def encode_batch_fn(model, tokenizer, sentences, device)
29
  input_ids = tokenizer(sentences, padding=True, max_length=128, truncation=True, return_tensors="pt",
30
  add_special_tokens=True).to(device)
31
  features = model(**input_ids)[0]
32
+
33
+ features = torch.sum(features[:,:,:] * input_ids["attention_mask"][:,:].unsqueeze(-1), dim=1) / torch.clamp(torch.sum(input_ids["attention_mask"][:,:], dim=1, keepdims=True), min=1e-9)
34
+
35
  return features
36
 
37
  ```