File size: 9,916 Bytes
3fee10b
9bf5d77
 
 
 
 
3fee10b
 
9bf5d77
99d5161
f523886
dc9e46f
3fee10b
 
 
9bf5d77
 
 
1226387
 
 
 
64fc971
3fee10b
64fc971
 
3fee10b
64fc971
3fee10b
9bf5d77
 
 
 
 
 
 
 
99d5161
e572754
9bf5d77
dc9e46f
61642c8
3fee10b
 
 
 
f523886
3fee10b
92462f8
 
 
 
 
9bf5d77
f523886
 
 
 
 
 
3fee10b
 
 
99d5161
3fee10b
61642c8
92462f8
3fee10b
 
 
 
 
64fc971
 
 
 
 
 
e572754
3fee10b
 
e572754
3fee10b
 
 
 
 
 
 
 
 
 
 
 
9bf5d77
dc9e46f
 
99d5161
dc9e46f
 
99d5161
e572754
9bf5d77
 
 
 
8acea2f
9bf5d77
8acea2f
3fee10b
8acea2f
1226387
808b728
 
 
 
 
 
 
3fee10b
 
 
 
 
 
 
 
 
 
 
 
808b728
3fee10b
 
 
 
 
 
 
 
 
808b728
3fee10b
 
1226387
3fee10b
 
 
 
1320e9a
808b728
 
 
 
 
 
 
 
 
 
 
 
 
3fee10b
1226387
 
 
 
c5a1fa6
 
8acea2f
 
 
 
 
 
 
 
 
 
 
 
 
c5a1fa6
 
 
 
1226387
 
f1be757
 
 
 
 
 
1226387
 
f1be757
808b728
 
 
 
 
 
 
 
f1be757
 
1226387
f1be757
1320e9a
 
3fee10b
8acea2f
 
 
 
 
 
 
 
9bf5d77
 
3fee10b
 
9bf5d77
 
f523886
9bf5d77
 
 
92462f8
9bf5d77
3fee10b
 
 
9bf5d77
92462f8
 
 
 
 
 
61642c8
9bf5d77
f1be757
8acea2f
 
 
 
f1be757
 
 
 
 
 
 
1226387
9bf5d77
 
3fee10b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
from transformers import pipeline
import gradio as gr
import numpy as np
import time
import json
import os
from langchain_openai import ChatOpenAI
from langchain_core.output_parsers import StrOutputParser
from dotenv import load_dotenv
import logging
import librosa

# Configure logging: timestamped INFO-level messages.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Load environment variables from a local .env file; the ZhipuAI API key
# is required by generate_memo() below, so fail fast if it is missing.
load_dotenv()
zhipuai_api_key = os.getenv("ZHIPUAI_API_KEY")

if not zhipuai_api_key:
    logging.error("ZHIPUAI_API_KEY is not set in the environment variables")
    raise ValueError("ZHIPUAI_API_KEY is not set")

# Load a Chinese-optimized wav2vec2 ASR model. If loading fails, the app
# still starts and transcribe() reports "Transcriber not initialized".
try:
    transcriber = pipeline("automatic-speech-recognition", model="ydshieh/wav2vec2-large-xlsr-53-chinese-zh-cn-gpt")
    logging.info("Chinese ASR model loaded successfully")
except Exception as e:
    logging.error(f"Error loading Chinese ASR model: {e}")
    transcriber = None

# Global conversation state: the running list of utterance records, and the
# role ("患者" = patient / "医生" = doctor) attributed to the next utterance.
conversation = []
current_speaker = "患者"

