File size: 5,517 Bytes
3fee10b
9bf5d77
 
 
 
 
3fee10b
 
9bf5d77
99d5161
dc9e46f
3fee10b
 
 
9bf5d77
 
 
64fc971
3fee10b
64fc971
 
3fee10b
64fc971
3fee10b
9bf5d77
 
 
 
 
 
 
 
99d5161
e572754
9bf5d77
dc9e46f
61642c8
3fee10b
 
 
 
 
 
 
9bf5d77
3fee10b
 
 
99d5161
3fee10b
61642c8
3fee10b
 
 
 
 
 
64fc971
 
 
 
 
 
e572754
3fee10b
 
e572754
3fee10b
 
 
 
 
 
 
 
 
 
 
 
9bf5d77
dc9e46f
 
99d5161
dc9e46f
 
99d5161
e572754
9bf5d77
 
 
 
 
 
3fee10b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9bf5d77
 
3fee10b
 
9bf5d77
 
e572754
9bf5d77
 
 
 
3fee10b
 
 
9bf5d77
61642c8
9bf5d77
3fee10b
9bf5d77
 
3fee10b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
from transformers import pipeline
import gradio as gr
import numpy as np
import time
import json
import os
from langchain_openai import ChatOpenAI
from langchain_core.output_parsers import StrOutputParser
from dotenv import load_dotenv
import logging

# Configure root-logger output: timestamp, level, message.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

load_dotenv()
# ZhipuAI API key read from the environment (populated from .env above);
# consumed by generate_memo() when calling the GLM endpoint.
zhipuai_api_key = os.getenv("ZHIPUAI_API_KEY")

# Load a Chinese-optimized ASR model. On failure, fall back to None so the
# UI can still start; transcribe() reports the missing model at call time.
try:
    transcriber = pipeline("automatic-speech-recognition", model="ydshieh/wav2vec2-large-xlsr-53-chinese-zh-cn-gpt")
    logging.info("Chinese ASR model loaded successfully")
except Exception as e:
    logging.error(f"Error loading Chinese ASR model: {e}")
    transcriber = None

# Global dialogue state: list of {"时间", "角色", "内容"} turn dicts, plus
# whose turn it is. current_speaker starts as "患者" (patient) and is
# toggled to "医生" (doctor) by transcribe() / switch_speaker().
conversation = []
current_speaker = "患者"

def transcribe(audio):
    """Transcribe one audio clip to Chinese text and append it to the global
    conversation log under the current speaker, then auto-switch speakers.

    Args:
        audio: Gradio audio input — a ``(sample_rate, samples)`` tuple, a
            bare numpy array (16 kHz assumed), or ``None`` when nothing was
            recorded.

    Returns:
        str: The whole conversation pretty-printed as JSON, or a JSON object
        with an ``"error"`` key when transcription is impossible or fails.
    """
    global current_speaker
    if audio is None:
        logging.info("No audio input received")
        return json.dumps({"error": "No audio input received"})

    try:
        logging.info(f"Audio input received: {type(audio)}")
        if isinstance(audio, tuple):
            sr, y = audio
        else:
            y = audio
            sr = 16000  # assume 16 kHz for bare arrays; adjust if the source differs

        logging.info(f"Sample rate: {sr}, Audio data shape: {y.shape}")

        # Down-mix multi-channel audio to mono.
        if y.ndim > 1:
            y = y.mean(axis=1)

        logging.info(f"Audio data shape after conversion: {y.shape}")

        y = y.astype(np.float32)
        # Peak-normalize, guarding against division by zero: the original
        # unconditional `y /= np.max(np.abs(y))` produced NaNs on silent
        # (all-zero) clips and raised on empty arrays.
        peak = np.max(np.abs(y)) if y.size else 0.0
        if peak > 0:
            y /= peak

        # Run Chinese transcription.
        if transcriber is not None:
            logging.info("Starting transcription")
            try:
                # NOTE(review): generate_kwargs is only honored by seq2seq ASR
                # pipelines — confirm this wav2vec2-based model accepts it.
                result = transcriber({"sampling_rate": sr, "raw": y}, generate_kwargs={"language": "chinese"})
                text = result["text"].strip()
                logging.info(f"Transcription result: {text}")
            except Exception as e:
                logging.error(f"Error during transcription: {e}", exc_info=True)
                return json.dumps({"error": f"Transcription error: {str(e)}"})
        else:
            logging.error("Transcriber not initialized")
            return json.dumps({"error": "Transcriber not initialized"})

        # Record the utterance as a structured turn (timestamp/role/content).
        if text:
            current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
            conversation.append({
                "时间": current_time,
                "角色": current_speaker,
                "内容": text
            })

            # Alternate speakers after each recognized utterance.
            current_speaker = "医生" if current_speaker == "患者" else "患者"

        # Serialize the full conversation for display in the JSON widget.
        formatted_conversation = json.dumps(conversation, ensure_ascii=False, indent=2)
        logging.info(f"Formatted conversation: {formatted_conversation}")
        return formatted_conversation
    except Exception as e:
        logging.error(f"Error in transcribe function: {e}", exc_info=True)
        return json.dumps({"error": f"Error: {str(e)}"})

