from transformers import pipeline
import gradio as gr
import numpy as np
import time
import json
import os
from langchain_openai import ChatOpenAI
from langchain_core.output_parsers import StrOutputParser
from dotenv import load_dotenv
import logging
import librosa

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

load_dotenv()
zhipuai_api_key = os.getenv("ZHIPUAI_API_KEY")
if not zhipuai_api_key:
    logging.error("ZHIPUAI_API_KEY is not set in the environment variables")
    raise ValueError("ZHIPUAI_API_KEY is not set")

# Load a model optimized for Chinese ASR
try:
    transcriber = pipeline("automatic-speech-recognition", model="ydshieh/wav2vec2-large-xlsr-53-chinese-zh-cn-gpt")
    logging.info("Chinese ASR model loaded successfully")
except Exception as e:
    logging.error(f"Error loading Chinese ASR model: {e}")
    transcriber = None

# Conversation log and current-speaker state
conversation = []
current_speaker = "患者"


def transcribe(audio):
    global current_speaker
    if audio is None:
        logging.info("No audio input received")
        return json.dumps({"error": "No audio input received"})
    try:
        logging.info(f"Audio input received: {type(audio)}")
        if isinstance(audio, tuple):
            sr, y = audio
        else:
            y = audio
            sr = 16000  # Assumed default sample rate; resampled below if the actual rate differs

        logging.info(f"Sample rate: {sr}, Audio data shape: {y.shape}, Audio data type: {y.dtype}")

        # Ensure the audio data is floating point
        if not np.issubdtype(y.dtype, np.floating):
            y = y.astype(np.float32)

        # Resample to 16 kHz if necessary
        if sr != 16000:
            logging.info("Resampling audio to 16000Hz")
            y = librosa.resample(y, orig_sr=sr, target_sr=16000)
            sr = 16000

        # Convert to mono
        if y.ndim > 1:
            y = y.mean(axis=1)

        logging.info(f"Audio data shape after conversion: {y.shape}")

        # Normalize the audio (guard against division by zero on silent input)
        max_abs = np.max(np.abs(y))
        if max_abs > 0:
            y = y / max_abs

        # Transcribe the Chinese audio
        if transcriber is not None:
            logging.info("Starting transcription")
            try:
                # Note: this is not a Whisper checkpoint, so Whisper-style
                # generate_kwargs (e.g. a language hint) are not supported
                # and would raise an error; the pipeline is called without them.
                result = transcriber({"sampling_rate": sr, "raw": y})
                text = result["text"].strip()
                logging.info(f"Transcription result: {text}")
            except Exception as e:
                logging.error(f"Error during transcription: {e}", exc_info=True)
                return json.dumps({"error": f"Transcription error: {str(e)}"})
        else:
            logging.error("Transcriber not initialized")
            return json.dumps({"error": "Transcriber not initialized"})

        # Append a structured record for this utterance
        if text:
            current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
            conversation.append({
                "时间": current_time,
                "角色": current_speaker,
                "内容": text
            })
            # Alternate the speaker after each utterance
            current_speaker = "医生" if current_speaker == "患者" else "患者"

        # Return the conversation log as a formatted JSON string
        formatted_conversation = json.dumps(conversation, ensure_ascii=False, indent=2)
        logging.info(f"Formatted conversation: {formatted_conversation}")
        return formatted_conversation
    except Exception as e:
        logging.error(f"Error in transcribe function: {e}", exc_info=True)
        return json.dumps({"error": f"Error: {str(e)}"})


def switch_speaker():
    global current_speaker
    current_speaker = "医生" if current_speaker == "患者" else "患者"
    return f"当前说话者:{current_speaker}"


