"""Gradio demo that embeds a query and runs a vector search in Meilisearch.

At import time the module loads the ``BAAI/bge-base-en-v1.5`` tokenizer and
model, and creates a Meilisearch client for the ``docs-embed`` index.  The
``search_embeddings`` function is exposed through a ``gr.Interface``.
"""

import os
import time
from typing import Any, Dict, List, Literal, Tuple

import gradio as gr
import meilisearch
import torch
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-base-en-v1.5")
model = AutoModel.from_pretrained("BAAI/bge-base-en-v1.5")
model.eval()

# Informational only: nothing in this file moves the model or its inputs to a
# GPU, so inference runs wherever ``from_pretrained`` placed the weights.
cuda_available = torch.cuda.is_available()
print(f"CUDA available: {cuda_available}")

# MEILISEARCH_KEY must be set; a missing variable raises KeyError on import.
meilisearch_client = meilisearch.Client(
    "https://edge.meilisearch.com", os.environ["MEILISEARCH_KEY"]
)
meilisearch_index_name = "docs-embed"
meilisearch_index = meilisearch_client.index(meilisearch_index_name)

output_options = ["RAG-friendly", "human-friendly"]


def _embed_query(query_text: str) -> List[float]:
    """Return the L2-normalized embedding of the prefixed query as a list."""
    query_prefix = "Represent this sentence for searching code documentation: "
    query_tokens = tokenizer(
        query_prefix + query_text,
        padding=True,
        truncation=True,
        return_tensors="pt",
        max_length=512,
    )
    with torch.no_grad():
        model_output = model(**query_tokens)
        # Index 0 along the second axis of the first model output
        # (presumably the [CLS] token of the last hidden state -- confirm
        # against the model card).
        sentence_embeddings = model_output[0][:, 0]
        sentence_embeddings = torch.nn.functional.normalize(
            sentence_embeddings, p=2, dim=1
        )
    # Only one query is encoded, so take the single row of the batch.
    return sentence_embeddings[0].tolist()


def _search_index(query_vector: List[float]) -> List[Dict[str, Any]]:
    """Run a purely semantic search (top 5) and return the response hits."""
    response = meilisearch_index.search(
        "",
        opt_params={
            "vector": query_vector,
            "hybrid": {"semanticRatio": 1.0},
            "limit": 5,
            "attributesToRetrieve": [
                "text",
                "source_page_url",
                "source_page_title",
                "library",
            ],
        },
    )
    return response["hits"]


def _format_human_friendly(
    hits: List[Dict[str, Any]],
    elapsed_time_embedding: float,
    elapsed_time_meilisearch: float,
) -> str:
    """Render timing stats followed by each hit's text and source link."""
    header = f"Stats:\n\nembedding time: {elapsed_time_embedding:.2f}s\n\nmeilisearch time: {elapsed_time_meilisearch:.2f}s\n\n---\n\n"
    sections = []
    for hit in hits:
        source = f'src: ["{hit["source_page_title"]}"]({hit["source_page_url"]})'
        sections.append(hit["text"] + f"\n\n{source}\n\n---\n\n")
    return header + "".join(sections)


def _format_rag_friendly(hits: List[Dict[str, Any]]) -> str:
    """Join the raw hit texts with a dashed separator line."""
    return "\n------------\n".join(hit["text"] for hit in hits)


def search_embeddings(
    query_text: str, output_option: Literal["RAG-friendly", "human-friendly"]
) -> Tuple[str, str]:
    """Embed ``query_text``, search the index, and format the hits.

    Args:
        query_text: Free-text query typed by the user.
        output_option: ``"human-friendly"`` renders timing stats plus each hit
            with its source link; ``"RAG-friendly"`` returns only the hit
            texts joined by a separator.

    Returns:
        A ``(results, sources)`` pair of Markdown strings.  ``sources`` is a
        comma-separated list of links to the pages the hits came from.  A
        blank query yields two empty strings without touching the model or
        the search service.

    Raises:
        ValueError: If ``output_option`` is not one of ``output_options``.
    """
    # Validate before doing any expensive work; previously an unknown option
    # fell through both branches and returned None.
    if output_option not in output_options:
        raise ValueError(
            f"Unknown output option {output_option!r}; "
            f"expected one of {output_options}"
        )

    # A blank query would only embed the instruction prefix, so skip the
    # forward pass and the network round-trip entirely.
    if not query_text or not query_text.strip():
        return "", ""

    # step1: tokenize and embed the query
    # perf_counter is monotonic, unlike time.time, so intervals are reliable.
    start_time_embedding = time.perf_counter()
    sentence_embeddings_list = _embed_query(query_text)
    elapsed_time_embedding = time.perf_counter() - start_time_embedding

    # step2: search meilisearch
    start_time_meilisearch = time.perf_counter()
    hits = _search_index(sentence_embeddings_list)
    elapsed_time_meilisearch = time.perf_counter() - start_time_meilisearch

    sources_md = ", ".join(
        f"[\"{hit['source_page_title']}\"]({hit['source_page_url']})"
        for hit in hits
    )

    # step3: present the results in markdown
    if output_option == "human-friendly":
        md = _format_human_friendly(
            hits, elapsed_time_embedding, elapsed_time_meilisearch
        )
        return md, sources_md
    return _format_rag_friendly(hits), sources_md


demo = gr.Interface(
    fn=search_embeddings,
    inputs=[
        gr.Textbox(
            label="enter your query", placeholder="Type Markdown here...", lines=10
        ),
        gr.Radio(
            label="Select an output option",
            choices=output_options,
            value="RAG-friendly",
        ),
    ],
    outputs=[gr.Markdown(), gr.Markdown()],
    title="HF Docs Embeddings Explorer",
    allow_flagging="never",
)

if __name__ == "__main__":
    demo.launch()