File size: 4,156 Bytes
0436f2c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import os
import sys
from typing import List, Optional
from langchain.embeddings.base import Embeddings
from langchain_qdrant import Qdrant
from langchain.schema import Document
from langchain_openai import OpenAIEmbeddings
from qdrant_client import QdrantClient
from qdrant_client.models import Distance, VectorParams

import config as config

# Qdrant collection parameters, keyed by collection name.
# Every entry shares one OpenAI embeddings client and the same model /
# vector-size / distance settings; only the collection name differs.
# NOTE(review): the shared client is built from config.EMBEDDINGS_MODEL_NAME,
# while each entry records 'text-embedding-3-large' with vector_size 3072 —
# confirm the configured model matches, otherwise embeddings will presumably
# not fit the collection's vector size.
openai_embeddings = OpenAIEmbeddings(model=config.EMBEDDINGS_MODEL_NAME)
QDRANT_COLLECTIONS_PARAMS = {
    collection: {
        'collection_name': collection,
        'embeddings_model_name': 'text-embedding-3-large',
        'vector_size': 3072,
        'distance': Distance.COSINE,
        'embeddings_model': openai_embeddings,
    }
    for collection in (
        'openai_large_chunks_1500char',
        'openai_large_chunks_500char',
    )
}




class QdrantManager:
    """Manage one Qdrant collection: create it, add documents, expose the vectorstore, delete it.

    Any setting not passed to the constructor falls back to the matching entry
    of ``QDRANT_COLLECTIONS_PARAMS`` (keyed by collection name). The URL and API
    key first fall back to the ``QDRANT_URL`` / ``QDRANT_API_KEY`` environment
    variables. To see the collections available on the server, call
    ``self.client.get_collections()``.
    """

    def __init__(
        self,
        collection_name: str,
        embeddings: Optional[Embeddings] = None,
        url: Optional[str] = None,
        api_key: Optional[str] = None,
        vector_size: Optional[int] = None,
        distance: Optional[Distance] = None
    ):
        """Resolve settings, then build the Qdrant client and LangChain vectorstore.

        Args:
            collection_name: Name of the Qdrant collection to manage.
            embeddings: Embeddings model. Defaults to the collection's
                ``'embeddings_model'`` entry in QDRANT_COLLECTIONS_PARAMS.
            url: Qdrant server URL. Falls back to $QDRANT_URL, then to an
                optional ``'url'`` entry in the collection's params. If all are
                missing, None is passed to QdrantClient as-is.
            api_key: Qdrant API key. Falls back to $QDRANT_API_KEY, then to an
                optional ``'api_key'`` entry in the collection's params.
            vector_size: Vector dimension used by create_collection. Defaults
                to the collection's ``'vector_size'`` entry.
            distance: Distance metric used by create_collection. Defaults to
                the collection's ``'distance'`` entry.

        Raises:
            KeyError: if embeddings, vector_size or distance is omitted and the
                collection has no such default in QDRANT_COLLECTIONS_PARAMS.
        """
        self.collection_name = collection_name
        # Optional per-collection settings. This is empty for collections that
        # are not declared, so a fully-specified unknown collection still works.
        collection_params = QDRANT_COLLECTIONS_PARAMS.get(collection_name, {})

        def default(key: str):
            # Only evaluated when the caller omitted the argument (lazy via `or`).
            try:
                return collection_params[key]
            except KeyError:
                raise KeyError(
                    f"{key!r} not given and no default for collection "
                    f"{collection_name!r} in QDRANT_COLLECTIONS_PARAMS"
                ) from None

        self.embeddings = embeddings or default('embeddings_model')
        self.url = url or os.getenv("QDRANT_URL") or collection_params.get('url')
        self.api_key = api_key or os.getenv("QDRANT_API_KEY") or collection_params.get('api_key')
        self.vector_size = vector_size or default('vector_size')
        self.distance = distance or default('distance')

        self.client = QdrantClient(url=self.url, api_key=self.api_key)
        self.qdrant = Qdrant(
            client=self.client,
            collection_name=self.collection_name,
            embeddings=self.embeddings,
            content_payload_key="page_content",
            metadata_payload_key="metadata",
        )

    def create_collection(self) -> None:
        """Create the collection with this manager's vector size and distance, unless it already exists."""
        existing_names = {c.name for c in self.client.get_collections().collections}
        if self.collection_name not in existing_names:
            self.client.create_collection(
                collection_name=self.collection_name,
                vectors_config=VectorParams(size=self.vector_size, distance=self.distance),
            )
            print(f"Collection '{self.collection_name}' created.")
        else:
            print(f"Collection '{self.collection_name}' already exists.")

    def add_documents(self, documents: List[Document], batch_size: int = 1000) -> None:
        """Add documents to the collection in consecutive batches.

        Args:
            documents: Documents to embed and upsert through the vectorstore.
            batch_size: Maximum number of documents sent per add_documents call.

        Raises:
            ValueError: if batch_size is less than 1. A zero or negative step
                would make range() fail or silently skip every batch.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        for start in range(0, len(documents), batch_size):
            batch = documents[start:start + batch_size]
            self.qdrant.add_documents(batch)
            print(f"Added batch {start // batch_size + 1} ({len(batch)} documents)")
        print(f"Total documents added: {len(documents)}")

    def get_vectorstore(self):
        """Return the LangChain ``Qdrant`` vectorstore built in __init__."""
        return self.qdrant

    def delete_collection(self) -> None:
        """Delete the collection from the Qdrant server."""
        self.client.delete_collection(self.collection_name)
        print(f"Collection '{self.collection_name}' deleted.")