cydxg committed on
Commit
e5c2a63
·
verified ·
1 Parent(s): 55d04e1

Upload 5 files

Browse files
Files changed (5) hide show
  1. flow_inference.py +142 -0
  2. model_server.py +116 -0
  3. quantification.py +27 -0
  4. requirements.txt +36 -0
  5. web_demo.py +258 -0
flow_inference.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torchaudio
3
+ import numpy as np
4
+ import re
5
+ from hyperpyyaml import load_hyperpyyaml
6
+ import uuid
7
+ from collections import defaultdict
8
+
9
+
10
def fade_in_out(fade_in_mel, fade_out_mel, window):
    """Cross-fade the end of the previous chunk into the start of the new one.

    Both tensors are taken to the CPU. The first ``len(window) / 2`` positions
    along the last axis of ``fade_in_mel`` are overwritten with a weighted sum:
    the new chunk scaled by the first half of ``window`` plus the last
    positions of ``fade_out_mel`` scaled by the second half.

    The write is in place on the CPU tensor, which is the argument itself when
    ``fade_in_mel`` already lives on the CPU.

    Returns:
        The blended ``fade_in_mel``, moved back to the device it arrived on.
    """
    original_device = fade_in_mel.device
    incoming = fade_in_mel.cpu()
    outgoing = fade_out_mel.cpu()

    half = int(window.shape[0] / 2)
    rising, falling = window[:half], window[half:]

    blended_head = incoming[..., :half] * rising + outgoing[..., -half:] * falling
    incoming[..., :half] = blended_head
    return incoming.to(original_device)
