|
from model import get_model |
|
from mapReduceSummarizer import get_map_reduce_chain |
|
from refineSummarizer import get_refine_chain |
|
from preprocess import prepare_for_summarize |
|
from transformers import AutoTokenizer |
|
from langchain.prompts import PromptTemplate |
|
from logging import getLogger |
|
import time |
|
|
|
# Logger for this module; records are emitted under the module's import name.
logger = getLogger(name=__name__)
|
|
|
|
|
|
|
def summarizer_init(model_name, model_type, api_key=None) -> tuple:
    """Load the tokenizer and base model used for summarization.

    Args:
        model_name: Identifier passed both to ``AutoTokenizer.from_pretrained``
            and to ``get_model``.
        model_type: Backend selector forwarded to ``get_model``.
        api_key: Optional credential forwarded to ``get_model``.

    Returns:
        A ``(tokenizer, base_summarizer)`` tuple, in that order.
    """
    # NOTE(review): the tokenizer is always fetched via AutoTokenizer, even for
    # model_type "openai", where summarizer_summarize only uses it on the
    # "local" path. Presumably this fails for non-Hugging-Face model names —
    # confirm with callers before relying on the "openai" path.
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    base_summarizer = get_model(model_type, model_name, api_key)
    return tokenizer, base_summarizer
|
|
|
# Prompt for the "openai" path; this text is sent to the model verbatim.
_OPENAI_SUMMARY_TEMPLATE = """Write a concise and complete summary in bullet points of the given annual report.
Important:
* Note that the summary should contain all important information and it should not contain any unwanted information.
* Make sure to keep the summary as short as possible. And Summary should be in bullet points. Seperate each point with a new line.
TEXT: {text}
SUMMARY:"""


def _timed_call(fn, *args):
    """Call ``fn(*args)``, print the elapsed time and return ``(result, seconds)``.

    ``seconds`` is the wall-clock duration rounded to two decimals.
    """
    # perf_counter is monotonic, so the measurement is not skewed by system
    # clock adjustments the way time.time() can be.
    start = time.perf_counter()
    result = fn(*args)
    elapsed = round(time.perf_counter() - start, 2)
    print(f"Summary generation took {elapsed}s.")
    return result, elapsed


def _summarize_openai(llm, text):
    """Summarize ``text`` with a single timed ``prompt | llm`` LangChain call."""
    prompt = PromptTemplate.from_template(template=_OPENAI_SUMMARY_TEMPLATE)
    llm_chain = prompt | llm
    return _timed_call(llm_chain.invoke, {"text": text})


def _summarize_local(tokenizer, model, text):
    """Summarize ``text`` with a locally loaded model through its chat template."""
    # Lazy %-formatting: the previous call passed an argument with no
    # placeholder, so logging failed to format the record.
    logger.info("Input: %s", text)
    chat = [
        {
            "role": "user",
            "content": f"""
SUmmarize this by focusing numerical importance sentences in the perspective of financial executive. input text: {text}
""",
        },
    ]
    prompt = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)

    # The rendered chat template already contains the model's special tokens;
    # letting the tokenizer add them again would duplicate them (e.g. BOS).
    # Inputs go to the model's device when it exposes one, else the CPU as before.
    device = getattr(model, "device", "cpu")
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        add_special_tokens=False,
    ).to(device)
    input_ids = inputs["input_ids"]

    output = model.generate(
        input_ids,
        attention_mask=inputs["attention_mask"],
        # NOTE(review): top_k only matters when sampling is enabled
        # (do_sample=True here or in the model's generation config) — confirm intent.
        top_k=10,
        # Output budget scales with input length: len(text) // 8 tokens plus 25.
        max_new_tokens=len(text) // 8 + 25,
        pad_token_id=tokenizer.eos_token_id,
    )

    # Skip the first input_ids.shape[-1] positions (the prompt) and decode the rest.
    new_tokens = output[:, input_ids.shape[-1]:]
    summary = tokenizer.batch_decode(new_tokens, skip_special_tokens=True)[0]
    logger.info("Output: %s", summary)
    return summary


def _summarize_long(base_summarizer, model_type, text, summarizer_type):
    """Summarize ``text`` with a refine or map-reduce chain, timed."""
    if summarizer_type == "refine":
        print("The text is too long, Running Refine Summarizer")
        llm_chain = get_refine_chain(base_summarizer, model_type)
        logger.info("Running Refine Chain for Summarization")
    else:
        print("The text is too long, Running Map Reduce Summarizer")
        llm_chain = get_map_reduce_chain(base_summarizer, model_type=model_type)
        logger.info("Running Map Reduce Chain for Summarization")

    def _run():
        # NOTE(review): the raw string is passed as "input_documents"; these
        # chains presumably expect a list of Documents (prepare_for_summarize
        # is imported in this file but unused) — confirm before enabling.
        result = llm_chain.invoke({"input_documents": text}, return_only_outputs=True)
        return result["output_text"]

    return _timed_call(_run)


def summarizer_summarize(model_type, tokenizer, base_summarizer, text: str, summarizer_type="map_reduce") -> tuple:
    """Summarize ``text`` and report how long generation took.

    Args:
        model_type: ``"openai"`` runs a LangChain ``prompt | llm`` chain;
            ``"local"`` renders a chat template and calls ``base_summarizer.generate``.
        tokenizer: Tokenizer used only on the ``"local"`` path.
        base_summarizer: The LLM/runnable (``"openai"``) or generation model
            (``"local"``) returned by ``summarizer_init``.
        text: The document text to summarize.
        summarizer_type: ``"refine"`` selects the refine chain; any other value
            selects map-reduce. Only consulted on the long-text path.

    Returns:
        ``(summary, seconds)``, with ``seconds`` rounded to two decimals. On the
        ``"openai"`` path ``summary`` is whatever ``llm_chain.invoke`` returns
        (presumably a message object for chat models — verify with callers).

    Raises:
        ValueError: if ``model_type`` is neither ``"openai"`` nor ``"local"``.
    """
    # NOTE(review): routing is hard-coded to the short path, so the
    # refine/map-reduce branch is currently unreachable, and no token count is
    # actually checked against the 12000 figure in the log message below.
    length_type = "short"

    if length_type != "short":
        return _summarize_long(base_summarizer, model_type, text, summarizer_type)

    logger.info("Processing Input Text less than 12000 Tokens")
    if model_type == "openai":
        return _summarize_openai(base_summarizer, text)
    if model_type == "local":
        return _timed_call(_summarize_local, tokenizer, base_summarizer, text)
    # Previously fell through and returned None, which broke tuple unpacking
    # at the call site with an unhelpful TypeError.
    raise ValueError(f"Unsupported model_type {model_type!r}; expected 'openai' or 'local'.")
|
|
|
|