import nltk
from nltk.collocations import BigramAssocMeasures, BigramCollocationFinder
from nltk.corpus import stopwords
import spacy
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from datetime import datetime, timedelta
import numpy as np
import heapq
from concurrent.futures import ThreadPoolExecutor
from annoy import AnnoyIndex
from transformers import pipeline
from rank_bm25 import BM25Okapi
from functools import partial
import ssl

# Allow nltk.download() to work in environments with broken SSL certificate chains.
try:
    _create_unverified_https_context = ssl._create_unverified_context
except AttributeError:
    pass
else:
    ssl._create_default_https_context = _create_unverified_https_context

nltk.download('stopwords')
nltk.download('punkt')

nlp = spacy.load('en_core_web_sm')

# Global counter used only to print keyword-extraction progress every 1000 calls.
q = 0


def get_keywords(text, cache):
    """Extract content-word keywords and top PMI bigrams from text, with caching."""
    global q
    if q % 1000 == 0:
        print(q)
    q += 1
    if text in cache:
        return cache[text]
    doc = nlp(text)
    # Keep nouns, proper nouns, and verbs as candidate keywords.
    keywords = [token.text.lower() for token in doc if token.pos_ in ('NOUN', 'PROPN', 'VERB')]
    stop_words = set(stopwords.words('english'))
    keywords = [word for word in keywords if word not in stop_words]
    # Add the ten highest-PMI bigrams as multi-word keywords.
    bigram_measures = BigramAssocMeasures()
    finder = BigramCollocationFinder.from_words([token.text for token in doc])
    bigrams = finder.nbest(bigram_measures.pmi, 10)
    keywords.extend(' '.join(bigram) for bigram in bigrams)
    cache[text] = keywords
    return keywords


def calculate_weight(message, sender_messages, cache):
    """Weight a message by how often its keywords appear in the sender's messages
    sent within five hours of it."""
    message_time = datetime.strptime(message[1], '%Y-%m-%d %H:%M:%S')
    times = np.array([datetime.strptime(m[1], '%Y-%m-%d %H:%M:%S') for m in sender_messages])
    # Keep messages whose absolute time difference from this message is at most 5 hours.
    within_window = np.abs((times - message_time).astype('timedelta64[s]').astype(int)) <= 5 * 3600
    recent_messages = sender_messages[within_window]
    recent_keywords = [get_keywords(m[2], cache) for m in recent_messages]
    keyword_counts = [sum(k.count(keyword) for k in recent_keywords)
                      for keyword in get_keywords(message[2], cache)]
    return sum(keyword_counts)


class ChatDatabase:
    def __init__(self, filename):
        self.filename = filename
        self.messages = []
        self.messages_array = None
        self.sender_array = None
        self.load_messages()
        self.index = None
        self.tfidf = None

    def load_messages(self):
        # Create the file if it does not exist yet.
        with open(self.filename, 'a'):
            pass
        with open(self.filename, 'r') as f:
            for line in f:
                parts = line.strip().split('\t')
                if len(parts) == 4:
                    sender, time, text, tag = parts
                else:
                    sender, time, text = parts
                    tag = None
                self.messages.append((sender, time, text, tag))
        self.messages_array = np.array(self.messages, dtype=object)
        # Guard against an empty file: a 0-row array cannot be column-sliced.
        if len(self.messages_array) > 0:
            self.sender_array = self.messages_array[:, 0]
        else:
            self.sender_array = np.array([], dtype=object)
        print(f'Database loaded.\nNumber of messages: {len(self.messages_array)}')
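
    # NOTE: build_index_separate is called in the usage example at the bottom of
    # this file but was not defined in the original. This is a minimal sketch of
    # one plausible implementation, assuming a single TF-IDF vectorizer plus an
    # Annoy index built over the keywords of all stored messages; the n_trees
    # parameter is an added assumption. predict_response_separate queries this index.
    def build_index_separate(self, cache, n_trees=10):
        if self.messages_array is None or len(self.messages_array) == 0:
            print("Error: no messages to index")
            return
        # One keyword "document" per stored message.
        keyword_docs = [' '.join(get_keywords(text, cache)) for text in self.messages_array[:, 2]]
        self.tfidf = TfidfVectorizer()
        vectors = self.tfidf.fit_transform(keyword_docs).toarray()
        # Angular distance approximates cosine similarity between TF-IDF vectors.
        self.index = AnnoyIndex(vectors.shape[1], 'angular')
        for i, vector in enumerate(vectors):
            self.index.add_item(i, vector)
        self.index.build(n_trees)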

    def add_message(self, sender, time, text, tag=None):
        message = (sender, time, text, tag)
        self.messages.append(message)
        self.messages_array = np.append(self.messages_array, [message], axis=0)
        self.sender_array = np.append(self.sender_array, sender)
        with open(self.filename, 'a') as f:
            f.write(f'{sender}\t{time}\t{text}\t{tag}\n')

    def predict_response_separate(self, query, sender, cache):
        if self.messages_array is None:
            print("Error: messages_array is None")
            return None
        if self.index is None or self.tfidf is None:
            print("Error: index not built; call build_index_separate first")
            return None
        sender_messages = self.messages_array[self.sender_array == sender]
        if len(sender_messages) == 0:
            print(f"No messages found for sender: {sender}")
            return None
        query_keywords = ' '.join(get_keywords(query, cache))
        query_vector = self.tfidf.transform([query_keywords]).toarray()[0]
        # Search a handful of nearest neighbours and keep the first one that
        # belongs to the requested sender.
        neighbour_indices = self.index.get_nns_by_vector(query_vector, 20)
        relevant_index = next((i for i in neighbour_indices if self.sender_array[i] == sender), None)
        if relevant_index is None:
            return None
        # The predicted response is the next message after the matched one that
        # came from a different sender.
        later_others = np.where(self.sender_array[relevant_index + 1:] != sender)[0]
        if len(later_others) > 0:
            predicted_response = self.messages_array[relevant_index + 1 + later_others[0]]
            return tuple(predicted_response)
        return None

    def get_relevant_messages(self, sender, query, N, cache, query_tag=None, n_threads=30, tag_boost=1.5):
        if self.messages_array is None:
            print("Error: messages_array is None")
            return []
        query_keywords = query.lower().split()
        # Filter to the sender's messages that contain at least one query keyword.
        # The tag is not used as a hard filter here; it is applied as a soft boost
        # (tag_boost) when scoring below.
        keyword_mask = np.array([any(keyword in message.lower() for keyword in query_keywords)
                                 for message in self.messages_array[:, 2]])
        sender_messages = self.messages_array[(self.sender_array == sender) & keyword_mask]
        if len(sender_messages) == 0:
            print(f"No messages found for sender: {sender} with the given keywords")
            return []
        print(len(sender_messages))

        def process_batch(batch, query_keywords, current_time, query_tag):
            # Score each message in the batch by BM25 relevance, recency, and tag match.
            batch_keywords = [get_keywords(message[2], cache) for message in batch]
            bm25 = BM25Okapi(batch_keywords)
            bm25_scores = bm25.get_scores(query_keywords)
            ages_in_days = (current_time - np.array(
                [datetime.strptime(m[1], '%Y-%m-%d %H:%M:%S') for m in batch]
            )).astype('timedelta64[D]').astype(int)
            time_scores = 1 / (1 + ages_in_days)
            tag_scores = np.where(np.array([m[3] for m in batch]) == query_tag, tag_boost, 1)
            combined_scores = 0.6 * np.array(bm25_scores) + 0.2 * time_scores + 0.2 * tag_scores
            return combined_scores, batch

        current_time = datetime.now()
        batch_size = max(1, len(sender_messages) // n_threads)
        batches = [sender_messages[i:i + batch_size] for i in range(0, len(sender_messages), batch_size)]
        with ThreadPoolExecutor(max_workers=n_threads) as executor:
            process_func = partial(process_batch, query_keywords=query_keywords,
                                   current_time=current_time, query_tag=query_tag)
            results = list(executor.map(process_func, batches))
        all_scores = np.concatenate([r[0] for r in results])
        all_messages = np.concatenate([r[1] for r in results])
        # Return the N highest-scoring messages, best first.
        top_indices = np.argsort(all_scores)[-N:][::-1]
        return all_messages[top_indices].tolist()

    def generate_response(self, query, sender, cache, query_tag=None):
        relevant_messages = self.get_relevant_messages(sender, query, 5, cache, query_tag)
        context = ' '.join([message[2] for message in relevant_messages])
        # Build the text-generation pipeline lazily and reuse it across calls.
        if not hasattr(self, '_generator'):
            self._generator = pipeline('text-generation', model='EleutherAI/gpt-neo-2.7B')
        generated = self._generator(f'{context} {query}', max_new_tokens=100,
                                    do_sample=True)[0]['generated_text']
        # Keep only the text generated after the query itself.
        response = generated.split(query)[-1].strip()
        return response


# Usage example
'''
db = ChatDatabase('messages.txt')

# Example 1: Get relevant messages
query = 'Goodnight.'
sender = 'Alice'
N = 10
cache = {}
query_tag = 'evening'
relevant_messages = db.get_relevant_messages(sender, query, N, cache, query_tag)
print("Relevant messages:")
for message in relevant_messages:
    print(f"Sender: {message[0]}, Time: {message[1]}, Tag: {message[3]}")
    print(f"Message: {message[2][:100]}...")
    print()

# Example 2: Predict response
query = "what was that?"
sender = 'David'
db.build_index_separate(cache)
predicted_response = db.predict_response_separate(query, sender, cache)
print("\nPredicted response:")
if predicted_response is not None:
    print(f"Sender: {predicted_response[0]}, Time: {predicted_response[1]}, Tag: {predicted_response[3]}")
    print(f"Message: {predicted_response[2][:100]}...")
else:
    print('No predicted response found')

# Example 3: Generate response
query = "Let's plan a trip"
sender = 'Alice'
query_tag = 'travel'
generated_response = db.generate_response(query, sender, cache, query_tag)
print("\nGenerated response:")
print(generated_response)
'''
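
# A minimal sketch of adding a new message after loading the database (not part of
# the original example). The timestamp must use the '%Y-%m-%d %H:%M:%S' format that
# the scoring code parses, and the index should be rebuilt afterwards so
# predict_response_separate can see the new message.
'''
db.add_message('Alice', datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
               'See you tomorrow!', tag='evening')
db.build_index_separate(cache)
'''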