# Abigail99216's picture
# Update app.py
# 808b728 verified
from transformers import pipeline
import gradio as gr
import numpy as np
import time
import json
import os
from langchain_openai import ChatOpenAI
from langchain_core.output_parsers import StrOutputParser
from dotenv import load_dotenv
import logging
import librosa
# Configure application-wide logging.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
)

# Load environment settings, then require the API key used by the memo LLM.
load_dotenv()
zhipuai_api_key = os.environ.get("ZHIPUAI_API_KEY")
if not zhipuai_api_key:
    logging.error("ZHIPUAI_API_KEY is not set in the environment variables")
    raise ValueError("ZHIPUAI_API_KEY is not set")

# Load the Chinese speech-recognition model. A load failure is logged and
# leaves `transcriber` as None so the rest of the module still initialises.
try:
    transcriber = pipeline(
        "automatic-speech-recognition",
        model="ydshieh/wav2vec2-large-xlsr-53-chinese-zh-cn-gpt",
    )
    logging.info("Chinese ASR model loaded successfully")
except Exception as exc:
    logging.error(f"Error loading Chinese ASR model: {exc}")
    transcriber = None

# Conversation state shared by the handlers below: the running transcript
# and the role attributed to the next utterance.
conversation = []
current_speaker = "患者"
def _prepare_audio(sr, y):
    """Return ``(16000, samples)`` with mono, floating-point, peak-normalised audio.

    Args:
        sr: Sample rate of ``y`` in Hz.
        y: Array of samples. Multi-dimensional input is averaged over axis 1,
            i.e. it is assumed to be laid out as (samples, channels).

    Returns:
        Tuple of the (possibly updated) sample rate and a new 1-D float array.
        The caller's array is never modified in place.
    """
    y = np.asarray(y)
    # The model is fed floating-point samples; integer PCM is converted first.
    if not np.issubdtype(y.dtype, np.floating):
        y = y.astype(np.float32)
    # Down-mix BEFORE resampling: resampling operates on the last axis, which
    # for (samples, channels) data would be the channel axis, not time.
    if y.ndim > 1:
        y = y.mean(axis=1)
        logging.info(f"Audio data shape after conversion: {y.shape}")
    if sr != 16000:
        logging.info("Resampling audio to 16000Hz")
        y = librosa.resample(y, orig_sr=sr, target_sr=16000)
        sr = 16000
    # Peak-normalise, but leave silent (all-zero) audio untouched so the
    # division cannot produce NaN/inf samples.
    peak = np.max(np.abs(y)) if y.size else 0.0
    if peak > 0:
        y = y / peak
    return sr, y


def transcribe(audio):
    """Transcribe one audio clip and append it to the shared conversation.

    Args:
        audio: ``None``, a ``(sample_rate, samples)`` tuple, or a bare array
            of samples (then treated as 16000 Hz).

    Returns:
        A JSON string: the full conversation (list of records with the keys
        "时间", "角色" and "内容") on success, or ``{"error": ...}`` on failure.
        This function never raises.

    Side effects:
        Appends to the module-level ``conversation`` and flips the
        module-level ``current_speaker`` after each non-empty transcription.
    """
    global current_speaker
    if audio is None:
        logging.info("No audio input received")
        return json.dumps({"error": "No audio input received"})
    try:
        logging.info(f"Audio input received: {type(audio)}")
        if isinstance(audio, tuple):
            sr, y = audio
        else:
            # No sample rate supplied; fall back to the model's native rate.
            y, sr = audio, 16000
        y = np.asarray(y)
        logging.info(f"Sample rate: {sr}, Audio data shape: {y.shape}, Audio data type: {y.dtype}")

        # An empty recording has nothing to transcribe (and no peak to
        # normalise by), so report it the same way as missing input.
        if y.size == 0:
            logging.info("Empty audio input received")
            return json.dumps({"error": "No audio input received"})

        sr, y = _prepare_audio(sr, y)

        if transcriber is None:
            logging.error("Transcriber not initialized")
            return json.dumps({"error": "Transcriber not initialized"})

        logging.info("Starting transcription")
        try:
            result = transcriber({"sampling_rate": sr, "raw": y}, generate_kwargs={"language": "chinese"})
            text = result["text"].strip()
            logging.info(f"Transcription result: {text}")
        except Exception as e:
            logging.error(f"Error during transcription: {e}", exc_info=True)
            return json.dumps({"error": f"Transcription error: {str(e)}"})

        # Record the utterance, then hand the turn to the other speaker.
        if text:
            current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
            conversation.append({
                "时间": current_time,
                "角色": current_speaker,
                "内容": text
            })
            current_speaker = "医生" if current_speaker == "患者" else "患者"

        formatted_conversation = json.dumps(conversation, ensure_ascii=False, indent=2)
        logging.info(f"Formatted conversation: {formatted_conversation}")
        return formatted_conversation
    except Exception as e:
        logging.error(f"Error in transcribe function: {e}", exc_info=True)
        return json.dumps({"error": f"Error: {str(e)}"})
