# Arcana / nylon.py

import nltk
from nltk.collocations import BigramAssocMeasures, BigramCollocationFinder
from nltk.corpus import stopwords
import spacy
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from datetime import datetime, timedelta
import numpy as np
import heapq
from concurrent.futures import ThreadPoolExecutor
from annoy import AnnoyIndex
from transformers import pipeline
from rank_bm25 import BM25Okapi
from functools import partial
import ssl

# Allow the NLTK downloads below to work on systems with strict SSL certificate checks.
try:
    _create_unverified_https_context = ssl._create_unverified_context
except AttributeError:
    pass
else:
    ssl._create_default_https_context = _create_unverified_https_context

nltk.download('stopwords')
nltk.download('punkt')

nlp = spacy.load('en_core_web_sm')

q = 0  # global call counter, used only to print progress every 1000 calls


def get_keywords(text, cache):
    """Return lowercased noun/proper-noun/verb tokens plus the top-10 PMI bigrams, with caching."""
    global q
    if q % 1000 == 0:
        print(q)
    q += 1
    if text in cache:
        return cache[text]
    doc = nlp(text)
    keywords = []
    for token in doc:
        if token.pos_ in ['NOUN', 'PROPN', 'VERB']:
            keywords.append(token.text.lower())
    stop_words = set(stopwords.words('english'))
    keywords = [word for word in keywords if word not in stop_words]
    bigram_measures = BigramAssocMeasures()
    finder = BigramCollocationFinder.from_words([token.text for token in doc])
    bigrams = finder.nbest(bigram_measures.pmi, 10)
    keywords.extend([' '.join(bigram) for bigram in bigrams])
    cache[text] = keywords
    return keywords


def calculate_weight(message, sender_messages, cache):
    """Weight a message by how often its keywords appear in the sender's messages within +/- 5 hours."""
    message_time = datetime.strptime(message[1], '%Y-%m-%d %H:%M:%S')
    time_deltas = np.array([datetime.strptime(m[1], '%Y-%m-%d %H:%M:%S') for m in sender_messages]) - message_time
    recent_mask = np.abs(time_deltas.astype('timedelta64[s]').astype(int)) <= 5 * 3600
    recent_messages = sender_messages[recent_mask]
    recent_keywords = [get_keywords(m[2], cache) for m in recent_messages]
    keyword_counts = [sum(k.count(keyword) for k in recent_keywords) for keyword in get_keywords(message[2], cache)]
    return sum(keyword_counts)


class ChatDatabase:
    def __init__(self, filename):
        self.filename = filename
        self.messages = []
        self.messages_array = None
        self.sender_array = None
        self.load_messages()
        self.index = None   # Annoy index over message keyword vectors (see build_index_separate)
        self.tfidf = None   # fitted TfidfVectorizer (see build_index_separate)

    def load_messages(self):
        # Touch the file so it exists, then read tab-separated (sender, time, text[, tag]) lines.
        with open(self.filename, 'a'):
            pass
        with open(self.filename, 'r') as f:
            for line in f:
                parts = line.strip().split('\t')
                if len(parts) < 3:
                    continue  # skip blank or malformed lines
                if len(parts) == 4:
                    sender, time, text, tag = parts
                else:
                    sender, time, text = parts
                    tag = None
                self.messages.append((sender, time, text, tag))
        if self.messages:
            self.messages_array = np.array(self.messages, dtype=object)
        else:
            self.messages_array = np.empty((0, 4), dtype=object)
        self.sender_array = self.messages_array[:, 0]
        print(f'Database loaded. Number of messages: {len(self.messages_array)}')

    def add_message(self, sender, time, text, tag=None):
        message = (sender, time, text, tag)
        self.messages.append(message)
        self.messages_array = np.append(self.messages_array, [message], axis=0)
        self.sender_array = np.append(self.sender_array, sender)
        with open(self.filename, 'a') as f:
            f.write(f'{sender}\t{time}\t{text}\t{tag}\n')
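
    # NOTE: the usage example at the bottom of this file calls build_index_separate,
    # but that method is missing from this upload. The sketch below is an assumed
    # reconstruction: it fits one TF-IDF vectorizer and builds one Annoy index over
    # the keyword strings of all messages, in message order. The metric ('angular')
    # and the n_trees parameter are assumptions, not part of the original code.
    def build_index_separate(self, cache, n_trees=10):
        keyword_docs = [' '.join(get_keywords(m[2], cache)) for m in self.messages_array]
        self.tfidf = TfidfVectorizer()
        tfidf_matrix = self.tfidf.fit_transform(keyword_docs).toarray()
        self.index = AnnoyIndex(tfidf_matrix.shape[1], 'angular')
        for i, vector in enumerate(tfidf_matrix):
            self.index.add_item(i, vector)
        self.index.build(n_trees)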

    def predict_response_separate(self, query, sender, cache):
        if self.messages_array is None or len(self.messages_array) == 0:
            print("Error: messages_array is None")
            return None
        if self.tfidf is None or self.index is None:
            print("Error: index not built; call build_index_separate first")
            return None
        sender_messages = self.messages_array[self.sender_array == sender]
        if len(sender_messages) == 0:
            print(f"No messages found for sender: {sender}")
            return None
        query_keywords = ' '.join(get_keywords(query, cache))
        query_vector = self.tfidf.transform([query_keywords]).toarray()[0]
        # Nearest message in the corpus (the index is assumed to cover all messages in order).
        relevant_index = self.index.get_nns_by_vector(query_vector, 1)[0]
        # Predicted response: the first later message sent by someone other than `sender`.
        later_indices = np.where(self.sender_array[relevant_index + 1:] != sender)[0]
        if len(later_indices) > 0:
            predicted_response = self.messages_array[relevant_index + 1 + later_indices[0]]
            return tuple(predicted_response)
        return None

    def get_relevant_messages(self, sender, query, N, cache, query_tag=None, n_threads=30, tag_boost=1.5):
        if self.messages_array is None:
            print("Error: messages_array is None")
            return []
        query_keywords = query.lower().split()
        # Filter to the sender's messages that contain at least one query keyword.
        # The tag is not a hard filter; matching tags are boosted in the scoring below.
        sender_messages = self.messages_array[
            (self.sender_array == sender) &
            np.array([any(keyword in message.lower() for keyword in query_keywords)
                      for message in self.messages_array[:, 2]])
        ]
        if len(sender_messages) == 0:
            print(f"No messages found for sender: {sender} with the given keywords")
            return []
        print(len(sender_messages))

        def process_batch(batch, query_keywords, current_time, query_tag):
            # Score each message in the batch: BM25 keyword relevance, recency, and tag match.
            batch_keywords = [get_keywords(message[2], cache) for message in batch]
            bm25 = BM25Okapi(batch_keywords)
            bm25_scores = bm25.get_scores(query_keywords)
            time_scores = 1 / (1 + (current_time - np.array([datetime.strptime(m[1], '%Y-%m-%d %H:%M:%S') for m in batch])).astype('timedelta64[D]').astype(int))
            tag_scores = np.where(np.array([m[3] for m in batch]) == query_tag, tag_boost, 1)
            combined_scores = 0.6 * np.array(bm25_scores) + 0.2 * time_scores + 0.2 * tag_scores
            return combined_scores, batch

        current_time = datetime.now()
        batch_size = max(1, len(sender_messages) // n_threads)
        batches = [sender_messages[i:i + batch_size] for i in range(0, len(sender_messages), batch_size)]
        with ThreadPoolExecutor(max_workers=n_threads) as executor:
            process_func = partial(process_batch, query_keywords=query_keywords, current_time=current_time, query_tag=query_tag)
            results = list(executor.map(process_func, batches))
        all_scores = np.concatenate([r[0] for r in results])
        all_messages = np.concatenate([r[1] for r in results])
        top_indices = np.argsort(all_scores)[-N:][::-1]
        relevant_messages = all_messages[top_indices]
        return relevant_messages.tolist()

    def generate_response(self, query, sender, cache, query_tag=None):
        relevant_messages = self.get_relevant_messages(sender, query, 5, cache, query_tag)
        context = ' '.join([message[2] for message in relevant_messages])
        # Note: the generation pipeline is loaded on every call, which is slow; cache it if called often.
        generator = pipeline('text-generation', model='EleutherAI/gpt-neo-2.7B')
        response = generator(f'{context} {query}', max_length=100, do_sample=True)[0]['generated_text']
        response = response.split(query)[-1].strip()
        return response


'''
# Usage example
db = ChatDatabase('messages.txt')
# Example 1: Get relevant messages
query = 'Goodnight.'
sender = 'Alice'
N = 10
cache = {}
query_tag = 'evening'
relevant_messages = db.get_relevant_messages(sender, query, N, cache, query_tag)
print("Relevant messages:")
for message in relevant_messages:
print(f"Sender: {message[0]}, Time: {message[1]}, Tag: {message[3]}")
print(f"Message: {message[2][:100]}...")
print()
# Example 2: Predict response (using the original method)
query = "what was that?"
sender = 'David'
db.build_index_separate(cache)
predicted_response = db.predict_response_separate(query, sender, cache)
print("\nPredicted response:")
if predicted_response is not None:
print(f"Sender: {predicted_response[0]}, Time: {predicted_response[1]}, Tag: {predicted_response[3]}")
print(f"Message: {predicted_response[2][:100]}...")
else:
print('No predicted response found')
# Example 3: Generate response
query = "Let's plan a trip"
sender = 'Alice'
query_tag = 'travel'
generated_response = db.generate_response(query, sender, cache, query_tag)
print("\nGenerated response:")
print(generated_response)
'''