article_writer / app.py
minko186's picture
Update app.py
31e5c35 verified
raw
history blame
3.16 kB
import gradio as gr
# from humanize import paraphrase_text
from gradio_client import Client
from ai_generate import generate
# Module-level gradio_client handle for the "polygraf-ai/Humanizer" Space.
# It is created once when this module is imported and reused by every
# humanize() call, which sends its paraphrase requests through it.
# NOTE(review): constructing the Client presumably contacts the Space at
# import time — confirm if the app must start without network access.
client = Client("polygraf-ai/Humanizer")
def humanize(
    text,
    model,
    temperature=1.2,
    repetition_penalty=1,
    top_k=50,
    length_penalty=1,
):
    """Write an article about a topic, then humanize it with the Humanizer Space.

    The article is first drafted by ``generate`` from a prompt built from
    *text*. The drafted article (not the bare topic) is then sent to the
    ``/paraphrase_text`` endpoint of the module-level ``client``.

    Args:
        text: The topic to write an article about.
        model: Humanizer model name, forwarded as ``model_name``.
        temperature: Sampling temperature forwarded to the humanizer.
        repetition_penalty: Repetition penalty forwarded to the humanizer.
        top_k: Top-k cutoff forwarded to the humanizer.
        length_penalty: Length penalty forwarded to the humanizer.

    Returns:
        Whatever the ``/paraphrase_text`` endpoint returns. The UI in this
        file renders it in a ``gr.HTML`` component.
    """
    response = generate(f"Write an article about the topic: {text}")
    # NOTE(review): assumes ``generate`` returns an OpenAI-style chat
    # completion mapping; confirm against ai_generate.generate.
    ai_text = response["choices"][0]["message"]["content"]
    print(f"AI Generated: {ai_text}")
    # Humanize the generated article. Previously the bare topic ``text`` was
    # sent here, so the drafted article was silently thrown away.
    result = client.predict(
        text=ai_text,
        model_name=model,
        temperature=temperature,
        repetition_penalty=repetition_penalty,
        top_k=top_k,
        length_penalty=length_penalty,
        api_name="/paraphrase_text",
    )
    return result
# Build the Gradio UI.
# - Left column: topic input, model picker, run button and output area.
# - Right column: generation-parameter sliders, each with a short
#   Markdown explanation.
# The button feeds all inputs, in order, to humanize().
with gr.Blocks() as demo:
    gr.Markdown("# Polygraf Writer")
    with gr.Row():
        # Gradio column ``scale`` is an integer relative width; float values
        # trigger a warning and can lay out incorrectly. 7:3 preserves the
        # intended 0.7/0.3 split.
        with gr.Column(scale=7):
            gr.Markdown("## Enter a topic to write an article about:")
            input_topic = gr.Textbox(label="Topic")
            model_dropdown = gr.Radio(
                choices=[
                    "Base Model",
                    "Large Model",
                    "XL Model",
                    "XL Law Model",
                    "XL Marketing Model",
                    "XL Child Style Model",
                ],
                value="Large Model",
                label="Select Model Version",
            )
            process_button = gr.Button("Humanize Text")
            gr.Markdown("### Humanized article:")
            output_label = gr.HTML(label="Output")
        with gr.Column(scale=3):
            temperature_slider = gr.Slider(minimum=0.5, maximum=2.0, step=0.1, value=1.2, label="Temperature")
            gr.Markdown("Controls the randomness of the paraphrase. Higher values generate more varied text.")
            top_k_slider = gr.Slider(
                minimum=0,
                maximum=300,
                step=25,
                value=50,
                label="Top k",
            )
            gr.Markdown("Limits the number of top tokens considered during generation.")
            repetition_penalty_slider = gr.Slider(
                minimum=1.0,
                maximum=2.0,
                step=0.1,
                value=1,
                label="Repetition Penalty",
            )
            gr.Markdown("Penalizes repeated words to encourage diverse language use")
            length_penalty_slider = gr.Slider(
                minimum=0.0,
                maximum=2.0,
                step=0.1,
                value=1.0,
                label="Length Penalty",
            )
            gr.Markdown("Penalizes shorter outputs.")
    # Event listeners must be registered inside the Blocks context. The
    # input order must match humanize()'s positional parameters.
    process_button.click(
        fn=humanize,
        inputs=[
            input_topic,
            model_dropdown,
            temperature_slider,
            repetition_penalty_slider,
            top_k_slider,
            length_penalty_slider,
        ],
        outputs=output_label,
    )
if __name__ == "__main__":
    # Launch exactly once, listening on all network interfaces (0.0.0.0).
    # The previous code nested a second demo.launch() around this call,
    # which would attempt to launch the app twice.
    demo.launch(server_name="0.0.0.0")