def switch_speaker():
    """Toggle the active speaker between patient ("患者") and doctor ("医生").

    Returns:
        str: A label stating who is speaking now.
    """
    global current_speaker
    if current_speaker == "患者":
        current_speaker = "医生"
    else:
        current_speaker = "患者"
    return f"当前说话者:{current_speaker}"

def generate_memo(conversation_json):
    """Ask the GLM model to distill the doctor/patient dialogue into a
    structured memo (chief complaint, exams, diagnosis, treatment, notes).

    Args:
        conversation_json: The conversation log serialized as JSON text.

    Returns:
        str: The model's memo output (expected to be JSON text), or an
        error message string if the LLM call fails.
    """
    try:
        chat_model = ChatOpenAI(
            model="glm-3-turbo",
            temperature=0.7,
            openai_api_key=zhipuai_api_key,
            openai_api_base="https://open.bigmodel.cn/api/paas/v4/"
        )

        prompt = f"""
        请根据以下医生和患者的对话,生成一份结构化的备忘录。备忘录应包含以下字段:主诉、检查、诊断、治疗和备注。
        如果某个字段在对话中没有明确提及,请填写"未提及"。

        对话内容:
        {conversation_json}

        请以JSON格式输出备忘录,格式如下:
        {{
            "主诉": "患者的主要症状和不适",
            "检查": "医生建议或已进行的检查",
            "诊断": "医生对患者的诊断",
            "治疗": "医生对患者的治疗建议",
            "备注": "医生对患者的备注"
        }}
        """

        # Invoke the model, then flatten the message object to plain text.
        output = StrOutputParser().invoke(chat_model.invoke(prompt))
        logging.info(f"Generated memo: {output}")
        return output
    except Exception as e:
        logging.error(f"Error in generate_memo function: {e}", exc_info=True)
        return f"Error generating memo: {str(e)}"

# Build the Gradio UI: live Chinese speech transcription plus memo generation.
with gr.Blocks() as demo:
    gr.Markdown("# 实时中文对话转录与备忘录生成")
    gr.Markdown("点击麦克风图标开始录音,说话后会自动进行语音识别。支持中文识别。")
    
    with gr.Row():
        # Microphone input delivered to transcribe() in "numpy" format
        # ((sample_rate, samples) tuple), requested at 16 kHz.
        audio_input = gr.Audio(sources=["microphone"], type="numpy", sample_rate=16000)
        speaker_button = gr.Button("切换说话者")
    
    speaker_label = gr.Label("当前说话者:患者")
    conversation_output = gr.JSON(label="对话记录")
    memo_output = gr.JSON(label="备忘录")
    
    generate_memo_button = gr.Button("生成备忘录")
    
    # Event wiring: new audio triggers transcription into the conversation
    # view; buttons toggle the speaker label and build the memo on demand.
    audio_input.change(transcribe, inputs=[audio_input], outputs=[conversation_output])
    speaker_button.click(switch_speaker, outputs=[speaker_label])
    generate_memo_button.click(generate_memo, inputs=[conversation_output], outputs=[memo_output])

# Start the local Gradio server only when run as a script.
if __name__ == "__main__":
    demo.launch()