Spaces:
Sleeping
Sleeping
from typing import Callable | |
import gradio as gr | |
from src.scraper.generic_scraper import GenericScraper | |
if gr.NO_RELOAD: | |
import numpy as np | |
from src.model import BaseTransferLearningModel | |
DEVICE = 'cpu'


def _lazy_model(backbone: str, checkpoint: str) -> Callable[[], object]:
    """Return a zero-argument factory for a fine-tuned ``BaseTransferLearningModel``.

    Nothing is constructed until the factory is called, so only the model the
    UI actually selects is instantiated. ``BaseTransferLearningModel`` and
    ``DEVICE`` are looked up when the factory runs, not when it is created.

    Args:
        backbone: Base model name passed as the first constructor argument,
            e.g. ``'microsoft/deberta-v3-base'``.
        checkpoint: Path of the fine-tuned ``state_dict`` file to load.
    """
    # NOTE(review): ``('softmax')`` is the plain string 'softmax', not a
    # 1-tuple; presumably BaseTransferLearningModel accepts both forms — confirm.
    return lambda: BaseTransferLearningModel(
        backbone,
        [('linear', ['in', 'out']), ('softmax')],
        2,
        device=DEVICE,
        state_dict=checkpoint,
    )


# (display name, lazy factory) pairs offered by the UI; WebUI loads the first
# entry at start-up. Disabled checkpoints are kept below for reference and can
# be made selectable by uncommenting them.
# NOTE(review): the deberta display names carry a different step number than
# their checkpoint files (e.g. ..._2000 -> ..._4000.pt); presumably
# intentional — confirm.
MODELS = [
    # ('bert-model_1950', _lazy_model('bert-base-uncased', 'src/ckpt/bert-model_1950.pt')),
    # ('bert-model_2000', _lazy_model('bert-base-uncased', 'src/ckpt/bert-model_2000.pt')),
    # ('deberta-base-model_1100', _lazy_model('microsoft/deberta-base', 'src/ckpt/deberta-base-model_4400.pt')),
    # ('deberta-base-model_2000', _lazy_model('microsoft/deberta-base', 'src/ckpt/deberta-base-model_8000.pt')),
    # ('deberta-v3-base-model_1700', _lazy_model('microsoft/deberta-v3-base', 'src/ckpt/deberta-v3-base-model_3400.pt')),
    (
        'deberta-v3-base-model_2000',
        _lazy_model('microsoft/deberta-v3-base', 'src/ckpt/deberta-v3-base-model_4000.pt'),
    ),
    # ('distilbert-model_1850', _lazy_model('distilbert-base-uncased', 'src/ckpt/distilbert-model_1850.pt')),
    # ('distilbert-model_2000', _lazy_model('distilbert-base-uncased', 'src/ckpt/distilbert-model_2000.pt')),
    # ('roberta-base-model_1250', _lazy_model('FacebookAI/roberta-base', 'src/ckpt/roberta-base-model_1250.pt')),
    # ('roberta-base-model_2000', _lazy_model('FacebookAI/roberta-base', 'src/ckpt/roberta-base-model_2000.pt')),
]
class WebUI:
    """Gradio front end that labels pasted or scraped article text as fake or real.

    Args:
        models: Non-empty list of ``(display_name, factory)`` pairs. Each
            factory takes no arguments and returns a model exposing
            ``predict(text, device)``. The first entry is loaded eagerly.
        device: Device string forwarded to ``model.predict``.
        debug: When true, the UI shows a model-picker dropdown and prints raw
            per-class scores instead of the user-facing verdict.

    Raises:
        ValueError: If ``models`` is empty or omitted.
    """

    def __init__(
        self,
        models: 'list[tuple[str, Callable[[], object]]] | None' = None,
        device: str = 'cpu',
        debug: bool = False,
    ) -> None:
        # The old ``[]`` default was a shared mutable object and, being empty,
        # could only ever fail on ``models[0]``; reject that case explicitly.
        if not models:
            raise ValueError('WebUI needs at least one (name, factory) model entry.')
        self.models = models
        self.device = device
        self.debug = debug
        self.is_ready = False
        # Index into self.models of the model currently loaded; lets get_ui
        # show the right dropdown value without loading the model again.
        self.model_idx = 0
        self.model = self.models[0][1]()
        self.is_ready = True
        self.scraper = GenericScraper()

    def _change_model(self, idx: int) -> str:
        """Load ``self.models[idx]`` as the active model and return its display name.

        Load errors are printed and surfaced as a Gradio warning rather than
        raised; ``is_ready`` then stays False, so ``_predict`` refuses to run.
        When ``gr.NO_RELOAD`` is false (Gradio's reload mode re-running this
        module) nothing is loaded and only the name is returned.
        """
        if gr.NO_RELOAD:
            try:
                print(self.models[idx])
                self.is_ready = False
                # Drop the old model before building the new one so both are
                # never held at once. Rebinding instead of ``del`` keeps the
                # attribute defined, so one failed load does not make every
                # later switch die on AttributeError before loading anything.
                self.model = None
                self.model = self.models[idx][1]()
                self.model_idx = idx
                self.is_ready = True
                print('done loading')
            except Exception as e:
                print(e)
                # ``gr.Error(e)`` without ``raise`` did nothing. ``gr.Warning``
                # reports the failure without aborting, which is also safe when
                # this runs outside an event handler.
                gr.Warning(f'Failed to load model: {e}')
        return self.models[idx][0]

    def _predict(self, text: str) -> str:
        """Classify *text* and return a human-readable verdict.

        Returns a "not ready" message while no model is loaded, and a prompt
        when *text* is empty or blank. In debug mode it returns the raw fake/real
        scores. Otherwise it returns the winning label with its score as a
        percentage; ties count as "real".
        """
        print(text)
        if not self.is_ready:
            return 'Model is not yet ready!'
        if not text or not text.strip():
            return 'Please enter some text to classify.'
        # Assumes predict returns a torch tensor shaped (1, 2) holding
        # (fake, real) scores, presumably softmax probabilities — TODO confirm.
        output = self.model.predict(text, self.device).detach().cpu().numpy()[0]
        fake, real = output[0], output[1]
        if self.debug:
            return f'Fake: {fake:.10f}, Real: {real:.10f}'
        label = 'fake' if fake > real else 'real'
        return f'We think that this is a {label} news article with {max(fake, real) * 100:.2f}% certainty.'

    def _scrape(self, url: str) -> str:
        """Return the text scraped from *url*, or the error message if scraping fails.

        The result is written into the input box, so on failure the user sees
        the error text there instead of article content.
        """
        try:
            return self.scraper.scrape(url)
        except Exception as e:
            return str(e)

    def get_ui(self) -> 'gr.Blocks':
        """Build the Gradio layout wired to this instance and return the Blocks object.

        The left column holds the URL box with Reset / "Get From URL" buttons
        and the input box with Reset / Submit buttons. The right column holds
        the output box and, in debug mode, a model dropdown whose selection
        index is passed to ``_change_model``.
        """
        with gr.Blocks() as ui:
            with gr.Row():
                with gr.Column():
                    t_url = gr.Textbox(label='URL')
                    with gr.Row():
                        gr.ClearButton(value='Reset', components=[t_url])
                        btn_scrape = gr.Button(value='Get From URL', variant='primary')
                    t_inp = gr.Textbox(label='Input')
                    with gr.Row():
                        gr.ClearButton(value='Reset', components=[t_inp])
                        btn_submit = gr.Button(value='Submit', variant='primary')
                with gr.Column():
                    if self.debug:
                        ddl_model = gr.Dropdown(
                            label='Model',
                            choices=[model[0] for model in self.models],
                            # Show the already-loaded model instead of calling
                            # _change_model, which would reload it from disk.
                            value=self.models[self.model_idx][0],
                            type='index',
                            interactive=True,
                            filterable=True,
                        )
                    t_out = gr.Textbox(label='Output')
            if self.debug:
                ddl_model.change(fn=self._change_model, inputs=ddl_model)
            btn_scrape.click(fn=self._scrape, inputs=t_url, outputs=t_inp)
            btn_submit.click(fn=self._predict, inputs=t_inp, outputs=t_out)
        return ui
# Built at import time so the Blocks app is reachable as the module-level
# ``webui``; the server is started only when this file is run as a script.
webui = WebUI(
    models=MODELS,
    device=DEVICE,
).get_ui()

if __name__ == '__main__':
    webui.launch()