File size: 4,849 Bytes
99e744f
 
 
 
 
 
 
 
 
 
 
 
 
c293aab
 
 
 
 
99e744f
c293aab
173b5f1
 
 
 
 
 
 
99e744f
c293aab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
173b5f1
 
 
76446bf
173b5f1
 
 
 
 
 
 
 
793459b
173b5f1
6238067
173b5f1
 
09fa1cc
173b5f1
 
 
 
 
 
76446bf
173b5f1
 
c293aab
 
 
 
 
 
 
 
 
 
 
 
 
99e744f
 
c293aab
 
99e744f
c293aab
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
from model import get_model
from mapReduceSummarizer import get_map_reduce_chain
from refineSummarizer import get_refine_chain
from preprocess import prepare_for_summarize
from transformers import AutoTokenizer
from langchain.prompts import PromptTemplate
from logging import getLogger
import time

logger = getLogger(__name__)



def summarizer_init(model_name, model_type, api_key=None):
    """Load the tokenizer and the base summarizer model.

    Args:
        model_name: Hugging Face model identifier, used both for the
            tokenizer and (via ``get_model``) the summarizer backend.
        model_type: backend selector forwarded to ``get_model``
            (e.g. "openai" or "local" — see summarizer_summarize).
        api_key: optional API key for hosted backends; unused for local models.

    Returns:
        A ``(tokenizer, base_summarizer)`` tuple.

    Note: the original signature was annotated ``-> None`` although the
    function returns a tuple; the wrong annotation has been removed.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    base_summarizer = get_model(model_type, model_name, api_key)
    return tokenizer, base_summarizer

def _summarize_short_openai(llm, text_to_summarize):
    """Single-shot bullet-point summary through an OpenAI-style chat LLM.

    Returns (summary, elapsed_seconds).
    """
    prompt = PromptTemplate.from_template(
        template="""Write a concise and complete summary in bullet points of the given annual report.
                    Important:
                    * Note that the summary should contain all important information and it should not contain any unwanted information.
                    * Make sure to keep the summary as short as possible. And Summary should be in bullet points. Seperate each point with a new line.
                    TEXT: {text}
                    SUMMARY:"""
    )
    llm_chain = prompt | llm
    start = time.time()
    summary = llm_chain.invoke({"text": text_to_summarize})
    elapsed = round(time.time() - start, 2)
    print(f"Summary generation took {elapsed}s.")
    return summary, elapsed


def _summarize_short_local(tokenizer, model, text_to_summarize):
    """Single-shot summary through a local chat model (Llama-style) on CPU.

    Returns (summary, elapsed_seconds).
    """
    start = time.time()
    input_text = text_to_summarize
    # Bug fix: the original called logger.info("Input", input_text) — a
    # positional arg with no %s placeholder, which makes the logging module
    # report a formatting error instead of logging the text.
    logger.info("Input: %s", input_text)
    chat = [
        { "role": "user", 
        "content": f"""
            SUmmarize this by focusing numerical importance sentences in the perspective of financial executive. input text: {input_text}
            """ },
    ]
    prompt = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(prompt,
                       return_tensors="pt", truncation=True).to('cpu')
    attention_mask = inputs["attention_mask"]
    # Heuristic output budget: roughly one token per 8 characters of input,
    # plus a small fixed slack.
    approximate_tokens = int(len(input_text) // 8)
    output = model.generate(inputs['input_ids'],
                            attention_mask=attention_mask,
                            top_k=10, max_new_tokens=approximate_tokens + 25,
                            pad_token_id=tokenizer.eos_token_id)
    # Decode only the newly generated tokens, skipping the echoed prompt.
    base_summary = tokenizer.batch_decode(output[:, inputs['input_ids'].shape[-1]:], skip_special_tokens=True)
    summary = base_summary[0]
    logger.info("Output: %s", summary)
    elapsed = round(time.time() - start, 2)
    print(f"Summary generation took {elapsed}s.")
    return summary, elapsed


def summarizer_summarize(model_type, tokenizer, base_summarizer, text: str, summarizer_type="map_reduce"):
    """Summarize ``text`` with the backend selected by ``model_type``.

    Args:
        model_type: "openai" for a langchain chat LLM, "local" for a
            transformers model; anything else yields None on the short path.
        tokenizer: tokenizer matching the local model (used on the local path).
        base_summarizer: the model/LLM object returned by summarizer_init.
        text: raw input text to summarize.
        summarizer_type: long-text strategy, "map_reduce" (default) or
            "refine" — only used when the input is classified as long.

    Returns:
        A ``(summary, elapsed_seconds)`` tuple, or None for an unknown
        ``model_type`` on the short path (preserving the original behavior).
    """
    # NOTE: length-based routing via prepare_for_summarize is deliberately
    # disabled (map-reduce was unsuitable / too slow), so every input is
    # currently treated as "short".
    # text_to_summarize, length_type = prepare_for_summarize(text, tokenizer)
    length_type = "short"
    text_to_summarize = text

    if length_type == "short":
        logger.info("Processing Input Text less than 12000 Tokens")
        if model_type == "openai":
            return _summarize_short_openai(base_summarizer, text_to_summarize)
        if model_type == "local":
            return _summarize_short_local(tokenizer, base_summarizer, text_to_summarize)
        # Unknown model_type: fall through and return None, as the
        # original implementation did implicitly.
        return None

    # Long-text path (currently unreachable while length_type is forced
    # to "short" above, but kept for when routing is re-enabled).
    if summarizer_type == "refine":
        print("The text is too long, Running Refine Summarizer")
        llm_chain = get_refine_chain(base_summarizer, model_type)
        logger.info("Running Refine Chain for Summarization")
    else:
        print("The text is too long, Running Map Reduce Summarizer")
        llm_chain = get_map_reduce_chain(base_summarizer, model_type=model_type)
        logger.info("Running Map Reduce Chain for Summarization")
    start = time.time()
    summary = llm_chain.invoke({"input_documents": text_to_summarize}, return_only_outputs=True)['output_text']
    elapsed = round(time.time() - start, 2)
    print(f"Summary generation took {elapsed}s.")
    return summary, elapsed