17
+
18
+
19
class AudioDecoder:
    """Convert speech tokens to audio with a flow model followed by a HiFT model.

    ``flow`` turns tokens into a mel spectrogram and ``hift`` turns the mel
    into a waveform. For chunked decoding, per-stream state (the overlap mel
    and the HiFT cache) is kept in dictionaries keyed by a caller-chosen id
    and is dropped when a chunk is decoded with ``finalize=True``.
    """

    def __init__(self, config_path, flow_ckpt_path, hift_ckpt_path, device="cuda"):
        """Build both models from the YAML config and load their checkpoints.

        Args:
            config_path: HyperPyYAML file providing the ``flow`` and ``hift`` objects.
            flow_ckpt_path: State dict for the flow model.
            hift_ckpt_path: State dict for the HiFT model.
            device: Device both models are moved to.
        """
        self.device = device

        # NOTE(security): HyperPyYAML can construct objects named in the file
        # and torch.load unpickles its input; only use trusted config and
        # checkpoint files here.
        with open(config_path, 'r') as f:
            self.scratch_configs = load_hyperpyyaml(f)

        # Load models
        self.flow = self.scratch_configs['flow']
        self.flow.load_state_dict(torch.load(flow_ckpt_path, map_location=self.device))
        self.hift = self.scratch_configs['hift']
        self.hift.load_state_dict(torch.load(hift_ckpt_path, map_location=self.device))

        # Move models to the appropriate device
        self.flow.to(self.device)
        self.hift.to(self.device)

        # Per-stream state; a missing id reads as None (and creates the key).
        self.mel_overlap_dict = defaultdict(lambda: None)
        self.hift_cache_dict = defaultdict(lambda: None)

        self.token_min_hop_len = 2 * self.flow.input_frame_rate
        self.token_max_hop_len = 4 * self.flow.input_frame_rate
        self.token_overlap_len = 5
        # presumably 22050 is the sample rate and 256 the mel hop size - TODO confirm
        self.mel_overlap_len = int(self.token_overlap_len / self.flow.input_frame_rate * 22050 / 256)
        self.mel_window = np.hamming(2 * self.mel_overlap_len)
        # hift cache
        self.mel_cache_len = 1
        self.source_cache_len = int(self.mel_cache_len * 256)
        # speech fade in out
        self.speech_window = np.hamming(2 * self.source_cache_len)

    def token2wav(self, token, uuid, prompt_token=torch.zeros(1, 0, dtype=torch.int32),
                  prompt_feat=torch.zeros(1, 0, 80), embedding=torch.zeros(1, 192), finalize=False):
        """Decode one chunk of tokens for the stream identified by ``uuid``.

        Args:
            token: Token tensor; its second dimension is used as the length.
            uuid: Key under which this stream's overlap mel and HiFT cache live.
            prompt_token: Prompt tokens handed to the flow model.
            prompt_feat: Prompt features handed to the flow model.
            embedding: Embedding handed to the flow model.
            finalize: When True this is the last chunk: nothing is held back
                and the state stored for ``uuid`` is deleted.

        Returns:
            ``(tts_speech, tts_mel)``: the waveform for this chunk and the mel
            that was passed to the HiFT model.
        """
        tts_mel = self.flow.inference(token=token.to(self.device),
                                      token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
                                      prompt_token=prompt_token.to(self.device),
                                      prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(
                                          self.device),
                                      prompt_feat=prompt_feat.to(self.device),
                                      prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(
                                          self.device),
                                      embedding=embedding.to(self.device))

        # mel overlap fade in out
        if self.mel_overlap_dict[uuid] is not None:
            tts_mel = fade_in_out(tts_mel, self.mel_overlap_dict[uuid], self.mel_window)
        # append hift cache
        if self.hift_cache_dict[uuid] is not None:
            hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
            tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
        else:
            hift_cache_source = torch.zeros(1, 1, 0)

        # keep overlap mel and hift cache
        if finalize is False:
            # Hold back the last frames so the next chunk can be faded into them.
            self.mel_overlap_dict[uuid] = tts_mel[:, :, -self.mel_overlap_len:]
            tts_mel = tts_mel[:, :, :-self.mel_overlap_len]
            tts_speech, tts_source = self.hift.inference(mel=tts_mel, cache_source=hift_cache_source)

            self.hift_cache_dict[uuid] = {'mel': tts_mel[:, :, -self.mel_cache_len:],
                                          'source': tts_source[:, :, -self.source_cache_len:],
                                          'speech': tts_speech[:, -self.source_cache_len:]}
            # The cached mel frame is decoded again with the next chunk, so
            # its samples are cut from this chunk's output.
            tts_speech = tts_speech[:, :-self.source_cache_len]
        else:
            tts_speech, tts_source = self.hift.inference(mel=tts_mel, cache_source=hift_cache_source)
            # The lookups above created both keys, so these deletes cannot fail.
            del self.hift_cache_dict[uuid]
            del self.mel_overlap_dict[uuid]
        return tts_speech, tts_mel

    def offline_inference(self, token):
        """Decode all of ``token`` in a single finalized pass; returns CPU audio."""
        this_uuid = str(uuid.uuid1())
        tts_speech, tts_mel = self.token2wav(token, uuid=this_uuid, finalize=True)
        return tts_speech.cpu()

    def stream_inference(self, token):
        """Decode ``token`` block by block and return the concatenated CPU audio.

        Blocks are ``self.flow.encoder.block_size`` tokens long. From the
        second block on, the mels and tokens decoded so far are passed to the
        flow model as the prompt. The last block is decoded with
        ``finalize=True``.
        """
        # Tensor.to() is not in place: the result has to be kept.
        token = token.to(self.device)
        this_uuid = str(uuid.uuid1())

        # Empty prompt for the first block.
        prompt_speech_feat = torch.zeros(1, 0, 80).to(self.device)
        flow_prompt_speech_token = torch.zeros(1, 0, dtype=torch.int32).to(self.device)

        tts_speechs = []
        tts_mels = []

        block_size = self.flow.encoder.block_size

        for idx in range(0, token.size(1), block_size):
            tts_token = token[:, idx:idx + block_size]

            if tts_mels:
                prompt_speech_feat = torch.cat(tts_mels, dim=-1).transpose(1, 2)
                flow_prompt_speech_token = token[:, :idx]

            is_finalize = idx + block_size >= token.size(-1)

            tts_speech, tts_mel = self.token2wav(tts_token, uuid=this_uuid,
                                                 prompt_token=flow_prompt_speech_token.to(self.device),
                                                 prompt_feat=prompt_speech_feat.to(self.device),
                                                 finalize=is_finalize)

            tts_speechs.append(tts_speech)
            tts_mels.append(tts_mel)

        return torch.cat(tts_speechs, dim=-1).cpu()
