IlyasMoutawwakil HF staff committed on
Commit
f5d46f6
·
verified ·
1 Parent(s): 5a4e99e

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +14 -34
handler.py CHANGED
@@ -1,43 +1,23 @@
1
- from typing import Any, Dict, List
2
 
3
- from haystack.schema import Document
4
- from fastrag.rankers import QuantizedBiEncoderRanker
 
5
 
6
 
7
  class EndpointHandler:
8
  def __init__(self, path=""):
9
  model_id = "Intel/bge-large-en-v1.5-rag-int8-static"
10
- self.ranker = QuantizedBiEncoderRanker(model_name_or_path=model_id)
11
 
12
- def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
13
- query = data.get("query", None)
14
- queries = data.get("queries", None)
15
- documents = data.get("documents", None)
16
- batch_size = data.get("batch_size", None)
 
 
17
  top_k = data.get("top_k", None)
18
 
19
- if query is not None:
20
- assert isinstance(query, str), "Expected query to be a string"
21
- assert isinstance(documents, list), "Expected documents to be a list"
22
- assert all(
23
- isinstance(d, dict) for d in documents
24
- ), "Expected each document in documents to be a dictionary"
25
- documents = [Document.from_dict(d) for d in documents]
26
- return self.ranker.predict(query=query, documents=documents, top_k=top_k)
27
-
28
- elif queries is not None:
29
- assert isinstance(queries, list), "Expected queries to be a list"
30
- assert all(
31
- isinstance(query, str) for query in queries
32
- ), "Expected each query in queries to be a string"
33
- assert isinstance(documents, list), "Expected documents to be a list"
34
- assert all(
35
- all(isinstance(d, dict) for d in doc) for doc in documents
36
- ), "Expected each document in list of documents to be a dictionary"
37
- documents = [Document.from_dict(d) for d in documents]
38
- return self.ranker.predict_batch(
39
- queries=queries, documents=documents, batch_size=batch_size, top_k=top_k
40
- )
41
-
42
- else:
43
- raise ValueError("Expected either query or queries")
 
1
+ from typing import Any, Dict
2
 
3
+ from haystack import Document
4
+
5
+ from fastrag.rankers import IPEXBiEncoderSimilarityRanker
6
 
7
 
8
  class EndpointHandler:
9
  def __init__(self, path=""):
10
  model_id = "Intel/bge-large-en-v1.5-rag-int8-static"
 
11
 
12
+ self.ranker = IPEXBiEncoderSimilarityRanker(model=model_id)
13
+
14
+ self.ranker.warm_up()
15
+
16
+ def __call__(self, data: Dict[str, Any]):
17
+ query = data.get("query")
18
+ documents = data.get("documents")
19
  top_k = data.get("top_k", None)
20
 
21
+ documents = [Document.from_dict(doc) for doc in documents]
22
+
23
+ return self.ranker.run(query, documents, top_k)