Spaces:
Running
Running
File size: 9,916 Bytes
3fee10b 9bf5d77 3fee10b 9bf5d77 99d5161 f523886 dc9e46f 3fee10b 9bf5d77 1226387 64fc971 3fee10b 64fc971 3fee10b 64fc971 3fee10b 9bf5d77 99d5161 e572754 9bf5d77 dc9e46f 61642c8 3fee10b f523886 3fee10b 92462f8 9bf5d77 f523886 3fee10b 99d5161 3fee10b 61642c8 92462f8 3fee10b 64fc971 e572754 3fee10b e572754 3fee10b 9bf5d77 dc9e46f 99d5161 dc9e46f 99d5161 e572754 9bf5d77 8acea2f 9bf5d77 8acea2f 3fee10b 8acea2f 1226387 808b728 3fee10b 808b728 3fee10b 808b728 3fee10b 1226387 3fee10b 1320e9a 808b728 3fee10b 1226387 c5a1fa6 8acea2f c5a1fa6 1226387 f1be757 1226387 f1be757 808b728 f1be757 1226387 f1be757 1320e9a 3fee10b 8acea2f 9bf5d77 3fee10b 9bf5d77 f523886 9bf5d77 92462f8 9bf5d77 3fee10b 9bf5d77 92462f8 61642c8 9bf5d77 f1be757 8acea2f f1be757 1226387 9bf5d77 3fee10b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 |
from transformers import pipeline
import gradio as gr
import numpy as np
import time
import json
import os
from langchain_openai import ChatOpenAI
from langchain_core.output_parsers import StrOutputParser
from dotenv import load_dotenv
import logging
import librosa
# Application-wide logging: timestamps plus severity for every message.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

load_dotenv()

# The ZhipuAI key is mandatory — memo generation cannot run without it,
# so fail fast at import time rather than mid-session.
zhipuai_api_key = os.getenv("ZHIPUAI_API_KEY")
if not zhipuai_api_key:
    logging.error("ZHIPUAI_API_KEY is not set in the environment variables")
    raise ValueError("ZHIPUAI_API_KEY is not set")

# Load a Chinese-optimized ASR pipeline. On failure the app still starts;
# transcribe() reports the missing model to the user instead.
try:
    transcriber = pipeline("automatic-speech-recognition", model="ydshieh/wav2vec2-large-xlsr-53-chinese-zh-cn-gpt")
    logging.info("Chinese ASR model loaded successfully")
except Exception as e:
    logging.error(f"Error loading Chinese ASR model: {e}")
    transcriber = None

# Mutable module state: the running transcript and whose turn it is to speak.
conversation = []
current_speaker = "患者"
def transcribe(audio):
    """Transcribe one audio clip, append it to the running conversation, and
    return the whole conversation as a formatted JSON string.

    Args:
        audio: either a (sample_rate, ndarray) tuple (Gradio microphone
            input) or a bare ndarray (sample rate then assumed 16 kHz).

    Returns:
        JSON string of the conversation list on success, or a JSON object
        with an "error" key on failure.
    """
    global current_speaker
    if audio is None:
        logging.info("No audio input received")
        return json.dumps({"error": "No audio input received"})
    try:
        logging.info(f"Audio input received: {type(audio)}")
        # Gradio delivers microphone input as (sample_rate, ndarray).
        if isinstance(audio, tuple):
            sr, y = audio
        else:
            y = audio
            sr = 16000  # assumed default; resampled below if actually different

        logging.info(f"Sample rate: {sr}, Audio data shape: {y.shape}, Audio data type: {y.dtype}")

        # The ASR pipeline expects floating-point samples.
        if not np.issubdtype(y.dtype, np.floating):
            y = y.astype(np.float32)

        # Resample to the 16 kHz rate the model was trained on.
        if sr != 16000:
            logging.info("Resampling audio to 16000Hz")
            y = librosa.resample(y, orig_sr=sr, target_sr=16000)
            sr = 16000

        # Mix multi-channel audio down to mono.
        if y.ndim > 1:
            y = y.mean(axis=1)
        logging.info(f"Audio data shape after conversion: {y.shape}")

        # Peak-normalize. Guard against silent or empty clips: dividing by a
        # zero peak would fill the buffer with NaNs before it reaches the model.
        peak = np.max(np.abs(y)) if y.size else 0.0
        if peak > 0:
            y /= peak

        # Run Chinese speech recognition.
        if transcriber is not None:
            logging.info("Starting transcription")
            try:
                result = transcriber({"sampling_rate": sr, "raw": y}, generate_kwargs={"language": "chinese"})
                text = result["text"].strip()
                logging.info(f"Transcription result: {text}")
            except Exception as e:
                logging.error(f"Error during transcription: {e}", exc_info=True)
                return json.dumps({"error": f"Transcription error: {str(e)}"})
        else:
            logging.error("Transcriber not initialized")
            return json.dumps({"error": "Transcriber not initialized"})

        # Record the utterance and alternate the speaker role for the next clip.
        if text:
            current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
            conversation.append({
                "时间": current_time,
                "角色": current_speaker,
                "内容": text
            })
            current_speaker = "医生" if current_speaker == "患者" else "患者"

        formatted_conversation = json.dumps(conversation, ensure_ascii=False, indent=2)
        logging.info(f"Formatted conversation: {formatted_conversation}")
        return formatted_conversation
    except Exception as e:
        logging.error(f"Error in transcribe function: {e}", exc_info=True)
        return json.dumps({"error": f"Error: {str(e)}"})
