Spaces:
Runtime error
Runtime error
File size: 9,676 Bytes
6996634 d85e9cd 6996634 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 |
# -- Utils .py file
# -- Libraries
from typing import Any, Dict, List, Mapping, Optional
from pydantic import Extra, Field, root_validator
from langchain.llms.base import LLM
from langchain.utils import get_from_dict_or_env
from langchain.vectorstores import Chroma
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.chains import RetrievalQA
from langchain.document_loaders import TextLoader
from langchain.embeddings import HuggingFaceEmbeddings
from googletrans import Translator
import streamlit as st
import together
import textwrap
import spacy
import os
import re
# -- Together API key
# SECURITY: an API key used to be hard-coded (and so published) here; that key
# must be revoked. The key is now taken from the environment (e.g. a deployment
# secret named TOGETHER_API_KEY) and a value supplied there is no longer
# overwritten. The empty placeholder only keeps `os.environ["TOGETHER_API_KEY"]`
# lookups elsewhere in this module from raising KeyError; Together requests
# presumably fail to authenticate until a real key is provided.
os.environ.setdefault("TOGETHER_API_KEY", "")
# -- LLM class
class TogetherLLM(LLM):
    """Together large language models.

    LangChain ``LLM`` wrapper around ``together.Complete.create``. Before a
    prompt is sent, repeated paragraphs are removed and the section between
    ``CONTEXTO:`` and ``PREGUNTA`` is rewritten so that every retrieved
    fragment carries the "Desde el instante ... dice:" header it has in
    ``original_transcription``.
    """

    # Model endpoint to use.
    model: str = "togethercomputer/llama-2-70b-chat"
    # Together API key. Read once, when the class is defined; `.get` avoids a
    # KeyError at import time when the variable is not set.
    together_api_key: str = os.environ.get("TOGETHER_API_KEY", "")
    # Sampling temperature to use.
    temperature: float = 0.7
    # Maximum number of tokens to generate in the completion.
    max_tokens: int = 512
    # Full transcription; expected to hold one "Desde el instante ..." fragment
    # per paragraph (paragraphs separated by blank lines).
    original_transcription: str = ""

    class Config:
        extra = Extra.forbid

    # NOTE(review): the @root_validator decorator is commented out, so pydantic
    # never runs this method; it only has an effect if called explicitly.
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that the API key is set."""
        api_key = get_from_dict_or_env(
            values, "together_api_key", "TOGETHER_API_KEY"
        )
        values["together_api_key"] = api_key
        return values

    @property
    def _llm_type(self) -> str:
        """Return type of LLM."""
        return "together"

    def clean_duplicates(self, transcription: str) -> str:
        """Strip *transcription* and drop repeated paragraphs.

        Paragraphs are separated by blank lines; only the first occurrence of
        each paragraph is kept, in original order.
        """
        # NOTE(review): '/n/n ' uses forward slashes, so it matches that literal
        # text rather than newlines. Kept as-is because the intended
        # replacement cannot be confirmed from this file.
        transcription = transcription.strip().replace('/n/n ', '\n')
        # dict.fromkeys de-duplicates in O(n) while preserving order.
        return '\n\n'.join(dict.fromkeys(transcription.split('\n\n')))

    def _restore_speakers(self, cleaned_prompt: str) -> str:
        """Rewrite the CONTEXTO section of *cleaned_prompt* with speaker headers.

        The retrieved context is split into paragraphs and each paragraph is
        normalised: headers and double quotes are removed, runs of non-word
        characters become single spaces, and the text is lower-cased. Each
        normalised fragment is then matched by substring against the
        normalised paragraphs of ``self.original_transcription``. Every match
        is emitted as ``<original header>"<fragment>"``. Duplicates are
        dropped, keeping first-seen order. A prompt without a
        CONTEXTO...PREGUNTA section is returned unchanged.
        """
        # A single greedy ``.*`` under re.DOTALL captures everything up to the
        # last "PREGUNTA". The former ``(\n.*)+`` captured the same text but
        # could backtrack exponentially when "PREGUNTA" was absent. The
        # optional colon is consumed so the rewritten "PREGUNTA:" cannot turn
        # into "PREGUNTA::".
        regex_transcription = r'CONTEXTO:(\n.*)PREGUNTA:?'
        # NOTE(review): `[a-zA-Z ]+` does not match speaker names containing
        # accented letters (e.g. "José"); such paragraphs are skipped below.
        regex_init_transcription = r"Desde el instante [0-9]+:[0-9]+:[0-9]+(?:\.[0-9]+)? hasta [0-9]+:[0-9]+:[0-9]+(?:\.[0-9]+)? [a-zA-Z ]+ dice: ?"

        context_match = re.search(regex_transcription, cleaned_prompt, re.DOTALL)
        if context_match is None:
            # Nothing to rewrite (previously an AttributeError on None.group).
            return cleaned_prompt

        retrieved = re.sub(
            regex_init_transcription, "", context_match.group(1).strip()
        ).replace('"', '')
        fragments = [
            re.sub(r'\W+', ' ', fragment).strip().lower()
            for fragment in retrieved.split('\n\n')
        ]
        # An empty fragment is a substring of every paragraph; drop it.
        fragments = [fragment for fragment in fragments if fragment]

        new_transcription = []
        for transcription in self.original_transcription.split('\n\n'):
            header = re.search(regex_init_transcription, transcription)
            if header is None:
                # Blank or header-less paragraphs cannot be attributed to a
                # speaker (previously an AttributeError on None.group).
                continue
            transcription_cleaned = re.sub(
                regex_init_transcription, "", transcription.strip()
            ).replace('"', '')
            transcription_cleaned = re.sub(r'\W+', ' ', transcription_cleaned).strip().lower()
            if not transcription_cleaned:
                # An empty paragraph would "match" every fragment.
                continue
            for fragment in fragments:
                if fragment in transcription_cleaned or transcription_cleaned in fragment:
                    new_transcription.append(header.group(0) + '"' + fragment + '"')

        # dict.fromkeys keeps transcription order; set() made the prompt order
        # vary between runs because of string-hash randomisation.
        new_context = (
            'CONTEXTO:\n' + '\n\n'.join(dict.fromkeys(new_transcription)) + '\nPREGUNTA:'
        )
        # The replacement is supplied via a function so that backslashes in the
        # transcription are not interpreted as regex escapes. re.DOTALL is
        # passed as `flags`; it used to be passed positionally, which re.sub
        # reads as `count`.
        return re.sub(
            regex_transcription,
            lambda _match: new_context,
            cleaned_prompt,
            count=1,
            flags=re.DOTALL,
        )

    def _call(
        self,
        prompt: str,
        **kwargs: Any,
    ) -> str:
        """Call to Together endpoint.

        Args:
            prompt: The fully formatted prompt produced by the chain.
            **kwargs: Extra arguments from LangChain (e.g. ``stop``); ignored.

        Returns:
            The text of the first completion choice.
        """
        together.api_key = self.together_api_key
        cleaned_prompt = self.clean_duplicates(prompt)
        print(cleaned_prompt)
        new_cleaned_prompt = self._restore_speakers(cleaned_prompt)
        print(new_cleaned_prompt)
        output = together.Complete.create(
            new_cleaned_prompt,
            model=self.model,
            max_tokens=self.max_tokens,
            temperature=self.temperature,
        )
        return output['output']['choices'][0]['text']
# -- Set up the app's shared resources: translator, spaCy pipeline, LLM endpoint and document retriever
@st.cache_resource
def setup_app(transcription_path, emb_model, model, _logger):
    """Create the resources shared by every Streamlit rerun.

    Builds the Google translator and the spaCy ``es_core_news_lg`` pipeline,
    starts the Together ``model`` endpoint, and splits the transcription file
    into chunks (chunk_size=1500, chunk_overlap=100). The chunks are embedded
    on CPU with the HuggingFace model ``emb_model`` (normalised embeddings)
    and indexed in a Chroma vector store.

    ``st.cache_resource`` makes Streamlit run this once per distinct
    ``(transcription_path, emb_model, model)``. ``_logger`` is not hashed
    because of its leading underscore.

    Returns:
        tuple: ``(translator, nlp, retriever)``. The retriever returns the 5
        most similar chunks.
    """
    import uuid  # only needed for the per-call collection name below

    # -- Set up environment and features
    translator = Translator(service_urls=['translate.googleapis.com'])
    nlp = spacy.load('es_core_news_lg')
    _logger.info('Setup environment and features...')
    # -- Set up the LLM: authenticate and start the requested model endpoint.
    # The former `together.Models.list()` call was dropped: it was an extra
    # API round trip whose result was never used.
    together.api_key = os.environ["TOGETHER_API_KEY"]
    together.Models.start(model)
    _logger.info('Setup environment and features - FINISHED!')

    # -- Read the transcription and split it into chunks
    _logger.info('Loading transcription...')
    documents = TextLoader(transcription_path).load()
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1500, chunk_overlap=100)
    texts = text_splitter.split_documents(documents)
    _logger.info('Loading transcription - FINISHED!')

    # -- Load the embedding model
    _logger.info('Loading embedding...')
    embedding = HuggingFaceEmbeddings(
        model_name=emb_model,
        model_kwargs={'device': 'cpu'},
        # Normalised vectors make the similarity search cosine-based.
        encode_kwargs={'normalize_embeddings': True},
    )
    _logger.info('Loading embedding - FINISHED!')

    # -- Create the document database
    _logger.info('Creating document database...')
    # In-memory store with a fresh collection per call. The former
    # persist_directory='db' reused LangChain's default Chroma collection on
    # disk. Every restart therefore re-added the same chunks (duplicated
    # retrieval results), and a new transcription was mixed with earlier ones.
    vectordb = Chroma.from_documents(
        documents=texts,
        embedding=embedding,
        collection_name=f"transcription-{uuid.uuid4().hex}",
    )
    # -- Make a retriever
    retriever = vectordb.as_retriever(search_kwargs={"k": 5})
    _logger.info('Creating document database - FINISHED!')
    _logger.info('Setup finished!')
    return translator, nlp, retriever
# -- Function to get prompt template
def get_prompt(instruction, system_prompt, b_sys, e_sys, b_inst, e_inst, _logger):
    """Assemble a chat prompt from an instruction and a system prompt.

    The result is ``b_inst + b_sys + system_prompt + e_sys + instruction + e_inst``.
    The delimiters are presumably the Llama-2 ``[INST]``/``<<SYS>>`` markers;
    confirm with the caller. The instruction is logged at INFO level.
    """
    system_section = b_sys + system_prompt + e_sys
    prompt_text = b_inst + system_section + instruction + e_inst
    log_message = 'Prompt template created: {}'.format(instruction)
    _logger.info(log_message)
    return prompt_text
# -- Function to create the chain to answer questions
@st.cache_resource
def create_llm_chain(model, _retriever, _chain_type_kwargs, _logger, transcription_path):
    """Build the question-answering chain used by the app.

    Wraps a ``TogetherLLM`` (temperature 0.0, at most 1024 new tokens) in a
    LangChain "stuff" ``RetrievalQA`` chain over ``_retriever`` that also
    returns its source documents.

    The raw transcription file is passed to the LLM as
    ``original_transcription`` so it can restore speaker headers in the
    retrieved context. Streamlit caches the chain per ``(model,
    transcription_path)``; underscore-prefixed arguments are not hashed.
    """
    _logger.info('Creating LLM chain...')
    # -- Keep the original transcription untouched for TogetherLLM.
    with open(transcription_path, 'r') as transcription_file:
        original_text = transcription_file.read()
    together_llm = TogetherLLM(
        model=model,
        temperature=0.0,
        max_tokens=1024,
        original_transcription=original_text,
    )
    chain_options = dict(
        chain_type="stuff",
        retriever=_retriever,
        chain_type_kwargs=_chain_type_kwargs,
        return_source_documents=True,
    )
    qa = RetrievalQA.from_chain_type(llm=together_llm, **chain_options)
    _logger.info('Creating LLM chain - FINISHED!')
    return qa
# -------------------------------------------
# -- Auxiliary functions
def wrap_text_preserve_newlines(text, width=110):
    """Wrap each line of *text* to *width* columns, keeping existing line breaks.

    Every newline-separated line goes through ``textwrap.fill`` on its own,
    so the original line structure (including empty lines) is preserved.
    """
    return '\n'.join(
        textwrap.fill(source_line, width=width) for source_line in text.split('\n')
    )
def process_llm_response(llm_response, nlp):
    """Return the chain's answer (``llm_response['result']``) wrapped for display.

    ``nlp`` is accepted but not used by this function.
    """
    return wrap_text_preserve_newlines(text=llm_response['result'])
def time_to_seconds(time_str):
    """Convert an ``H:MM:SS[.ffffff]`` timestamp to whole seconds.

    The fractional part is truncated. A ValueError is raised if the string
    does not split into exactly three numeric parts.
    """
    hours, minutes, seconds = map(float, time_str.split(':'))
    return int((hours * 3600) + (minutes * 60) + seconds)


# -- Extract the first two timestamps of a transcription fragment, with their values in seconds
def add_hyperlink_and_convert_to_seconds(text):
    """Return the first two timestamps in *text* and their values in seconds.

    Timestamps look like ``0:00:05`` or ``00:01:02.500000``, with any number
    of hour digits, as also accepted by the "Desde el instante" header regex
    in ``TogetherLLM``. Despite its name, this function builds no hyperlink.

    Returns:
        ``(start_time_str, start_time_seconds, end_time_str, end_time_seconds)``

    Raises:
        IndexError: if *text* contains fewer than two timestamps.
    """
    # The decimal point is escaped: the former unescaped ``.`` accepted any
    # character (e.g. "00:00:01 123456"), which time_to_seconds cannot parse.
    # Hours accept one or more digits so "0:00:05" is recognised as well.
    time_pattern = r'(\d+:\d{2}:\d{2}(?:\.\d+)?)'
    matches = re.findall(time_pattern, text)
    start_time_str, end_time_str = matches[0], matches[1]
    start_time_seconds = time_to_seconds(start_time_str)
    end_time_seconds = time_to_seconds(end_time_str)
    return start_time_str, start_time_seconds, end_time_str, end_time_seconds
# -- Streamlit HTML template
def typewrite(youtube_video_url, i=0):
    """Return an HTML snippet that embeds a YouTube player.

    ``?enablejsapi=1`` is removed from *youtube_video_url* before the URL is
    used as the iframe ``src``. *i* is used in the iframe id (``player_<i>``).
    The page also loads the YouTube player API script.
    """
    # NOTE(review): the URL is interpolated without HTML escaping; escape it
    # if it can ever come from user input.
    player_src = youtube_video_url.replace("?enablejsapi=1", "")
    html_lines = [
        "",
        "<html>",
        "<style>",
        "p {margin: 0;}",
        "</style>",
        "<body>",
        '<script src="https://www.youtube.com/player_api"></script>',
        '<p align="center">',
        f'<iframe id="player_{i}" src="{player_src}" width="600" height="450"></iframe>',
        "</p>",
        "</body>",
        "</html>",
        "",
    ]
    return "\n".join(html_lines)