"""Gradio app that fine-tunes ``anferico/bert-for-patents`` on uploaded tokenized data.

The user uploads a JSON file of pre-tokenized data. The app fine-tunes a
sequence-classification model with ``num_labels=2``, saves model and tokenizer
to ``./fine_tuned_patentbert`` and offers that directory as a zip download.
"""

import inspect
import json
import os
import shutil

import gradio as gr
import torch
from datasets import Dataset
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
)

MODEL_NAME = "anferico/bert-for-patents"
OUTPUT_MODEL_DIR = "./fine_tuned_patentbert"
ARCHIVE_BASE_NAME = "fine_tuned_patentbert"


# Load tokenized data
def load_data(tokenized_file):
    """Read the uploaded tokenized-data JSON file and return the parsed object.

    Args:
        tokenized_file: Value produced by ``gr.File``. Gradio 3 passes a
            tempfile-like object exposing ``.name``; Gradio 4 passes the file
            path as a ``str``. Both are accepted.

    Returns:
        The decoded JSON content.

    Raises:
        gr.Error: If no file was uploaded.
    """
    if tokenized_file is None:
        raise gr.Error("Please upload a tokenized JSON file.")
    # Fall back to the value itself when it is already a path string.
    path = getattr(tokenized_file, "name", tokenized_file)
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def _eval_strategy_kwargs(strategy):
    """Return the evaluation-strategy keyword accepted by the installed transformers.

    Newer releases renamed ``evaluation_strategy`` to ``eval_strategy`` and
    later dropped the old name, so pick whichever ``TrainingArguments`` exposes.
    """
    params = inspect.signature(TrainingArguments.__init__).parameters
    key = "eval_strategy" if "eval_strategy" in params else "evaluation_strategy"
    return {key: strategy}


# Fine-tune the model
def fine_tune_model(tokenized_file, progress=gr.Progress()):
    """Fine-tune the patent BERT model on an uploaded tokenized dataset.

    Args:
        tokenized_file: Upload value from ``gr.File`` (see ``load_data``).
        progress: Progress tracker injected by Gradio.

    Returns:
        tuple: ``(status_message, archive_path)``. ``archive_path`` points to
        the zip of the saved model and tokenizer.
    """
    tokenized_data = load_data(tokenized_file)

    # NOTE(review): assumes the JSON is a dict of equal-length columns
    # (e.g. input_ids / attention_mask / labels) that Trainer can consume
    # directly; no data collator or tokenizer is given to Trainer, so
    # sequences presumably arrive already padded — confirm with the producer.
    dataset = Dataset.from_dict(tokenized_data)

    # Hold out 20% of rows as the per-epoch evaluation set.
    tokenized_datasets = dataset.train_test_split(test_size=0.2)

    model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=2)
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

    training_args = TrainingArguments(
        output_dir="./results",
        num_train_epochs=3,
        per_device_train_batch_size=8,  # Reduce batch size
        per_device_eval_batch_size=16,
        gradient_accumulation_steps=2,  # effective train batch: 8 * 2 = 16 per device
        # Mixed precision requires CUDA; TrainingArguments raises if fp16=True on CPU.
        fp16=torch.cuda.is_available(),
        warmup_steps=500,
        weight_decay=0.01,
        logging_dir="./logs",
        logging_steps=10,
        save_strategy="epoch",
        # load_best_model_at_end requires matching evaluation and save strategies.
        load_best_model_at_end=True,
        **_eval_strategy_kwargs("epoch"),
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_datasets["train"],
        eval_dataset=tokenized_datasets["test"],
    )

    progress(0.5, "Fine-tuning the model...")
    trainer.train()
    progress(1.0, "Fine-tuning complete.")

    model.save_pretrained(OUTPUT_MODEL_DIR)
    tokenizer.save_pretrained(OUTPUT_MODEL_DIR)

    # Create a zip file of the fine-tuned model; make_archive returns its path.
    archive_path = shutil.make_archive(ARCHIVE_BASE_NAME, "zip", OUTPUT_MODEL_DIR)
    return (
        "Model fine-tuned and saved successfully. Download the model using the link below.",
        archive_path,
    )


# Create Gradio interface.
# gr.Progress works without live mode; live=True would re-run the whole
# fine-tuning on every input change (including clearing the file), so the
# default submit-button flow is used instead.
iface = gr.Interface(
    fn=fine_tune_model,
    inputs=[
        gr.File(label="Upload Tokenized Data JSON"),
    ],
    outputs=[
        gr.Textbox(label="Processing Information"),
        gr.File(label="Download Fine-Tuned Model"),
    ],
    title="Fine-Tune Patent BERT Model",
    description="Upload tokenized JSON file to fine-tune the BERT model.",
)

# Launch the interface when run as a script.
if __name__ == "__main__":
    iface.launch()