def switch_speaker():
    """Toggle the active speaker between patient and doctor.

    Returns the label text shown in the UI for the new current speaker.
    """
    global current_speaker
    if current_speaker == "患者":
        current_speaker = "医生"
    else:
        current_speaker = "患者"
    return f"当前说话者:{current_speaker}"
def generate_memo(conversation):
    """Ask the LLM to condense a doctor/patient dialogue into a structured memo.

    Args:
        conversation: list of dicts with "角色" (role) and "内容" (content) keys;
            entries missing either key are skipped.

    Returns:
        JSON string: {"result": <memo dict>} on success, {"error": <msg>} on failure.
    """
    import re  # hoisted from the except branch so the dependency is visible up front

    try:
        logging.info(f"Generating memo for conversation: {conversation}")
        if not isinstance(conversation, list):
            raise ValueError("Conversation must be a list")

        # Flatten the structured entries into "role: content" lines for the prompt.
        dialogue = "\n".join(
            f"{item['角色']}: {item['内容']}"
            for item in conversation
            if '角色' in item and '内容' in item
        )

        # ZhipuAI exposes an OpenAI-compatible endpoint, so ChatOpenAI works here.
        llm = ChatOpenAI(
            model="glm-3-turbo",
            temperature=0.7,
            openai_api_key=zhipuai_api_key,
            openai_api_base="https://open.bigmodel.cn/api/paas/v4/"
        )
        prompt = f"""
请根据以下医生和患者的对话,生成一份结构化的备忘录。备忘录应包含以下字段:主诉、检查、诊断、治疗和备注。
如果某个字段在对话中没有明确提及,请填写"未提及"。
对话内容:
{dialogue}
请以JSON格式输出备忘录,格式如下:
{{
"主诉": "患者的主要症状和不适",
"检查": "医生建议或已进行的检查",
"诊断": "医生对患者的诊断",
"治疗": "医生对患者的治疗建议",
"备注": "医生对患者的备注"
}}
请确保输出是有效的JSON格式。
"""
        logging.info(f"Sending prompt to LLM: {prompt}")
        output = llm.invoke(prompt)
        output_parser = StrOutputParser()
        output = output_parser.invoke(output)
        logging.info(f"Generated memo: {output}")

        # The model may wrap the JSON in prose; parse directly, then fall back
        # to extracting the outermost {...} span.
        try:
            json_output = json.loads(output)
        except json.JSONDecodeError:
            json_match = re.search(r'\{.*\}', output, re.DOTALL)
            if json_match:
                json_output = json.loads(json_match.group())
            else:
                raise ValueError("Unable to extract JSON from LLM output")

        return json.dumps({"result": json_output}, ensure_ascii=False)
    except Exception as e:
        error_msg = f"Error in generate_memo function: {str(e)}"
        logging.error(error_msg, exc_info=True)
        return json.dumps({"error": error_msg})