def transcribe(audio):
    """Transcribe one chunk of microphone audio and append it to the
    global conversation log.

    Parameters
    ----------
    audio : tuple[int, np.ndarray] | np.ndarray | None
        Gradio "numpy" audio payload: either ``(sample_rate, samples)``
        or a bare sample array (assumed to be 16 kHz).

    Returns
    -------
    str
        The full conversation serialized as pretty-printed JSON, or a
        JSON object with an ``"error"`` key on failure.
    """
    global current_speaker
    if audio is None:
        logging.info("No audio input received")
        return json.dumps({"error": "No audio input received"})

    try:
        logging.info(f"Audio input received: {type(audio)}")
        if isinstance(audio, tuple):
            sr, y = audio
        else:
            y = audio
            sr = 16000  # assumed default; resampled below if it differs

        logging.info(f"Sample rate: {sr}, Audio data shape: {y.shape}, Audio data type: {y.dtype}")

        # The ASR model expects floating-point samples.
        if not np.issubdtype(y.dtype, np.floating):
            y = y.astype(np.float32)

        # Resample to the 16 kHz rate the wav2vec2 model expects.
        if sr != 16000:
            logging.info("Resampling audio to 16000Hz")
            y = librosa.resample(y, orig_sr=sr, target_sr=16000)
            sr = 16000

        # Down-mix multi-channel audio to mono.
        if y.ndim > 1:
            y = y.mean(axis=1)

        logging.info(f"Audio data shape after conversion: {y.shape}")

        # Peak-normalize. BUGFIX: the original `y /= np.max(np.abs(y))`
        # divided by zero on a silent chunk (filling y with NaNs) and
        # mutated the caller's buffer in place; guard and copy instead.
        peak = np.max(np.abs(y)) if y.size else 0.0
        if peak > 0:
            y = y / peak

        # Run Chinese speech recognition.
        if transcriber is not None:
            logging.info("Starting transcription")
            try:
                # BUGFIX: dropped generate_kwargs — wav2vec2 is a CTC
                # model, and the transformers ASR pipeline rejects
                # generation kwargs for non-seq2seq models.
                result = transcriber({"sampling_rate": sr, "raw": y})
                text = result["text"].strip()
                logging.info(f"Transcription result: {text}")
            except Exception as e:
                logging.error(f"Error during transcription: {e}", exc_info=True)
                return json.dumps({"error": f"Transcription error: {str(e)}"})
        else:
            logging.error("Transcriber not initialized")
            return json.dumps({"error": "Transcriber not initialized"})

        # Append the utterance with a timestamp and the current role.
        if text:
            current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
            conversation.append({
                "时间": current_time,
                "角色": current_speaker,
                "内容": text
            })

            # Alternate roles automatically after each utterance.
            current_speaker = "医生" if current_speaker == "患者" else "患者"

        # Return the whole conversation as formatted JSON for the UI.
        formatted_conversation = json.dumps(conversation, ensure_ascii=False, indent=2)
        logging.info(f"Formatted conversation: {formatted_conversation}")
        return formatted_conversation
    except Exception as e:
        logging.error(f"Error in transcribe function: {e}", exc_info=True)
        return json.dumps({"error": f"Error: {str(e)}"})

def switch_speaker():
    """Toggle the active speaker between patient ("患者") and doctor ("医生").

    Returns the label text shown in the UI for the newly active speaker.
    """
    global current_speaker
    if current_speaker == "患者":
        current_speaker = "医生"
    else:
        current_speaker = "患者"
    return f"当前说话者:{current_speaker}"

def generate_memo(conversation):
    """Condense a doctor/patient dialogue into a structured memo via the LLM.

    Parameters
    ----------
    conversation : list[dict]
        Utterance records; only entries carrying both "角色" (role) and
        "内容" (content) keys contribute to the prompt.

    Returns
    -------
    str
        JSON string: ``{"result": <memo dict>}`` on success,
        ``{"error": <message>}`` on any failure.
    """
    try:
        logging.info(f"Generating memo for conversation: {conversation}")

        if not isinstance(conversation, list):
            raise ValueError("Conversation must be a list")

        # Flatten the usable entries into "role: content" lines.
        spoken = [
            f"{entry['角色']}: {entry['内容']}"
            for entry in conversation
            if '角色' in entry and '内容' in entry
        ]
        dialogue = "\n".join(spoken)

        llm = ChatOpenAI(
            model="glm-3-turbo",
            temperature=0.7,
            openai_api_key=zhipuai_api_key,
            openai_api_base="https://open.bigmodel.cn/api/paas/v4/"
        )

        prompt = f"""
        请根据以下医生和患者的对话,生成一份结构化的备忘录。备忘录应包含以下字段:主诉、检查、诊断、治疗和备注。
        如果某个字段在对话中没有明确提及,请填写"未提及"。

        对话内容:
        {dialogue}

        请以JSON格式输出备忘录,格式如下:
        {{
            "主诉": "患者的主要症状和不适",
            "检查": "医生建议或已进行的检查",
            "诊断": "医生对患者的诊断",
            "治疗": "医生对患者的治疗建议",
            "备注": "医生对患者的备注"
        }}
        请确保输出是有效的JSON格式。
        """

        logging.info(f"Sending prompt to LLM: {prompt}")
        raw = StrOutputParser().invoke(llm.invoke(prompt))
        logging.info(f"Generated memo: {raw}")

        # The model should answer with JSON; if not, salvage the first
        # brace-delimited span before giving up.
        try:
            memo = json.loads(raw)
        except json.JSONDecodeError:
            import re
            candidate = re.search(r'\{.*\}', raw, re.DOTALL)
            if candidate is None:
                raise ValueError("Unable to extract JSON from LLM output")
            memo = json.loads(candidate.group())

        return json.dumps({"result": memo}, ensure_ascii=False)
    except Exception as e:
        error_msg = f"Error in generate_memo function: {str(e)}"
        logging.error(error_msg, exc_info=True)
        return json.dumps({"error": error_msg})

