Lahiru Menikdiwela
fix max_new_toke_issue
09fa1cc
raw
history blame
4.76 kB
from model import get_model
from mapReduceSummarizer import get_map_reduce_chain
from refineSummarizer import get_refine_chain
from preprocess import prepare_for_summarize
from transformers import AutoTokenizer
from langchain.prompts import PromptTemplate
from logging import getLogger
import time
# Module-level logger, named after this module's import path.
logger = getLogger(name=__name__)
def summarizer_init(model_name, model_type, api_key=None) -> tuple:
    """Load the tokenizer and the summarization model.

    Args:
        model_name: Name or path passed to both ``AutoTokenizer.from_pretrained``
            and ``get_model``.
        model_type: Backend selector forwarded to ``get_model``
            (``summarizer_summarize`` handles ``"openai"`` and ``"local"``).
        api_key: Optional API key forwarded to ``get_model``.

    Returns:
        A ``(tokenizer, base_summarizer)`` tuple.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    base_summarizer = get_model(model_type, model_name, api_key)
    return tokenizer, base_summarizer
def _report_elapsed(start: float) -> float:
    """Print and return seconds elapsed since ``start`` (a ``perf_counter`` reading), rounded to 2 decimals."""
    elapsed = round(time.perf_counter() - start, 2)
    print(f"Summary generation took {elapsed}s.")
    return elapsed


def _summarize_with_openai(llm, text_to_summarize: str) -> tuple:
    """Summarize via a LangChain ``prompt | llm`` chain with a bullet-point prompt."""
    prompt = PromptTemplate.from_template(
        template=(
            "Write a concise and complete summary in bullet points of the given annual report.\n"
            "Important:\n"
            "* Note that the summary should contain all important information and it should not contain any unwanted information.\n"
            "* Make sure to keep the summary as short as possible. And Summary should be in bullet points. Seperate each point with a new line.\n"
            "TEXT: {text}\n"
            "SUMMARY:"
        )
    )
    llm_chain = prompt | llm
    start = time.perf_counter()
    summary = llm_chain.invoke({"text": text_to_summarize})
    return summary, _report_elapsed(start)


def _summarize_with_local_model(tokenizer, model, text_to_summarize: str) -> tuple:
    """Summarize with a local chat model using ``apply_chat_template`` + ``generate``."""
    start = time.perf_counter()
    chat = [
        {
            "role": "user",
            "content": f"\nSUmmarize this by focusing numerical importance sentences in the perspective of financial executive. input text: {text_to_summarize}\n",
        },
    ]
    prompt = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
    # The rendered chat template typically already contains the model's special
    # tokens (e.g. BOS). Tokenizing it again with the default
    # add_special_tokens=True would add them a second time, which the Hugging
    # Face chat-templating guide warns can hurt output quality.
    # NOTE(review): inputs are pinned to CPU; if the model is ever placed on an
    # accelerator this must match the model's device.
    inputs = tokenizer(
        prompt, return_tensors="pt", truncation=True, add_special_tokens=False
    ).to("cpu")
    prompt_ids = inputs["input_ids"]
    # Heuristic output budget: ~1 new token per 10 input characters, plus slack.
    max_new_tokens = len(text_to_summarize) // 10 + 25
    # NOTE(review): top_k only takes effect when sampling is enabled (do_sample);
    # confirm whether greedy decoding is intended here.
    output = model.generate(
        prompt_ids,
        attention_mask=inputs["attention_mask"],
        top_k=10,
        max_new_tokens=max_new_tokens,
        pad_token_id=tokenizer.eos_token_id,
    )
    # Slice off the echoed prompt tokens so only newly generated text is decoded.
    decoded = tokenizer.batch_decode(output[:, prompt_ids.shape[-1]:], skip_special_tokens=True)
    return decoded[0], _report_elapsed(start)


def _summarize_long_text(base_summarizer, model_type, text_to_summarize, summarizer_type: str) -> tuple:
    """Summarize long input with the refine chain, or the map-reduce chain for any other type."""
    if summarizer_type == "refine":
        print("The text is too long, Running Refine Summarizer")
        llm_chain = get_refine_chain(base_summarizer, model_type)
        logger.info("Running Refine Chain for Summarization")
    else:
        print("The text is too long, Running Map Reduce Summarizer")
        llm_chain = get_map_reduce_chain(base_summarizer, model_type=model_type)
        logger.info("Running Map Reduce Chain for Summarization")
    start = time.perf_counter()
    # NOTE(review): "input_documents" presumably expects the document list that
    # prepare_for_summarize produced, not a raw string; confirm before
    # re-enabling this path.
    result = llm_chain.invoke({"input_documents": text_to_summarize}, return_only_outputs=True)
    return result["output_text"], _report_elapsed(start)


def summarizer_summarize(model_type, tokenizer, base_summarizer, text: str, summarizer_type="map_reduce") -> tuple:
    """Summarize ``text`` with the configured backend and time the generation.

    Length-based routing is currently disabled: ``prepare_for_summarize`` is
    not called and every input is treated as "short", so the refine /
    map-reduce path is unreachable until that call is restored.

    Args:
        model_type: ``"openai"`` (LangChain ``prompt | llm`` chain) or
            ``"local"`` (Hugging Face tokenizer + ``generate``-capable model).
        tokenizer: Tokenizer used by the ``"local"`` path for the chat
            template, encoding, and decoding.
        base_summarizer: The LLM / model returned by ``summarizer_init``.
        text: Raw input text to summarize.
        summarizer_type: For long inputs, ``"refine"`` selects the refine
            chain; any other value selects map-reduce.

    Returns:
        ``(summary, seconds)``, where ``seconds`` is the generation time
        rounded to two decimals. On the ``"openai"`` path ``summary`` is
        whatever the chain's ``invoke`` returns (possibly a message object
        rather than ``str``).

    Raises:
        ValueError: If ``model_type`` is neither ``"openai"`` nor ``"local"``
            on the short-text path (previously this silently returned ``None``).
    """
    # Length-based routing was removed because map-reduce was unsuitable / too
    # slow. To restore it:
    #   text_to_summarize, length_type = prepare_for_summarize(text, tokenizer)
    length_type = "short"
    text_to_summarize = text

    if length_type != "short":
        return _summarize_long_text(base_summarizer, model_type, text_to_summarize, summarizer_type)

    # NOTE: the token count in this message is not checked while routing is disabled.
    logger.info("Processing Input Text less than 12000 Tokens")
    if model_type == "openai":
        return _summarize_with_openai(base_summarizer, text_to_summarize)
    if model_type == "local":
        return _summarize_with_local_model(tokenizer, base_summarizer, text_to_summarize)
    raise ValueError(f"Unsupported model_type {model_type!r}; expected 'openai' or 'local'.")