# image_captioner / vector_search.py
# Author: Sverd — commit 1352a28 ("upload from local pc")
import cohere
from annoy import AnnoyIndex
import numpy as np
import dotenv
import os
import pandas as pd
dotenv.load_dotenv()

# Cohere embedding model shared by index construction and query-time lookups.
model_name = "embed-english-v3.0"
api_key = os.environ['COHERE_API_KEY']
# Topics are embedded as stored documents; captions are embedded as
# "search_query" inside topic_from_caption.
input_type_embed = "search_document"
# Cohere's embed endpoint accepts at most 96 texts per request, so larger
# topic lists must be sent in batches.
_EMBED_BATCH_SIZE = 96

# Set up the cohere client
co = cohere.Client(api_key)

# Get the dataset of topics. Row position i becomes Annoy item id i below,
# which is what lets a nearest-neighbour id be mapped back with .iloc.
topics = pd.read_csv("aicovers_topics.csv")
topic_texts = topics['topic_cleaned'].tolist()
if not topic_texts:
    # Without at least one topic there is no embedding dimension to build from.
    raise ValueError("aicovers_topics.csv contains no topics to index")

# Get the embeddings, one request per batch, preserving row order.
list_embeds = []
for start in range(0, len(topic_texts), _EMBED_BATCH_SIZE):
    batch = topic_texts[start:start + _EMBED_BATCH_SIZE]
    list_embeds.extend(
        co.embed(texts=batch, model=model_name, input_type=input_type_embed).embeddings
    )

# Create the search index; all embeddings share one length, so read it from the first.
search_index = AnnoyIndex(len(list_embeds[0]), metric='angular')
# Add vectors to the search index, using the row position as the item id.
for item_id, vector in enumerate(list_embeds):
    search_index.add_item(item_id, vector)
search_index.build(10)  # 10 trees
search_index.save('test.ann')
def topic_from_caption(caption):
    """
    Return the indexed topic that is semantically closest to an image caption.

    The caption is embedded with the same Cohere model used for the topics
    (as a "search_query" rather than a "search_document") and looked up in
    the module-level Annoy index built from ``topics['topic_cleaned']``.

    Args:
    - caption (str): The image caption generated by MS Azure.

    Returns:
    - str: The ``topic_cleaned`` value of the nearest topic.

    Raises:
    - LookupError: If the index returns no neighbour for the caption.
    """
    input_type_query = "search_query"
    # Embed the caption; the response holds one vector per input text.
    caption_embed = co.embed(texts=[caption], model=model_name, input_type=input_type_query).embeddings
    # Only the single nearest item id is needed, so distances are not requested.
    nearest_ids = search_index.get_nns_by_vector(caption_embed[0], n=1)
    if not nearest_ids:
        raise LookupError("no topic found for caption")
    # Annoy item ids are the row positions the topics were added in, hence
    # positional .iloc. Read the cell directly rather than rendering a one-row
    # Series through to_string(), which applies display formatting.
    return str(topics['topic_cleaned'].iloc[nearest_ids[0]])