nicholasKluge commited on
Commit
8d1df0b
·
1 Parent(s): 4237769

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +88 -27
app.py CHANGED
@@ -1,22 +1,31 @@
1
  import time
2
  import torch
3
  import gradio as gr
4
- from transformers import AutoTokenizer, AutoModelForCausalLM
5
 
6
  model_id = "nicholasKluge/Aira-Instruct-PT-560M"
 
 
7
  token = "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"
8
 
9
  device = "cuda" if torch.cuda.is_available() else "cpu"
10
 
 
 
 
11
 
12
- if device == "cuda":
13
- model = AutoModelForCausalLM.from_pretrained(model_id, use_auth_token=token, load_in_8bit=True)
14
-
15
- else:
16
- model = AutoModelForCausalLM.from_pretrained(model_id, use_auth_token=token)
17
 
18
- tokenizer = AutoTokenizer.from_pretrained(model_id, use_auth_token=token)
19
  model.to(device)
 
 
 
 
 
 
 
20
 
21
  intro = """
22
  ## O que é `Aira`?
@@ -25,11 +34,19 @@ intro = """
25
 
26
  Desenvolvemos os nossos chatbots de conversação de domínio aberto através da geração de texto condicional/ajuste fino por instruções. Esta abordagem tem muitas limitações. Apesar de podermos criar um chatbot capaz de responder a perguntas sobre qualquer assunto, é difícil forçar o modelo a produzir respostas de boa qualidade. E por boa, queremos dizer texto **factual** e **não tóxico**. Isto leva-nos a dois dos problemas mais comuns quando lidando com modelos generativos utilizados em aplicações de conversação:
27
 
 
 
28
  🤥 Modelos generativos podem perpetuar a geração de conteúdo pseudo-informativo, ou seja, informações falsas que podem parecer verdadeiras.
29
 
30
  🤬 Em certos tipos de tarefas, modelos generativos podem produzir conteúdo prejudicial e discriminatório inspirado em estereótipos históricos.