def safe_generate_memo(conversation_data):
    """Normalize arbitrary input into a conversation list, then build the memo.

    Accepts a list, a JSON string, plain text, or any other value; always
    returns the JSON string produced by generate_memo (or a JSON error object).
    """
    try:
        if isinstance(conversation_data, list):
            conversation = conversation_data
        elif isinstance(conversation_data, str):
            try:
                conversation = json.loads(conversation_data)
            except json.JSONDecodeError:
                # Not JSON — wrap the raw text as a single utterance.
                conversation = [{"内容": conversation_data}]
        else:
            conversation = [{"内容": str(conversation_data)}]
        return generate_memo(conversation)
    except Exception as e:
        logging.error(f"Error in safe_generate_memo: {e}", exc_info=True)
        return json.dumps({"error": f"生成备忘录时出错: {str(e)}"})
def display_memo(memo):
    """Render a memo payload as human-readable text for the UI.

    Accepts either a JSON string or an already-parsed mapping; unwraps the
    {"result": ...} / {"error": ...} envelope and pretty-prints the memo.
    """
    try:
        memo_json = json.loads(memo) if isinstance(memo, str) else memo

        if "error" in memo_json:
            return f"生成备忘录时出错: {memo_json['error']}"

        if "result" in memo_json:
            result = memo_json["result"]
            if not isinstance(result, str):
                return json.dumps(result, ensure_ascii=False, indent=2)
            # The result may itself be JSON-encoded text; pretty-print if so.
            try:
                return json.dumps(json.loads(result), ensure_ascii=False, indent=2)
            except json.JSONDecodeError:
                return result

        return json.dumps(memo_json, ensure_ascii=False, indent=2)
    except json.JSONDecodeError:
        return f"无效的 JSON 格式: {memo}"
    except Exception as e:
        return f"处理备忘录时出错: {str(e)}"
def prepare_conversation(conversation_data):
    """Coerce the UI component's value into a conversation structure.

    Strings are parsed as JSON; unparseable text becomes a one-entry
    conversation; non-string values pass through unchanged.
    """
    if not isinstance(conversation_data, str):
        return conversation_data
    try:
        return json.loads(conversation_data)
    except json.JSONDecodeError:
        # Plain text: treat the whole string as a single utterance.
        return [{"内容": conversation_data}]
# ---- Gradio UI -----------------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown("# 实时中文对话转录与备忘录生成")
    gr.Markdown("点击麦克风图标开始录音,说话后会自动进行语音识别。支持中文识别。")

    with gr.Row():
        audio_input = gr.Audio(sources=["microphone"], type="numpy")
        speaker_button = gr.Button("切换说话者")

    speaker_label = gr.Label("当前说话者:患者")
    debug_output = gr.Textbox(label="调试信息")
    conversation_output = gr.JSON(label="对话记录")
    memo_output = gr.JSON(label="备忘录")
    generate_memo_button = gr.Button("生成备忘录")

    def debug_audio(audio):
        """Summarize the raw audio payload for the debug textbox."""
        if audio is None:
            return "No audio input received"
        return f"Audio type: {type(audio)}, Shape: {audio[1].shape if isinstance(audio, tuple) else audio.shape}, Dtype: {audio[1].dtype if isinstance(audio, tuple) else audio.dtype}"

    # Every new recording updates both the debug info and the transcript.
    audio_input.change(debug_audio, inputs=[audio_input], outputs=[debug_output])
    audio_input.change(transcribe, inputs=[audio_input], outputs=[conversation_output])
    speaker_button.click(switch_speaker, outputs=[speaker_label])

    # Memo pipeline: normalize the transcript, generate, then pretty-print.
    generate_memo_button.click(
        prepare_conversation,
        inputs=[conversation_output],
        outputs=[conversation_output],
    ).then(
        safe_generate_memo,
        inputs=[conversation_output],
        outputs=[memo_output],
    ).then(
        display_memo,
        inputs=[memo_output],
        outputs=[memo_output],
    )

if __name__ == "__main__":
    demo.launch()
|