142
+
model_server.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ A model worker executes the model.
3
+ """
4
+ import argparse
5
+ import json
6
+ import uuid
7
+
8
+ from fastapi import FastAPI, Request
9
+ from fastapi.responses import StreamingResponse
10
+ from transformers import AutoModel, AutoTokenizer
11
+ import torch
12
+ import uvicorn
13
+
14
+ from transformers.generation.streamers import BaseStreamer
15
+ from threading import Thread
16
+ from queue import Queue
17
+
18
+
19
class TokenStreamer(BaseStreamer):
    """Streamer that hands generated token ids to a consumer through a queue.

    An instance is passed as ``streamer=`` to ``generate`` running in another
    thread; iterating the instance then blocks until the next token id is
    available and stops once ``end()`` has been called.
    """

    def __init__(self, skip_prompt: bool = False, timeout=None):
        """
        Args:
            skip_prompt: When True, the first value given to ``put`` is dropped.
            timeout: Seconds ``__next__`` waits for a token; ``None`` waits
                indefinitely. On expiry ``queue.Empty`` is raised by the queue.
        """
        self.skip_prompt = skip_prompt

        # variables used in the streaming process
        self.token_queue = Queue()
        self.stop_signal = None
        self.next_tokens_are_prompt = True
        self.timeout = timeout

    def put(self, value):
        """Queue every token id contained in ``value`` (batch size 1 only)."""
        if len(value.shape) > 1 and value.shape[0] > 1:
            raise ValueError("TokenStreamer only supports batch size 1")
        elif len(value.shape) > 1:
            value = value[0]

        if self.skip_prompt and self.next_tokens_are_prompt:
            self.next_tokens_are_prompt = False
            return

        for token in value.tolist():
            self.token_queue.put(token)

    def end(self):
        """Signal the consumer that no more tokens will arrive."""
        self.token_queue.put(self.stop_signal)

    def __iter__(self):
        return self

    def __next__(self):
        value = self.token_queue.get(timeout=self.timeout)
        # The stop signal is the None singleton: compare by identity.
        if value is self.stop_signal:
            raise StopIteration()
        return value
54
+
55
+
56
class ModelWorker:
    """Loads the GLM model and tokenizer and streams generated token ids."""

    def __init__(self, model_path, device='cuda'):
        """Load the model (4-bit, eval mode) and its tokenizer from ``model_path``."""
        self.device = device
        self.glm_model = AutoModel.from_pretrained(model_path, trust_remote_code=True,
                                                   device_map=device, low_cpu_mem_usage=True,
                                                   load_in_4bit=True).eval()
        self.glm_tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

    @torch.inference_mode()
    def generate_stream(self, params):
        """Yield one encoded JSON line per generated token id.

        Args:
            params: Mapping with a required ``"prompt"`` and optional
                ``"temperature"`` (default 1.0), ``"top_p"`` (default 1.0)
                and ``"max_new_tokens"`` (default 256).

        Yields:
            ``b'{"token_id": <id>, "error_code": 0}\\n'`` for each token.

        Raises:
            KeyError: If ``"prompt"`` is missing.
            Exception: Whatever ``generate`` raised in the worker thread,
                re-raised here once the stream has ended.
        """
        tokenizer, model = self.glm_tokenizer, self.glm_model

        prompt = params["prompt"]

        temperature = float(params.get("temperature", 1.0))
        top_p = float(params.get("top_p", 1.0))
        max_new_tokens = int(params.get("max_new_tokens", 256))

        inputs = tokenizer([prompt], return_tensors="pt")
        inputs = inputs.to(self.device)
        streamer = TokenStreamer(skip_prompt=True)
        failures = []

        def run_generation():
            try:
                model.generate(**inputs, max_new_tokens=max_new_tokens,
                               temperature=temperature, top_p=top_p,
                               streamer=streamer)
            except Exception as exc:
                # Without this the consumer below would wait on the queue
                # forever: end the stream and hand the error over.
                failures.append(exc)
                streamer.end()

        thread = Thread(target=run_generation)
        thread.start()
        for token_id in streamer:
            yield (json.dumps({"token_id": token_id, "error_code": 0}) + "\n").encode()
        thread.join()
        if failures:
            raise failures[0]

    def generate_stream_gate(self, params):
        """Wrap ``generate_stream`` and turn any exception into an error line.

        On failure the error is printed and a single line
        ``{"text": "Server Error", "error_code": 1}`` is yielded.
        """
        try:
            yield from self.generate_stream(params)
        except Exception as e:
            print("Caught Unknown Error", e)
            ret = {
                "text": "Server Error",
                "error_code": 1,
            }
            yield (json.dumps(ret) + "\n").encode()
95
+
96
+
97
app = FastAPI()


@app.post("/generate_stream")
async def generate_stream(request: Request):
    """Stream the worker's newline-delimited JSON output for the posted parameters.

    The request body is parsed as JSON and passed unchanged to the module-level
    ``worker``.
    """
    payload = await request.json()
    return StreamingResponse(worker.generate_stream_gate(payload))
106
+
107
+
108
if __name__ == "__main__":
    # Command-line options: listen address and the checkpoint to serve.
    parser = argparse.ArgumentParser()
    for option, option_type, fallback in (
        ("--host", str, "localhost"),
        ("--port", int, 10000),
        ("--model-path", str, "glm-4-voice-9b-int4"),
    ):
        parser.add_argument(option, type=option_type, default=fallback)
    args = parser.parse_args()

    # `worker` is the module-level object the /generate_stream endpoint uses.
    worker = ModelWorker(args.model_path)
    uvicorn.run(app, host=args.host, port=args.port, log_level="info")
quantification.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer
3
+
4
# Load the "glm-4-voice-9b" checkpoint in 4-bit and save it, together with its
# tokenizer, to "glm-4-voice-9b-int4".
device = "cuda:0"
source_dir = "glm-4-voice-9b"
target_dir = "glm-4-voice-9b-int4"

tokenizer = AutoTokenizer.from_pretrained(source_dir, trust_remote_code=True)
# Set before saving, so the template is part of the tokenizer written below.
tokenizer.chat_template = "{{role}}: {{content}}"

query = "你好"
messages = [{"role": "user", "content": query}]

# NOTE(review): `inputs` is built and moved to the device but not used by
# anything below - confirm whether a test generation was intended here.
inputs = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=True,
    return_tensors="pt",
    return_dict=True,
)
inputs = inputs.to(device)

model = AutoModelForCausalLM.from_pretrained(
    source_dir,
    low_cpu_mem_usage=True,
    trust_remote_code=True,
    load_in_4bit=True,
)
model = model.eval()

model.save_pretrained(target_dir)
tokenizer.save_pretrained(target_dir)
requirements.txt ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ conformer==0.3.2
2
+ deepspeed==0.14.2; sys_platform == 'linux'
3
+ diffusers==0.27.2
4
+ fastapi==0.115.3
5
+ fastapi-cli==0.0.4
6
+ gdown==5.1.0
7
+ gradio==5.3.0
8
+ grpcio==1.57.0
9
+ grpcio-tools==1.57.0
10
+ huggingface_hub==0.25.2
11
+ hydra-core==1.3.2
12
+ HyperPyYAML==1.2.2
13
+ inflect==7.3.1
14
+ librosa==0.10.2
15
+ lightning==2.2.4
16
+ matplotlib==3.7.5
17
+ modelscope==1.15.0
18
+ networkx==3.1
19
+ numpy==1.24.4
20
+ omegaconf==2.3.0
21
+ onnxruntime-gpu==1.16.0; sys_platform == 'linux'
22
+ onnxruntime==1.16.0; sys_platform == 'darwin' or sys_platform == 'win32'
23
+ openai-whisper==20231117
24
+ protobuf==4.25
25
+ pydantic==2.7.0
26
+ rich==13.7.1
27
+ Requests==2.32.3
28
+ safetensors==0.4.5
29
+ soundfile==0.12.1
30
+ tensorboard==2.14.0
31
+ transformers==4.44.1
32
+ uvicorn==0.32.0
33
+ wget==3.2
34
+ WeTextProcessing==1.0.3
35
+ torch==2.3.0
36
+ torchaudio==2.3.0
web_demo.py ADDED
@@ -0,0 +1,258 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os.path
3
+ import tempfile
4
+ import sys
5
+ import re
6
+ import uuid
7
+ import requests
8
+ from argparse import ArgumentParser
9
+
10
+ import torchaudio
11
+ from transformers import WhisperFeatureExtractor, AutoTokenizer, AutoModel
12
+ from speech_tokenizer.modeling_whisper import WhisperVQEncoder
13
+
14
+
15
+ sys.path.insert(0, "./cosyvoice")
16
+ sys.path.insert(0, "./third_party/Matcha-TTS")
17
+
18
+ from speech_tokenizer.utils import extract_speech_token
19
+
20
+ import gradio as gr
21
+ import torch
22
+
23
+ audio_token_pattern = re.compile(r"<\|audio_(\d+)\|>")
24
+
25
+ from flow_inference import AudioDecoder
26
+
27
if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost")
    # A string default is run through `type` by argparse, so this parses to an int.
    parser.add_argument("--port", type=int, default="8888")
    parser.add_argument("--flow-path", type=str, default="./glm-4-voice-decoder")
    parser.add_argument("--model-path", type=str, default="./glm-4-voice-9b-int4")
    parser.add_argument("--tokenizer-path", type=str, default="./glm-4-voice-tokenizer")
    args = parser.parse_args()

    # The decoder config and both checkpoints sit in the --flow-path directory.
    flow_config, flow_checkpoint, hift_checkpoint = (
        os.path.join(args.flow_path, asset)
        for asset in ("config.yaml", 'flow.pt', 'hift.pt')
    )
    device = "cuda"

    # Placeholders that initialize_fn() replaces with the loaded objects.
    glm_tokenizer = None
    audio_decoder: AudioDecoder = None
    whisper_model = feature_extractor = None
43
+
44
+
45
def initialize_fn():
    """Load the tokenizer, the audio decoder and the speech tokenizer once.

    The loaded objects are stored in module globals. The function does nothing
    when ``audio_decoder`` has already been set.
    """
    global audio_decoder, feature_extractor, whisper_model, glm_model, glm_tokenizer

    if audio_decoder is None:
        # GLM tokenizer
        glm_tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True)

        # Flow & HiFT decoder
        audio_decoder = AudioDecoder(
            config_path=flow_config,
            flow_ckpt_path=flow_checkpoint,
            hift_ckpt_path=hift_checkpoint,
            device=device,
        )

        # Speech tokenizer: encoder in eval mode on the target device, plus
        # its feature extractor.
        speech_encoder = WhisperVQEncoder.from_pretrained(args.tokenizer_path)
        whisper_model = speech_encoder.eval().to(device)
        feature_extractor = WhisperFeatureExtractor.from_pretrained(args.tokenizer_path)
61
+
62
+
63
def clear_fn():
    """Return empty values for the seven components that are reset together.

    Two fresh empty lists, three empty strings and two ``None`` values, in the
    order of the ``outputs`` lists this function is wired to.
    """
    empty_chat, empty_history = [], []
    blank_input = blank_completion = blank_error = ''
    return empty_chat, empty_history, blank_input, blank_completion, blank_error, None, None
65
+
66
+
67
def _iter_token_ids(response, end_token_id):
    """Yield the token ids streamed by the model server, then ``end_token_id`` once.

    The terminating ``end_token_id`` is yielded whether the server sent it or
    the stream simply ended, so the caller always gets to flush its buffer.

    Raises:
        gr.Error: If the server sent a line with a non-zero ``error_code``.
    """
    for chunk in response.iter_lines():
        if not chunk:
            # Skip empty keep-alive lines.
            continue
        payload = json.loads(chunk)
        if payload.get("error_code", 0) != 0:
            raise gr.Error(payload.get("text", "Server Error"))
        token_id = payload["token_id"]
        if token_id == end_token_id:
            break
        yield token_id
    yield end_token_id


def inference_fn(
        temperature: float,
        top_p: float,
        max_new_token: int,
        input_mode,
        audio_path: str | None,
        input_text: str | None,
        history: list[dict],
        previous_input_tokens: str,
        previous_completion_tokens: str,
):
    """Send one user turn to the model server and stream back the spoken reply.

    The prompt is the previous input and completion tokens followed by the new
    user turn. Generated audio tokens are decoded in blocks (10 tokens for the
    first block, 20 afterwards) and each decoded block is yielded right away.

    Yields:
        6-tuples ``(history, inputs, completion_text, detailed_error,
        output_audio, complete_audio)``. While streaming, ``output_audio`` is
        ``(22050, samples)`` for the new block and ``complete_audio`` is
        ``None``. The final tuple carries the decoded completion text and the
        whole reply as ``complete_audio`` (``None`` if no audio was produced).

    Raises:
        gr.Error: If the selected input is missing, no audio tokens could be
            extracted, or the model server reported an error.
        requests.HTTPError: If the model server answered with an error status.
    """
    if input_mode == "audio":
        if audio_path is None:
            raise gr.Error("No input audio provided")
        history.append({"role": "user", "content": {"path": audio_path}})
        audio_tokens = extract_speech_token(
            whisper_model, feature_extractor, [audio_path]
        )[0]
        if len(audio_tokens) == 0:
            raise gr.Error("No audio tokens extracted")
        audio_tokens = "".join([f"<|audio_{x}|>" for x in audio_tokens])
        audio_tokens = "<|begin_of_audio|>" + audio_tokens + "<|end_of_audio|>"
        user_input = audio_tokens
        system_prompt = "User will provide you with a speech instruction. Do it step by step. First, think about the instruction and respond in a interleaved manner, with 13 text token followed by 26 audio tokens. "
    else:
        if input_text is None:
            raise gr.Error("No input text provided")
        history.append({"role": "user", "content": input_text})
        user_input = input_text
        system_prompt = "User will provide you with a text instruction. Do it step by step. First, think about the instruction and respond in a interleaved manner, with 13 text token followed by 26 audio tokens."

    # Gather history
    inputs = previous_input_tokens + previous_completion_tokens
    inputs = inputs.strip()
    if "<|system|>" not in inputs:
        inputs += f"<|system|>\n{system_prompt}"
    inputs += f"<|user|>\n{user_input}<|assistant|>streaming_transcription\n"

    this_uuid = str(uuid.uuid4())
    try:
        with torch.no_grad():
            response = requests.post(
                "http://localhost:10000/generate_stream",
                data=json.dumps({
                    "prompt": inputs,
                    "temperature": temperature,
                    "top_p": top_p,
                    "max_new_tokens": max_new_token,
                }),
                stream=True
            )
            response.raise_for_status()

            text_tokens, audio_tokens = [], []
            audio_offset = glm_tokenizer.convert_tokens_to_ids('<|audio_0|>')
            end_token_id = glm_tokenizer.convert_tokens_to_ids('<|user|>')
            complete_tokens = []
            prompt_speech_feat = torch.zeros(1, 0, 80).to(device)
            flow_prompt_speech_token = torch.zeros(1, 0, dtype=torch.int64).to(device)
            tts_speechs = []
            tts_mels = []
            block_size = 10

            for token_id in _iter_token_ids(response, end_token_id):
                is_finalize = token_id == end_token_id

                if len(audio_tokens) >= block_size or (is_finalize and audio_tokens):
                    block_size = 20
                    tts_token = torch.tensor(audio_tokens, device=device).unsqueeze(0)

                    if tts_mels:
                        # Condition on everything decoded so far.
                        prompt_speech_feat = torch.cat(tts_mels, dim=-1).transpose(1, 2)

                    tts_speech, tts_mel = audio_decoder.token2wav(tts_token, uuid=this_uuid,
                                                                  prompt_token=flow_prompt_speech_token.to(device),
                                                                  prompt_feat=prompt_speech_feat.to(device),
                                                                  finalize=is_finalize)

                    tts_speechs.append(tts_speech.squeeze())
                    tts_mels.append(tts_mel)
                    yield history, inputs, '', '', (22050, tts_speech.squeeze().cpu().numpy()), None
                    flow_prompt_speech_token = torch.cat((flow_prompt_speech_token, tts_token), dim=-1)
                    audio_tokens = []

                if not is_finalize:
                    complete_tokens.append(token_id)
                    if token_id >= audio_offset:
                        audio_tokens.append(token_id - audio_offset)
                    else:
                        text_tokens.append(token_id)

            complete_text = glm_tokenizer.decode(complete_tokens, spaces_between_special_tokens=False)
            complete_audio = None
            if tts_speechs:
                tts_speech = torch.cat(tts_speechs, dim=-1).cpu()
                with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
                    torchaudio.save(f, tts_speech.unsqueeze(0), 22050, format="wav")
                history.append({"role": "assistant", "content": {"path": f.name, "type": "audio/wav"}})
                complete_audio = (22050, tts_speech.numpy())
            history.append({"role": "assistant", "content": glm_tokenizer.decode(text_tokens, skip_special_tokens=False)})
            yield history, inputs, complete_text, '', None, complete_audio
    finally:
        # A stream that was not closed by a finalized block (error, early
        # exit, end token with an empty buffer) would otherwise leave its
        # tensors in the decoder's per-stream caches.
        if audio_decoder is not None:
            audio_decoder.hift_cache_dict.pop(this_uuid, None)
            audio_decoder.mel_overlap_dict.pop(this_uuid, None)
164
+
165
+
166
def update_input_interface(input_mode):
    """Return visibility updates for the audio input and the text input.

    The first update is visible only when ``input_mode`` is ``"audio"``; the
    second one is visible in every other case.
    """
    audio_selected = input_mode == "audio"
    return [gr.update(visible=audio_selected), gr.update(visible=not audio_selected)]
171
+
172
+
173
+ # Create the Gradio interface
174
# NOTE(review): the text this block was rebuilt from carried no indentation,
# so the nesting of the layout containers below is a reconstruction - confirm
# it against the original file.
with gr.Blocks(title="GLM-4-Voice Demo", fill_height=True) as demo:
    # Sampling controls
    with gr.Row():
        temperature = gr.Number(label="Temperature", value=0.2)
        top_p = gr.Number(label="Top p", value=0.8)
        max_new_token = gr.Number(label="Max new tokens", value=2000)

    chatbot = gr.Chatbot(elem_id="chatbot", bubble_full_width=False, type="messages", scale=1)

    with gr.Row():
        # Left: input selection; only one of the two inputs is visible at a time.
        with gr.Column():
            input_mode = gr.Radio(["audio", "text"], label="Input Mode", value="audio")
            audio = gr.Audio(label="Input audio", type='filepath', show_download_button=True, visible=True)
            text_input = gr.Textbox(label="Input text", placeholder="Enter your text here...", lines=2,
                                    visible=False)

        # Right: actions and audio outputs.
        with gr.Column():
            submit_btn = gr.Button("Submit")
            reset_btn = gr.Button("Clear")
            output_audio = gr.Audio(label="Play", streaming=True, autoplay=True, show_download_button=False)
            complete_audio = gr.Audio(label="Last Output Audio (If Any)", show_download_button=True)

    gr.Markdown("""## Debug Info""")
    with gr.Row():
        input_tokens = gr.Textbox(label="Input Tokens", interactive=False)
        completion_tokens = gr.Textbox(label="Completion Tokens", interactive=False)

    detailed_error = gr.Textbox(label="Detailed Error", interactive=False)

    history_state = gr.State([])

    # Run inference on submit, then copy the updated history into the chatbot.
    respond = submit_btn.click(
        inference_fn,
        inputs=[
            temperature,
            top_p,
            max_new_token,
            input_mode,
            audio,
            text_input,
            history_state,
            input_tokens,
            completion_tokens,
        ],
        outputs=[history_state, input_tokens, completion_tokens, detailed_error, output_audio, complete_audio],
    )
    respond.then(lambda s: s, [history_state], chatbot)

    # "Clear" and a change of input mode both reset the conversation; the mode
    # change additionally swaps which input component is visible.
    reset_btn.click(
        clear_fn,
        outputs=[chatbot, history_state, input_tokens, completion_tokens, detailed_error,
                 output_audio, complete_audio],
    )
    mode_changed = input_mode.input(
        clear_fn,
        outputs=[chatbot, history_state, input_tokens, completion_tokens, detailed_error,
                 output_audio, complete_audio],
    )
    mode_changed.then(update_input_interface, inputs=[input_mode], outputs=[audio, text_input])

initialize_fn()
# Launch the interface
demo.launch(server_port=args.port, server_name=args.host)