def generate_memo(conversation):
    try:
        logging.info(f"Generating memo for conversation: {conversation}")

        # The conversation must be a list of utterance records
        if not isinstance(conversation, list):
            raise ValueError("Conversation must be a list")

        # Flatten the records into "role: content" lines
        dialogue = "\n".join([f"{item['角色']}: {item['内容']}" for item in conversation if '角色' in item and '内容' in item])

        llm = ChatOpenAI(
            model="glm-3-turbo",
            temperature=0.7,
            openai_api_key=zhipuai_api_key,
            openai_api_base="https://open.bigmodel.cn/api/paas/v4/"
        )
        prompt = f"""
请根据以下医生和患者的对话,生成一份结构化的备忘录。备忘录应包含以下字段:主诉、检查、诊断、治疗和备注。
如果某个字段在对话中没有明确提及,请填写"未提及"。

对话内容:
{dialogue}

请以JSON格式输出备忘录,格式如下:
{{
    "主诉": "患者的主要症状和不适",
    "检查": "医生建议或已进行的检查",
    "诊断": "医生对患者的诊断",
    "治疗": "医生对患者的治疗建议",
    "备注": "医生对患者的备注"
}}

请确保输出是有效的JSON格式。
"""
        logging.info(f"Sending prompt to LLM: {prompt}")
        output = llm.invoke(prompt)
        output_parser = StrOutputParser()
        output = output_parser.invoke(output)
        logging.info(f"Generated memo: {output}")

        # Try to parse the LLM output as JSON
        try:
            json_output = json.loads(output)
        except json.JSONDecodeError:
            # If parsing fails, try to extract the JSON portion
            import re
            json_match = re.search(r'\{.*\}', output, re.DOTALL)
            if json_match:
                json_output = json.loads(json_match.group())
            else:
                raise ValueError("Unable to extract JSON from LLM output")

        return json.dumps({"result": json_output}, ensure_ascii=False)
    except Exception as e:
        error_msg = f"Error in generate_memo function: {str(e)}"
        logging.error(error_msg, exc_info=True)
        return json.dumps({"error": error_msg})


def safe_generate_memo(conversation_data):
    try:
        # If conversation_data is a string, try to parse it as JSON
        if isinstance(conversation_data, str):
            try:
                conversation = json.loads(conversation_data)
            except json.JSONDecodeError:
                # If it cannot be parsed as JSON, treat it as plain text
                conversation = [{"内容": conversation_data}]
        elif isinstance(conversation_data, list):
            conversation = conversation_data
        else:
            conversation = [{"内容": str(conversation_data)}]
        return generate_memo(conversation)
    except Exception as e:
        logging.error(f"Error in safe_generate_memo: {e}", exc_info=True)
        return json.dumps({"error": f"生成备忘录时出错: {str(e)}"})


def display_memo(memo):
    try:
        # If memo is a string, try to parse it as JSON
        if isinstance(memo, str):
            memo_json = json.loads(memo)
        else:
            memo_json = memo

        if "error" in memo_json:
            return f"生成备忘录时出错: {memo_json['error']}"
        elif "result" in memo_json:
            if isinstance(memo_json["result"], str):
                try:
                    result_json = json.loads(memo_json["result"])
                    return json.dumps(result_json, ensure_ascii=False, indent=2)
                except json.JSONDecodeError:
                    return memo_json["result"]
            else:
                return json.dumps(memo_json["result"], ensure_ascii=False, indent=2)
        else:
            return json.dumps(memo_json, ensure_ascii=False, indent=2)
    except json.JSONDecodeError:
        return f"无效的 JSON 格式: {memo}"
    except Exception as e:
        return f"处理备忘录时出错: {str(e)}"


def prepare_conversation(conversation_data):
    try:
        if isinstance(conversation_data, str):
            return json.loads(conversation_data)
        return conversation_data
    except json.JSONDecodeError:
        return [{"内容": conversation_data}]


# Build the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# 实时中文对话转录与备忘录生成")
    gr.Markdown("点击麦克风图标开始录音,说话后会自动进行语音识别。支持中文识别。")
    with gr.Row():
        audio_input = gr.Audio(sources=["microphone"], type="numpy")
        speaker_button = gr.Button("切换说话者")
    speaker_label = gr.Label("当前说话者:患者")
    debug_output = gr.Textbox(label="调试信息")
    conversation_output = gr.JSON(label="对话记录")
    memo_output = gr.JSON(label="备忘录")
    generate_memo_button = gr.Button("生成备忘录")

    def debug_audio(audio):
        if audio is None:
            return "No audio input received"
        return f"Audio type: {type(audio)}, Shape: {audio[1].shape if isinstance(audio, tuple) else audio.shape}, Dtype: {audio[1].dtype if isinstance(audio, tuple) else audio.dtype}"

    audio_input.change(debug_audio, inputs=[audio_input], outputs=[debug_output])
    audio_input.change(transcribe, inputs=[audio_input], outputs=[conversation_output])
    speaker_button.click(switch_speaker, outputs=[speaker_label])
    generate_memo_button.click(
        prepare_conversation, inputs=[conversation_output], outputs=[conversation_output]
    ).then(
        safe_generate_memo, inputs=[conversation_output], outputs=[memo_output]
    ).then(
        display_memo, inputs=[memo_output], outputs=[memo_output]
    )

if __name__ == "__main__":
    demo.launch()