def safe_generate_memo(conversation_data):
    """Coerce arbitrary input into a conversation list and delegate to
    generate_memo(); never raises.

    Strings are decoded as JSON when possible, otherwise wrapped as a
    single plain-text utterance; non-list, non-string values are
    stringified and wrapped the same way.
    """
    try:
        if isinstance(conversation_data, list):
            conversation = conversation_data
        elif isinstance(conversation_data, str):
            try:
                conversation = json.loads(conversation_data)
            except json.JSONDecodeError:
                conversation = [{"内容": conversation_data}]
        else:
            conversation = [{"内容": str(conversation_data)}]

        return generate_memo(conversation)
    except Exception as e:
        logging.error(f"Error in safe_generate_memo: {e}", exc_info=True)
        return json.dumps({"error": f"生成备忘录时出错: {str(e)}"})

def display_memo(memo):
    """Render a memo payload as display text for the UI.

    Accepts either a JSON string or an already-decoded object and unwraps
    the ``{"result": ...}`` / ``{"error": ...}`` envelope produced by
    generate_memo(); anything else is pretty-printed as-is.
    """
    try:
        memo_json = json.loads(memo) if isinstance(memo, str) else memo

        if "error" in memo_json:
            return f"生成备忘录时出错: {memo_json['error']}"

        if "result" in memo_json:
            payload = memo_json["result"]
            if not isinstance(payload, str):
                return json.dumps(payload, ensure_ascii=False, indent=2)
            # A string result may itself be JSON; show it formatted if so.
            try:
                return json.dumps(json.loads(payload), ensure_ascii=False, indent=2)
            except json.JSONDecodeError:
                return payload

        return json.dumps(memo_json, ensure_ascii=False, indent=2)
    except json.JSONDecodeError:
        return f"无效的 JSON 格式: {memo}"
    except Exception as e:
        return f"处理备忘录时出错: {str(e)}"

def prepare_conversation(conversation_data):
    """Decode a JSON-encoded conversation if the input is a string.

    Non-strings pass through untouched; a string that is not valid JSON
    is wrapped as a single plain-text utterance record.
    """
    if not isinstance(conversation_data, str):
        return conversation_data
    try:
        return json.loads(conversation_data)
    except json.JSONDecodeError:
        return [{"内容": conversation_data}]

# Build the Gradio UI: microphone input feeds transcribe(), one button
# toggles the speaker label, another runs the memo-generation pipeline.
with gr.Blocks() as demo:
    gr.Markdown("# 实时中文对话转录与备忘录生成")
    gr.Markdown("点击麦克风图标开始录音,说话后会自动进行语音识别。支持中文识别。")

    with gr.Row():
        audio_input = gr.Audio(sources=["microphone"], type="numpy")
        speaker_button = gr.Button("切换说话者")

    speaker_label = gr.Label("当前说话者:患者")
    debug_output = gr.Textbox(label="调试信息")
    conversation_output = gr.JSON(label="对话记录")
    memo_output = gr.JSON(label="备忘录")

    generate_memo_button = gr.Button("生成备忘录")

    def debug_audio(audio):
        # Report the raw audio payload's type/shape/dtype in the debug box.
        if audio is None:
            return "No audio input received"
        return f"Audio type: {type(audio)}, Shape: {audio[1].shape if isinstance(audio, tuple) else audio.shape}, Dtype: {audio[1].dtype if isinstance(audio, tuple) else audio.dtype}"

    # Both change-handlers fire on every new recording.
    audio_input.change(debug_audio, inputs=[audio_input], outputs=[debug_output])
    audio_input.change(transcribe, inputs=[audio_input], outputs=[conversation_output])
    # NOTE(review): transcribe() also auto-toggles current_speaker after
    # each utterance, so this manual button can drift out of sync with the
    # label shown — confirm the intended interaction.
    speaker_button.click(switch_speaker, outputs=[speaker_label])
    # Memo pipeline: normalize the JSON component's value, generate the
    # memo via the LLM, then pretty-print the result into memo_output.
    generate_memo_button.click(
        prepare_conversation,
        inputs=[conversation_output],
        outputs=[conversation_output]
    ).then(
        safe_generate_memo,
        inputs=[conversation_output],
        outputs=[memo_output]
    ).then(
        display_memo,
        inputs=[memo_output],
        outputs=[memo_output]
    )

# Launch the web app only when run as a script.
if __name__ == "__main__":
    demo.launch()