"""Gradio web UI that classifies news articles as fake or real.

Article text can be pasted directly or fetched from a URL via
``GenericScraper``; classification is done by a fine-tuned
``BaseTransferLearningModel`` checkpoint chosen from ``MODELS``.
"""
from typing import Callable, Optional

import gradio as gr

from src.scraper.generic_scraper import GenericScraper

# Heavy imports live under ``gr.NO_RELOAD`` so that, per Gradio's reload-mode
# docs, they are not re-executed on every hot reload (``gradio <file>``).
if gr.NO_RELOAD:
    import numpy as np
    from src.model import BaseTransferLearningModel

DEVICE = 'cpu'

# (display name, zero-argument loader). Loaders are only called when a model is
# actually selected, so listing an entry does not load its weights.
ModelSpec = tuple[str, Callable]


def _model_entry(name: str, base_model: str, checkpoint: str) -> ModelSpec:
    """Return a ``(name, loader)`` pair for one fine-tuned checkpoint.

    Args:
        name: label shown in the model dropdown.
        base_model: pretrained backbone identifier passed to
            ``BaseTransferLearningModel``.
        checkpoint: path of the state dict to load into the model.

    Returns:
        A tuple whose second element builds the model on ``DEVICE`` when
        called; nothing is constructed before that.
    """

    def load():
        # NOTE(review): ('softmax') is the plain string 'softmax', not a
        # 1-tuple like the linear entry; kept exactly as every original entry
        # had it — confirm this is the format BaseTransferLearningModel expects.
        # The positional 2 is presumably the number of output classes
        # (fake / real) — TODO confirm.
        return BaseTransferLearningModel(
            base_model,
            [('linear', ['in', 'out']), ('softmax')],
            2,
            device=DEVICE,
            state_dict=checkpoint,
        )

    return name, load


# Models selectable in the UI; the first entry is loaded at start-up and the
# dropdown to switch between them is only shown in debug mode. Uncomment
# entries to make them available.
# NOTE(review): several display names differ from their checkpoint files
# (e.g. 'deberta-v3-base-model_2000' loads deberta-v3-base-model_4000.pt);
# presumably intentional, but confirm before relying on the names.
MODELS: list[ModelSpec] = [
    # _model_entry('bert-model_1950', 'bert-base-uncased', 'src/ckpt/bert-model_1950.pt'),
    # _model_entry('bert-model_2000', 'bert-base-uncased', 'src/ckpt/bert-model_2000.pt'),
    # _model_entry('deberta-base-model_1100', 'microsoft/deberta-base', 'src/ckpt/deberta-base-model_4400.pt'),
    # _model_entry('deberta-base-model_2000', 'microsoft/deberta-base', 'src/ckpt/deberta-base-model_8000.pt'),
    # _model_entry('deberta-v3-base-model_1700', 'microsoft/deberta-v3-base', 'src/ckpt/deberta-v3-base-model_3400.pt'),
    _model_entry(
        'deberta-v3-base-model_2000',
        'microsoft/deberta-v3-base',
        'src/ckpt/deberta-v3-base-model_4000.pt',
    ),
    # _model_entry('distilbert-model_1850', 'distilbert-base-uncased', 'src/ckpt/distilbert-model_1850.pt'),
    # _model_entry('distilbert-model_2000', 'distilbert-base-uncased', 'src/ckpt/distilbert-model_2000.pt'),
    # _model_entry('roberta-base-model_1250', 'FacebookAI/roberta-base', 'src/ckpt/roberta-base-model_1250.pt'),
    # _model_entry('roberta-base-model_2000', 'FacebookAI/roberta-base', 'src/ckpt/roberta-base-model_2000.pt'),
]


