from transformers import pipeline
import gradio as gr
import numpy as np
import time
import json
import os
from langchain_openai import ChatOpenAI
from langchain_core.output_parsers import StrOutputParser
from dotenv import load_dotenv
import logging
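# Runtime dependencies implied by the imports above (a sketch; the ASR
# pipeline also needs a backend such as torch): transformers, gradio,
# numpy, langchain-openai, langchain-core, python-dotenv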
# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
load_dotenv()
zhipuai_api_key = os.getenv("ZHIPUAI_API_KEY")
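# The key is read from a local .env file; a minimal example entry
# (placeholder value):
#   ZHIPUAI_API_KEY=your-zhipuai-key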
# Load an ASR model optimized for Chinese
try:
    transcriber = pipeline("automatic-speech-recognition", model="ydshieh/wav2vec2-large-xlsr-53-chinese-zh-cn-gpt")
    logging.info("Chinese ASR model loaded successfully")
except Exception as e:
    logging.error(f"Error loading Chinese ASR model: {e}")
    transcriber = None
# Initialize the conversation log; speakers alternate between
# "患者" (patient) and "医生" (doctor)
conversation = []
current_speaker = "患者"
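# Each utterance is appended to `conversation` as a dict; an illustrative
# entry (example values only):
#   {"时间": "2024-01-01 09:30:00", "角色": "患者", "内容": "我最近头疼"}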
def transcribe(audio):
    global current_speaker
    if audio is None:
        logging.info("No audio input received")
        return json.dumps({"error": "No audio input received"})
    try:
        logging.info(f"Audio input received: {type(audio)}")
        if isinstance(audio, tuple):
            sr, y = audio
        else:
            y = audio
            sr = 16000  # Assume 16 kHz; adjust if your input differs
        logging.info(f"Sample rate: {sr}, Audio data shape: {y.shape}")
        # Downmix to mono
        if y.ndim > 1:
            y = y.mean(axis=1)
            logging.info(f"Audio data shape after conversion: {y.shape}")
        y = y.astype(np.float32)
        # Normalize to [-1, 1]; guard against division by zero on silent input
        peak = np.max(np.abs(y))
        if peak > 0:
            y /= peak
        # Transcribe the Chinese audio
        if transcriber is not None:
            logging.info("Starting transcription")
            try:
                # wav2vec2 is a CTC model, so no generation arguments
                # (such as a target language) are passed here
                result = transcriber({"sampling_rate": sr, "raw": y})
                text = result["text"].strip()
                logging.info(f"Transcription result: {text}")
            except Exception as e:
                logging.error(f"Error during transcription: {e}", exc_info=True)
                return json.dumps({"error": f"Transcription error: {str(e)}"})
        else:
            logging.error("Transcriber not initialized")
            return json.dumps({"error": "Transcriber not initialized"})
        # Append a structured entry (时间/角色/内容: time, role, content)
        if text:
            current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
            conversation.append({
                "时间": current_time,
                "角色": current_speaker,
                "内容": text
            })
            # Alternate the speaker for the next utterance
            current_speaker = "医生" if current_speaker == "患者" else "患者"
        # Serialize the conversation log as a formatted JSON string
        formatted_conversation = json.dumps(conversation, ensure_ascii=False, indent=2)
        logging.info(f"Formatted conversation: {formatted_conversation}")
        return formatted_conversation
    except Exception as e:
        logging.error(f"Error in transcribe function: {e}", exc_info=True)
        return json.dumps({"error": f"Error: {str(e)}"})
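# A minimal local smoke test for transcribe(), assuming a hypothetical
# one-second 16 kHz noise clip rather than microphone input:
#   print(transcribe((16000, np.random.randn(16000).astype(np.float32))))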
def switch_speaker():
    global current_speaker
    current_speaker = "医生" if current_speaker == "患者" else "患者"
    return f"当前说话者:{current_speaker}"
def generate_memo(conversation_json):
    try:
        # gr.JSON hands this function a Python object; serialize it so the
        # prompt contains clean JSON rather than a Python repr
        if not isinstance(conversation_json, str):
            conversation_json = json.dumps(conversation_json, ensure_ascii=False, indent=2)
        llm = ChatOpenAI(
            model="glm-3-turbo",
            temperature=0.7,
            openai_api_key=zhipuai_api_key,
            openai_api_base="https://open.bigmodel.cn/api/paas/v4/"
        )
        prompt = f"""
请根据以下医生和患者的对话,生成一份结构化的备忘录。备忘录应包含以下字段:主诉、检查、诊断、治疗和备注。
如果某个字段在对话中没有明确提及,请填写"未提及"。

对话内容:
{conversation_json}

请以JSON格式输出备忘录,格式如下:
{{
    "主诉": "患者的主要症状和不适",
    "检查": "医生建议或已进行的检查",
    "诊断": "医生对患者的诊断",
    "治疗": "医生对患者的治疗建议",
    "备注": "医生对患者的备注"
}}
"""
        output = llm.invoke(prompt)
        # StrOutputParser extracts the text content from the AIMessage
        output_parser = StrOutputParser()
        output = output_parser.invoke(output)
        logging.info(f"Generated memo: {output}")
        return output
    except Exception as e:
        logging.error(f"Error in generate_memo function: {e}", exc_info=True)
        return f"Error generating memo: {str(e)}"
# Build the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# 实时中文对话转录与备忘录生成")  # "Real-time Chinese conversation transcription and memo generation"
    gr.Markdown("点击麦克风图标开始录音,说话后会自动进行语音识别。支持中文识别。")  # "Click the mic icon to record; speech is recognized automatically. Chinese is supported."
    with gr.Row():
        # gr.Audio takes no sample_rate argument; the actual recording rate
        # arrives in the (sr, data) tuple the component returns
        audio_input = gr.Audio(sources=["microphone"], type="numpy")
        speaker_button = gr.Button("切换说话者")  # "Switch speaker"
        speaker_label = gr.Label("当前说话者:患者")  # "Current speaker: patient"
    conversation_output = gr.JSON(label="对话记录")  # conversation log
    memo_output = gr.JSON(label="备忘录")  # memo
    generate_memo_button = gr.Button("生成备忘录")  # "Generate memo"
    audio_input.change(transcribe, inputs=[audio_input], outputs=[conversation_output])
    speaker_button.click(switch_speaker, outputs=[speaker_label])
    generate_memo_button.click(generate_memo, inputs=[conversation_output], outputs=[memo_output])
if __name__ == "__main__":
    demo.launch()