def switch_speaker():
    """Flip the module-level speaker between "患者" and "医生".

    Returns:
        The label text announcing the speaker that is now active.
    """
    global current_speaker
    if current_speaker == "患者":
        current_speaker = "医生"
    else:
        current_speaker = "患者"
    return f"当前说话者:{current_speaker}"
def _format_dialogue(conversation):
    """Render conversation records as one "role: content" line per utterance.

    Records that are not dicts or lack "内容" are skipped. Records that have
    content but no "角色" (e.g. plain text wrapped as ``{"内容": ...}``) are
    kept as bare content lines instead of being silently dropped.
    """
    lines = []
    for item in conversation:
        if not isinstance(item, dict) or '内容' not in item:
            continue
        if '角色' in item:
            lines.append(f"{item['角色']}: {item['内容']}")
        else:
            lines.append(str(item['内容']))
    return "\n".join(lines)


def _parse_memo_json(output):
    """Parse LLM output as JSON, falling back to the outermost ``{...}`` span.

    Raises:
        ValueError: if the output contains no brace-delimited span
            (``json.JSONDecodeError``, a ``ValueError`` subclass, is raised
            when the extracted span itself is not valid JSON).
    """
    import re  # local import: the module does not import `re` at file level

    try:
        return json.loads(output)
    except json.JSONDecodeError:
        # Models often wrap the JSON in prose or code fences; grab the span
        # from the first "{" to the last "}".
        json_match = re.search(r'\{.*\}', output, re.DOTALL)
        if json_match is None:
            raise ValueError("Unable to extract JSON from LLM output") from None
        return json.loads(json_match.group())


def generate_memo(conversation):
    """Ask the LLM for a structured memo summarising a doctor/patient dialogue.

    Args:
        conversation: List of records; each usable record is a dict with
            "内容" and, optionally, "角色".

    Returns:
        A JSON string: ``{"result": <memo object>}`` on success, or
        ``{"error": <message>}`` on any failure. This function never raises.
    """
    try:
        logging.info(f"Generating memo for conversation: {conversation}")
        if not isinstance(conversation, list):
            raise ValueError("Conversation must be a list")

        dialogue = _format_dialogue(conversation)
        # Without any dialogue the model would invent a memo from nothing,
        # so fail fast instead of spending an API call.
        if not dialogue.strip():
            raise ValueError("Conversation contains no dialogue content")

        llm = ChatOpenAI(
            model="glm-3-turbo",
            temperature=0.7,
            openai_api_key=zhipuai_api_key,
            openai_api_base="https://open.bigmodel.cn/api/paas/v4/"
        )
        prompt = f"""
请根据以下医生和患者的对话,生成一份结构化的备忘录。备忘录应包含以下字段:主诉、检查、诊断、治疗和备注。
如果某个字段在对话中没有明确提及,请填写"未提及"。
对话内容:
{dialogue}
请以JSON格式输出备忘录,格式如下:
{{
"主诉": "患者的主要症状和不适",
"检查": "医生建议或已进行的检查",
"诊断": "医生对患者的诊断",
"治疗": "医生对患者的治疗建议",
"备注": "医生对患者的备注"
}}
请确保输出是有效的JSON格式。
"""
        logging.info(f"Sending prompt to LLM: {prompt}")
        output = llm.invoke(prompt)
        output_parser = StrOutputParser()
        output = output_parser.invoke(output)
        logging.info(f"Generated memo: {output}")

        json_output = _parse_memo_json(output)
        return json.dumps({"result": json_output}, ensure_ascii=False)
    except Exception as e:
        error_msg = f"Error in generate_memo function: {str(e)}"
        logging.error(error_msg, exc_info=True)
        return json.dumps({"error": error_msg})
