Spaces:
Sleeping
Sleeping
File size: 5,517 Bytes
3fee10b 9bf5d77 3fee10b 9bf5d77 99d5161 dc9e46f 3fee10b 9bf5d77 64fc971 3fee10b 64fc971 3fee10b 64fc971 3fee10b 9bf5d77 99d5161 e572754 9bf5d77 dc9e46f 61642c8 3fee10b 9bf5d77 3fee10b 99d5161 3fee10b 61642c8 3fee10b 64fc971 e572754 3fee10b e572754 3fee10b 9bf5d77 dc9e46f 99d5161 dc9e46f 99d5161 e572754 9bf5d77 3fee10b 9bf5d77 3fee10b 9bf5d77 e572754 9bf5d77 3fee10b 9bf5d77 61642c8 9bf5d77 3fee10b 9bf5d77 3fee10b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 |
from transformers import pipeline
import gradio as gr
import numpy as np
import time
import json
import os
from langchain_openai import ChatOpenAI
from langchain_core.output_parsers import StrOutputParser
from dotenv import load_dotenv
import logging
# Configure root logging so model loading and request handling are timestamped.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
load_dotenv()
# API key for the ZhipuAI OpenAI-compatible endpoint, read from .env / environment.
zhipuai_api_key = os.getenv("ZHIPUAI_API_KEY")
# Load a Chinese-optimized ASR model. On failure fall back to None so the
# app still starts and transcribe() can report a clear error instead of
# crashing at import time.
try:
    transcriber = pipeline("automatic-speech-recognition", model="ydshieh/wav2vec2-large-xlsr-53-chinese-zh-cn-gpt")
    logging.info("Chinese ASR model loaded successfully")
except Exception as e:
    logging.error(f"Error loading Chinese ASR model: {e}")
    transcriber = None
# Mutable module state: the running dialogue log and whose turn it is.
# conversation holds dicts with keys 时间/角色/内容; current_speaker alternates
# between "患者" (patient) and "医生" (doctor).
conversation = []
current_speaker = "患者"
def transcribe(audio):
    """Transcribe one audio clip to Chinese text and append it to the
    global conversation log, alternating the speaker role per utterance.

    Parameters
    ----------
    audio : tuple[int, np.ndarray] | np.ndarray | None
        Either a ``(sample_rate, samples)`` tuple as produced by
        ``gr.Audio(type="numpy")``, or a bare sample array (16 kHz assumed).

    Returns
    -------
    str
        The full conversation serialized as pretty-printed JSON, or a JSON
        object with an ``"error"`` key when transcription fails.
    """
    global current_speaker
    if audio is None:
        logging.info("No audio input received")
        return json.dumps({"error": "No audio input received"})
    try:
        logging.info(f"Audio input received: {type(audio)}")
        if isinstance(audio, tuple):
            sr, y = audio
        else:
            y = audio
            sr = 16000  # no rate supplied; assume 16 kHz — adjust if the source differs
        logging.info(f"Sample rate: {sr}, Audio data shape: {y.shape}")
        # Downmix to mono: the ASR pipeline expects a 1-D waveform.
        if y.ndim > 1:
            y = y.mean(axis=1)
            logging.info(f"Audio data shape after conversion: {y.shape}")
        y = y.astype(np.float32)
        # Peak-normalize to [-1, 1]. Guard against silent or empty clips:
        # the original unconditional division produced NaNs (0/0) here.
        peak = np.max(np.abs(y)) if y.size else 0.0
        if peak > 0:
            y /= peak
        if transcriber is None:
            logging.error("Transcriber not initialized")
            return json.dumps({"error": "Transcriber not initialized"})
        logging.info("Starting transcription")
        try:
            # NOTE(review): generate_kwargs assumes a seq2seq-style ASR
            # checkpoint that supports generation — confirm for this model.
            result = transcriber({"sampling_rate": sr, "raw": y}, generate_kwargs={"language": "chinese"})
            text = result["text"].strip()
            logging.info(f"Transcription result: {text}")
        except Exception as e:
            logging.error(f"Error during transcription: {e}", exc_info=True)
            return json.dumps({"error": f"Transcription error: {str(e)}"})
        # Record the utterance and flip the speaker for the next turn.
        if text:
            current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
            conversation.append({
                "时间": current_time,
                "角色": current_speaker,
                "内容": text
            })
            current_speaker = "医生" if current_speaker == "患者" else "患者"
        formatted_conversation = json.dumps(conversation, ensure_ascii=False, indent=2)
        logging.info(f"Formatted conversation: {formatted_conversation}")
        return formatted_conversation
    except Exception as e:
        logging.error(f"Error in transcribe function: {e}", exc_info=True)
        return json.dumps({"error": f"Error: {str(e)}"})
