from typing import Callable

import gradio as gr

from src.scraper.generic_scraper import GenericScraper

if gr.NO_RELOAD:
    import numpy as np
    from src.model import BaseTransferLearningModel

DEVICE = 'cpu'
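
# Model registry: each entry pairs a display name with a zero-argument loader.
# The loaders are called lazily (see WebUI.__init__ and WebUI._change_model),
# so only the currently selected checkpoint is held in memory; the
# commented-out entries are alternative checkpoints that can be re-enabled.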
MODELS = [
    # (
    #     'bert-model_1950',
    #     lambda: BaseTransferLearningModel(
    #         'bert-base-uncased',
    #         [('linear', ['in', 'out']), ('softmax')],
    #         2,
    #         device=DEVICE,
    #         state_dict='src/ckpt/bert-model_1950.pt',
    #     ),
    # ),
    # (
    #     'bert-model_2000',
    #     lambda: BaseTransferLearningModel(
    #         'bert-base-uncased',
    #         [('linear', ['in', 'out']), ('softmax')],
    #         2,
    #         device=DEVICE,
    #         state_dict='src/ckpt/bert-model_2000.pt',
    #     ),
    # ),
    # (
    #     'deberta-base-model_1100',
    #     lambda: BaseTransferLearningModel(
    #         'microsoft/deberta-base',
    #         [('linear', ['in', 'out']), ('softmax')],
    #         2,
    #         device=DEVICE,
    #         state_dict='src/ckpt/deberta-base-model_4400.pt',
    #     ),
    # ),
    # (
    #     'deberta-base-model_2000',
    #     lambda: BaseTransferLearningModel(
    #         'microsoft/deberta-base',
    #         [('linear', ['in', 'out']), ('softmax')],
    #         2,
    #         device=DEVICE,
    #         state_dict='src/ckpt/deberta-base-model_8000.pt',
    #     ),
    # ),
    # (
    #     'deberta-v3-base-model_1700',
    #     lambda: BaseTransferLearningModel(
    #         'microsoft/deberta-v3-base',
    #         [('linear', ['in', 'out']), ('softmax')],
    #         2,
    #         device=DEVICE,
    #         state_dict='src/ckpt/deberta-v3-base-model_3400.pt',
    #     ),
    # ),
    (
        'deberta-v3-base-model_2000',
        lambda: BaseTransferLearningModel(
            'microsoft/deberta-v3-base',
            [('linear', ['in', 'out']), ('softmax')],
            2,
            device=DEVICE,
            state_dict='src/ckpt/deberta-v3-base-model_4000.pt',
        ),
    ),
    # (
    #     'distilbert-model_1850',
    #     lambda: BaseTransferLearningModel(
    #         'distilbert-base-uncased',
    #         [('linear', ['in', 'out']), ('softmax')],
    #         2,
    #         device=DEVICE,
    #         state_dict='src/ckpt/distilbert-model_1850.pt',
    #     ),
    # ),
    # (
    #     'distilbert-model_2000',
    #     lambda: BaseTransferLearningModel(
    #         'distilbert-base-uncased',
    #         [('linear', ['in', 'out']), ('softmax')],
    #         2,
    #         device=DEVICE,
    #         state_dict='src/ckpt/distilbert-model_2000.pt',
    #     ),
    # ),
    # (
    #     'roberta-base-model_1250',
    #     lambda: BaseTransferLearningModel(
    #         'FacebookAI/roberta-base',
    #         [('linear', ['in', 'out']), ('softmax')],
    #         2,
    #         device=DEVICE,
    #         state_dict='src/ckpt/roberta-base-model_1250.pt',
    #     ),
    # ),
    # (
    #     'roberta-base-model_2000',
    #     lambda: BaseTransferLearningModel(
    #         'FacebookAI/roberta-base',
    #         [('linear', ['in', 'out']), ('softmax')],
    #         2,
    #         device=DEVICE,
    #         state_dict='src/ckpt/roberta-base-model_2000.pt',
    #     ),
    # ),
]

class WebUI:
    """Gradio front end that scrapes a news article from a URL and classifies it as fake or real."""

    def __init__(
        self,
        models: list[tuple[str, Callable]] = [],
        device: str = 'cpu',
        debug: bool = False,
    ) -> None:
        self.models = models
        self.device = device
        self.is_ready = False
        # Load the first model in the registry by calling its loader.
        self.model = self.models[0][1]()
        self.is_ready = True
        self.scraper = GenericScraper()
        self.debug = debug
    def _change_model(self, idx: int) -> str:
        # Swap the active model for the one selected in the dropdown.
        if gr.NO_RELOAD:
            try:
                print(self.models[idx])
                self.is_ready = False
                del self.model
                self.model = self.models[idx][1]()
                self.is_ready = True
                print('done loading')
            except Exception as e:
                print(e)
                raise gr.Error(str(e))
        return self.models[idx][0]
    def _predict(self, text: str) -> str:
        print(text)
        if not self.is_ready:
            return 'Model is not yet ready!'
        # output[0] is the "fake" score, output[1] the "real" score.
        output = self.model.predict(text, self.device).detach().cpu().numpy()[0]
        if self.debug:
            return f'Fake: {output[0]:.10f}, Real: {output[1]:.10f}'
        return (
            f'We think that this is a {"fake" if output[0] > output[1] else "real"} '
            f'news article with {max(output[0], output[1]) * 100:.2f}% certainty.'
        )
    def _scrape(self, url: str) -> str:
        try:
            return self.scraper.scrape(url)
        except Exception as e:
            return str(e)
    def get_ui(self) -> gr.Blocks:
        with gr.Blocks() as ui:
            with gr.Row():
                with gr.Column():
                    t_url = gr.Textbox(label='URL')
                    with gr.Row():
                        btn_scrape_reset = gr.ClearButton(
                            value='Reset',
                            components=[
                                t_url,
                            ],
                        )
                        btn_scrape = gr.Button(value='Get From URL', variant='primary')
                    t_inp = gr.Textbox(label='Input')
                    with gr.Row():
                        btn_reset = gr.ClearButton(
                            value='Reset',
                            components=[
                                t_inp,
                            ],
                        )
                        btn_submit = gr.Button(value='Submit', variant='primary')
                with gr.Column():
                    if self.debug:
                        # The model selector is only exposed in debug mode.
                        ddl_model = gr.Dropdown(
                            label='Model',
                            choices=[model[0] for model in self.models],
                            value=self._change_model(0),
                            type='index',
                            interactive=True,
                            filterable=True,
                        )
                    t_out = gr.Textbox(label='Output')

            if self.debug:
                ddl_model.change(fn=self._change_model, inputs=ddl_model)
            btn_scrape.click(fn=self._scrape, inputs=t_url, outputs=t_inp)
            btn_submit.click(fn=self._predict, inputs=t_inp, outputs=t_out)
        return ui

webui = WebUI(models=MODELS, device=DEVICE).get_ui()

if __name__ == '__main__':
    webui.launch()
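
# Usage note (assuming this module is saved as app.py): launch locally with
#   python app.py
# or, during development, with Gradio's auto-reload
#   gradio app.py
# The gr.NO_RELOAD blocks above keep the heavy imports and model swapping out
# of the reload loop while the rest of the file is re-executed on changes.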