# Forensic-Type2 / app.py
# Author: DrBardiaGh — last commit 03c17d1 ("Update app.py", verified)
import gradio as gr
import torch
from transformers import Wav2Vec2ForCTC, Wav2Vec2Tokenizer
import openai
import os
# --------------- STEP 1: Load the free STT model from Hugging Face ---------------
# Speech-to-text runs locally with a Persian wav2vec2 (XLSR) CTC checkpoint, so
# transcription itself costs nothing. Both objects are created once at import time
# and shared by every request.
# NOTE(review): recent transformers releases mark Wav2Vec2Tokenizer as deprecated
# in favour of Wav2Vec2Processor. If migrating, note that the processor validates
# `sampling_rate`, so callers must then feed 16 kHz audio — confirm before switching.
stt_model_name = "m3hrdadfi/wav2vec2-large-xlsr-persian-v2"
tokenizer, stt_model = (
    Wav2Vec2Tokenizer.from_pretrained(stt_model_name),  # raw audio -> input_values; ids -> text
    Wav2Vec2ForCTC.from_pretrained(stt_model_name),  # acoustic model producing CTC logits
)
# --------------- STEP 2: Configure OpenAI API for GPT-4o mini (fine-tuned) ---------------
# The API key is read from the environment only — never hardcode or commit a real key.
# If OPENAI_API_KEY is unset the key stays None, so the openai library reports a
# missing-key error on the first request instead of sending a placeholder credential.
openai.api_key = os.getenv("OPENAI_API_KEY")
# Model used for text correction. Set OPENAI_MODEL_NAME to use a different model,
# e.g. the id of your own fine-tuned model; defaults to "gpt-4o-mini".
OPENAI_MODEL_NAME = os.getenv("OPENAI_MODEL_NAME", "gpt-4o-mini")
# --------------- STT Function ---------------
def speech_to_text(audio, target_sample_rate=16000):
    """Transcribe Persian speech to text with the local wav2vec2 CTC model.

    Before inference, the waveform is converted to the form the model expects:
    - integer PCM is scaled to float32 in [-1, 1);
    - multi-channel audio is averaged down to mono;
    - the signal is resampled to ``target_sample_rate``.

    Args:
        audio: ``(sample_rate, data)`` tuple as produced by
            ``gr.Audio(type="numpy")``, or ``None`` when nothing was provided.
            ``data`` may be 1-D ``(samples,)`` or 2-D ``(samples, channels)``
            and may hold integer PCM or float samples.
        target_sample_rate: Rate in Hz to resample to before inference. The
            default 16000 matches the rate wav2vec2 XLSR checkpoints are
            trained on.

    Returns:
        The lower-cased transcription, or ``""`` when ``audio`` is ``None`` or
        contains no samples.
    """
    if audio is None:
        return ""
    sample_rate, data = audio
    waveform = torch.as_tensor(data)
    if waveform.numel() == 0:
        # An empty recording would crash the model's convolutional front-end.
        return ""

    # Integer PCM (Gradio commonly delivers int16) -> float32 in [-1, 1).
    if waveform.is_floating_point():
        waveform = waveform.to(torch.float32)
    else:
        full_scale = float(torch.iinfo(waveform.dtype).max) + 1.0
        waveform = waveform.to(torch.float32) / full_scale

    # (samples, channels) -> (samples,): average the channels down to mono.
    if waveform.ndim > 1:
        waveform = waveform.reshape(waveform.shape[0], -1).mean(dim=1)

    # The tokenizer call below does not resample, and browser microphones
    # usually record at 44.1/48 kHz, so resample here. This uses plain linear
    # interpolation with no anti-aliasing filter; a band-limited resampler
    # (e.g. torchaudio's) would be more accurate.
    if sample_rate != target_sample_rate:
        new_length = max(1, round(waveform.shape[0] * target_sample_rate / sample_rate))
        waveform = torch.nn.functional.interpolate(
            waveform.reshape(1, 1, -1),
            size=new_length,
            mode="linear",
            align_corners=False,
        ).reshape(-1)

    input_values = tokenizer(
        waveform.numpy(), return_tensors="pt", sampling_rate=target_sample_rate
    ).input_values
    with torch.no_grad():
        logits = stt_model(input_values).logits
    # Greedy CTC decoding: take the most likely token for each frame.
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = tokenizer.decode(predicted_ids[0])
    # Kept from the original; for Persian script, lower() is effectively a no-op.
    return transcription.lower()
# --------------- Correction Function with GPT-4o mini ---------------
def correct_text_with_gpt(text):
    """Rewrite a raw Persian transcription as formal text via the OpenAI chat API.

    The system prompt asks the model to:
    - use a formal medical/legal register;
    - make the text coherent and unify spacing and punctuation;
    - return only the corrected text, with no commentary.

    Args:
        text: Raw transcription. ``None``, empty, or whitespace-only input
            returns ``""`` without calling the API.

    Returns:
        The model's reply with surrounding whitespace stripped, or ``""`` if
        the reply carries no text content.

    Raises:
        Whatever the installed ``openai`` library raises on request failure
        (authentication, network, rate limiting). Errors are not swallowed here.
    """
    if not text or not text.strip():
        return ""
    system_message = (
        "You are a specialized model that receives Persian text and corrects it with a formal tone, "
        "particularly for medical/legal context. Make the text coherent, unify spacing, and punctuation, "
        "and ensure a formal writing style. Do not add commentary. Just provide the corrected text."
    )
    user_message = f"متن خام: {text}\n\nلطفاً متن فوق را در یک پاراگراف رسمی و اداری اصلاح کن."
    messages = [
        {"role": "system", "content": system_message},
        {"role": "user", "content": user_message},
    ]
    if hasattr(openai, "chat"):
        # openai>=1.0 removed `openai.ChatCompletion`. The module-level client
        # (configured through `openai.api_key`) exposes chat completions here.
        response = openai.chat.completions.create(
            model=OPENAI_MODEL_NAME,
            messages=messages,
            temperature=0.2,
            max_tokens=1024,
        )
    else:
        # Legacy openai<1.0 interface.
        response = openai.ChatCompletion.create(
            model=OPENAI_MODEL_NAME,
            messages=messages,
            temperature=0.2,
            max_tokens=1024,
        )
    corrected = response.choices[0].message.content
    # `content` is optional in the API response (e.g. a refusal), so guard None.
    return (corrected or "").strip()
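# --------------- Gradio Interface ---------------
# The UI provides:
# - An audio widget to record or upload Persian speech
# - A textbox showing the raw transcription (free, local model)
# - A textbox showing the text corrected by GPT-4o mini (paid API)
# - A "تبدیل و اصلاح" (convert and correct) button that runs both steps in one click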
def process_audio_and_correct(audio):
    """Run the full pipeline for one recording: transcribe, then polish with GPT.

    Args:
        audio: The value delivered by the Gradio audio widget (passed straight
            to ``speech_to_text``).

    Returns:
        A ``(raw_text, corrected_text)`` pair, in the order of the two output
        textboxes wired to the button.
    """
    transcript = speech_to_text(audio)  # free, local STT
    return transcript, correct_text_with_gpt(transcript)  # paid GPT correction
with gr.Blocks() as demo:
    gr.Markdown("# وب‌سایت تبدیل گفتار به متن (فارسی) + اصلاح متن با GPT-4o mini")
    gr.Markdown("در اینجا می‌توانید یک وویس فارسی ضبط یا فایل صوتی آپلود کنید. سپس متن استخراج‌شده و اصلاح‌شده را دریافت کنید.")

    # Gradio 4 replaced `source="..."` with `sources=[...]` and rejects the old
    # keyword. Gradio 3 accepts only a single `source`. Choose per version so the
    # widget offers both recording and upload wherever the installed Gradio allows.
    if int(gr.__version__.split(".")[0]) >= 4:
        _audio_source_kwargs = {"sources": ["microphone", "upload"]}
    else:
        _audio_source_kwargs = {"source": "microphone"}

    with gr.Row():
        audio_in = gr.Audio(
            type="numpy",  # handler receives a (sample_rate, data) tuple
            label="ضبط صدا یا آپلود فایل (رایگان)",
            **_audio_source_kwargs,
        )
    transcribed_text = gr.Textbox(label="متن خام (خروجی مرحله اول - رایگان)", lines=3)
    corrected_text = gr.Textbox(label="متن اصلاح‌شده (با GPT-4o mini - پولی)", lines=3)
    btn_process = gr.Button("تبدیل و اصلاح")

    # One click runs STT and GPT correction; results fill the two textboxes in order.
    btn_process.click(
        fn=process_audio_and_correct,
        inputs=audio_in,
        outputs=[transcribed_text, corrected_text],
    )

demo.launch()