class WebUI:
    """Gradio front-end around a swappable classifier and a URL scraper."""

    def __init__(
        self,
        models: Optional[list[ModelSpec]] = None,
        device: str = 'cpu',
        debug: bool = False,
    ) -> None:
        """Load the first model and set up the scraper.

        Args:
            models: ``(display name, loader)`` pairs; ``models[0]`` is loaded
                immediately. Must be non-empty.
            device: device string passed to ``model.predict``.
            debug: when true, show the model dropdown, print inputs, and
                report raw class scores instead of a sentence.

        Raises:
            ValueError: if ``models`` is ``None`` or empty.
        """
        # A ``None`` default avoids the shared-mutable-default pitfall; an
        # empty list could never be served anyway, so fail with a clear message.
        if not models:
            raise ValueError('WebUI requires at least one (name, loader) model entry.')
        self.models = models
        self.device = device
        self.debug = debug
        self.is_ready = False
        self.model_idx = 0
        self.model = self.models[0][1]()
        self.is_ready = True
        self.scraper = GenericScraper()

    def _change_model(self, idx: Optional[int]) -> str:
        """Switch the active model to ``models[idx]`` and return its name.

        Loading is skipped when ``gr.NO_RELOAD`` is false (Gradio hot-reload
        mode). If loading fails, the error is printed and shown to the user as
        a warning; the UI then reports the model as not ready until another
        (or the same) entry is loaded successfully.

        Args:
            idx: index into ``models`` as produced by the dropdown
                (``type='index'``); ``None`` when the dropdown is cleared.

        Returns:
            Display name of the requested model (of the current model when
            ``idx`` is ``None``).
        """
        if idx is None:
            # Cleared dropdown: keep whatever is currently loaded.
            return self.models[self.model_idx][0]
        name = self.models[idx][0]
        if idx == self.model_idx and self.is_ready:
            # Already loaded; avoid an expensive no-op reload.
            return name
        if gr.NO_RELOAD:
            print(self.models[idx])
            self.is_ready = False
            # Drop the old model before building the new one so two models are
            # never held at once. Rebinding (rather than ``del``) keeps the
            # attribute present, so a later swap cannot hit AttributeError.
            self.model = None
            try:
                self.model = self.models[idx][1]()
            except Exception as e:
                print(e)
                # ``gr.Error`` only has an effect when raised; ``gr.Warning``
                # surfaces the problem without aborting the event handler.
                gr.Warning(f'Failed to load model {name!r}: {e}')
            else:
                self.model_idx = idx
                self.is_ready = True
                print('done loading')
        return name

    def _predict(self, text: str) -> str:
        """Classify ``text`` and return a human-readable verdict.

        ``output[0]`` is treated as the "fake" score and ``output[1]`` as the
        "real" score. The "% certainty" wording assumes these are
        probabilities (the head spec ends in 'softmax') — TODO confirm.
        """
        if self.debug:
            # Only echo (potentially long) article text when debugging.
            print(text)
        if not self.is_ready:
            return 'Model is not yet ready!'
        if not text or not text.strip():
            return 'Please provide some text to classify.'
        # NOTE(review): assumes predict returns a torch tensor with a leading
        # batch dimension of size 1 — confirm against BaseTransferLearningModel.
        output = self.model.predict(text, self.device).detach().cpu().numpy()[0]
        if self.debug:
            return f'Fake: {output[0]:.10f}, Real: {output[1]:.10f}'
        verdict = 'fake' if output[0] > output[1] else 'real'
        certainty = max(output[0], output[1]) * 100
        return f'We think that this is a {verdict} news article with {certainty:.2f}% certainty.'

    def _scrape(self, url: str) -> str:
        """Return the article text scraped from ``url``.

        Errors are deliberately returned as text so they appear in the input
        box instead of breaking the UI.
        """
        try:
            return self.scraper.scrape(url)
        except Exception as e:
            return str(e)

    def get_ui(self) -> gr.Blocks:
        """Build and return the Gradio ``Blocks`` layout with event wiring."""
        with gr.Blocks() as ui:
            with gr.Row():
                with gr.Column():
                    t_url = gr.Textbox(label='URL')
                    with gr.Row():
                        gr.ClearButton(value='Reset', components=[t_url])
                        btn_scrape = gr.Button(value='Get From URL', variant='primary')
                    t_inp = gr.Textbox(label='Input')
                    with gr.Row():
                        gr.ClearButton(value='Reset', components=[t_inp])
                        btn_submit = gr.Button(value='Submit', variant='primary')
                with gr.Column():
                    if self.debug:
                        ddl_model = gr.Dropdown(
                            label='Model',
                            choices=[name for name, _ in self.models],
                            # The current model is already loaded; just show
                            # its name instead of reloading it.
                            value=self.models[self.model_idx][0],
                            type='index',
                            interactive=True,
                            filterable=True,
                        )
                    t_out = gr.Textbox(label='Output')
            if self.debug:
                ddl_model.change(fn=self._change_model, inputs=ddl_model)
            btn_scrape.click(fn=self._scrape, inputs=t_url, outputs=t_inp)
            btn_submit.click(fn=self._predict, inputs=t_inp, outputs=t_out)
        return ui


webui = WebUI(models=MODELS, device=DEVICE).get_ui()

if __name__ == '__main__':
    webui.launch()