def safe_generate_memo(conversation_data):
    """Coerce arbitrary UI input into a record list and generate a memo from it.

    Lists are passed through unchanged. Strings are parsed as JSON, or wrapped
    as a single content-only record when they are not valid JSON. Anything
    else is stringified and wrapped the same way.

    Returns:
        Whatever ``generate_memo`` returns, or a JSON error string if anything
        here raises.
    """
    try:
        if isinstance(conversation_data, list):
            records = conversation_data
        elif isinstance(conversation_data, str):
            try:
                records = json.loads(conversation_data)
            except json.JSONDecodeError:
                # Not JSON: treat the whole string as plain text.
                records = [{"内容": conversation_data}]
        else:
            records = [{"内容": str(conversation_data)}]
        return generate_memo(records)
    except Exception as exc:
        logging.error(f"Error in safe_generate_memo: {exc}", exc_info=True)
        return json.dumps({"error": f"生成备忘录时出错: {str(exc)}"})
def display_memo(memo):
    """Turn a memo payload into display text.

    Args:
        memo: A JSON string or an already-parsed object, normally shaped as
            ``{"result": ...}`` or ``{"error": ...}``.

    Returns:
        An error sentence for error payloads, pretty-printed JSON for results
        (or the raw result string when it is not itself JSON), or a message
        describing why the payload could not be processed.
    """
    try:
        parsed = json.loads(memo) if isinstance(memo, str) else memo

        if "error" in parsed:
            return f"生成备忘录时出错: {parsed['error']}"
        if "result" not in parsed:
            # Unknown shape: show the payload as-is.
            return json.dumps(parsed, ensure_ascii=False, indent=2)

        result = parsed["result"]
        if not isinstance(result, str):
            return json.dumps(result, ensure_ascii=False, indent=2)
        # A string result may itself be serialised JSON; pretty-print it when
        # it is, otherwise return it untouched.
        try:
            return json.dumps(json.loads(result), ensure_ascii=False, indent=2)
        except json.JSONDecodeError:
            return result
    except json.JSONDecodeError:
        return f"无效的 JSON 格式: {memo}"
    except Exception as exc:
        return f"处理备忘录时出错: {str(exc)}"
def prepare_conversation(conversation_data):
    """Decode a JSON string into conversation data; pass other values through.

    A string that is not valid JSON is wrapped as a single content-only
    record so downstream code still receives a list.
    """
    if not isinstance(conversation_data, str):
        return conversation_data
    try:
        return json.loads(conversation_data)
    except json.JSONDecodeError:
        return [{"内容": conversation_data}]
# Build the Gradio interface.
with gr.Blocks() as demo:
    gr.Markdown("# 实时中文对话转录与备忘录生成")
    gr.Markdown("点击麦克风图标开始录音,说话后会自动进行语音识别。支持中文识别。")

    # NOTE(review): the source's indentation was lost, so the exact set of
    # components inside this row is reconstructed — confirm the layout.
    with gr.Row():
        audio_input = gr.Audio(sources=["microphone"], type="numpy")
        speaker_button = gr.Button("切换说话者")
        speaker_label = gr.Label("当前说话者:患者")

    debug_output = gr.Textbox(label="调试信息")
    conversation_output = gr.JSON(label="对话记录")
    memo_output = gr.JSON(label="备忘录")
    generate_memo_button = gr.Button("生成备忘录")

    def debug_audio(audio):
        """Describe the raw audio payload (type, shape, dtype) for debugging."""
        if audio is None:
            return "No audio input received"
        return f"Audio type: {type(audio)}, Shape: {audio[1].shape if isinstance(audio, tuple) else audio.shape}, Dtype: {audio[1].dtype if isinstance(audio, tuple) else audio.dtype}"

    def current_speaker_text():
        """Return the label text for whichever speaker is currently active."""
        return f"当前说话者:{current_speaker}"

    audio_input.change(debug_audio, inputs=[audio_input], outputs=[debug_output])
    # transcribe() hands the turn to the other speaker after each utterance,
    # so refresh the label afterwards; otherwise it only changes on button
    # clicks and drifts out of sync with the role actually being recorded.
    audio_input.change(
        transcribe,
        inputs=[audio_input],
        outputs=[conversation_output]
    ).then(
        current_speaker_text,
        outputs=[speaker_label]
    )
    speaker_button.click(switch_speaker, outputs=[speaker_label])

    # Memo pipeline: normalise the transcript, generate the memo, then
    # format it for display.
    generate_memo_button.click(
        prepare_conversation,
        inputs=[conversation_output],
        outputs=[conversation_output]
    ).then(
        safe_generate_memo,
        inputs=[conversation_output],
        outputs=[memo_output]
    ).then(
        display_memo,
        inputs=[memo_output],
        outputs=[memo_output]
    )

if __name__ == "__main__":
    demo.launch()