Spaces:
Paused
Paused
File size: 8,571 Bytes
0fdb130 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 |
{
"cells": [
{
"cell_type": "code",
"execution_count": 16,
"id": "e3000a69",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Some weights of PegasusForConditionalGeneration were not initialized from the model checkpoint at human-centered-summarization/financial-summarization-pegasus and are newly initialized: ['model.decoder.embed_positions.weight', 'model.encoder.embed_positions.weight']\n",
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
]
}
],
"source": [
"from transformers import PegasusTokenizer, PegasusForConditionalGeneration, TFPegasusForConditionalGeneration\n",
"from rouge import Rouge\n",
"\n",
"# Let's load the model and the tokenizer \n",
"model_name = \"human-centered-summarization/financial-summarization-pegasus\"\n",
"tokenizer = PegasusTokenizer.from_pretrained(model_name, local_files_only=True)\n",
"model = PegasusForConditionalGeneration.from_pretrained(model_name, local_files_only=True)"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "6832cc0c",
"metadata": {
"scrolled": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2 8\n",
"0.09230769142721895\n",
"0.02312138672190853\n",
"0.09230769142721895\n",
"----------------------------------------------------------------------\n",
"2 32\n",
"0.28767123031713265\n",
"0.11578947163656512\n",
"0.2465753399061738\n",
"----------------------------------------------------------------------\n",
"2 64\n",
"0.28767123031713265\n",
"0.11578947163656512\n",
"0.2465753399061738\n",
"----------------------------------------------------------------------\n",
"2 128\n",
"0.28767123031713265\n",
"0.11578947163656512\n",
"0.2465753399061738\n",
"----------------------------------------------------------------------\n",
"2 256\n",
"0.28767123031713265\n",
"0.11578947163656512\n",
"0.2465753399061738\n",
"----------------------------------------------------------------------\n",
"\n",
"5 8\n",
"0.09230769142721895\n",
"0.02312138672190853\n",
"0.09230769142721895\n",
"----------------------------------------------------------------------\n",
"5 32\n",
"0.28767123031713265\n",
"0.11578947163656512\n",
"0.2465753399061738\n",
"----------------------------------------------------------------------\n",
"5 64\n",
"0.28767123031713265\n",
"0.11578947163656512\n",
"0.2465753399061738\n",
"----------------------------------------------------------------------\n",
"5 128\n",
"0.28767123031713265\n",
"0.11578947163656512\n",
"0.2465753399061738\n",
"----------------------------------------------------------------------\n",
"5 256\n",
"0.28767123031713265\n",
"0.11578947163656512\n",
"0.2465753399061738\n",
"----------------------------------------------------------------------\n",
"\n",
"8 8\n",
"0.09230769142721895\n",
"0.02312138672190853\n",
"0.09230769142721895\n",
"----------------------------------------------------------------------\n",
"8 32\n",
"0.28767123031713265\n",
"0.11578947163656512\n",
"0.2465753399061738\n",
"----------------------------------------------------------------------\n",
"8 64\n",
"0.28767123031713265\n",
"0.11578947163656512\n",
"0.2465753399061738\n",
"----------------------------------------------------------------------\n",
"8 128\n",
"0.28767123031713265\n",
"0.11578947163656512\n",
"0.2465753399061738\n",
"----------------------------------------------------------------------\n",
"8 256\n",
"0.28767123031713265\n",
"0.11578947163656512\n",
"0.2465753399061738\n",
"----------------------------------------------------------------------\n",
"\n",
"12 8\n",
"0.09230769142721895\n",
"0.02312138672190853\n",
"0.09230769142721895\n",
"----------------------------------------------------------------------\n",
"12 32\n",
"0.28767123031713265\n",
"0.11578947163656512\n",
"0.2465753399061738\n",
"----------------------------------------------------------------------\n",
"12 64\n",
"0.28767123031713265\n",
"0.11578947163656512\n",
"0.2465753399061738\n",
"----------------------------------------------------------------------\n",
"12 128\n",
"0.28767123031713265\n",
"0.11578947163656512\n",
"0.2465753399061738\n",
"----------------------------------------------------------------------\n",
"12 256\n",
"0.28767123031713265\n",
"0.11578947163656512\n",
"0.2465753399061738\n",
"----------------------------------------------------------------------\n",
"\n",
"20 8\n",
"0.09230769142721895\n",
"0.02312138672190853\n",
"0.09230769142721895\n",
"----------------------------------------------------------------------\n",
"20 32\n",
"0.28767123031713265\n",
"0.11578947163656512\n",
"0.2465753399061738\n",
"----------------------------------------------------------------------\n",
"20 64\n",
"0.28767123031713265\n",
"0.11578947163656512\n",
"0.2465753399061738\n",
"----------------------------------------------------------------------\n",
"20 128\n",
"0.28767123031713265\n",
"0.11578947163656512\n",
"0.2465753399061738\n",
"----------------------------------------------------------------------\n",
"20 256\n",
"0.28767123031713265\n",
"0.11578947163656512\n",
"0.2465753399061738\n",
"----------------------------------------------------------------------\n",
"\n"
]
}
],
"source": [
"reference = \"National Commercial Bank (NCB), Saudi Arabia’s largest lender by assets, agreed to buy rival Samba Financial Group for $15 billion in the biggest banking takeover this year.NCB will pay 28.45 riyals ($7.58) for each Samba share, according to a statement on Sunday, valuing it at about 55.7 billion riyals. NCB will offer 0.739 new shares for each Samba share, at the lower end of the 0.736-0.787 ratio the banks set when they signed an initial framework agreement in June.The offer is a 3.5% premium to Samba’s Oct. 8 closing price of 27.50 riyals and about 24% higher than the level the shares traded at before the talks were made public. Bloomberg News first reported the merger discussions.The new bank will have total assets of more than $220 billion, creating the Gulf region’s third-largest lender. The entity’s $46 billion market capitalization nearly matches that of Qatar National Bank QPSC, which is still the Middle East’s biggest lender with about $268 billion of assets.\"\n",
"for num_beams in [2, 5, 8, 12, 20]:\n",
" for max_length in [8, 32, 64, 128, 256]:\n",
" print(num_beams, max_length)\n",
" input_ids = tokenizer(reference, return_tensors=\"pt\").input_ids\n",
"\n",
" # Generate the output (Here, we use beam search but you can also use any other strategy you like)\n",
" output = model.generate(\n",
" input_ids, \n",
" max_length=max_length, \n",
" num_beams=5, \n",
" early_stopping=True\n",
" )\n",
"\n",
" summary = tokenizer.decode(output[0], skip_special_tokens=True)\n",
" ROUGE = Rouge()\n",
" scores = ROUGE.get_scores(summary, reference)\n",
" for rouge, score in scores[-1].items():\n",
" print(score['f'])\n",
" print('-' * 70)\n",
" print()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.0"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
|