shubhambhawsar commited on
Commit
2e6c3ed
·
verified ·
1 Parent(s): 360b342
Files changed (1) hide show
  1. app.py +111 -0
app.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pickle
2
+ import json
3
+ import torch
4
+ import torch
5
+ import torch.nn as nn
6
+ import gradio as gr
7
+ from torch.utils.data import Dataset
8
+ from transformers import BertTokenizer, BertModel
9
+
10
def _load_pickle(path):
    """Unpickle and return the object stored in the local file at ``path``."""
    with open(path, "rb") as handle:
        return pickle.load(handle)


# Tokenizer and encoder are both built from the same pre-trained checkpoint name.
tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-cased')
model = BertModel.from_pretrained('bert-base-multilingual-cased')

# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Vocabulary mappings shipped next to the app as pickle files.
word_to_index = _load_pickle("word_to_index.pkl")
index_to_word = _load_pickle("index_to_word.pkl")

# The classifier gets one output per entry in the index -> word mapping.
numclass = len(index_to_word)
21
+
22
+
23
class DefinitionClassifier(nn.Module):
    """Linear classification head stacked on a caller-supplied encoder.

    The encoder passed as ``model`` must expose ``config.hidden_size`` and
    return an object with a ``pooler_output`` attribute when called.
    """

    def __init__(self, model, tokenizer, num_classes):
        """Store the encoder, tokenizer and class count, and build the head.

        Args:
            model: Encoder module; stored as ``self.bert``.
            tokenizer: Tokenizer kept on the instance for reference only;
                it is not used by ``forward``.
            num_classes: Number of output classes of the linear layer.
        """
        super().__init__()

        # The encoder is provided by the caller rather than loaded here.
        self.bert = model
        self.tokenizer = tokenizer
        self.num_classes = num_classes

        # Single linear layer: encoder hidden size -> num_classes scores.
        hidden_size = self.bert.config.hidden_size
        self.classifier = nn.Linear(hidden_size, num_classes)

    def forward(self, input_ids, attention_mask):
        """Return the classifier scores (logits) for the given encoded inputs.

        Args:
            input_ids: Token ids forwarded to the encoder.
            attention_mask: Attention mask forwarded to the encoder.

        Returns:
            The output of the linear head applied to the encoder's
            ``pooler_output``.
        """
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)

        # Classify on the encoder's pooled representation.
        return self.classifier(encoded.pooler_output)
46
# Wrap the shared encoder with the classification head and move it to `device`.
model_final = DefinitionClassifier(
    model=model, tokenizer=tokenizer, num_classes=numclass
).to(device)

# map_location remaps the stored tensors onto the active device, so a
# checkpoint saved from a GPU process still loads when only a CPU is available.
# NOTE(review): this is an absolute path on one specific machine - confirm it
# exists wherever the app is deployed.
state_dict = torch.load(
    "/data5/home/shubhambhaws2/project_drona1/project/modelmbert.pth",
    map_location=device,
)

# Load the state dictionary into the model
model_final.load_state_dict(state_dict)

# The app only runs inference: switch off dropout so predictions are
# deterministic for a given input.
model_final.eval()
51
+
52
def formate_story(story):
    """Re-join the lines of ``story`` with an alternating newline pattern.

    The text is split on ``"\n"``. The first line first receives one extra
    trailing newline; then every even-indexed line (0, 2, 4, ...) is followed
    by a newline while every odd-indexed line is concatenated directly with
    whatever comes next.

    Args:
        story: Text to reformat.

    Returns:
        The re-joined string.
    """
    lines = story.split("\n")
    # The first line ends up with two newlines: this one plus the
    # even-index newline added below.
    lines[0] = lines[0] + "\n"
    return "".join(
        line + "\n" if position % 2 == 0 else line
        for position, line in enumerate(lines)
    )
64
+
65
# Console marker emitted at import time, before the UI is built.
print("Model Loaded")
66
+
67
def generate_word(defination, k):
    """Return the ``k`` highest-scoring words for a description, one per line.

    Args:
        defination: Free-text description typed by the user. ``None`` or a
            blank string yields an empty result instead of an error.
        k: Number of words to return. Accepts ints or floats (UI sliders may
            deliver either); it is clamped to the range
            ``[1, number of classes]``.

    Returns:
        A string with the predicted words separated by newlines, best first.
    """
    # The textbox is reset to None by the "Clear" button; nothing to predict.
    if not defination or not defination.strip():
        return ""

    inputs = tokenizer.encode_plus(
        defination,
        max_length=128,
        padding='max_length',
        truncation=True,
        return_tensors='pt'
    )
    input_ids = inputs['input_ids'].to(device)
    attention_mask = inputs['attention_mask'].to(device)

    # Inference only: disable dropout and skip building the autograd graph.
    model_final.eval()
    with torch.no_grad():
        logits = model_final(input_ids, attention_mask)

    # torch.topk needs an int no larger than the number of classes.
    k = max(1, min(int(k), logits.size(1)))

    # Softmax is monotonic, so ranking the raw logits gives the same order
    # as ranking the probabilities.
    _, topk_indices = torch.topk(logits, k, dim=1)

    # Index the single batch row explicitly; squeeze() would collapse the
    # k == 1 case to a scalar and break iteration.
    pred = topk_indices[0].cpu().tolist()

    # Return a string with each word on a new line
    return "\n".join(index_to_word[i] for i in pred)
90
+
91
def gradio_reset():
    """Return the values used to reset the UI.

    Returns:
        A 3-tuple ``(None, None, 10)``; the "Clear" button maps it onto the
        description box, the output box and the result-count slider.
    """
    cleared_story = None
    cleared_output = None
    default_total = 10
    return cleared_story, cleared_output, default_total
93
+
94
# Page heading shown at the top of the app. The original markup carried a
# stray closing </a> tag with no matching opening tag; it is removed here.
title = """<h1 align='center'>Reverse Dictionary</h1>"""

with gr.Blocks() as demo:
    gr.Markdown(title)
    story = gr.Textbox(label="Input Description", lines=2)
    k = gr.Slider(label="Total Output", minimum=1, maximum=100, step=1, value=10)

    # NOTE(review): the source's indentation was lost; the two buttons are
    # assumed to share a row with the output box placed below - confirm.
    with gr.Row():
        upload_button = gr.Button(value="Generate Word", interactive=True, variant="primary")
        clear = gr.Button("Clear")
    output = gr.Textbox(label="Output Hindi word", lines=20)

    # "Generate Word" runs the model on the description and slider value.
    upload_button.click(generate_word, [story, k], [output])
    # "Clear" resets the description, the output and the slider.
    clear.click(gradio_reset, [], [story, output, k])

# Enable request queuing, then start the server with a public share link.
demo.queue()
demo.launch(share=True)