IlyasMoutawwakil HF staff commited on
Commit
5e0585b
·
verified ·
1 Parent(s): 4f8ccb0

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +29 -7
handler.py CHANGED
@@ -1,5 +1,6 @@
1
  from typing import Any, Dict, List
2
 
 
3
  from fastrag.rankers import QuantizedBiEncoderRanker
4
 
5
 
@@ -9,13 +10,34 @@ class EndpointHandler:
9
  self.ranker = QuantizedBiEncoderRanker(model_name_or_path=model_id)
10
 
11
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
12
- query = data.get("query")
13
- documents = data.get("documents")
14
-
 
15
  top_k = data.get("top_k", None)
16
 
17
- assert isinstance(query, str), "Expected query to be a string"
18
- assert isinstance(documents, list), "Expected documents to be a list"
19
- assert all(isinstance(document, str) for document in documents), "Expected documents to be a list of strings"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
- return self.ranker.predict(query=query, documents=documents, top_k=top_k)
 
 
1
  from typing import Any, Dict, List
2
 
3
+ from haystack.schema import Document
4
  from fastrag.rankers import QuantizedBiEncoderRanker
5
 
6
 
 
10
  self.ranker = QuantizedBiEncoderRanker(model_name_or_path=model_id)
11
 
12
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
13
+ query = data.get("query", None)
14
+ queries = data.get("queries", None)
15
+ documents = data.get("documents", None)
16
+ batch_size = data.get("batch_size", None)
17
  top_k = data.get("top_k", None)
18
 
19
+ if query is not None:
20
+ assert isinstance(query, str), "Expected query to be a string"
21
+ assert isinstance(documents, list), "Expected documents to be a list"
22
+ assert all(
23
+ isinstance(d, dict) for d in documents
24
+ ), "Expected each document in documents to be a dictionary"
25
+ documents = [Document.from_dict(d) for d in documents]
26
+ return self.ranker.predict(query=query, documents=documents, top_k=top_k)
27
+
28
+ elif queries is not None:
29
+ assert isinstance(queries, list), "Expected queries to be a list"
30
+ assert all(
31
+ isinstance(query, str) for query in queries
32
+ ), "Expected each query in queries to be a string"
33
+ assert isinstance(documents, list), "Expected documents to be a list"
34
+ assert all(
35
+ all(isinstance(d, dict) for d in doc) for doc in documents
36
+ ), "Expected each document in list of documents to be a dictionary"
37
+ documents = [Document.from_dict(d) for d in documents]
38
+ return self.ranker.predict_batch(
39
+ queries=queries, documents=documents, batch_size=batch_size, top_k=top_k
40
+ )
41
 
42
+ else:
43
+ raise ValueError("Expected either query or queries")