31
-
 
 
32
  `Aira` destina-se apenas à investigação académica. Para mais informações, visite o nosso [HuggingFace models](https://huggingface.co/nicholasKluge) para ver como desenvolvemos `Aira`.
 
 
 
 
33
  """
34
 
35
  disclaimer = """
@@ -39,19 +56,20 @@ Se desejar apresentar uma reclamação sobre qualquer mensagem produzida por `Ai
39
  """
40
 
41
  with gr.Blocks(theme='freddyaboulton/dracula_revamped') as demo:
42
-
43
- gr.Markdown("""<h1><center>Aira Demo (Portuguese) 🤓💬</h1></center>""")
44
  gr.Markdown(intro)
45
-
46
  chatbot = gr.Chatbot(label="Aira").style(height=500)
 
47
 
48
- with gr.Accordion(label="Parâmetros ⚙️", open=False):
49
- top_k = gr.Slider( minimum=10, maximum=100, value=15, step=5, interactive=True, label="Top-k",)
50
- top_p = gr.Slider( minimum=0.1, maximum=1.0, value=0.15, step=0.05, interactive=True, label="Top-p",)
51
- temperature = gr.Slider( minimum=0.001, maximum=2.0, value=0.1, step=0.1, interactive=True, label="Temperatura",)
52
- max_length = gr.Slider( minimum=10, maximum=500, value=100, step=10, interactive=True, label="Comprimento Máximo",)
53
-
54
- msg = gr.Textbox(label="Escreva uma pergunta para Aira ...", placeholder="Olá Aira, como você vai?")
55
 
56
  clear = gr.Button("Limpar Conversa 🧹")
57
  gr.Markdown(disclaimer)
@@ -59,23 +77,66 @@ with gr.Blocks(theme='freddyaboulton/dracula_revamped') as demo:
59
  def user(user_message, chat_history):
60
  return gr.update(value=user_message, interactive=True), chat_history + [["👤 " + user_message, None]]
61
 
62
- def generate_response(user_msg, top_p, temperature, top_k, max_length, chat_history):
63
 
64
- inputs = tokenizer(tokenizer.bos_token + user_msg + tokenizer.eos_token, return_tensors="pt").to(device)
65
 
66
  generated_response = model.generate(**inputs,
67
  bos_token_id=tokenizer.bos_token_id,
68
  pad_token_id=tokenizer.pad_token_id,
69
  eos_token_id=tokenizer.eos_token_id,
 
70
  do_sample=True,
71
- early_stopping=True,
72
- top_k=top_k,
73
  max_length=max_length,
74
  top_p=top_p,
75
- temperature=temperature,
76
- num_return_sequences=1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
 
78
- bot_message = tokenizer.decode(generated_response[0], skip_special_tokens=True).replace(user_msg, "")
 
79
 
80
  chat_history[-1][1] = "🤖 "
81
  for character in bot_message:
@@ -84,10 +145,10 @@ with gr.Blocks(theme='freddyaboulton/dracula_revamped') as demo:
84
  yield chat_history
85
 
86
  response = msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
87
- generate_response, [msg, top_p, temperature, top_k, max_length, chatbot], chatbot
88
  )
89
  response.then(lambda: gr.update(interactive=True), None, [msg], queue=False)
90
- msg.submit(lambda x: gr.update(value=''), [],[msg])
91
  clear.click(lambda: None, None, chatbot, queue=False)
92
 
93
  demo.queue()
 
1
  import time
2
  import torch
3
  import gradio as gr
4
+ from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSequenceClassification
5
 
6
  model_id = "nicholasKluge/Aira-Instruct-PT-560M"
7
+ rewardmodel_id = "nicholasKluge/RewardModelPT"
8
+ toxicitymodel_id = "nicholasKluge/ToxicityModelPT"
9
  token = "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"
10
 
11
  device = "cuda" if torch.cuda.is_available() else "cpu"
12
 
13
+ model = AutoModelForCausalLM.from_pretrained(model_id, use_auth_token=token)
14
+ rewardModel = AutoModelForSequenceClassification.from_pretrained(rewardmodel_id, use_auth_token=token)
15
+ toxicityModel = AutoModelForSequenceClassification.from_pretrained(toxicitymodel_id, use_auth_token=token)
16
 
17
+ model.eval()
18
+ rewardModel.eval()
19
+ toxicityModel.eval()
 
 
20
 
 
21
  model.to(device)
22
+ rewardModel.to(device)
23
+ toxicityModel.to(device)
24
+
25
+ tokenizer = AutoTokenizer.from_pretrained(model_id, use_auth_token=token)
26
+ rewardTokenizer = AutoTokenizer.from_pretrained(rewardmodel_id, use_auth_token=token)
27
+ toxiciyTokenizer = AutoTokenizer.from_pretrained(toxicitymodel_id, use_auth_token=token)
28
+
29
 
30
  intro = """
31
  ## O que é `Aira`?
 
34
 
35
  Desenvolvemos os nossos chatbots de conversação de domínio aberto através da geração de texto condicional/ajuste fino por instruções. Esta abordagem tem muitas limitações. Apesar de podermos criar um chatbot capaz de responder a perguntas sobre qualquer assunto, é difícil forçar o modelo a produzir respostas de boa qualidade. E por boa, queremos dizer texto **factual** e **não tóxico**. Isto leva-nos a dois dos problemas mais comuns quando lidando com modelos generativos utilizados em aplicações de conversação:
36
 
37
+ ## Limitações
38
+
39
  🤥 Modelos generativos podem perpetuar a geração de conteúdo pseudo-informativo, ou seja, informações falsas que podem parecer verdadeiras.
40
 
41
  🤬 Em certos tipos de tarefas, modelos generativos podem produzir conteúdo prejudicial e discriminatório inspirado em estereótipos históricos.
42
+
43
+ ## Uso Intendido
44
+
45
  `Aira` destina-se apenas à investigação académica. Para mais informações, visite o nosso [HuggingFace models](https://huggingface.co/nicholasKluge) para ver como desenvolvemos `Aira`.
46
+
47
+ ## Como essa demo funciona?
48
+
49
+ Esta demonstração utiliza um [`modelo de recompensa`](https://huggingface.co/nicholasKluge/RewardModel) e um [`modelo de toxicidade`](https://huggingface.co/nicholasKluge/ToxicityModel) para avaliar a pontuação de cada resposta candidata, considerando o seu alinhamento com a mensagem do utilizador e o seu nível de toxicidade. A função de geração organiza as respostas candidatas por ordem da sua pontuação de recompensa e elimina as respostas consideradas tóxicas ou nocivas. Posteriormente, a função de geração devolve a resposta candidata com a pontuação mais elevada que ultrapassa o limiar de segurança, ou uma mensagem pré-estabelecida se não forem identificados candidatos seguros.
50
  """
51
 
52
  disclaimer = """
 
56
  """
57
 
58
  with gr.Blocks(theme='freddyaboulton/dracula_revamped') as demo:
59
+
60
+ gr.Markdown("""<h1><center>Aira Demo 🤓💬</h1></center>""")
61
  gr.Markdown(intro)
62
+
63
  chatbot = gr.Chatbot(label="Aira").style(height=500)
64
+ msg = gr.Textbox(label="Write a question or comment to Aira ...", placeholder="Hi Aira, how are you?")
65
 
66
+ with gr.Accordion(label="Parâmetros ⚙️", open=True):
67
+ safety = gr.Radio(["On", "Off"], label="Proteção 🛡️", value="On", info="Ajuda a prevenir o modelo de gerar conteúdo tóxico.")
68
+ top_k = gr.Slider(minimum=10, maximum=100, value=50, step=5, interactive=True, label="Top-k", info="Controla o número de tokens de maior probabilidade a considerar em cada passo.")
69
+ top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.50, step=0.05, interactive=True, label="Top-p", info="Controla a probabilidade cumulativa dos tokens gerados.")
70
+ temperature = gr.Slider(minimum=0.1, maximum=2.0, value=0.1, step=0.1, interactive=True, label="Temperatura", info="Controla a aleatoriedade dos tokens gerados.")
71
+ max_length = gr.Slider(minimum=10, maximum=500, value=100, step=10, interactive=True, label="Comprimento Máximo", info="Controla o comprimento máximo do texto gerado.")
72
+ smaple_from = gr.Slider(minimum=2, maximum=10, value=2, step=1, interactive=True, label="Amostragem por Rejeição", info="Controla o número de gerações a partir das quais o modelo de recompensa irá selecionar.")
73
 
74
  clear = gr.Button("Limpar Conversa 🧹")
75
  gr.Markdown(disclaimer)
 
77
  def user(user_message, chat_history):
78
  return gr.update(value=user_message, interactive=True), chat_history + [["👤 " + user_message, None]]
79
 
80
+ def generate_response(user_msg, top_p, temperature, top_k, max_length, smaple_from, safety, chat_history):
81
 
82
+ inputs = tokenizer(tokenizer.bos_token + user_msg + tokenizer.eos_token, return_tensors="pt").to(model.device)
83
 
84
  generated_response = model.generate(**inputs,
85
  bos_token_id=tokenizer.bos_token_id,
86
  pad_token_id=tokenizer.pad_token_id,
87
  eos_token_id=tokenizer.eos_token_id,
88
+ repetition_penalty=1.8,
89
  do_sample=True,
90
+ early_stopping=True,
91
+ top_k=top_k,
92
  max_length=max_length,
93
  top_p=top_p,
94
+ temperature=temperature,
95
+ num_return_sequences=smaple_from)
96
+
97
+ decoded_text = [tokenizer.decode(tokens, skip_special_tokens=True).replace(user_msg, "") for tokens in generated_response]
98
+
99
+ rewards = list()
100
+ toxicities = list()
101
+
102
+ for text in decoded_text:
103
+ reward_tokens = rewardTokenizer(user_msg, text,
104
+ truncation=True,
105
+ max_length=512,
106
+ return_token_type_ids=False,
107
+ return_tensors="pt",
108
+ return_attention_mask=True)
109
+
110
+ reward_tokens.to(rewardModel.device)
111
+
112
+ reward = rewardModel(**reward_tokens)[0].item()
113
+
114
+ toxicity_tokens = toxiciyTokenizer(user_msg + " " + text,
115
+ truncation=True,
116
+ max_length=512,
117
+ return_token_type_ids=False,
118
+ return_tensors="pt",
119
+ return_attention_mask=True)
120
+
121
+ toxicity_tokens.to(toxicityModel.device)
122
+
123
+ toxicity = toxicityModel(**toxicity_tokens)[0].item()
124
+
125
+ rewards.append(reward)
126
+ toxicities.append(toxicity)
127
+
128
+ toxicity_threshold = 5
129
+
130
+ ordered_generations = sorted(zip(decoded_text, rewards, toxicities), key=lambda x: x[1], reverse=True)
131
+
132
+ if safety == "On":
133
+ ordered_generations = [(x, y, z) for (x, y, z) in ordered_generations if z >= toxicity_threshold]
134
+
135
+ if len(ordered_generations) == 0:
136
+ bot_message = """Peço desculpa pelo incómodo, mas parece que não foi possível identificar respostas adequadas que cumpram as nossas normas de segurança. Infelizmente, isto indica que o conteúdo gerado pode conter elementos de toxicidade ou pode não ajudar a responder à sua mensagem. A sua opinião é valiosa para nós e esforçamo-nos por garantir uma conversa segura e construtiva. Não hesite em fornecer mais pormenores ou colocar quaisquer outras questões, e farei o meu melhor para o ajudar."""
137
 
138
+ else:
139
+ bot_message = ordered_generations[0][0]
140
 
141
  chat_history[-1][1] = "🤖 "
142
  for character in bot_message:
 
145
  yield chat_history
146
 
147
  response = msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
148
+ generate_response, [msg, top_p, temperature, top_k, max_length, smaple_from, safety, chatbot], chatbot
149
  )
150
  response.then(lambda: gr.update(interactive=True), None, [msg], queue=False)
151
+ msg.submit(lambda x: gr.update(value=''), None,[msg])
152
  clear.click(lambda: None, None, chatbot, queue=False)
153
 
154
  demo.queue()