def switch_speaker():
    """Toggle the active speaker role and return the updated UI label.

    Flips the module-level ``current_speaker`` between "患者" (patient) and
    "医生" (doctor); any unexpected value falls back to "患者", matching the
    original ternary's else-branch.
    """
    global current_speaker
    current_speaker = {"患者": "医生"}.get(current_speaker, "患者")
    return f"当前说话者:{current_speaker}"
def generate_memo(conversation_json):
    """Summarize a doctor-patient dialogue into a structured medical memo.

    Parameters
    ----------
    conversation_json : str | list | dict
        The dialogue log; interpolated verbatim into the LLM prompt.

    Returns
    -------
    str
        The raw LLM output (expected to be a JSON memo with fields
        主诉/检查/诊断/治疗/备注), or a message starting with
        "Error generating memo:" on failure.
    """
    try:
        # ZhipuAI exposes an OpenAI-compatible API, so ChatOpenAI works
        # against it with a custom base URL and the ZhipuAI key.
        client = ChatOpenAI(
            model="glm-3-turbo",
            temperature=0.7,
            openai_api_key=zhipuai_api_key,
            openai_api_base="https://open.bigmodel.cn/api/paas/v4/"
        )
        prompt = f"""
请根据以下医生和患者的对话,生成一份结构化的备忘录。备忘录应包含以下字段:主诉、检查、诊断、治疗和备注。
如果某个字段在对话中没有明确提及,请填写"未提及"。
对话内容:
{conversation_json}
请以JSON格式输出备忘录,格式如下:
{{
"主诉": "患者的主要症状和不适",
"检查": "医生建议或已进行的检查",
"诊断": "医生对患者的诊断",
"治疗": "医生对患者的治疗建议",
"备注": "医生对患者的备注"
}}
"""
        raw_response = client.invoke(prompt)
        # Normalize the chat-model message object down to plain text.
        memo_text = StrOutputParser().invoke(raw_response)
        logging.info(f"Generated memo: {memo_text}")
        return memo_text
    except Exception as e:
        logging.error(f"Error in generate_memo function: {e}", exc_info=True)
        return f"Error generating memo: {str(e)}"
# ---- Gradio UI -----------------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown("# 实时中文对话转录与备忘录生成")
    gr.Markdown("点击麦克风图标开始录音,说话后会自动进行语音识别。支持中文识别。")

    with gr.Row():
        # NOTE(review): confirm the installed Gradio version accepts a
        # sample_rate keyword on gr.Audio.
        mic_input = gr.Audio(sources=["microphone"], type="numpy", sample_rate=16000)
        toggle_button = gr.Button("切换说话者")

    speaker_display = gr.Label("当前说话者:患者")
    dialogue_view = gr.JSON(label="对话记录")
    memo_view = gr.JSON(label="备忘录")
    memo_button = gr.Button("生成备忘录")

    # Wiring: each new recording is transcribed into the dialogue log, the
    # toggle flips the speaker label, and the memo button summarizes the log.
    mic_input.change(transcribe, inputs=[mic_input], outputs=[dialogue_view])
    toggle_button.click(switch_speaker, outputs=[speaker_display])
    memo_button.click(generate_memo, inputs=[dialogue_view], outputs=[memo_view])

if __name__ == "__main__":
    demo.launch()
|