srf_chatbot_v2 / qdrant_manager.py
nadaaaita's picture
set up folder structure
0436f2c
raw
history blame
4.16 kB
import os
import sys
from typing import List, Optional
from langchain.embeddings.base import Embeddings
from langchain_qdrant import Qdrant
from langchain.schema import Document
from langchain_openai import OpenAIEmbeddings
from qdrant_client import QdrantClient
from qdrant_client.models import Distance, VectorParams
import config as config
# Qdrant Collections Params
# One embeddings client, built from the model named in config, is shared by
# every preset below.
openai_embeddings = OpenAIEmbeddings(model=config.EMBEDDINGS_MODEL_NAME)

# Per-collection defaults, keyed by collection name. QdrantManager reads
# 'embeddings_model', 'vector_size' and 'distance' from here whenever the
# caller does not pass them explicitly. The presets differ only in their name,
# so they are generated from the list of names.
# NOTE(review): 'embeddings_model_name' is a hard-coded label, while the actual
# client above takes its model from config.EMBEDDINGS_MODEL_NAME -- confirm
# that the two agree.
QDRANT_COLLECTIONS_PARAMS = {
    preset_name: {
        'collection_name': preset_name,
        'embeddings_model_name': 'text-embedding-3-large',
        'vector_size': 3072,
        'distance': Distance.COSINE,
        'embeddings_model': openai_embeddings,
    }
    for preset_name in (
        'openai_large_chunks_1500char',
        'openai_large_chunks_500char',
    )
}
class QdrantManager:
    """Manage a single Qdrant collection.

    Bundles a ``QdrantClient`` and a LangChain ``Qdrant`` vectorstore that
    point at the same collection, and offers helpers to create the collection,
    add documents to it, obtain the vectorstore, and delete the collection.
    To see available collections, run ``client.get_collections()``.
    """

    def __init__(
        self,
        collection_name: str,
        embeddings: Optional[Embeddings] = None,
        url: Optional[str] = None,
        api_key: Optional[str] = None,
        vector_size: Optional[int] = None,
        distance: Optional[Distance] = None
    ):
        """Resolve the settings and build the client and vectorstore objects.

        Args:
            collection_name: Name of the Qdrant collection to manage.
            embeddings: Embeddings model; when omitted, the preset registered
                for ``collection_name`` in ``QDRANT_COLLECTIONS_PARAMS`` is used.
            url: Qdrant endpoint. Falls back to the ``QDRANT_URL`` environment
                variable, then to an optional ``'url'`` entry in the preset.
            api_key: Qdrant API key. Falls back to the ``QDRANT_API_KEY``
                environment variable, then to an optional ``'api_key'`` entry
                in the preset. May end up ``None``.
            vector_size: Vector dimension used by ``create_collection``;
                defaults to the preset value.
            distance: Distance metric used by ``create_collection``; defaults
                to the preset value.

        Raises:
            KeyError: If ``embeddings``, ``vector_size`` or ``distance`` is
                omitted and ``collection_name`` has no preset.
            ValueError: If no Qdrant URL can be resolved from any source.
        """
        self.collection_name = collection_name
        self.embeddings = embeddings or QDRANT_COLLECTIONS_PARAMS[collection_name]['embeddings_model']

        # The original fallback subscripted the `config` module itself
        # (`config[collection_name]`), which raises TypeError because modules
        # are not subscriptable. The per-collection preset is the mapping that
        # is keyed by collection name, so optional connection settings are
        # looked up there instead.
        preset = QDRANT_COLLECTIONS_PARAMS.get(collection_name, {})
        self.url = url or os.getenv("QDRANT_URL") or preset.get('url')
        if not self.url:
            raise ValueError(
                "No Qdrant URL available: pass `url` or set the QDRANT_URL "
                "environment variable."
            )
        self.api_key = api_key or os.getenv("QDRANT_API_KEY") or preset.get('api_key')

        self.vector_size = vector_size or QDRANT_COLLECTIONS_PARAMS[collection_name]['vector_size']
        self.distance = distance or QDRANT_COLLECTIONS_PARAMS[collection_name]['distance']

        self.client = QdrantClient(url=self.url, api_key=self.api_key)
        self.qdrant = Qdrant(
            client=self.client,
            collection_name=self.collection_name,
            embeddings=self.embeddings,
            content_payload_key="page_content",
            metadata_payload_key="metadata",
        )

    def create_collection(self) -> None:
        """Create a new collection if it doesn't exist.

        Uses ``self.vector_size`` and ``self.distance`` for the vector
        configuration and prints whether the collection was created or was
        already present.
        """
        existing = self.client.get_collections().collections
        if any(c.name == self.collection_name for c in existing):
            print(f"Collection '{self.collection_name}' already exists.")
            return

        self.client.create_collection(
            collection_name=self.collection_name,
            vectors_config=VectorParams(size=self.vector_size, distance=self.distance),
        )
        print(f"Collection '{self.collection_name}' created.")

    def add_documents(self, documents: List[Document], batch_size: int = 1000) -> None:
        """Add documents to the collection in batches.

        Args:
            documents: Documents to add through the vectorstore.
            batch_size: Maximum number of documents sent per call; must be a
                positive integer.

        Raises:
            ValueError: If ``batch_size`` is not positive.
        """
        # A negative step would make the range empty, so nothing would be
        # uploaded while the summary line still reported every document.
        if batch_size <= 0:
            raise ValueError("batch_size must be a positive integer")

        for start in range(0, len(documents), batch_size):
            batch = documents[start:start + batch_size]
            self.qdrant.add_documents(batch)
            print(f"Added batch {start//batch_size + 1} ({len(batch)} documents)")
        print(f"Total documents added: {len(documents)}")

    def get_vectorstore(self):
        """Get the Qdrant vectorstore."""
        return self.qdrant

    def delete_collection(self) -> None:
        """Delete the collection."""
        self.client.delete_collection(self.collection_name)
        print(f"Collection '{self.collection_name}' deleted.")