diff --git "a/modeling_minicpmo.py" "b/modeling_minicpmo.py"
new file mode 100644
--- /dev/null
+++ "b/modeling_minicpmo.py"
@@ -0,0 +1,3259 @@
+# coding=utf-8
+# Copyright 2025 The OpenBMB Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import logging
+import math
+import os
+import types
+from collections.abc import Iterator
+from copy import deepcopy
+from dataclasses import dataclass
+from threading import Thread
+from typing import List
+from typing import Literal
+from typing import Optional
+from typing import Tuple
+from typing import Union
+
+import numpy as np
+import soundfile as sf
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.nn.utils.parametrize as P
+from huggingface_hub import hf_hub_download
+from PIL import Image
+from torch.nn.utils.parametrizations import weight_norm
+from tqdm import tqdm
+from transformers import AutoProcessor
+from transformers import BertTokenizerFast
+from transformers import LlamaConfig
+from transformers import LlamaModel
+from transformers import LogitsWarper
+from transformers import PreTrainedModel
+from transformers import Qwen2ForCausalLM
+from transformers import Qwen2PreTrainedModel
+from transformers import TextIteratorStreamer
+from transformers import TopKLogitsWarper
+from transformers import TopPLogitsWarper
+from transformers.cache_utils import Cache
+from transformers.cache_utils import DynamicCache
+from transformers.cache_utils import EncoderDecoderCache
+from transformers.cache_utils import StaticCache
+from transformers.modeling_outputs import BaseModelOutputWithPast
+from transformers.modeling_outputs import ModelOutput
+from transformers.models.whisper.modeling_whisper import ACT2FN
+from transformers.models.whisper.modeling_whisper import WHISPER_ATTENTION_CLASSES
+from transformers.models.whisper.modeling_whisper import WhisperConfig
+from transformers.models.whisper.modeling_whisper import WhisperEncoder
+
+try:
+    from vector_quantize_pytorch import GroupedResidualFSQ
+    from vocos import Vocos
+    from vocos.pretrained import instantiate_class
+
+    _tts_deps = True
+except ImportError:
+    _tts_deps = False
+
+from .configuration_minicpm import ConditionalChatTTSConfig
+from .configuration_minicpm import MiniCPMOConfig
+from .modeling_navit_siglip import SiglipVisionTransformer
+from .resampler import Resampler
+from .utils import NumberToTextConverter
+from .utils import sentence_end
+from .utils import VoiceChecker
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class OmniOutput(ModelOutput):
+    text: Optional[Union[str, List[str], Iterator]] = None
+    spk_embeds: Optional[torch.FloatTensor] = None
+    audio_wav: Optional[np.ndarray] = None
+    sampling_rate: Optional[int] = None
+
+
+class MiniCPMOPreTrainedModel(Qwen2PreTrainedModel):
+    config_class = MiniCPMOConfig
+
+
+class MiniCPMO(MiniCPMOPreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+        self.llm = Qwen2ForCausalLM(config)
+        self.llm.prepare_inputs_for_generation = types.MethodType(prepare_inputs_for_generation, self.llm)  # patch llm
+
+        self.embed_dim = self.llm.config.hidden_size
+
+        # init vision module
+        if self.config.init_vision:
+            self.vpm = self.init_vision_module()
+            self.vision_dim = self.vpm.embed_dim
+            self.resampler = self.init_resampler(self.embed_dim, self.vision_dim)
+
+        # init audio module
+        if self.config.init_audio:
+            self.apm = self.init_audio_module()
+            audio_output_dim = int(self.apm.config.encoder_ffn_dim // 4)
+            self.audio_avg_pooler = nn.AvgPool1d(self.config.audio_pool_step, stride=self.config.audio_pool_step)
+            self.audio_projection_layer = MultiModalProjector(in_dim=audio_output_dim, out_dim=self.embed_dim)
+            self.audio_encoder_layer = -1
+
+        # init tts module
+        if self.config.init_tts:
+            assert _tts_deps, "please make sure vector_quantize_pytorch and vocos are installed."
+            self.tts = self.init_tts_module()
+
+        self.processor = AutoProcessor.from_pretrained(self.config._name_or_path, trust_remote_code=True)
+
+        self.terminators = ["<|im_end|>", "<|endoftext|>"]
+
+        self.default_tts_chat_template = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n<|spk_bos|><|spk|><|spk_eos|><|tts_bos|>' }}{% endif %}"
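+        # Example rendering with add_generation_prompt=True (illustrative):
+        #   <|im_start|>user\nHi<|im_end|>\n<|im_start|>assistant\n<|spk_bos|><|spk|><|spk_eos|><|tts_bos|>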
+        self.force_no_stop = False
+
+        # for stream api
+        self.reset_session()
+
+    def reset_session(self):
+        self.session_id = None
+        self.new_user_msg = True
+        self.llm_generated = False
+        self.llm_generate_completed = False
+        self.llm_past_key_values = None
+        self.audio_past_key_values = None  # apm kv cache
+
+    def init_tts(
+        self,
+        tts_text_tokenizer_path=None,
+        vocos_ckpt_path=None,
+    ):
+        """
+        Load the TTS text tokenizer and the Vocos vocoder.
+        1. Try to load from a local path. 2. Otherwise, download from Hugging Face.
+        """
+        from .processing_minicpmo import ChatTTSProcessor
+
+        if tts_text_tokenizer_path is None:
+            tts_text_tokenizer_path = os.path.join(self.config._name_or_path, "assets/chattts_tokenizer")
+        if not os.path.exists(tts_text_tokenizer_path):
+            # try from hf model_id
+            tts_text_tokenizer_path = "openbmb/chattts_tokenizer"
+
+        tts_text_tokenizer = BertTokenizerFast.from_pretrained(tts_text_tokenizer_path)
+        self.tts_processor = ChatTTSProcessor(text_tokenizer=tts_text_tokenizer)
+
+        if vocos_ckpt_path is None:
+            vocos_ckpt_path = os.path.join(self.config._name_or_path, "assets/Vocos.pt")
+        if not os.path.exists(vocos_ckpt_path):
+            vocos_ckpt_path = hf_hub_download(repo_id="openbmb/MiniCPM-o-2_6", subfolder="assets", filename="Vocos.pt")
+
+        assert os.path.exists(vocos_ckpt_path)
+        self.vocos = self.initialize_vocos(vocos_ckpt_path)
+
+    def initialize_vocos(self, ckpt_path):
+        feature_extractor = instantiate_class(
+            args=(),
+            init={
+                "class_path": "vocos.feature_extractors.MelSpectrogramFeatures",
+                "init_args": {"sample_rate": 24000, "n_fft": 1024, "hop_length": 256, "n_mels": 100},
+            },
+        )
+        backbone = instantiate_class(
+            args=(),
+            init={
+                "class_path": "vocos.models.VocosBackbone",
+                "init_args": {"input_channels": 100, "dim": 512, "intermediate_dim": 1536, "num_layers": 8},
+            },
+        )
+        head = instantiate_class(
+            args=(),
+            init={"class_path": "vocos.heads.ISTFTHead", "init_args": {"dim": 512, "n_fft": 1024, "hop_length": 256}},
+        )
+        vocos = Vocos(feature_extractor, backbone, head).to("cuda").eval().to(torch.float32)
+        vocos.load_state_dict(torch.load(ckpt_path, weights_only=True, mmap=True))
+        return vocos
+
+    def init_vision_module(self):
+        if self.config._attn_implementation == "flash_attention_2":
+            self.config.vision_config._attn_implementation = "flash_attention_2"
+        else:
+            self.config.vision_config._attn_implementation = "eager"
+        model = SiglipVisionTransformer(self.config.vision_config)
+        if self.config.drop_vision_last_layer:
+            model.encoder.layers = model.encoder.layers[:-1]
+
+        setattr(model, "embed_dim", model.embeddings.embed_dim)
+        setattr(model, "patch_size", model.embeddings.patch_size)
+
+        return model
+
+    def init_resampler(self, embed_dim, vision_dim):
+        return Resampler(
+            num_queries=self.config.query_num,
+            embed_dim=embed_dim,
+            num_heads=embed_dim // 128,
+            kv_dim=vision_dim,
+            adaptive=True,
+        )
+
+    def init_audio_module(self):
+        model = MiniCPMWhisperEncoder(self.config.audio_config)
+        return model
+
+    def init_tts_module(self):
+        model = ConditionalChatTTS(self.config.tts_config)
+        return model
+
+    def get_input_embeddings(self):
+        return self.llm.get_input_embeddings()
+
+    def set_input_embeddings(self, value):
+        self.llm.embed_tokens = value
+
+    def get_output_embeddings(self):
+        return self.llm.lm_head
+
+    def set_output_embeddings(self, new_embeddings):
+        self.llm.lm_head = new_embeddings
+
+    def set_decoder(self, decoder):
+        self.llm = decoder
+
+    def get_decoder(self):
+        return self.llm
+
+    def subsequent_chunk_mask(
+        self,
+        size: int,
+        chunk_size: int,
+        num_left_chunks: int = -1,
+        device: torch.device = torch.device("cpu"),
+        num_lookhead: int = 0,
+    ) -> torch.Tensor:
+        """Create a (size, size) chunked attention mask for the streaming encoder.
+
+        Args:
+            size (int): size of mask
+            chunk_size (int): size of chunk
+            num_left_chunks (int): number of left chunks
+                <0: use full chunk
+                >=0: use num_left_chunks
+            device (torch.device): "cpu" or "cuda" or torch.Tensor.device
+
+        Returns:
+            torch.Tensor: mask
+
+        Examples:
+            >>> subsequent_chunk_mask(4, 2)
+            [[1, 1, 0, 0],
+            [1, 1, 0, 0],
+            [1, 1, 1, 1],
+            [1, 1, 1, 1]]
+        """
+        ret = torch.zeros(size, size, device=device, dtype=torch.bool)
+        for i in range(size):
+            if num_left_chunks < 0:
+                start = 0
+            else:
+                start = max((i // chunk_size - num_left_chunks) * chunk_size, 0)
+            ending = min((i // chunk_size + 1) * chunk_size + num_lookhead, size)
+            ret[i, start:ending] = True
+        return ret
+
+    def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor):
+        """
+        Computes the output length of the convolutional layers and the output length of the audio encoder
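+
+        Worked example (illustrative, assuming ``audio_pool_step == 2``): 100 input mel frames give
+        ``(100 - 1) // 2 + 1 = 50`` frames after the convolutional front-end and
+        ``(50 - 2) // 2 + 1 = 25`` tokens after average pooling.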
+        """
+        input_lengths_after_cnn = (input_lengths - 1) // 2 + 1
+        input_lengths_after_pooling = (
+            input_lengths_after_cnn - self.config.audio_pool_step
+        ) // self.config.audio_pool_step + 1
+        input_lengths_after_pooling = input_lengths_after_pooling.to(dtype=torch.int32)
+
+        return input_lengths_after_cnn, input_lengths_after_pooling
+
+    def get_vllm_embedding(self, data):
+        """
+        Compute all visual embeddings and scatter them into the LLM token embeddings.
+        Args:
+            data: Dict
+                tgt_sizes: image sizes (in patches) after patch embedding
+                pixel_values: image features
+                image_bound: position range in input_ids occupied by each image
+                input_ids: full input_ids, including image placeholders
+        Returns:
+            the token embeddings with vision features merged in, and vision_hidden_states
+        """
+        if "vision_hidden_states" not in data:
+            dtype = self.llm.model.embed_tokens.weight.dtype
+            device = self.llm.model.embed_tokens.weight.device
+            tgt_sizes = data["tgt_sizes"]
+            pixel_values_list = data["pixel_values"]
+            vision_hidden_states = []
+            all_pixel_values = []
+            img_cnt = []
+            for pixel_values in pixel_values_list:
+                img_cnt.append(len(pixel_values))
+                all_pixel_values.extend([i.flatten(end_dim=1).permute(1, 0) for i in pixel_values])
+
+            # images exist
+            if all_pixel_values:
+                tgt_sizes = [tgt_size for tgt_size in tgt_sizes if isinstance(tgt_size, torch.Tensor)]
+                tgt_sizes = torch.vstack(tgt_sizes).type(torch.int32)
+
+                max_patches = torch.max(tgt_sizes[:, 0] * tgt_sizes[:, 1])
+
+                all_pixel_values = torch.nn.utils.rnn.pad_sequence(
+                    all_pixel_values, batch_first=True, padding_value=0.0
+                )
+                B, L, _ = all_pixel_values.shape
+                all_pixel_values = all_pixel_values.permute(0, 2, 1).reshape(B, 3, -1, L)
+
+                patch_attn_mask = torch.zeros((B, 1, max_patches), dtype=torch.bool, device=device)
+                for i in range(B):
+                    patch_attn_mask[i, 0, : tgt_sizes[i][0] * tgt_sizes[i][1]] = True
+
+                vision_batch_size = self.config.vision_batch_size
+                all_pixel_values = all_pixel_values.type(dtype)
+                if B > vision_batch_size:
+                    hs = []
+                    for i in range(0, B, vision_batch_size):
+                        start_idx = i
+                        end_idx = i + vision_batch_size
+                        tmp_hs = self.vpm(
+                            all_pixel_values[start_idx:end_idx],
+                            patch_attention_mask=patch_attn_mask[start_idx:end_idx],
+                            tgt_sizes=tgt_sizes[start_idx:end_idx],
+                        ).last_hidden_state
+                        hs.append(tmp_hs)
+                    vision_embedding = torch.cat(hs, dim=0)
+                else:
+                    vision_embedding = self.vpm(
+                        all_pixel_values, patch_attention_mask=patch_attn_mask, tgt_sizes=tgt_sizes
+                    ).last_hidden_state
+                vision_embedding = self.resampler(vision_embedding, tgt_sizes)
+
+                start = 0
+                for pixel_values in pixel_values_list:
+                    img_cnt = len(pixel_values)
+                    if img_cnt > 0:
+                        vision_hidden_states.append(vision_embedding[start : start + img_cnt])
+                        start += img_cnt
+                    else:
+                        vision_hidden_states.append([])
+            else:  # no image
+                if self.training:
+                    dummy_image = torch.zeros((1, 3, 224, 224), device=device, dtype=dtype)
+                    tgt_sizes = torch.Tensor(
+                        [[(224 // self.config.patch_size), math.ceil(224 / self.config.patch_size)]]
+                    ).type(torch.int32)
+                    dummy_feature = self.resampler(self.vpm(dummy_image).last_hidden_state, tgt_sizes)
+                else:
+                    dummy_feature = []
+                for _ in range(len(pixel_values_list)):
+                    vision_hidden_states.append(dummy_feature)
+
+        else:
+            vision_hidden_states = data["vision_hidden_states"]
+
+        if hasattr(self.llm.config, "scale_emb"):
+            vllm_embedding = self.llm.model.embed_tokens(data["input_ids"]) * self.llm.config.scale_emb
+        else:
+            vllm_embedding = self.llm.model.embed_tokens(data["input_ids"])
+
+        vision_hidden_states = [
+            i.type(vllm_embedding.dtype) if isinstance(i, torch.Tensor) else i for i in vision_hidden_states
+        ]
+
+        bs = len(data["input_ids"])
+        for i in range(bs):
+            cur_vs_hs = vision_hidden_states[i]
+            if len(cur_vs_hs) > 0:
+                cur_vllm_emb = vllm_embedding[i]
+                cur_image_bound = data["image_bound"][i]
+                if len(cur_image_bound) > 0:
+                    image_indices = torch.stack(
+                        [torch.arange(r[0], r[1], dtype=torch.long) for r in cur_image_bound]
+                    ).to(vllm_embedding.device)
+
+                    cur_vllm_emb.scatter_(
+                        0,
+                        image_indices.view(-1, 1).repeat(1, cur_vllm_emb.shape[-1]),
+                        cur_vs_hs.view(-1, cur_vs_hs.shape[-1]),
+                    )
+                elif self.training:
+                    cur_vllm_emb += cur_vs_hs[0].mean() * 0
+
+        return vllm_embedding, vision_hidden_states
+
+    def get_audio_embedding_streaming(self, data):
+        r"""
+        Extract audio embeddings in a streaming manner using cached key-value pairs.
+
+        This method processes incoming audio features incrementally and stores/updates `past_key_values`
+        for faster inference on subsequent audio frames. It only supports batch_size=1 and is intended
+        for streaming scenarios.
+
+        Args:
+            data (dict):
+                - **"audio_features"** (`torch.FloatTensor`): Input mel-spectrograms of shape `(batch_size, 80, frames)`.
+                - **"audio_feature_lens"** (List[List[int]]): Lengths of each audio segment for each item in the batch.
+
+        Returns:
+            List[List[torch.Tensor]]: audio embeddings
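+
+        Example (illustrative sketch, not part of the original API; assumes a fully loaded ``model``
+        and a hypothetical 1-second mel chunk):
+
+            >>> feats = torch.randn(1, 80, 100, device=model.device, dtype=model.apm.conv1.weight.dtype)
+            >>> embeds = model.get_audio_embedding_streaming(
+            ...     {"audio_features": feats, "audio_feature_lens": [torch.tensor([100])]}
+            ... )
+            >>> # embeds is a List[List[torch.Tensor]] with one entry per audio segment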
+        """
+        wavforms = data.get("audio_features", [])  # (bs, 80, frames) or []; multiple audios must be packed into the batch in advance
+        audio_feature_lens_raw = data.get("audio_feature_lens", [])  # list, [[x1, x2], [y1], [z1]]
+
+        # audio exists
+        if len(wavforms) > 0:
+            audio_feature_lens = torch.hstack(audio_feature_lens_raw)
+            batch_size, _, max_mel_seq_len = wavforms.shape
+            assert batch_size == 1
+            max_seq_len = (max_mel_seq_len - 1) // 2 + 1
+
+            if self.audio_past_key_values is not None:
+                cache_length = self.audio_past_key_values[0][0].shape[2]
+                apm_max_len = self.apm.embed_positions.weight.shape[0]
+                if cache_length + max_seq_len >= apm_max_len:
+                    logger.warning(
+                        f"audio_past_key_values length {cache_length + max_seq_len} exceeds {apm_max_len}, resetting the cache."
+                    )
+                    self.audio_past_key_values = None
+
+            audio_outputs = self.apm(wavforms, past_key_values=self.audio_past_key_values, use_cache=True)
+            audio_states = audio_outputs.last_hidden_state  # [:, :audio_feat_lengths, :]
+            self.audio_past_key_values = audio_outputs.past_key_values
+
+            audio_embeds = self.audio_projection_layer(audio_states)
+
+            audio_embeds = audio_embeds.transpose(1, 2)
+            audio_embeds = self.audio_avg_pooler(audio_embeds)
+            audio_embeds = audio_embeds.transpose(1, 2)
+
+            _, feature_lens_after_pooling = self._get_feat_extract_output_lengths(audio_feature_lens)
+
+            num_audio_tokens = feature_lens_after_pooling
+
+            final_audio_embeds = []
+            idx = 0
+            for i in range(len(audio_feature_lens_raw)):
+                target_audio_embeds = []
+                for _ in range(len(audio_feature_lens_raw[i])):
+                    target_audio_embeds.append(audio_embeds[idx, : num_audio_tokens[idx], :])
+                    idx += 1
+                final_audio_embeds.append(target_audio_embeds)
+            return final_audio_embeds
+        else:
+            return []
+
+    def get_audio_embedding(self, data, chunk_length=-1):
+        r"""
+        Extract full audio embeddings with optional chunk-based attention.
+
+        This method computes embeddings for all audio frames at once, either using full attention (when
+        `chunk_length` is -1) or chunk-based attention (when `chunk_length` is a positive number). It does
+        not use key-value caching and is suitable for non-streaming inference.
+
+        Args:
+            data (dict):
+                - **"audio_features"** (`torch.FloatTensor`): Input mel-spectrograms of shape `(batch_size, 80, frames)`.
+                - **"audio_feature_lens"** (List[List[int]]): Lengths of each audio segment for each item in the batch.
+            chunk_length (int, optional): Determines whether to use full attention (-1) or chunk-based
+                attention (>0) during embedding computation.
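+                For example, ``chunk_length=1`` restricts attention to 50 post-convolution frames,
+                i.e. roughly 1 second of audio.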
+
+        Returns:
+            List[List[torch.Tensor]]: audio embeddings
+        """
+
+        wavforms = data.get("audio_features", [])  # (bs, 80, frames) or []; multiple audios must be packed into the batch in advance
+        audio_feature_lens_raw = data.get("audio_feature_lens", [])  # list, [[x1, x2], [y1], [z1]]
+
+        # audio exists
+        if len(wavforms) > 0:
+            audio_feature_lens = torch.hstack(audio_feature_lens_raw)
+            batch_size, _, max_mel_seq_len = wavforms.shape
+            max_seq_len = (max_mel_seq_len - 1) // 2 + 1
+
+            # Create a sequence tensor of shape (batch_size, max_seq_len)
+            seq_range = (
+                torch.arange(0, max_seq_len, dtype=audio_feature_lens.dtype, device=audio_feature_lens.device)
+                .unsqueeze(0)
+                .expand(batch_size, max_seq_len)
+            )
+            lengths_expand = audio_feature_lens.unsqueeze(1).expand(batch_size, max_seq_len)
+            # Create mask
+            padding_mask = seq_range >= lengths_expand  # 1 for padded values
+
+            audio_attention_mask_ = padding_mask.view(batch_size, 1, 1, max_seq_len).expand(
+                batch_size, 1, max_seq_len, max_seq_len
+            )
+            audio_attention_mask = audio_attention_mask_.to(
+                dtype=self.apm.conv1.weight.dtype, device=self.apm.conv1.weight.device
+            )
+
+            if chunk_length > 0:
+                chunk_num_frame = int(chunk_length * 50)
+                chunk_mask = self.subsequent_chunk_mask(
+                    size=max_seq_len,
+                    chunk_size=chunk_num_frame,
+                    num_left_chunks=-1,
+                    device=audio_attention_mask_.device,
+                )
+                audio_attention_mask_ = torch.logical_or(audio_attention_mask_, torch.logical_not(chunk_mask))
+
+            audio_attention_mask[audio_attention_mask_] = float("-inf")
+            audio_states = self.apm(
+                wavforms, output_hidden_states=True, attention_mask=audio_attention_mask
+            ).hidden_states[self.audio_encoder_layer]
+            audio_embeds = self.audio_projection_layer(audio_states)
+
+            audio_embeds = audio_embeds.transpose(1, 2)
+            audio_embeds = self.audio_avg_pooler(audio_embeds)
+            audio_embeds = audio_embeds.transpose(1, 2)
+
+            _, feature_lens_after_pooling = self._get_feat_extract_output_lengths(audio_feature_lens)
+
+            num_audio_tokens = feature_lens_after_pooling
+
+            final_audio_embeds = []
+            idx = 0
+            for i in range(len(audio_feature_lens_raw)):
+                target_audio_embeds = []
+                for _ in range(len(audio_feature_lens_raw[i])):
+                    target_audio_embeds.append(audio_embeds[idx, : num_audio_tokens[idx], :])
+                    idx += 1
+                final_audio_embeds.append(target_audio_embeds)
+            return final_audio_embeds
+        else:
+            return []
+
+    def get_omni_embedding(self, data, input_embeddings, chunk_length=-1, stream_input=False):
+        """
+        Merge audio embeddings into the input embeddings at the audio placeholder positions.
+
+        Args:
+            data: dict containing audio features, feature lengths and audio bounds
+            input_embeddings: token embeddings produced by `get_vllm_embedding`
+            chunk_length: whisper attention mode; -1 for full attention, >0 for chunk attention
+            stream_input: whether to use the streaming audio embedding path
+        Returns:
+            the final embeddings with audio features merged in
+        """
+        if stream_input:
+            audio_embeddings = self.get_audio_embedding_streaming(data)
+        else:
+            audio_embeddings = self.get_audio_embedding(data, chunk_length)
+
+        bs = len(input_embeddings)
+        if len(data.get("audio_features", [])) > 0:
+            assert len(audio_embeddings) == len(input_embeddings)
+            if len(audio_embeddings) > 0:
+                audio_bounds = data["audio_bounds"]
+
+                if self.config.chunk_input:
+                    for i in range(bs):
+                        audio_embs = torch.cat(audio_embeddings[i], dim=0).to(
+                            device=input_embeddings.device, dtype=input_embeddings.dtype
+                        )
+                        audio_start_pos = 0
+                        for bound in audio_bounds[i]:
+                            audio_len = bound[1] - bound[0]
+                            input_embeddings[i, bound[0] : bound[1]] = audio_embs[
+                                audio_start_pos : audio_start_pos + audio_len, :
+                            ]
+                            audio_start_pos += audio_len
+                else:
+                    for i in range(bs):
+                        audio_embs = audio_embeddings[i]
+                        bounds = audio_bounds[i]
+                        for embs, bound in zip(audio_embs, bounds):
+                            audio_indices = torch.arange(bound[0], bound[1], dtype=torch.long).to(
+                                input_embeddings.device
+                            )
+
+                            if embs.shape[0] != len(audio_indices):
+                                raise ValueError(
+                                    f"Shape mismatch: Trying to assign embeddings of shape {embs.shape} "
+                                    f"to input indices of length {len(audio_indices)}"
+                                )
+                            input_embeddings[i, audio_indices] = embs.to(input_embeddings.dtype)
+        elif self.training:
+            for i in range(bs):
+                # dummy audio_embeddings
+                input_embeddings += audio_embeddings[0].mean() * 0
+
+        return input_embeddings
+
+    def forward(self, data, **kwargs):
+        vllm_embedding, vision_hidden_states = self.get_vllm_embedding(data)
+
+        if self.config.init_audio:
+            vllm_embedding = self.get_omni_embedding(
+                data, input_embeddings=vllm_embedding, chunk_length=self.config.audio_chunk_length
+            )
+
+        position_ids = data["position_ids"]
+        if position_ids.dtype != torch.int64:
+            position_ids = position_ids.long()
+
+        # compatible with llama factory
+        for key in ["input_ids", "inputs_embeds", "position_ids"]:
+            if key in kwargs:
+                del kwargs[key]
+
+        return self.llm(input_ids=None, position_ids=position_ids, inputs_embeds=vllm_embedding, **kwargs)
+
+    def _decode(self, inputs_embeds, tokenizer, attention_mask, **kwargs):
+        terminators = [tokenizer.convert_tokens_to_ids(i) for i in self.terminators]
+        outputs = self.llm.generate(
+            inputs_embeds=inputs_embeds,
+            pad_token_id=0,
+            eos_token_id=terminators,
+            attention_mask=attention_mask,
+            output_hidden_states=True,
+            return_dict_in_generate=True,
+            **kwargs,
+        )
+        return outputs
+
+    def _decode_stream(self, inputs_embeds, tokenizer, **kwargs):
+        terminators = [tokenizer.convert_tokens_to_ids(i) for i in self.terminators]
+        streamer = TextIteratorStreamer(tokenizer=tokenizer)
+        generation_kwargs = {
+            "inputs_embeds": inputs_embeds,
+            "pad_token_id": 0,
+            "eos_token_id": terminators,
+            "streamer": streamer,
+        }
+        generation_kwargs.update(kwargs)
+
+        thread = Thread(target=self.llm.generate, kwargs=generation_kwargs)
+        thread.start()
+
+        return streamer
+
+    def _decode_text(self, result_ids, tokenizer):
+        terminators = [tokenizer.convert_tokens_to_ids(i) for i in self.terminators]
+        result_text = []
+        for result in result_ids:
+            result = result[result != 0]
+            if result[0] == tokenizer.bos_id:
+                result = result[1:]
+            if result[-1] in terminators:
+                result = result[:-1]
+            result_text.append(tokenizer.decode(result))
+        return result_text
+
+    def get_sys_prompt(self, ref_audio=None, mode="default", language="zh"):
+        """
+        Choose different system prompts according to different tasks
+        Args:
+            ref_audio: if ref_audio is not None, the voice-cloning prompts are used and the voice
+                       generated by the model will follow the timbre of the reference audio
+            mode:
+                "default": default system prompt, not tied to any specific task
+                "omni": video and audio are input simultaneously
+                "audio_assistant": voice-only assistant mode; the model replies to the user's question as a helpful assistant, speaking in the ref_audio voice.
+                "audio_roleplay": voice-only roleplay mode; the model replies in the ref_audio voice and role-plays the character described by the audio prompt.
+                "voice_cloning": TTS mode; the model clones the voice of ref_audio.
+            language: prompt language; the model automatically selects the response language
+                    based on the question language
+        Returns:
+            dict: a single system message to prepend to the chat msgs
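+
+        Example (illustrative; ``ref_audio_np`` is a hypothetical 16 kHz mono numpy waveform):
+
+            >>> sys_msg = model.get_sys_prompt(ref_audio=ref_audio_np, mode="audio_assistant", language="en")
+            >>> msgs = [sys_msg, {"role": "user", "content": ["Please introduce yourself."]}]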
+        """
+        if ref_audio is not None:
+            assert isinstance(ref_audio, np.ndarray), "ref_audio must be a numpy.ndarray"
+        if mode == "omni":
+            if language == "zh":
+                sys_prompt = "你是一个AI助手。你能接受视频,音频和文本输入并输出语音和文本。"
+                vc_prompt_prefix = sys_prompt + "模仿输入音频中的声音特征。"
+                vc_prompt_suffix = "作为助手,你将使用这种声音风格说话。"
+            else:
+                sys_prompt = "You are a helpful assistant. You can accept video, audio and text input and output voice and text. "
+                vc_prompt_prefix = sys_prompt + "Clone the voice in the provided audio prompt."
+                vc_prompt_suffix = "As an assistant, you will speak using this voice style."
+
+            if ref_audio is not None:
+                sys_msgs = {"role": "user", "content": [vc_prompt_prefix, ref_audio, vc_prompt_suffix]}
+
+            else:
+                sys_msgs = {"role": "user", "content": [sys_prompt]}
+
+            return sys_msgs
+        elif mode == "audio_assistant":
+            if language == "zh":
+                vc_prompt_prefix = "模仿输入音频中的声音特征。"
+                vc_prompt_suffix = "作为助手,你将使用这种声音风格说话。"
+            else:
+                vc_prompt_prefix = "Clone the voice in the provided audio prompt."
+                vc_prompt_suffix = "As an assistant, you will speak using this voice style."
+
+            if ref_audio is not None:
+                sys_msgs = {"role": "user", "content": [vc_prompt_prefix, ref_audio, vc_prompt_suffix]}
+
+            else:
+                logger.warning(
+                    "Warning: ref_audio is None, speech generation will be performed based on the default voice."
+                )
+                sys_msgs = {"role": "user", "content": ["Use the <reserved_53> voice.", vc_prompt_suffix]}
+
+            return sys_msgs
+        elif mode == "audio_roleplay":
+            if language == "zh":
+                vc_prompt_prefix = "模仿输入音频中的声音特征。"
+                vc_prompt_suffix = "假装你是上述音频中的人物,与我进行对话。"
+            else:
+                vc_prompt_prefix = "Clone the voice in the provided audio prompt."
+                vc_prompt_suffix = "Try to role-play the character based on the audio prompt above."
+
+            if ref_audio is not None:
+                sys_msgs = {"role": "user", "content": [vc_prompt_prefix, ref_audio, vc_prompt_suffix]}
+            else:
+                logger.warning(
+                    "Warning: ref_audio is None, speech generation will be performed based on the default voice."
+                )
+                sys_msgs = {"role": "user", "content": ["Use the <reserved_53> voice.", vc_prompt_suffix]}
+
+            return sys_msgs
+        elif mode == "voice_cloning":
+            if language == "zh":
+                vc_prompt_prefix = "模仿输入音频中的声音特征。"
+            else:
+                vc_prompt_prefix = "Clone the voice in the provided audio prompt."
+
+            if ref_audio is not None:
+                sys_msgs = {"role": "user", "content": [vc_prompt_prefix, ref_audio]}
+            else:
+                raise ValueError("ref_audio can't be None in voice_cloning mode.")
+
+            return sys_msgs
+        else:
+            sys_prompt = "You are a helpful assistant. You can accept audio and text input and output voice and text."
+            sys_msgs = {"role": "user", "content": [sys_prompt]}
+
+            return sys_msgs
+
+    def generate(
+        self,
+        input_ids=None,
+        pixel_values=None,
+        tgt_sizes=None,
+        audio_features=None,
+        audio_feature_lens=None,
+        image_bound=None,
+        audio_bounds=None,
+        spk_bounds=None,
+        attention_mask=None,
+        tokenizer=None,
+        vision_hidden_states=None,
+        stream=False,
+        **kwargs,
+    ):
+        assert input_ids is not None
+        assert len(input_ids) == len(pixel_values)
+
+        model_inputs = {
+            "input_ids": input_ids,
+            "audio_features": audio_features,
+            "audio_feature_lens": audio_feature_lens,
+            "image_bound": image_bound,
+            "audio_bounds": audio_bounds,
+            "spk_bounds": spk_bounds,
+        }
+
+        if vision_hidden_states is None:
+            model_inputs["pixel_values"] = pixel_values
+            model_inputs["tgt_sizes"] = tgt_sizes
+        else:
+            model_inputs["vision_hidden_states"] = vision_hidden_states
+
+        model_output = {}
+        with torch.inference_mode():
+            model_inputs["inputs_embeds"], vision_hidden_states = self.get_vllm_embedding(model_inputs)
+            model_inputs["inputs_embeds"] = self.get_omni_embedding(
+                model_inputs,
+                input_embeddings=model_inputs["inputs_embeds"],
+                chunk_length=self.config.audio_chunk_length,
+            )
+
+            if stream:
+                result = self._decode_stream(model_inputs["inputs_embeds"], tokenizer, **kwargs)
+                # in stream mode, return the TextIteratorStreamer and leave outputs empty
+                outputs = {}
+            else:
+                outputs = self._decode(model_inputs["inputs_embeds"], tokenizer, attention_mask, **kwargs)
+
+                result = self._decode_text(outputs.sequences, tokenizer)
+
+        return result, outputs
+
+    def chat(
+        self,
+        image=None,
+        msgs=None,
+        tokenizer=None,
+        processor=None,
+        vision_hidden_states=None,
+        max_new_tokens=2048,
+        min_new_tokens=0,
+        sampling=True,
+        max_inp_length=32768,
+        stream=False,
+        chunk_input=True,
+        omni_input=False,
+        max_slice_nums=None,
+        use_image_id=None,
+        use_tts_template=False,
+        generate_audio=False,
+        return_spk_embed=False,
+        return_dict=False,
+        output_audio_path=None,
+        **kwargs,
+    ):
+        """
+        Unified chat function
+
+        Args:
+            image: used for batch_size=1 VQA; using this parameter is no longer recommended (put the image in msgs instead)
+            msgs: the input chat messages; content items may be text (str), image (PIL.Image) or audio (numpy.ndarray)
+            tokenizer: tokenizer for the llm
+            processor: if None, use the default processor
+            max_new_tokens: the maximum length of the generation
+            min_new_tokens: the minimum length of the generation
+            sampling: whether to use sampling decoding or beam search decoding
+            max_inp_length: the maximum length of the input
+            stream: whether to return a generator; only usable when TTS output is not required
+            chunk_input: whether to split audio into 1s chunks
+            omni_input: whether the inputs are in omni (interleaved video/audio) mode
+            max_slice_nums: controls the maximum number of image slices
+            use_image_id: for video understanding or omni understanding, use_image_id should be False
+            use_tts_template: if the msgs contain audio, use_tts_template should be True
+            generate_audio: whether to generate audio output, only used when return_dict=True
+            return_spk_embed: whether to return the spk embedding, only used when return_dict=True
+            return_dict: whether to return a dict
+            output_audio_path: path to save the generated audio when generate_audio is True
+            **kwargs:
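+
+        Example (illustrative sketch; assumes a loaded ``model``/``tokenizer`` and a PIL image ``img``):
+
+            >>> msgs = [{"role": "user", "content": [img, "Describe this image."]}]
+            >>> answer = model.chat(msgs=msgs, tokenizer=tokenizer)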
+        """
+        if isinstance(msgs[0], list):
+            batched = True
+        else:
+            batched = False
+
+        if generate_audio or return_spk_embed:
+            return_dict = True
+
+        msgs_list = msgs
+        images_list = image
+
+        if batched is False:
+            images_list, msgs_list = [images_list], [msgs_list]
+        else:
+            assert images_list is None, "Please integrate images into msgs when using batch inference."
+            images_list = [None] * len(msgs_list)
+        assert len(images_list) == len(msgs_list), "The batch dim of images_list and msgs_list should be the same."
+
+        if processor is None:
+            if self.processor is None:
+                self.processor = AutoProcessor.from_pretrained(self.config._name_or_path, trust_remote_code=True)
+            processor = self.processor
+
+        assert (
+            self.config.query_num == processor.image_processor.image_feature_size
+        ), "These two values should be the same. Check `config.json` and `preprocessor_config.json`."
+        assert (
+            self.config.patch_size == processor.image_processor.patch_size
+        ), "These two values should be the same. Check `config.json` and `preprocessor_config.json`."
+        assert (
+            self.config.use_image_id == processor.image_processor.use_image_id
+        ), "These two values should be the same. Check `config.json` and `preprocessor_config.json`."
+        assert (
+            self.config.slice_config.max_slice_nums == processor.image_processor.max_slice_nums
+        ), "These two values should be the same. Check `config.json` and `preprocessor_config.json`."
+        assert (
+            self.config.slice_mode == processor.image_processor.slice_mode
+        ), "These two values should be the same. Check `config.json` and `preprocessor_config.json`."
+
+        prompts_lists = []
+        input_images_list = []
+        input_audios_list = []
+        audio_parts_list = []
+
+        for image, msgs in zip(images_list, msgs_list):
+            if isinstance(msgs, str):
+                msgs = json.loads(msgs)
+            copy_msgs = deepcopy(msgs)
+
+            assert len(msgs) > 0, "msgs is empty"
+            assert sampling or not stream, "if using stream mode, make sure sampling=True"
+
+            if image is not None and isinstance(copy_msgs[0]["content"], str):
+                copy_msgs[0]["content"] = [image, copy_msgs[0]["content"]]
+
+            images = []
+            audios = []
+            audio_parts = []
+            for i, msg in enumerate(copy_msgs):
+                role = msg["role"]
+                content = msg["content"]
+                assert role in ["system", "user", "assistant"]
+                if i == 0:
+                    assert role in ["user", "system"], "The role of the first msg should be user or system"
+                if isinstance(content, str):
+                    content = [content]
+                cur_msgs = []
+                for c in content:
+                    if isinstance(c, Image.Image):
+                        images.append(c)
+                        cur_msgs.append("(<image>./</image>)")
+                    elif isinstance(c, np.ndarray):  # audio
+                        audios.append(c)
+                        audio_parts.append(i)
+                        cur_msgs.append("(<audio>./</audio>)")
+                        use_tts_template = True
+                    elif isinstance(c, str):
+                        cur_msgs.append(c)
+                if omni_input:
+                    msg["content"] = "".join(cur_msgs)
+                else:
+                    msg["content"] = "\n".join(cur_msgs)
+
+            prompts_lists.append(
+                processor.tokenizer.apply_chat_template(
+                    copy_msgs,
+                    tokenize=False,
+                    add_generation_prompt=True,
+                    chat_template=self.default_tts_chat_template if use_tts_template else None,
+                )
+            )
+            input_images_list.append(images)
+            input_audios_list.append(audios)
+            audio_parts_list.append(audio_parts)
+
+        inputs = processor(
+            prompts_lists,
+            input_images_list,
+            input_audios_list,
+            audio_parts_list,
+            max_slice_nums=max_slice_nums,
+            use_image_id=use_image_id,
+            chunk_input=chunk_input,
+            return_tensors="pt",
+            max_length=max_inp_length,
+        ).to(self.device)
+
+        if sampling:
+            generation_config = {
+                "top_p": 0.8,
+                "top_k": 100,
+                "temperature": 0.7,
+                "do_sample": True,
+                "repetition_penalty": 1.01,
+            }
+        else:
+            generation_config = {
+                "num_beams": 3,
+                "repetition_penalty": 1.2,
+            }
+
+        if min_new_tokens > 0:
+            generation_config["min_new_tokens"] = min_new_tokens
+
+        generation_config.update((k, kwargs[k]) for k in generation_config.keys() & kwargs.keys())
+
+        inputs.pop("image_sizes")
+        with torch.inference_mode():
+            res, outputs = self.generate(
+                **inputs,
+                tokenizer=tokenizer,
+                max_new_tokens=max_new_tokens,
+                vision_hidden_states=vision_hidden_states,
+                stream=stream,
+                **generation_config,
+            )
+
+        if stream:
+
+            def stream_gen():
+                for text in res:
+                    for term in self.terminators:
+                        text = text.replace(term, "")
+                    yield text
+
+            if return_dict:
+                return OmniOutput(text=stream_gen())
+            else:
+                return stream_gen()
+
+        else:
+            spk_embeds = wav_numpy = sr = None
+
+            if batched:
+                answer = res
+            else:
+                answer = res[0]
+
+                if use_tts_template and generate_audio:
+                    mel_spec = self._generate_mel_spec(inputs, outputs, answer)
+                    wav_numpy, sr = self.decode_mel_to_audio(mel_spec, output_audio_path)
+
+            if return_spk_embed:
+                spk_embeds = self._get_last_spk_embeds(inputs, outputs)
+
+            if isinstance(answer, list):
+                answer = [i.replace(tokenizer.tts_end, "") for i in answer]
+            else:
+                answer = answer.replace(tokenizer.tts_end, "")
+
+            if return_dict:
+                return OmniOutput(text=answer, spk_embeds=spk_embeds, audio_wav=wav_numpy, sampling_rate=sr)
+            else:
+                return answer
+
+    @torch.inference_mode()
+    def streaming_prefill(
+        self,
+        session_id,
+        msgs,
+        tokenizer,
+        omni_input=True,
+        max_slice_nums=None,
+        ls_temperature=1.0,
+        **kwargs,
+    ):
+        """
+        Prefill streaming video/audio input into the LLM kv cache. Only batch_size=1 is supported.
+        Args:
+            session_id: id of the current streaming session; each new connection should use a new session_id
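+
+        Example (illustrative sketch; ``frame`` is a PIL video frame and ``audio_chunk`` a hypothetical
+        1-second 16 kHz mono numpy chunk):
+
+            >>> msgs = [{"role": "user", "content": [frame, audio_chunk]}]
+            >>> model.streaming_prefill(session_id="demo-1", msgs=msgs, tokenizer=tokenizer)
+            >>> streamer = model.streaming_generate(session_id="demo-1", tokenizer=tokenizer, generate_audio=False)
+            >>> text = "".join(chunk["text"] for chunk in streamer)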
+        """
+        assert session_id is not None
+        if self.session_id is None or session_id != self.session_id:  # new session
+            self.is_first = True
+        else:
+            self.is_first = False
+
+        images = []
+        audios = []
+
+        assert len(msgs) == 1
+        copy_msgs = deepcopy(msgs)
+        msg = copy_msgs[0]
+
+        assert msg["role"] in ["system", "user", "assistant"]
+
+        content = msg["content"]
+        cur_msgs = []
+        for j, c in enumerate(content):
+            if isinstance(c, Image.Image):
+                images.append(c)
+                cur_msgs.append("(<image>./</image>)")
+            elif isinstance(c, np.ndarray):  # audio
+                audios.append(c)
+                cur_msgs.append("(<audio>./</audio>)")
+            elif isinstance(c, str):
+                cur_msgs.append(c)
+            else:
+                logger.error("Invalid content type: %s", type(c))
+
+        cur_contents = "".join(cur_msgs) if omni_input else "\n".join(cur_msgs)
+        if not self.is_first and self.new_user_msg and msg["role"] == "user":  # new user add im_start
+            if self.llm_generated:
+                if self.llm_generate_completed:
+                    msg["content"] = "<|im_end|>\n<|im_start|>user\n" + cur_contents
+                else:  # break llm gen, add tts_eos
+                    msg["content"] = "<|tts_eos|><|im_end|>\n<|im_start|>user\n" + cur_contents
+            else:
+                msg["content"] = "<|im_start|>user\n" + cur_contents
+            self.new_user_msg = False
+        else:
+            msg["content"] = cur_contents
+
+        if msg["role"] in ["system", "assistant"]:
+            self.new_user_msg = True
+            self.audio_past_key_values = None  # apm kv cache
+
+        if self.is_first:
+            # init past_key_values
+            logger.info(f"new session_id: {session_id}, reset kv cache")
+            self.reset_session()
+            self.session_id = session_id
+
+            prompt = tokenizer.apply_chat_template(
+                copy_msgs, tokenize=False, add_generation_prompt=False, chat_template=self.default_tts_chat_template
+            )
+            add_special_tokens = True  # add bos
+        else:
+            prompt = copy_msgs[0]["content"]
+            add_special_tokens = False
+
+        model_inputs = self.processor(
+            [prompt],
+            [images],
+            [audios],
+            max_slice_nums=1 if max_slice_nums is None else max_slice_nums,
+            use_image_id=False,
+            chunk_input=True,
+            return_tensors="pt",
+            max_length=None,
+            sampling_rate=16000,
+            add_special_tokens=add_special_tokens,
+        ).to(self.device)
+
+        # 1. prepare input embeddings
+        model_inputs["inputs_embeds"], _ = self.get_vllm_embedding(model_inputs)
+        # get audio embedding with audio_past_key_values
+        inputs_embeds = self.get_omni_embedding(
+            model_inputs, input_embeddings=model_inputs["inputs_embeds"], stream_input=True
+        )
+
+        if self.is_first:
+            # clean audio_past_key_values after first prefill
+            self.audio_past_key_values = None
+
+        if self.llm_past_key_values is not None:
+            cache_length = self.llm_past_key_values[0][0].shape[2]
+        else:
+            cache_length = 0
+
+        attention_mask = torch.ones((1, cache_length + inputs_embeds.shape[1]), dtype=torch.bool, device=self.device)
+
+        # 2. do prefill and predict listen/speak label
+        outputs = self.llm(
+            past_key_values=self.llm_past_key_values,
+            inputs_embeds=inputs_embeds,
+            attention_mask=attention_mask,
+            position_ids=None,  # position_ids,
+            use_cache=True,
+            return_dict=True,
+        )
+        self.llm_past_key_values = outputs["past_key_values"]
+        return
+
+    @torch.inference_mode()
+    def streaming_generate(
+        self,
+        session_id,
+        tokenizer,
+        max_new_tokens=512,
+        min_new_tokens=0,
+        sampling=True,
+        generate_audio=True,
+        enable_regenerate=False,
+        **kwargs,
+    ):
+        """
+        Generate the response for the current streaming session, optionally as a streaming audio output.
+        Args:
+            session_id: id of the current streaming session (see `streaming_prefill`)
+            generate_audio: if True, return a generator yielding audio chunks; otherwise return a
+                generator yielding text chunks
+        """
+        if sampling:
+            generation_config = {
+                "top_p": 0.8,
+                "top_k": 100,
+                "temperature": 0.7,
+                "do_sample": True,
+                "repetition_penalty": 1.01,
+            }
+        else:
+            generation_config = {
+                "num_beams": 3,
+                "repetition_penalty": 1.2,
+            }
+        generation_config["min_new_tokens"] = min_new_tokens
+        generation_config.update((k, kwargs[k]) for k in generation_config.keys() & kwargs.keys())
+
+        # do generate
+        # reset buffer
+        self.new_user_msg = True
+        self.llm_generated = True
+        self.llm_generate_completed = False
+        self.audio_past_key_values = None  # apm kv cache
+
+        terminators = [tokenizer.convert_tokens_to_ids(i) for i in self.terminators]
+        generate_prompt = "<|im_end|>\n<|im_start|>assistant\n<|spk_bos|><|spk|><|spk_eos|><|tts_bos|>"
+        input_ids = tokenizer(generate_prompt, return_tensors="pt", add_special_tokens=False)["input_ids"].cuda()
+
+        spk_start_idx = torch.where(input_ids[0] == tokenizer.spk_start_id)[0]
+        spk_end_idx = torch.where(input_ids[0] == tokenizer.spk_end_id)[0]
+        spk_bounds = [
+            torch.hstack([(spk_start_idx + 1).unsqueeze(-1), spk_end_idx.unsqueeze(-1)])
+        ]  # List[Tensor], (1,2)
+
+        cache_length = past_length = self.llm_past_key_values[0][0].shape[2]
+        attention_mask = torch.ones((1, cache_length + input_ids.shape[1]), dtype=torch.bool, device=self.device)
+
+        generation_config["max_new_tokens"] = max_new_tokens
+        streamer = self.llm_generate_chunk(input_ids, attention_mask, tokenizer, terminators, generation_config)
+
+        if generate_audio:
+            result = self._generate_mel_spec_audio_streaming(
+                spk_bounds, streamer, output_chunk_size=25, enable_regenerate=enable_regenerate
+            )
+            return result
+        else:
+            return streamer
+
+    def llm_generate_chunk(self, input_ids, attention_mask, tokenizer, terminators, generation_config):
+        def check_uncompleted_token(ids):
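+            # Decoding can split a multi-byte character across token boundaries; such partial pieces
+            # decode to U+FFFD ("�"). Trim trailing ids until the decoded text is clean so the cut-off
+            # tokens are carried over to the next chunk via `left_ids`.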
+            cur_text = tokenizer.decode(ids)
+            end = len(ids)
+            while cur_text[-1] == "�":
+                end -= 1
+                if end == 0:
+                    break
+                cur_text = tokenizer.decode(ids[:end])
+            return end
+
+        max_new_tokens = int(generation_config.pop("max_new_tokens", 2048))
+        new_len = 0
+        first_chunk = True
+        eos = False
+        left_ids = None
+
+        while True:
+            outputs = self.llm.generate(
+                input_ids=input_ids,
+                past_key_values=self.llm_past_key_values,
+                attention_mask=attention_mask,
+                use_cache=True,
+                max_new_tokens=3,  # reduce first token delay
+                pad_token_id=0,
+                output_hidden_states=True if first_chunk else False,
+                return_dict_in_generate=True,
+                eos_token_id=terminators,
+                **generation_config,
+            )
+            if outputs.sequences[0, -1] in terminators:
+                eos = True
+            input_len = input_ids.shape[1]
+            cur_ids = outputs.sequences[:, input_len:]
+            new_len += cur_ids.shape[1]
+
+            if left_ids is not None and left_ids.shape[1] > 0:
+                cur_ids = torch.cat([left_ids, cur_ids], dim=1)
+            end = check_uncompleted_token(cur_ids[0])
+            left_ids = cur_ids[:, end:]
+            cur_ids = cur_ids[:, :end]
+            text = self._decode_text(cur_ids, tokenizer)[0] if end > 0 else ""
+
+            self.llm_past_key_values = outputs.past_key_values
+            input_ids = outputs.sequences[:, -1:]
+            cache_length = past_length = self.llm_past_key_values[0][0].shape[2]
+            attention_mask = torch.ones((1, cache_length + input_ids.shape[1]), dtype=torch.bool, device=self.device)
+
+            res = {"text": text}
+            if first_chunk:
+                res["hidden_states"] = outputs.hidden_states
+                first_chunk = False
+            yield res
+
+            if eos:
+                self.llm_generate_completed = True
+                break
+            if new_len >= max_new_tokens:
+                logger.debug(f"LLM generation {new_len} exceeds max_new_tokens({max_new_tokens}), break.")
+                break
+
+    def prepare_tts_text(self, text):
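+        """
+        Pad or truncate ``text`` to the fixed TTS text window expected by the TTS decoder.
+
+        Illustrative behaviour (assuming ``streaming_text_reserved_len == 300``): a 20-token input is
+        wrapped as ``[Stts][spk_emb]...<text>[Etts]`` followed by 279 ``[PAD]`` tokens and ``[Ptts]``;
+        longer inputs are truncated to 300 tokens and receive no padding.
+        """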
+        tts_tokens = self.tts_processor.text_tokenizer.encode(text, add_special_tokens=False)
+        tts_tokens_len = len(tts_tokens)
+        if tts_tokens_len < self.tts.streaming_text_reserved_len:
+            num_pad_tokens = self.tts.streaming_text_reserved_len - tts_tokens_len
+
+            pad_str = "[Etts]" + "[PAD]" * (num_pad_tokens - 1)
+        else:
+            tts_tokens = tts_tokens[0 : self.tts.streaming_text_reserved_len]
+            tts_tokens_len = len(tts_tokens)
+            text = self.tts_processor.text_tokenizer.decode(tts_tokens, add_special_tokens=False)
+            pad_str = ""
+        spk_emb_placeholder_tts = "[spk_emb]" * self.tts.num_spk_embs
+
+        new_text_tts = f"[Stts]{spk_emb_placeholder_tts}{text}{pad_str}[Ptts]"
+        return new_text_tts, tts_tokens_len
+
+    def get_tts_text_start_token_ids(self):
+        text = "[Stts]" + "[spk_emb]" * self.tts.num_spk_embs
+        tts_input_ids = self.tts_processor.text_tokenizer(text, return_tensors="pt", add_special_tokens=False)[
+            "input_ids"
+        ].cuda()
+        return tts_input_ids
+
+    def _build_streaming_mask(self, tts_tokens_len):
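+        """
+        Build a 1-D text attention mask over the fixed TTS window: positions covering [Stts], the
+        speaker-embedding slot, the real text tokens and [Etts] are set to 1, padded text positions
+        stay 0, and the trailing [Ptts] position is re-enabled.
+        """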
+        tts_sequence_full_length = (
+            1 + self.tts.num_spk_embs * self.tts.use_speaker_embedding + self.tts.streaming_text_reserved_len + 1
+        )
+        streaming_attention_mask = torch.zeros(tts_sequence_full_length, dtype=torch.int8)
+        streaming_attention_mask[0 : 1 + 1 + tts_tokens_len + 1] = 1
+        streaming_attention_mask[-1] = 1
+        return streaming_attention_mask
+
+    def _get_last_spk_embeds(self, inputs, outputs):
+        last_hidden_states = [hs[-1] for hs in outputs.hidden_states]
+
+        # batch = 1
+        last_hidden_states = torch.vstack([i[0] for i in last_hidden_states])
+
+        # last spk
+        spk_bound = inputs["spk_bounds"][0][-1]
+
+        spk_embeds = last_hidden_states[spk_bound[0] : spk_bound[1]]
+        return spk_embeds
+
+    def _generate_mel_spec(self, inputs, outputs, text, output_chunk_size=25, tts_max_new_tokens=2048):
+        spk_embeds = self._get_last_spk_embeds(inputs, outputs)
+
+        text = text.split("<|tts_bos|>")[-1]
+        gen_text = text.split("<|tts_eos|>")[0]
+        tts_text, tts_token_lens = self.prepare_tts_text(gen_text)
+        tts_inputs = self.tts_processor.text_tokenizer.encode(tts_text, add_special_tokens=False)
+        tts_input_ids = torch.Tensor(tts_inputs).unsqueeze(0).to("cuda", dtype=torch.long)
+        streaming_tts_text_mask = self._build_streaming_mask(tts_token_lens).to(device=self.tts.device)
+
+        logits_warpers, logits_processors = gen_logits(
+            num_code=626, top_P=self.tts.top_p, top_K=self.tts.top_k, repetition_penalty=self.tts.repetition_penalty
+        )
+
+        condition_length = (
+            1 + self.tts.use_speaker_embedding * self.tts.num_spk_embs + self.tts.streaming_text_reserved_len + 1
+        )
+
+        dtype = self.tts.emb_text.weight.dtype
+        emb = torch.zeros(1, condition_length, self.tts.num_vq, dtype=dtype, device=self.tts.device)
+        past_key_values = [
+            (
+                torch.zeros(
+                    1,
+                    self.tts.config.num_attention_heads,
+                    condition_length - 1,
+                    self.tts.config.hidden_size // self.tts.config.num_attention_heads,
+                    dtype=emb.dtype,
+                    device=self.tts.device,
+                ),
+                torch.zeros(
+                    1,
+                    self.tts.config.num_attention_heads,
+                    condition_length - 1,
+                    self.tts.config.hidden_size // self.tts.config.num_attention_heads,
+                    dtype=emb.dtype,
+                    device=self.tts.device,
+                ),
+            )
+            for _ in range(self.tts.config.num_hidden_layers)
+        ]
+
+        audio_input_ids = torch.zeros(1, condition_length, self.tts.num_vq, dtype=torch.long, device=self.tts.device)
+
+        eos_lab = False
+        for chunk_idx in range(math.ceil(emb.shape[1] / self.tts.streaming_text_chunk_size)):
+            if chunk_idx == 0:
+                begin = chunk_idx * self.tts.streaming_text_chunk_size + 0
+                end = (
+                    (chunk_idx + 1) * self.tts.streaming_text_chunk_size
+                    + 1
+                    + self.tts.use_speaker_embedding * self.tts.num_spk_embs
+                )
+            else:
+                begin = (
+                    chunk_idx * self.tts.streaming_text_chunk_size
+                    + 1
+                    + self.tts.use_speaker_embedding * self.tts.num_spk_embs
+                )
+                end = min(
+                    (chunk_idx + 1) * self.tts.streaming_text_chunk_size
+                    + 1
+                    + self.tts.use_speaker_embedding * self.tts.num_spk_embs,
+                    condition_length - 1,
+                )
+
+            if end - begin > 0:
+                text_input_ids = tts_input_ids[:, begin:end]
+                position_ids = torch.arange(begin, end, dtype=torch.long, device=self.tts.device).unsqueeze(0)
+
+                if begin == 0:
+                    past_key_values = self.tts.prefill_text(
+                        input_ids=text_input_ids,
+                        position_ids=position_ids,
+                        past_key_values=past_key_values,
+                        lm_spk_emb_last_hidden_states=spk_embeds,
+                    )
+                else:
+                    past_key_values = self.tts.prefill_text(
+                        input_ids=text_input_ids, position_ids=position_ids, past_key_values=past_key_values
+                    )
+
+            outputs = self.tts.generate(
+                input_ids=audio_input_ids,
+                past_key_values=past_key_values,
+                streaming_tts_text_mask=streaming_tts_text_mask,
+                max_new_token=output_chunk_size,
+                force_no_stop=self.force_no_stop,
+                temperature=torch.tensor([0.1, 0.3, 0.1, 0.3], dtype=torch.float, device=self.tts.device),
+                eos_token=torch.tensor([625], dtype=torch.long, device=self.tts.device),
+                logits_warpers=logits_warpers,
+                logits_processors=logits_processors,
+            )
+            audio_input_ids = outputs.audio_input_ids
+            past_key_values = outputs.past_key_values
+
+            if outputs.finished:
+                logger.debug("Generation finished.")
+                eos_lab = True
+                break
+
+        if not eos_lab:
+            logger.debug("eos_lab False, Generation continue.")
+            while True:
+                outputs = self.tts.generate(
+                    input_ids=audio_input_ids,
+                    past_key_values=past_key_values,
+                    streaming_tts_text_mask=streaming_tts_text_mask,
+                    max_new_token=output_chunk_size,
+                    force_no_stop=self.force_no_stop,
+                    temperature=torch.tensor([0.1, 0.3, 0.1, 0.3], dtype=torch.float, device=self.tts.device),
+                    eos_token=torch.tensor([625], dtype=torch.long, device=self.tts.device),
+                    logits_warpers=logits_warpers,
+                    logits_processors=logits_processors,
+                )
+
+                audio_input_ids = outputs.audio_input_ids
+                past_key_values = outputs.past_key_values
+
+                if outputs.finished:
+                    logger.debug("Generation finished.")
+                    break
+                if outputs.new_ids.shape[1] > tts_max_new_tokens:
+                    logger.debug(f"Generation length > {tts_max_new_tokens}, stopped.")
+                    break
+
+        mel_spec = self.tts.decode_to_mel_specs(outputs.new_ids)
+        return mel_spec
+
+    def _linear_overlap_add2_wav(self, frames: List[torch.Tensor], overlap: int):
+        """
+        Merge two audio waveforms with a linear crossfade over the overlapping region for streaming audio generation.
+        Adapted from `https://github.com/huggingface/transformers/blob/main/src/transformers/models/encodec/modeling_encodec.py`.
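+
+        Example (illustrative sketch; real callers pass 24 kHz chunks with `overlap=512 * 4`, and
+        `model` denotes a loaded MiniCPMO instance):
+            >>> import torch
+            >>> frames = [torch.ones(100), torch.ones(100)]
+            >>> head, tail = model._linear_overlap_add2_wav(frames, overlap=20)
+            >>> head.shape, tail.shape
+            (torch.Size([100]), torch.Size([80]))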
+        """
+        assert len(frames) == 2
+        device = frames[0].device
+        dtype = frames[0].dtype
+        # shape = frames[0].shape[:-1]
+
+        frame0_length = frames[0].shape[-1]
+        frame1_length = frames[1].shape[-1]
+        total_size = frame0_length + frame1_length - overlap
+        weight_len = max(frame0_length, frame1_length) + overlap
+        t = torch.linspace(0, 1, weight_len + 2, device=device, dtype=dtype)[1:-1]
+        weight = 0.5 - (t - 0.5).abs()
+
+        sum_weight = torch.zeros(total_size, device=device, dtype=dtype)
+        out = torch.zeros(total_size, device=device, dtype=dtype)
+        offset: int = 0
+
+        out[offset : offset + frame0_length] += weight[-frame0_length:] * frames[0]
+        sum_weight[offset : offset + frame0_length] += weight[-frame0_length:]
+        offset += frame0_length - overlap
+        out[offset : offset + frame1_length] += weight[:frame1_length] * frames[1]
+        sum_weight[offset : offset + frame1_length] += weight[:frame1_length]
+
+        assert sum_weight.min() > 0
+        out = out / sum_weight
+        return out[:frame0_length], out[frame0_length:]
+
+    def _generate_mel_spec_audio_streaming(
+        self,
+        spk_bounds,
+        streamer,
+        output_chunk_size=25,
+        spk_embeds=None,
+        prev_seg_text_ids=None,
+        prev_seg_text_left="",
+        prev_seg_audio_ids=None,
+        enable_regenerate=False,
+    ):
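+        """Streaming counterpart of `_generate_mel_spec`: consumes text from `streamer` as the LLM
+        produces it, prefills the TTS model chunk by chunk, and yields `OmniOutput` items carrying
+        the incremental text together with the corresponding waveform chunk. When the reserved
+        text length is nearly exhausted, generation recurses into a new segment, carrying over a
+        short tail of text and audio ids so the transition stays smooth.
+        """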
+        # get spk_embedding
+        gen_text = ""
+        tts_text = ""
+        new_segment_gen = False
+        if spk_embeds is None:
+            spk_bound = spk_bounds[0][-1]
+            r = next(streamer)
+            txt = r["text"]
+            gen_text += txt.split("<|tts_eos|>")[0]
+            tts_text, tts_token_lens = self.prepare_tts_text(gen_text)
+            last_hidden_states = r["hidden_states"][0][-1][0]  # output: (input_seq_len, dim)
+            spk_embeds = last_hidden_states[spk_bound[0] : spk_bound[1]]
+
+        # init past_key_values
+        logits_warpers, logits_processors = gen_logits(
+            num_code=626, top_P=self.tts.top_p, top_K=self.tts.top_k, repetition_penalty=self.tts.repetition_penalty
+        )
+        condition_length = (
+            1 + self.tts.use_speaker_embedding * self.tts.num_spk_embs + self.tts.streaming_text_reserved_len + 1
+        )
+        tts_start_token_len = 1 + self.tts.use_speaker_embedding * self.tts.num_spk_embs
+        dtype = self.tts.emb_text.weight.dtype
+        past_key_values = [
+            (
+                torch.zeros(
+                    1,
+                    self.tts.config.num_attention_heads,
+                    condition_length - 1,
+                    self.tts.config.hidden_size // self.tts.config.num_attention_heads,
+                    dtype=dtype,
+                    device=self.tts.device,
+                ),
+                torch.zeros(
+                    1,
+                    self.tts.config.num_attention_heads,
+                    condition_length - 1,
+                    self.tts.config.hidden_size // self.tts.config.num_attention_heads,
+                    dtype=dtype,
+                    device=self.tts.device,
+                ),
+            )
+            for _ in range(self.tts.config.num_hidden_layers)
+        ]
+        audio_input_ids = torch.zeros(1, condition_length, self.tts.num_vq, dtype=torch.long, device=self.tts.device)
+
+        # prefill prev segment for smooth
+        chunk_idx = 0
+        new_ids_len = 0
+        prev_text_len = 0
+        if prev_seg_text_ids is not None and prev_seg_audio_ids is not None:
+            tts_token_lens = prev_seg_text_ids.shape[1]
+            # assert tts_token_lens % self.tts.streaming_text_chunk_size == 0
+            streaming_tts_text_mask = self._build_streaming_mask(tts_token_lens).to(device=self.tts.device)
+            position_ids = torch.arange(
+                0, tts_token_lens + tts_start_token_len, dtype=torch.long, device=self.tts.device
+            ).unsqueeze(0)
+
+            text_input_ids = self.get_tts_text_start_token_ids()
+            text_input_ids = torch.cat([text_input_ids, prev_seg_text_ids], dim=1)
+            past_key_values = self.tts.prefill_text(
+                input_ids=text_input_ids,
+                position_ids=position_ids,
+                past_key_values=past_key_values,
+                lm_spk_emb_last_hidden_states=spk_embeds,
+            )
+            past_key_values = self.tts.prefill_audio_ids(
+                input_ids=prev_seg_audio_ids[:, :-1, :],
+                # do not prefill the last id; it becomes the input_id of the next generation step
+                past_key_values=past_key_values,
+                streaming_tts_text_mask=streaming_tts_text_mask,
+            )
+
+            # update init
+            chunk_idx += int(tts_token_lens / self.tts.streaming_text_chunk_size)
+            audio_input_ids = torch.cat([audio_input_ids, prev_seg_audio_ids], dim=1)
+            text = self.tts_processor.text_tokenizer.decode(prev_seg_text_ids[0].tolist(), add_special_tokens=False)
+
+            gen_text += text
+            gen_text += prev_seg_text_left
+            prev_text_len = len(gen_text)  # careful: this offset is later used to slice gen_text_raw
+            new_ids_len += prev_seg_audio_ids.shape[1]
+
+        prev_wav = None
+        eos_lab = False
+        stop = False
+        shift_len = 180
+        voice_checker = VoiceChecker()
+        number_converter = NumberToTextConverter()
+        lang = None
+        gen_text_raw = gen_text
+        for t, r in enumerate(streamer):
+            t += 1
+            txt = r["text"]
+            txt = txt.split("<|tts_eos|>")[0]
+            gen_text_raw += txt
+            if t == 1 and txt == "" and prev_seg_text_ids is not None:
+                logger.warning("New segment is empty, generation finished.")
+                return
+            if t <= 2:  # do just one time, more token greater certainty
+                lang = number_converter.detect_language(gen_text_raw)
+            gen_text += number_converter.replace_numbers_with_text(txt, lang).replace("*", "")  # markdown **
+
+            # TODO speed up
+            tts_text, tts_token_lens = self.prepare_tts_text(gen_text)
+
+            if tts_token_lens >= self.tts.streaming_text_reserved_len - shift_len:
+                end_c = sentence_end(txt)
+                if end_c:
+                    end_c_idx = gen_text.rfind(end_c)
+                    assert end_c_idx != -1
+                    text_left = gen_text[end_c_idx + 1 :]
+                    gen_text = gen_text[: end_c_idx + 1]
+                    tts_text, tts_token_lens = self.prepare_tts_text(gen_text)
+                    new_segment_gen = True
+                    logger.debug(
+                        f"tts_text tokens {tts_token_lens} exceed {self.tts.streaming_text_reserved_len - shift_len}, starting a new segment generation"
+                    )
+                    break
+
+            if tts_token_lens >= (chunk_idx + 1) * self.tts.streaming_text_chunk_size:
+
+                # do prefill and generate
+                if chunk_idx == 0:
+                    begin = 0
+                    end = (chunk_idx + 1) * self.tts.streaming_text_chunk_size + tts_start_token_len
+                else:
+                    begin = chunk_idx * self.tts.streaming_text_chunk_size + tts_start_token_len
+                    end = min(
+                        (chunk_idx + 1) * self.tts.streaming_text_chunk_size + tts_start_token_len, condition_length - 1
+                    )
+
+                tts_input_ids = self.tts_processor.text_tokenizer(
+                    tts_text, return_tensors="pt", add_special_tokens=False
+                )["input_ids"].cuda()
+                text_input_ids = tts_input_ids[:, begin:end]
+                streaming_tts_text_mask = self._build_streaming_mask(tts_token_lens).to(device=self.tts.device)
+                position_ids = torch.arange(begin, end, dtype=torch.long, device=self.tts.device).unsqueeze(0)
+
+                past_key_values = self.tts.prefill_text(
+                    input_ids=text_input_ids,
+                    position_ids=position_ids,
+                    past_key_values=past_key_values,
+                    lm_spk_emb_last_hidden_states=spk_embeds if chunk_idx == 0 else None,
+                )
+                outputs = self.tts.generate(
+                    input_ids=audio_input_ids,
+                    past_key_values=past_key_values,
+                    streaming_tts_text_mask=streaming_tts_text_mask,
+                    max_new_token=output_chunk_size,
+                    force_no_stop=self.force_no_stop,
+                    temperature=torch.tensor([0.1, 0.3, 0.1, 0.3], dtype=torch.float, device=self.tts.device),
+                    eos_token=torch.tensor([625], dtype=torch.long, device=self.tts.device),
+                    logits_warpers=logits_warpers,
+                    logits_processors=logits_processors,
+                )
+                audio_input_ids = (
+                    outputs.audio_input_ids
+                )  # [1,seq_len,4] seq_len=tts.streaming_text_reserved_len + 3 + len(new_ids)
+                past_key_values = outputs.past_key_values
+                chunk_idx += 1
+
+                mel_spec = self.tts.decode_to_mel_specs(outputs.new_ids[:, max(new_ids_len - 4, 0) :, :])
+                new_ids_len = outputs.new_ids.shape[1]  # [1, seq_len, 4]
+
+                wav_np, sr = self.decode_mel_to_audio(mel_spec)  # [1,100,50] -> [50*256]
+
+                if enable_regenerate:
+                    if prev_wav is not None:
+                        check_wav_np = wav_np[2048:].cpu().numpy()  # 2*4*256(hop)
+                        check_mel = mel_spec[0, :, 8:].cpu().numpy()  # 2*4
+                    else:
+                        check_wav_np = wav_np.cpu().numpy()
+                        check_mel = mel_spec[0].cpu().numpy()
+                if enable_regenerate and voice_checker.is_bad(check_wav_np, check_mel, chunk_size=2560):
+                    voice_checker.reset()
+                    # regenerate
+                    N = output_chunk_size if prev_wav is None else output_chunk_size * 2
+                    past_kv = []
+                    for i in range(len(past_key_values)):
+                        past_kv.append(
+                            (
+                                past_key_values[i][0][:, :, :-N, :],  # .clone(),
+                                past_key_values[i][1][:, :, :-N, :],  # .clone(),
+                            )
+                        )
+                    outputs = self.tts.generate(
+                        input_ids=audio_input_ids[:, :-N, :],
+                        past_key_values=past_kv,
+                        streaming_tts_text_mask=streaming_tts_text_mask,
+                        max_new_token=N,
+                        force_no_stop=self.force_no_stop,
+                        temperature=torch.tensor([0.1, 0.3, 0.1, 0.3], dtype=torch.float, device=self.tts.device),
+                        eos_token=torch.tensor([625], dtype=torch.long, device=self.tts.device),
+                        logits_warpers=logits_warpers,
+                        logits_processors=logits_processors,
+                    )
+                    audio_input_ids = outputs.audio_input_ids
+                    past_key_values = outputs.past_key_values
+
+                    new_ids_len -= N
+                    mel_spec = self.tts.decode_to_mel_specs(outputs.new_ids[:, new_ids_len:, :])
+                    new_ids_len = outputs.new_ids.shape[1]  # [1, seq_len, 4]
+                    wav_np, sr = self.decode_mel_to_audio(mel_spec)
+
+                    if prev_wav is not None:
+                        wav_y = wav_np[: len(prev_wav)]
+                        prev_wav = wav_np[len(prev_wav) :]
+                        cur_text = gen_text_raw[prev_text_len:]
+                        prev_text_len = len(gen_text_raw)
+                        yield OmniOutput(text=cur_text, audio_wav=wav_y, sampling_rate=sr)
+
+                    else:
+                        prev_wav = wav_np
+                else:
+                    # smooth wav
+                    if prev_wav is not None:
+                        wav_np, prev_wav = self._linear_overlap_add2_wav(
+                            [prev_wav, wav_np], overlap=512 * 4
+                        )  # tts_hop256*2
+                        cur_text = gen_text_raw[prev_text_len:]
+                        prev_text_len = len(gen_text_raw)
+                        yield OmniOutput(text=cur_text, audio_wav=wav_np, sampling_rate=sr)
+
+                    else:
+                        prev_wav = wav_np
+
+                if outputs.finished:
+                    logger.debug("Generation finished.")
+                    eos_lab = True
+                    break
+
+        if not eos_lab and tts_text:
+            logger.debug("eos_lab False, Generation continue.")
+
+            if chunk_idx == 0:
+                begin = 0
+            else:
+                begin = chunk_idx * self.tts.streaming_text_chunk_size + tts_start_token_len
+            end = tts_token_lens + tts_start_token_len + 1  # 1 for [Etts]
+            if end > begin:
+                tts_input_ids = self.tts_processor.text_tokenizer(
+                    tts_text, return_tensors="pt", add_special_tokens=False
+                )["input_ids"].cuda()
+                text_input_ids = tts_input_ids[:, begin:end]
+                streaming_tts_text_mask = self._build_streaming_mask(tts_token_lens).to(device=self.tts.device)
+                position_ids = torch.arange(begin, end, dtype=torch.long, device=self.tts.device).unsqueeze(0)
+
+                past_key_values = self.tts.prefill_text(
+                    input_ids=text_input_ids,
+                    position_ids=position_ids,
+                    past_key_values=past_key_values,
+                    lm_spk_emb_last_hidden_states=spk_embeds if chunk_idx == 0 else None,
+                )
+
+            while True:
+                # temp = [0.1, 0.3, 0.1, 0.3] if chunk_idx < 21 else [0.1] * self.tts.num_vq
+                outputs = self.tts.generate(
+                    input_ids=audio_input_ids,
+                    past_key_values=past_key_values,
+                    streaming_tts_text_mask=streaming_tts_text_mask,
+                    max_new_token=output_chunk_size,
+                    force_no_stop=self.force_no_stop,
+                    # temperature=torch.tensor([0.1] * self.tts.num_vq, dtype=torch.float, device=self.tts.device),
+                    temperature=torch.tensor([0.1, 0.3, 0.1, 0.3], dtype=torch.float, device=self.tts.device),
+                    eos_token=torch.tensor([625], dtype=torch.long, device=self.tts.device),
+                    logits_warpers=logits_warpers,
+                    logits_processors=logits_processors,
+                )
+                audio_input_ids = outputs.audio_input_ids
+                past_key_values = outputs.past_key_values
+                chunk_idx += 1
+
+                mel_spec = self.tts.decode_to_mel_specs(outputs.new_ids[:, max(new_ids_len - 4, 0) :, :])
+                new_ids_len = outputs.new_ids.shape[1]  # [1, seq_len, 4]
+
+                wav_np, sr = self.decode_mel_to_audio(mel_spec)
+
+                if enable_regenerate:
+                    if prev_wav is not None:
+                        check_wav_np = wav_np[2048:].cpu().numpy()  # 2*4*256(hop)
+                        check_mel = mel_spec[0, :, 8:].cpu().numpy()  # 2*4
+                    else:
+                        check_wav_np = wav_np.cpu().numpy()
+                        check_mel = mel_spec[0].cpu().numpy()
+                if enable_regenerate and voice_checker.is_bad(check_wav_np, check_mel, chunk_size=2560):
+                    voice_checker.reset()
+                    # regenerate
+                    N = output_chunk_size if prev_wav is None else output_chunk_size * 2
+                    past_kv = []
+                    for i in range(len(past_key_values)):
+                        past_kv.append(
+                            (
+                                past_key_values[i][0][:, :, :-N, :],  # .clone(),
+                                past_key_values[i][1][:, :, :-N, :],  # .clone(),
+                            )
+                        )
+                    outputs = self.tts.generate(
+                        input_ids=audio_input_ids[:, :-N, :],
+                        past_key_values=past_kv,
+                        streaming_tts_text_mask=streaming_tts_text_mask,
+                        max_new_token=N,
+                        force_no_stop=self.force_no_stop,
+                        temperature=torch.tensor([0.1, 0.3, 0.1, 0.3], dtype=torch.float, device=self.tts.device),
+                        eos_token=torch.tensor([625], dtype=torch.long, device=self.tts.device),
+                        logits_warpers=logits_warpers,
+                        logits_processors=logits_processors,
+                    )
+                    audio_input_ids = outputs.audio_input_ids
+                    past_key_values = outputs.past_key_values
+
+                    new_ids_len -= N
+                    mel_spec = self.tts.decode_to_mel_specs(outputs.new_ids[:, new_ids_len:, :])
+                    new_ids_len = outputs.new_ids.shape[1]  # [1, seq_len, 4]
+                    wav_np, sr = self.decode_mel_to_audio(mel_spec)
+
+                    if prev_wav is not None:
+                        wav_y = wav_np[: len(prev_wav)]
+                        prev_wav = wav_np[len(prev_wav) :]
+                        cur_text = gen_text_raw[prev_text_len:]
+                        prev_text_len = len(gen_text_raw)
+                        yield OmniOutput(text=cur_text, audio_wav=wav_y, sampling_rate=sr)
+                    else:
+                        prev_wav = wav_np
+                else:
+                    # smooth wav
+                    if prev_wav is not None:
+                        wav_np, prev_wav = self._linear_overlap_add2_wav(
+                            [prev_wav, wav_np], overlap=512 * 4
+                        )  # tts_hop256*2
+                        cur_text = gen_text_raw[prev_text_len:]
+                        prev_text_len = len(gen_text_raw)
+                        yield OmniOutput(text=cur_text, audio_wav=wav_np, sampling_rate=sr)
+                    else:
+                        prev_wav = wav_np
+
+                if outputs.finished:
+                    logger.debug("Generation finished.")
+                    break
+                if outputs.new_ids.shape[1] > 2048:
+                    stop = True
+                    logger.debug("Generation length > 2048, stopped.")
+                    break
+
+        if prev_wav is not None:
+            cur_text = gen_text_raw[prev_text_len:]
+            yield OmniOutput(text=cur_text, audio_wav=prev_wav, sampling_rate=sr)  # yield the last wav chunk without smoothing
+
+        if new_segment_gen and not stop:
+            logger.debug(
+                f"tts_text tokens {tts_token_lens} exceed {self.tts.streaming_text_reserved_len - shift_len}, start a new segment generation"
+            )
+            tid_len = 5  # self.tts.streaming_text_chunk_size
+            prev_seg_text_ids = tts_input_ids[:, end - 1 - tid_len : end - 1]  # exclude last Etts
+            aid_len = 50  # int(tid_len * new_ids_len / tts_token_lens)
+            prev_seg_audio_ids = outputs.new_ids[:, -aid_len:, :]
+
+            result = self._generate_mel_spec_audio_streaming(
+                spk_bounds,
+                streamer,
+                output_chunk_size,
+                spk_embeds,
+                prev_seg_text_ids,
+                text_left,
+                prev_seg_audio_ids,
+                enable_regenerate=enable_regenerate,
+            )
+            for res in result:
+                yield res
+
+    def decode_mel_to_audio(self, mel_spec, output_path=""):
+        with torch.inference_mode():
+            wav_numpy = self.vocos.decode(mel_spec.float()).cpu().squeeze()
+            sr = 24000
+        if output_path:
+            sf.write(output_path, wav_numpy.numpy(), samplerate=sr)
+            logger.info(f"Audio saved to {output_path}")
+        return wav_numpy, sr
+
+
+# Copied from transformers.models.whisper.modeling_whisper.WhisperEncoderLayer, with `use_cache` added for streaming inference
+class MiniCPMWhisperEncoderLayer(nn.Module):
+    def __init__(self, config: WhisperConfig, layer_idx: int = None):
+        super().__init__()
+        self.embed_dim = config.d_model
+        self.self_attn = WHISPER_ATTENTION_CLASSES[config._attn_implementation](
+            embed_dim=self.embed_dim,
+            num_heads=config.encoder_attention_heads,
+            dropout=config.attention_dropout,
+            config=config,
+            layer_idx=layer_idx,
+        )
+        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
+        self.dropout = config.dropout
+        self.activation_fn = ACT2FN[config.activation_function]
+        self.activation_dropout = config.activation_dropout
+        self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
+        self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
+        self.final_layer_norm = nn.LayerNorm(self.embed_dim)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: torch.Tensor,
+        layer_head_mask: torch.Tensor,
+        output_attentions: bool = False,
+        past_key_values: Optional[EncoderDecoderCache] = None,
+        use_cache: Optional[bool] = False,
+    ) -> torch.Tensor:
+        r"""
+        Args:
+            hidden_states (`torch.FloatTensor` of shape `(batch_size, seq_len, embed_dim)`):
+                Hidden states to be fed into the encoder layer.
+            attention_mask (`torch.FloatTensor` of shape `(batch_size, 1, tgt_len, src_len)`):
+                Attention mask where padding elements are indicated by large negative values.
+            layer_head_mask (`torch.FloatTensor` of shape `(encoder_attention_heads,)`):
+                Mask to nullify selected heads of the attention modules.
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attention weights.
+            past_key_values (`EncoderDecoderCache`, *optional*):
+                Past key-value pairs used for incremental decoding.
+            use_cache (`bool`, *optional*):
+                Whether or not to return updated `past_key_values` for caching.
+
+        Returns:
+            A tuple of the form `(hidden_states, optional(attn_weights), optional(past_key_values))`.
+        """
+        residual = hidden_states
+        hidden_states = self.self_attn_layer_norm(hidden_states)
+        hidden_states, attn_weights, past_key_values = self.self_attn(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            layer_head_mask=layer_head_mask,
+            output_attentions=output_attentions,
+            past_key_value=past_key_values,
+        )
+        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+        hidden_states = residual + hidden_states
+
+        residual = hidden_states
+        hidden_states = self.final_layer_norm(hidden_states)
+        hidden_states = self.activation_fn(self.fc1(hidden_states))
+        hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
+        hidden_states = self.fc2(hidden_states)
+        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+        hidden_states = residual + hidden_states
+
+        if hidden_states.dtype == torch.float16 and (
+            torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
+        ):
+            clamp_value = torch.finfo(hidden_states.dtype).max - 1000
+            hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
+
+        outputs = (hidden_states,)
+
+        if output_attentions:
+            outputs += (attn_weights,)
+
+        if use_cache:
+            outputs += (past_key_values,)
+
+        return outputs
+
+
+# Copied from transformers.models.whisper.modeling_whisper.WhisperEncoder, with `use_cache` added for streaming inference
+class MiniCPMWhisperEncoder(WhisperEncoder):
+
+    def __init__(self, config: WhisperConfig):
+        super().__init__(config)
+        self.layers = nn.ModuleList(
+            [MiniCPMWhisperEncoderLayer(config, layer_idx=i) for i in range(config.encoder_layers)]
+        )
+
+    def forward(
+        self,
+        input_features,
+        attention_mask=None,
+        head_mask=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+        past_key_values: Optional[EncoderDecoderCache] = None,
+        use_cache: Optional[bool] = None,
+    ):
+        r"""
+        Forward pass of the Whisper encoder.
+
+        Args:
+            input_features (`torch.FloatTensor` of shape `(batch_size, feature_size, sequence_length)`):
+                Float values of log-mel features extracted from the raw audio waveform. Typically generated
+                by a feature extractor (e.g., `WhisperFeatureExtractor`) that processes `.flac` or `.wav`
+                files into padded 2D mel spectrogram frames. These features are projected via convolution layers
+                (`conv1` and `conv2`) and then transformed into embeddings for the encoder.
+
+            attention_mask (`torch.Tensor`, *optional*):
+                Not used by Whisper for masking `input_features`, but included for API compatibility with
+                other models. If provided, it is simply ignored within the model. By default, Whisper
+                effectively ignores silence in the input log-mel spectrogram.
+
+            head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
+                Mask to nullify selected attention heads. The elements should be either 1 or 0, where:
+                - 1 indicates the head is **not masked**,
+                - 0 indicates the head is **masked** (i.e., the attention head is dropped).
+
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attention tensors of all encoder layers. If set to `True`, the
+                returned tuple (or `BaseModelOutputWithPast`) will contain an additional element with
+                attention weights for each encoder layer.
+
+            output_hidden_states (`bool`, *optional*):
+                Whether or not to return the hidden states of all layers. If set to `True`, the returned
+                tuple (or `BaseModelOutputWithPast`) will contain a tuple of hidden states, including the
+                initial embedding output as well as the outputs of each layer.
+
+            return_dict (`bool`, *optional*):
+                Whether or not to return a `BaseModelOutputWithPast` (a subclass of `ModelOutput`) instead
+                of a plain tuple. If set to `True`, the output will be a `BaseModelOutputWithPast` object,
+                otherwise it will be a tuple.
+
+            past_key_values (`EncoderDecoderCache`, *optional*):
+                When using caching for faster inference, this is an object that stores the key-value pairs
+                for attention states. If provided, the model will append new states to the existing cache
+                and return the updated cache. This speeds up sequential decoding or chunked inference.
+
+                - If `past_key_values` is `None`, no past states are used or returned.
+                - If `past_key_values` is not `None` and `use_cache=True`, the model will use the provided
+                cache and return the updated cache (as `next_encoder_cache`).
+
+            use_cache (`bool`, *optional*):
+                Whether or not the model should use caching (`past_key_values`) to speed up processing
+                during inference. When set to `True`, the model will:
+                - Inspect and use `past_key_values` if provided.
+                - Return updated `past_key_values` (under the name `next_encoder_cache` in
+                    `BaseModelOutputWithPast`).
+
+        Returns:
+            `BaseModelOutputWithPast` or `tuple` (depending on `return_dict`):
+                If `return_dict=True`, a `BaseModelOutputWithPast` is returned, which contains:
+                - **last_hidden_state** (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+                The output of the final encoder layer.
+                - **hidden_states** (`tuple(torch.FloatTensor)`, *optional*, returned if `output_hidden_states=True`):
+                Hidden states of the model at each layer (including the initial projection).
+                - **attentions** (`tuple(torch.FloatTensor)`, *optional*, returned if `output_attentions=True`):
+                Attention weights from each encoder layer.
+                - **past_key_values** (an object of type `EncoderDecoderCache` or `None`, *optional*):
+                Updated cache of key-value pairs if `use_cache=True`.
+
+                If `return_dict=False`, a tuple is returned, where the format is:
+                `(last_hidden_state, hidden_states, attentions)`, with `hidden_states` and `attentions`
+                only present if their respective `output_*` arguments are set to `True`.
+
+        Example:
+            >>> from transformers import AutoFeatureExtractor, WhisperConfig, WhisperForConditionalGeneration
+            >>> import torch
+
+            >>> # Load a feature extractor and a Whisper model
+            >>> feature_extractor = AutoFeatureExtractor.from_pretrained("openai/whisper-tiny.en")
+            >>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
+
+            >>> # Assume you have audio (list of floats or numpy array) loaded from a file
+            >>> # Then extract the mel features:
+            >>> input_features = feature_extractor(audio, sampling_rate=16000, return_tensors="pt").input_features
+
+            >>> # Forward pass
+            >>> outputs = model.encoder(
+            ...     input_features=input_features,
+            ...     output_hidden_states=True,
+            ...     output_attentions=True,
+            ...     use_cache=True
+            ... )
+
+            >>> # Retrieve the last hidden state
+            >>> last_hidden_state = outputs.last_hidden_state
+            >>> print(last_hidden_state.shape)
+            torch.Size([batch_size, seq_length, hidden_size])
+
+            >>> # Retrieve the intermediate hidden states if output_hidden_states=True
+            >>> all_encoder_hidden_states = outputs.hidden_states
+
+            >>> # Retrieve attention weights if output_attentions=True
+            >>> all_encoder_attentions = outputs.attentions
+
+            >>> # Retrieve updated past key values if use_cache=True
+            >>> encoder_cache = outputs.past_key_values
+        """
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # Ignore copy
+        input_features = input_features.to(dtype=self.conv1.weight.dtype, device=self.conv1.weight.device)
+
+        inputs_embeds = nn.functional.gelu(self.conv1(input_features))
+        inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds))
+
+        inputs_embeds = inputs_embeds.permute(0, 2, 1)
+
+        embed_pos = self.embed_positions.weight
+        past_key_values_length = 0
+        if use_cache:
+            if past_key_values is None:
+                past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache())
+            elif isinstance(past_key_values, list):
+                past_key_values = EncoderDecoderCache(DynamicCache.from_legacy_cache(past_key_values), DynamicCache())
+            elif isinstance(past_key_values, DynamicCache):
+                past_key_values = EncoderDecoderCache(past_key_values, DynamicCache())
+            else:
+                pass
+            past_key_values_length = past_key_values.self_attention_cache.get_usable_length(inputs_embeds.shape[1])
+            if inputs_embeds.shape[1] + past_key_values_length > embed_pos.shape[0]:
+                logger.warning("seems the audio is longer than 30s. repeating the last part of the audio")
+                embed_pos_front = embed_pos[past_key_values_length:, :]
+                embed_pos = torch.cat(
+                    (
+                        embed_pos_front,
+                        torch.repeat_interleave(
+                            embed_pos[-1, :].unsqueeze(0),
+                            inputs_embeds.shape[1] - embed_pos.shape[0] + past_key_values_length,
+                            dim=0,
+                        ),
+                    )
+                )
+            else:
+                embed_pos = embed_pos[past_key_values_length : inputs_embeds.shape[1] + past_key_values_length, :]
+        else:
+            embed_pos = embed_pos[: inputs_embeds.shape[1], :]
+
+        hidden_states = inputs_embeds + embed_pos
+        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+
+        encoder_states = () if output_hidden_states else None
+        all_attentions = () if output_attentions else None
+
+        # check if head_mask has a correct number of layers specified if desired
+        if head_mask is not None:
+            assert head_mask.size()[0] == (
+                len(self.layers)
+            ), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
+
+        for idx, encoder_layer in enumerate(self.layers):
+            if output_hidden_states:
+                encoder_states = encoder_states + (hidden_states,)
+            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
+            to_drop = False
+            if self.training:
+                dropout_probability = torch.rand([])
+                if dropout_probability < self.layerdrop:  # skip the layer
+                    to_drop = True
+
+            # Ignore copy
+            if to_drop:
+                layer_outputs = (None, None)
+            else:
+                if self.gradient_checkpointing and self.training:
+                    layer_outputs = self._gradient_checkpointing_func(
+                        encoder_layer.__call__,
+                        hidden_states,
+                        attention_mask,
+                        (head_mask[idx] if head_mask is not None else None),
+                        output_attentions,
+                        past_key_values,
+                        use_cache,
+                    )
+                else:
+                    layer_outputs = encoder_layer(
+                        hidden_states,
+                        attention_mask,
+                        layer_head_mask=(head_mask[idx] if head_mask is not None else None),
+                        output_attentions=output_attentions,
+                        past_key_values=past_key_values,
+                        use_cache=use_cache,
+                    )
+
+                hidden_states = layer_outputs[0]
+
+            if use_cache:
+                next_encoder_cache = layer_outputs[2 if output_attentions else 1]
+            else:
+                next_encoder_cache = None
+
+            if output_attentions:
+                all_attentions = all_attentions + (layer_outputs[1],)
+
+        hidden_states = self.layer_norm(hidden_states)
+        if output_hidden_states:
+            encoder_states = encoder_states + (hidden_states,)
+
+        if not return_dict:
+            return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
+        return BaseModelOutputWithPast(
+            last_hidden_state=hidden_states,
+            hidden_states=encoder_states,
+            attentions=all_attentions,
+            past_key_values=next_encoder_cache,
+        )
+
+
+# Borrowed from `https://github.com/2noise/ChatTTS/blob/main/ChatTTS/model/dvae.py`
+class ConvNeXtBlock(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        intermediate_dim: int,
+        kernel: int,
+        dilation: int,
+        layer_scale_init_value: float = 1e-6,
+    ):
+        # ConvNeXt Block copied from Vocos.
+        super().__init__()
+        self.dwconv = nn.Conv1d(
+            dim,
+            dim,
+            kernel_size=kernel,
+            padding=dilation * (kernel // 2),
+            dilation=dilation,
+            groups=dim,
+        )
+
+        self.norm = nn.LayerNorm(dim, eps=1e-6)
+        self.pwconv1 = nn.Linear(dim, intermediate_dim)
+        self.act = nn.GELU()
+        self.pwconv2 = nn.Linear(intermediate_dim, dim)
+        self.coef = (
+            nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
+            if layer_scale_init_value > 0
+            else None
+        )
+
+    def forward(self, x: torch.Tensor, cond=None) -> torch.Tensor:
+        residual = x
+
+        y = self.dwconv(x)
+        y.transpose_(1, 2)  # (B, C, T) -> (B, T, C)
+        x = self.norm(y)
+        del y
+        y = self.pwconv1(x)
+        del x
+        x = self.act(y)
+        del y
+        y = self.pwconv2(x)
+        del x
+        if self.coef is not None:
+            y *= self.coef
+        y.transpose_(1, 2)  # (B, T, C) -> (B, C, T)
+
+        x = y + residual
+        del y
+
+        return x
+
+
+# Borrowed from `https://github.com/2noise/ChatTTS/blob/main/ChatTTS/model/dvae.py`
+class GFSQ(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        levels: List[int],
+        G: int,
+        R: int,
+        eps=1e-5,
+        transpose=True,
+    ):
+        super(GFSQ, self).__init__()
+        self.quantizer = GroupedResidualFSQ(
+            dim=dim,
+            levels=list(levels),
+            num_quantizers=R,
+            groups=G,
+        )
+        self.n_ind = math.prod(levels)
+        self.eps = eps
+        self.transpose = transpose
+        self.G = G
+        self.R = R
+
+    def _embed(self, x: torch.Tensor):
+        if self.transpose:
+            x = x.transpose(1, 2)
+        x = x.view(x.size(0), x.size(1), self.G, self.R).permute(2, 0, 1, 3)
+        feat = self.quantizer.get_output_from_indices(x)
+        return feat.transpose_(1, 2) if self.transpose else feat
+
+    def __call__(self, x: torch.Tensor) -> torch.Tensor:
+        return super().__call__(x)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if self.transpose:
+            x.transpose_(1, 2)
+        _, ind = self.quantizer(x)
+        ind = ind.permute(1, 2, 0, 3).contiguous()
+        ind = ind.view(ind.size(0), ind.size(1), -1)
+        return ind.transpose_(1, 2) if self.transpose else ind
+
+
+# Borrowed from `https://github.com/2noise/ChatTTS/blob/main/ChatTTS/model/dvae.py`
+class DVAEDecoder(nn.Module):
+    def __init__(
+        self,
+        idim: int,
+        odim: int,
+        n_layer=12,
+        bn_dim=64,
+        hidden=256,
+        kernel=7,
+        dilation=2,
+        up=False,
+    ):
+        super().__init__()
+        self.up = up
+        self.conv_in = nn.Sequential(
+            nn.Conv1d(idim, bn_dim, 3, 1, 1),
+            nn.GELU(),
+            nn.Conv1d(bn_dim, hidden, 3, 1, 1),
+        )
+        self.decoder_block = nn.ModuleList(
+            [
+                ConvNeXtBlock(
+                    hidden,
+                    hidden * 4,
+                    kernel,
+                    dilation,
+                )
+                for _ in range(n_layer)
+            ]
+        )
+        self.conv_out = nn.Conv1d(hidden, odim, kernel_size=1, bias=False)
+
+    def forward(self, x: torch.Tensor, conditioning=None) -> torch.Tensor:
+        # B, C, T
+        y = self.conv_in(x)
+        del x
+        for f in self.decoder_block:
+            y = f(y, conditioning)
+
+        x = self.conv_out(y)
+        del y
+        return x
+
+
+# Borrowed from `https://github.com/2noise/ChatTTS/blob/main/ChatTTS/model/dvae.py`
+class DVAE(nn.Module):
+    def __init__(
+        self,
+    ):
+        super().__init__()
+
+        coef = torch.rand(100)
+        self.coef = nn.Parameter(coef.unsqueeze(0).unsqueeze_(2))
+
+        self.downsample_conv = nn.Sequential(
+            nn.Conv1d(100, 512, 3, 1, 1),
+            nn.GELU(),
+            nn.Conv1d(512, 512, 4, 2, 1),
+            nn.GELU(),
+        )
+
+        self.encoder = DVAEDecoder(
+            idim=512,
+            odim=1024,
+            hidden=256,
+            n_layer=12,
+            bn_dim=128,
+        )
+
+        self.decoder = DVAEDecoder(
+            idim=512,
+            odim=512,
+            hidden=256,
+            n_layer=12,
+            bn_dim=128,
+        )
+
+        self.out_conv = nn.Conv1d(512, 100, 3, 1, 1, bias=False)
+
+        self.vq_layer = GFSQ(
+            dim=1024,
+            levels=(5, 5, 5, 5),
+            G=2,
+            R=2,
+        )
+
+    @torch.inference_mode()
+    def forward(self, inp: torch.Tensor, mode: Literal["encode", "decode"] = "decode") -> torch.Tensor:
+        if mode == "encode" and hasattr(self, "encoder") and self.vq_layer is not None:
+            mel = inp.clone()
+            x: torch.Tensor = self.downsample_conv(
+                torch.div(mel, self.coef.view(100, 1).expand(mel.shape), out=mel),
+            ).unsqueeze_(0)
+            del mel
+            x = self.encoder(x)
+            ind = self.vq_layer(x)
+            del x
+            return ind
+
+        if self.vq_layer is not None:
+            vq_feats = self.vq_layer._embed(inp)
+        else:
+            vq_feats = inp
+
+        vq_feats = (
+            vq_feats.view(
+                (vq_feats.size(0), 2, vq_feats.size(1) // 2, vq_feats.size(2)),
+            )
+            .permute(0, 2, 3, 1)
+            .flatten(2)
+        )
+
+        dec_out = self.out_conv(
+            self.decoder(
+                x=vq_feats,
+            ),
+        )
+
+        del vq_feats
+
+        return torch.mul(dec_out, self.coef, out=dec_out)
+
+
+def apply_spk_emb(
+    input_ids: torch.Tensor = None,
+    spk_emb: torch.Tensor = None,
+    input_embeds: torch.Tensor = None,
+    spk_emb_token_id: int = 0,
+    num_spk_embs: int = 1,
+):
+    """
+    Replace the consecutive `num_spk_embs` speaker-embedding placeholders in `input_embeds` with pre-computed speaker embeddings. This is an in-place replacement; no new tensor is created and no value is returned.
+
+    Args:
+        input_ids (torch.Tensor): Input ID tensor, shape [batch_size, seq_len_max]
+        spk_emb (torch.Tensor): Speaker embedding tensor, shape [batch_size, num_spk_emb, hidden_dim]
+        input_embeds (torch.Tensor): Input embedding tensor, shape [batch_size, seq_len_max, hidden_dim]
+        spk_emb_token_id (int): ID of the speaker embedding token
+        num_spk_embs (int): Number of speaker embeddings
+
+    Returns:
+        None
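+
+    Example (illustrative sketch with tiny shapes; here `hidden_dim=8` and token id `0` is assumed
+    to be the speaker-embedding placeholder):
+        >>> import torch
+        >>> input_ids = torch.tensor([[5, 0, 7]])
+        >>> input_embeds = torch.zeros(1, 3, 8)
+        >>> spk_emb = torch.ones(1, 1, 8)
+        >>> apply_spk_emb(input_ids=input_ids, spk_emb=spk_emb, input_embeds=input_embeds, spk_emb_token_id=0, num_spk_embs=1)
+        >>> input_embeds[0, 1].sum()
+        tensor(8.)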
+    """
+
+    batch_size = input_ids.shape[0]
+
+    for idx in range(batch_size):
+        input_ids_ = input_ids[idx]  # [seq_len_max]
+        spk_emb_ = spk_emb[idx]  # [num_spk_emb, hidden_dim]
+        mask_ = input_ids_ == spk_emb_token_id  # [seq_len_max]
+        nonzero_position_idx = mask_.nonzero(as_tuple=False)  # [num_spk_emb, 1]
+        assert nonzero_position_idx.shape[0] == num_spk_embs
+        begin_idx = nonzero_position_idx.min()
+        end_idx = nonzero_position_idx.max()
+        input_embeds[idx, begin_idx : end_idx + 1, :] = spk_emb_
+
+    return
+
+
+def make_streaming_chunk_mask_generation(
+    inputs_embeds: torch.Tensor,
+    past_seen_tokens: int,
+    streaming_tts_text_mask: torch.Tensor,
+    streaming_reserved_length: int = 300,
+    streaming_audio_chunk_size: int = 50,
+    streaming_text_chunk_size: int = 10,
+    num_spk_emb: int = 1,
+    use_spk_emb: bool = True,
+) -> torch.Tensor:
+    """
+    In streaming audio generation, determine which `text` positions the TTS model can attend to when generating each chunk of `audio` tokens.
+
+    This function creates a mask that allows the model to attend to a specific chunk of text
+    tokens when generating each chunk of audio tokens, enabling streaming TTS generation.
+
+    Args:
+        inputs_embeds (torch.Tensor): Input embeddings tensor.
+        past_seen_tokens (int): Number of tokens already seen by the model.
+        streaming_tts_text_mask (torch.Tensor): Mask for the text tokens.
+        streaming_reserved_length (int, optional): Number of reserved tokens for streaming. Defaults to 300.
+        streaming_audio_chunk_size (int, optional): Number of audio tokens generated per chunk. Defaults to 50.
+        streaming_text_chunk_size (int, optional): Number of text tokens made visible per audio chunk. Defaults to 10.
+        num_spk_emb (int, optional): Number of speaker embedding tokens. Defaults to 1.
+        use_spk_emb (bool, optional): Whether speaker embedding tokens are used. Defaults to True.
+
+    Returns:
+        torch.Tensor: Causal mask for streaming TTS generation, shape is [batch_size=1, 1, seq_len=1, past_seen_tokens+1]
+
+    Raises:
+        AssertionError: If the batch size is not 1 (only supports batch size of 1 for inference).
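+
+    Example (illustrative sketch with the default lengths and an assumed hidden size of 768):
+        >>> import torch
+        >>> inputs_embeds = torch.zeros(1, 1, 768)
+        >>> # text mask covers [Stts] + [spk_emb] + 300 reserved text positions + [Ptts]
+        >>> streaming_tts_text_mask = torch.ones(1 + 1 + 300 + 1, dtype=torch.int8)
+        >>> mask = make_streaming_chunk_mask_generation(
+        ...     inputs_embeds=inputs_embeds,
+        ...     past_seen_tokens=303,
+        ...     streaming_tts_text_mask=streaming_tts_text_mask,
+        ... )
+        >>> mask.shape
+        torch.Size([1, 1, 1, 304])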
+    """
+    assert inputs_embeds.shape[0] == 1
+
+    dtype = inputs_embeds.dtype
+    device = inputs_embeds.device
+    min_dtype = torch.finfo(dtype).min
+
+    # Add `1` to the past seen tokens to account for new `tokens` during `generate`
+    causal_mask = torch.full((1, past_seen_tokens + inputs_embeds.shape[1]), fill_value=0, dtype=dtype, device=device)
+
+    # Calculate the start of invisible text tokens
+    invisible_text_tokens_start = (
+        min(
+            math.ceil((past_seen_tokens - streaming_reserved_length) / streaming_audio_chunk_size)
+            * streaming_text_chunk_size,
+            streaming_reserved_length,
+        )
+        + 1
+        + num_spk_emb * use_spk_emb
+    )  # Add 1 for [Stts] and N for [spk_emb] tokens if `use_spk_emb` is True
+
+    invisible_text_tokens_end = (
+        streaming_reserved_length + 1 + num_spk_emb * use_spk_emb + 1
+    )  # Add 1 for [Ptts] (aka `audio_bos_token_id`)
+
+    # Set invisible text tokens to min_dtype (effectively -inf)
+    causal_mask[0, invisible_text_tokens_start:invisible_text_tokens_end] = min_dtype
+
+    # Mask padding positions in the text mask
+    causal_mask[0, 0 : 1 + num_spk_emb * use_spk_emb + streaming_reserved_length + 1].masked_fill_(
+        streaming_tts_text_mask == 0, min_dtype
+    )
+
+    # Add extra dimensions for batch and heads
+    causal_mask = causal_mask.unsqueeze(0).unsqueeze(0)
+
+    return causal_mask
+
+
+# Borrowed from `https://github.com/2noise/ChatTTS/blob/main/ChatTTS/model/processors.py`
+class CustomRepetitionPenaltyLogitsProcessorRepeat:
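+    """Repetition penalty restricted to audio-code ids and a sliding window.
+
+    Only the most recent `past_window` tokens are counted, and only ids below `max_input_ids`
+    (the audio-code vocabulary) are penalized; negative scores are multiplied and positive
+    scores divided by `penalty ** frequency`.
+    """
+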
+    def __init__(self, penalty: float, max_input_ids: int, past_window: int):
+        if not isinstance(penalty, float) or not (penalty > 0):
+            raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}")
+
+        self.penalty = penalty
+        self.max_input_ids = max_input_ids
+        self.past_window = past_window
+
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
+        if input_ids.size(1) > self.past_window:
+            input_ids = input_ids.narrow(1, -self.past_window, self.past_window)
+        freq = F.one_hot(input_ids, scores.size(1)).sum(1)
+        if freq.size(0) > self.max_input_ids:
+            freq.narrow(0, self.max_input_ids, freq.size(0) - self.max_input_ids).zero_()
+        alpha = torch.pow(self.penalty, freq)
+        scores = scores.contiguous()
+        inp = scores.multiply(alpha)
+        oth = scores.divide(alpha)
+        con = scores < 0
+        out = torch.where(con, inp, oth)
+        del inp, oth, scores, con, alpha
+        return out
+
+
+@dataclass
+class ConditionalChatTTSGenerationOutput(ModelOutput):
+    """
+    Output class for ConditionalChatTTS generation.
+
+    Args:
+        new_ids (torch.LongTensor): Newly generated audio code sequence, shape (batch_size, sequence_length, num_vq).
+        audio_input_ids (torch.LongTensor): Updated input IDs including condition and generated audio codes, shape (batch_size, full_sequence_length, num_vq).
+        past_key_values (Tuple[Tuple[torch.FloatTensor]]): Tuple containing pre-computed keys and values used for attention mechanism. Each element has shape (batch_size, num_heads, sequence_length, embed_size_per_head).
+        finished (bool): Boolean indicating whether generation is complete.
+
+    """
+
+    new_ids: torch.LongTensor = None
+    audio_input_ids: torch.LongTensor = None
+    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
+    finished: bool = None
+
+
+class MultiModalProjector(nn.Module):
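+    """Two-layer MLP (Linear -> ReLU -> Linear) that projects audio features from `in_dim` to `out_dim`."""
+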
+    def __init__(self, in_dim, out_dim):
+        super().__init__()
+        self.linear1 = nn.Linear(in_features=in_dim, out_features=out_dim, bias=True)
+        self.relu = nn.ReLU()
+        self.linear2 = nn.Linear(in_features=out_dim, out_features=out_dim, bias=True)
+
+    def forward(self, audio_features):
+        hidden_states = self.relu(self.linear1(audio_features))
+        hidden_states = self.linear2(hidden_states)
+        return hidden_states
+
+
+class ConditionalChatTTS(PreTrainedModel):
+    """A conditional text-to-speech model that can generate speech from text with speaker conditioning.
+
+    This model extends PreTrainedModel to provide text-to-speech capabilities with:
+    - LLM hidden state conditioning
+    - Streaming generation
+
+    The model uses a transformer architecture with LLM hidden states and can operate in both
+    streaming and non-streaming modes for flexible deployment.
+
+    The model processes sequences in the following format:
+    | text bos token | LLM embedding projected to tts embedding space | text tokens (fixed length, reserved for future tokens) | audio bos token | audio tokens (audio token length is not fixed) | audio eos token |
+
+    The format is designed to support LLM-conditioned streaming audio generation.
+
+    Usage:
+    To support streaming generation, two global variables should be maintained outside of the model.
+        1. `audio_input_ids`: stores *discrete* audio codes. It is a tensor with shape [1, sequence length+1, num_vq].
+        2. `past_key_values`: stores the KV cache for both text tokens and audio codes. It is a list of tuples; each tuple contains two tensors with shape [1, num_attention_heads, sequence length, hidden_size // num_attention_heads].
+
+    where `num_vq` is the number of audio codebooks; in the default setting it is `4`.
+
+    1. Create an empty `past_key_values` with
+    ```python
+    initial_kv_cache_length = 1 + model.num_spk_embs + model.streaming_text_reserved_len # where `1` denotes the `bos` token
+    dtype = model.emb_text.weight.dtype
+    device = model.emb_text.weight.device
+    past_key_values = [
+        (
+            torch.zeros(1, model.config.num_attention_heads, initial_kv_cache_length, model.config.hidden_size // model.config.num_attention_heads, dtype=dtype, device=device),
+            torch.zeros(1, model.config.num_attention_heads, initial_kv_cache_length, model.config.hidden_size // model.config.num_attention_heads, dtype=dtype, device=device)
+        )
+        for _ in range(model.config.num_hidden_layers)
+    ]
+    ```
+
+    2. At the same time, create an empty `audio_input_ids` with shape [1, sequence length, num_vq], where `num_vq` denotes the number of audio codebook layers. Text positions are included in this sequence as well, but they are left as zeros and never used; they are only placeholders.
+
+    ```python
+    initial_audio_input_ids_length = 1 + model.num_spk_embs + model.streaming_text_reserved_len + 1
+    # [bos token, speaker embeddings, text tokens, audio bos token]
+    audio_input_ids = torch.zeros(1, initial_audio_input_ids_length, model.num_vq, dtype=torch.long)
+    ```
+
+    3. Prefill some text tokens to the TTS model (for example, 10 tokens) using the `prefill_text` method.
+
+    ```python
+    outputs = llm.generate(**kwargs)
+    llm_tokens = some_function_to_extract_llm_tokens(outputs)
+    lm_spk_emb_last_hidden_states = some_function_to_extract_lm_spk_emb_last_hidden_states(outputs)
+    tts_text_input_ids = tts_tokenizer.encode(llm_tokenizer.decode(llm_tokens))
+    # Assume we are prefilling text tokens 0 through 9 (inclusive), 10 tokens in total.
+    begin = 0
+    end = 9+1
+    position_ids = torch.arange(begin, end, dtype=torch.long, device=device)
+
+    past_key_values = model.prefill_text(
+        input_ids=tts_text_input_ids,
+        position_ids=position_ids,
+        past_key_values=past_key_values,
+        lm_spk_emb_last_hidden_states=lm_spk_emb_last_hidden_states,
+    )
+    ```
+
+    4. Make a `streaming_tts_text_mask` to denote which positions contain valid text tokens, similar to `attention_mask` in standard causal attention.
+
+    ```python
+    streaming_tts_text_mask = torch.zeros(model.streaming_text_reserved_len)
+    streaming_tts_text_mask[0:end] = 1  # these positions contain valid (already prefilled) text tokens
+    ```
+
+    5. Generate audio codes using the `generate` method.
+
+    ```python
+    outputs = model.generate(
+        input_ids=audio_input_ids,
+        past_key_values=past_key_values,
+        streaming_tts_text_mask=streaming_tts_text_mask,
+        max_new_token=50,
+    )
+
+    # update past_key_values and input_ids
+    past_key_values = outputs.past_key_values
+    audio_input_ids = outputs.audio_input_ids
+    ```
+
+    After calling `generate`, `past_key_values` is extended by up to `max_new_token=50` positions, and `audio_input_ids` is extended by the same number of generated codes.
+
+    6. Note that after prefilling `10` text tokens, the model can generate up to `50` audio tokens. If you want to generate more audio tokens, you need to prefill the next `10` text tokens first. It is also fine to generate only `25` audio tokens for a faster initial response.
+
+    7. Repeat steps `3, 4, 5` as needed in your streaming audio generation use case, making sure the usage complies with the guidelines described above.
+    """
+
+    config_class = ConditionalChatTTSConfig
+
+    def __init__(self, config: ConditionalChatTTSConfig):
+        super().__init__(config)
+
+        self.use_speaker_embedding = config.use_speaker_embedding
+        self.use_llm_hidden_state = config.use_llm_hidden_state
+        self.num_spk_embs = config.num_spk_embs
+        self.spk_emb_token_id = config.spk_emb_token_id
+
+        self.use_text = config.use_text
+        self.streaming = config.streaming
+        self.streaming_text_chunk_size = config.streaming_text_chunk_size
+        self.streaming_audio_chunk_size = config.streaming_audio_chunk_size
+        self.streaming_text_reserved_len = config.streaming_text_reserved_len
+        self.audio_bos_token_id = config.audio_bos_token_id
+        self.num_mel_bins = config.num_mel_bins
+        self.num_vq = config.num_vq
+        self.num_audio_tokens = config.num_audio_tokens
+
+        self.top_p = config.top_p
+        self.top_k = config.top_k
+        self.repetition_penalty = config.repetition_penalty
+
+        if self.config.use_mlp:
+            self.projector = MultiModalProjector(config.llm_dim, config.hidden_size)
+        else:
+            self.projector = nn.Linear(config.llm_dim, config.hidden_size, bias=False)
+        self.emb_code = nn.ModuleList(
+            [nn.Embedding(config.num_audio_tokens, config.hidden_size) for _ in range(config.num_vq)]
+        )
+        self.emb_text = nn.Embedding(config.num_text_tokens, config.hidden_size)
+        self.head_code = nn.ModuleList(
+            [
+                weight_norm(
+                    nn.Linear(config.hidden_size, config.num_audio_tokens, bias=False),
+                    name="weight",
+                )
+                for _ in range(config.num_vq)
+            ]
+        )
+        dvae = DVAE()
+        self.dvae = dvae
+
+        model_config = LlamaConfig(
+            hidden_size=config.hidden_size,
+            intermediate_size=config.intermediate_size,
+            num_attention_heads=config.num_attention_heads,
+            num_hidden_layers=config.num_hidden_layers,
+            max_position_embeddings=config.max_position_embeddings,
+            attn_implementation=config.attn_implementation,
+        )
+
+        model = LlamaModel(model_config)
+        self.model = model
+
+    @torch.inference_mode()
+    def merge_inputs_embeds(
+        self,
+        input_ids: torch.Tensor,
+        lm_spk_emb_last_hidden_states: Optional[torch.Tensor] = None,
+    ):
+        """Merge `input_ids` and `lm_spk_emb_last_hidden_states` to `inputs_embeds`.
+
+        Args:
+            input_ids (torch.Tensor): Input token IDs.
+            lm_spk_emb_last_hidden_states (Optional[torch.Tensor], optional): Last hidden states of speaker embeddings from the language model. Defaults to None.
+
+        Raises:
+            NotImplementedError: If speaker embedding is not used and language model hidden states are not implemented.
+
+        Returns:
+            torch.Tensor: Prepared input embeddings for the model.
+        """
+        assert input_ids.shape[0] == 1
+
+        # Embed input_ids to input_embeds
+        inputs_embeds = self.emb_text(input_ids)
+
+        # Inject speaker embedding to input_embeds if it exists
+        if self.use_speaker_embedding:
+            spk_emb_mask = input_ids == self.spk_emb_token_id
+            if spk_emb_mask.any():
+                assert lm_spk_emb_last_hidden_states is not None
+                # Project spk emb to tts hidden size first, [batch_size, num_spk_emb, llm_dim] -> [batch_size, num_spk_emb, self.hidden_size]
+                lm_spk_emb_last_hidden_states = lm_spk_emb_last_hidden_states.to(self.projector.linear1.weight.dtype)
+                projected_spk_emb = self.projector(lm_spk_emb_last_hidden_states)
+                projected_spk_emb = F.normalize(projected_spk_emb, p=2, dim=-1)
+                apply_spk_emb(
+                    input_ids=input_ids,
+                    spk_emb=projected_spk_emb,
+                    input_embeds=inputs_embeds,
+                    spk_emb_token_id=self.spk_emb_token_id,
+                    num_spk_embs=self.num_spk_embs,
+                )
+        else:
+            raise NotImplementedError
+
+        return inputs_embeds
+
+    @torch.inference_mode()
+    def prefill_text(
+        self,
+        input_ids: torch.Tensor,
+        position_ids: torch.LongTensor,
+        past_key_values: List[Tuple[torch.Tensor, torch.Tensor]],
+        lm_spk_emb_last_hidden_states: Optional[torch.Tensor] = None,
+    ):
+        """Prefill a chunk of new text tokens in streaming setting.
+        Specifically speaking, update `past_key_values` using new text tokens, then the model will read the new text tokens.
+
+        Args:
+            input_ids (Tensor): Tensor of shape [batch_size, seq_len]
+            position_ids (LongTensor): Tensor of shape [batch_size, seq_len]
+            past_key_values (List[Tuple[Tensor]]): KV cache of all layers; each layer is a tuple (Tensor, Tensor) holding keys and values. The sequence dimension is pre-allocated to cover the full reserved text prefix (see the class docstring). `past_key_values` will be updated in place.
+            lm_spk_emb_last_hidden_states (Tensor, optional): Tensor of shape [batch_size, num_spk_emb, llm_dim]. Defaults to None.
+
+        Note that `batch_size` must be `1`.
+        """
+        assert input_ids.shape[0] == 1
+        assert past_key_values is not None
+
+        # Merge text and LLM embeddings
+        inputs_embeds = self.merge_inputs_embeds(
+            input_ids=input_ids,
+            lm_spk_emb_last_hidden_states=lm_spk_emb_last_hidden_states,
+        )
+
+        # Clone KV Cache
+        past_key_values_for_prefill = []
+        for i in range(len(past_key_values)):
+            past_key_values_for_prefill.append(
+                (
+                    past_key_values[i][0][:, :, : position_ids[:, 0], :].clone(),
+                    past_key_values[i][1][:, :, : position_ids[:, 0], :].clone(),
+                )
+            )
+
+        # Model forward
+        outputs_prefill: BaseModelOutputWithPast = self.model(
+            attention_mask=None,  # because for text, it is standard causal attention mask, do nothing
+            position_ids=position_ids,  # position_ids denotes the position of new text tokens in the sequence
+            past_key_values=past_key_values_for_prefill,  # `past_key_values` will be updated by the model
+            inputs_embeds=inputs_embeds,  # contains text and language model embedding
+            use_cache=True,
+            output_attentions=False,
+            cache_position=position_ids,  # which new positions will use this cache, basically the same as position_ids
+        )
+
+        # Get model updated KV Cache
+        past_key_values_for_prefill_updated = outputs_prefill.past_key_values
+
+        # Update generated KV Cache to input `past_key_values`
+        for layer_idx in range(len(past_key_values)):
+            # Update keys
+            past_key_values[layer_idx][0][:, :, position_ids[:, 0] : position_ids[:, -1] + 1, :] = (
+                past_key_values_for_prefill_updated[layer_idx][0][
+                    :, :, position_ids[:, 0] : position_ids[:, -1] + 1
+                ].clone()
+            )
+            # Update values
+            past_key_values[layer_idx][1][:, :, position_ids[:, 0] : position_ids[:, -1] + 1, :] = (
+                past_key_values_for_prefill_updated[layer_idx][1][
+                    :, :, position_ids[:, 0] : position_ids[:, -1] + 1
+                ].clone()
+            )
+
+        # TODO: del past_key_values_for_prefill_updated recursively
+        # TODO: del outputs_prefill recursively
+
+        return past_key_values
+
+    @torch.inference_mode()
+    def prefill_audio_ids(
+        self,
+        input_ids: torch.Tensor,
+        past_key_values: List[Tuple[torch.Tensor, torch.Tensor]],
+        streaming_tts_text_mask=None,
+        add_audio_bos: bool = True,
+    ):
+        """Prefill a chunk of audio ids to the model. Used in sliding-window long audio generation.
+        Specifically, prefill many audio ids (typically from last window) to the model in the new window.
+
+        Args:
+            input_ids (torch.Tensor): (1, seq_len, num_vq) Audio input token ids.
+            past_key_values (List[Tuple[torch.Tensor, torch.Tensor]]): Past key values for attention mechanism.
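+            streaming_tts_text_mask (Optional[torch.Tensor]): Mask marking which reserved text positions contain valid text tokens. Defaults to None.
+            add_audio_bos (bool): Whether to prepend the audio bos token before the prefilled audio ids. Defaults to True.
+
+        Example (a minimal sketch; `last_window_audio_codes` and the surrounding streaming state are illustrative):
+
+        ```python
+        # Replay the previous window's audio codes into a freshly prefilled KV cache.
+        past_key_values = model.prefill_audio_ids(
+            input_ids=last_window_audio_codes,  # (1, seq_len, num_vq)
+            past_key_values=past_key_values,
+            streaming_tts_text_mask=streaming_tts_text_mask,
+            add_audio_bos=True,
+        )
+        ```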
+        """
+        assert input_ids.shape[0] == 1
+        assert past_key_values is not None
+
+        code_emb = [self.emb_code[i](input_ids[:, :, i]) for i in range(self.num_vq)]
+        inputs_embeds = torch.stack(code_emb, 3).sum(3)  # [1,seq_len,768]
+        input_len = input_ids.shape[1]
+
+        if add_audio_bos:
+            narrowed_input_ids = torch.tensor([[self.audio_bos_token_id]], dtype=torch.long, device=self.device)
+            bos_inputs_embeds = self.emb_text(narrowed_input_ids)
+            inputs_embeds = torch.cat([bos_inputs_embeds, inputs_embeds], dim=1)
+            input_len += 1
+
+        past_key_values_length = past_key_values[0][0].shape[2]
+        position_ids = torch.arange(
+            past_key_values_length, past_key_values_length + input_len, dtype=torch.long, device=self.device
+        ).unsqueeze(0)
+
+        cache_position = position_ids.clone()
+        causal_mask = make_streaming_chunk_mask_generation(
+            inputs_embeds=inputs_embeds,
+            past_seen_tokens=past_key_values[0][0].shape[2],
+            streaming_tts_text_mask=streaming_tts_text_mask,
+            streaming_reserved_length=self.streaming_text_reserved_len,
+            streaming_text_chunk_size=self.streaming_text_chunk_size,
+        )  # [1, 1, 1, past_key_values_length + input_len]
+
+        # Model forward
+        outputs: BaseModelOutputWithPast = self.model(
+            attention_mask=causal_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=True,
+            output_attentions=False,
+            cache_position=cache_position,
+        )
+        past_key_values = outputs.past_key_values
+        return past_key_values
+
+    @torch.inference_mode()
+    def generate(
+        self,
+        input_ids: torch.Tensor,
+        past_key_values: List[Tuple[torch.Tensor, torch.Tensor]],
+        temperature: torch.Tensor,
+        eos_token: Union[int, torch.Tensor],
+        streaming_tts_text_mask=None,
+        force_no_stop=False,
+        min_new_token=10,
+        max_new_token=50,
+        logits_warpers: List[LogitsWarper] = [],
+        logits_processors: List[CustomRepetitionPenaltyLogitsProcessorRepeat] = [],
+        show_tqdm=False,
+    ):
+        """Generate audio codes in streaming setting or non-streaming setting.
+        Specifically speaking, generate audio codes when not all text tokens are prefilled.
+
+        Always pass a valid `past_key_values` to the method. The method does not do `prefill` by itself. It relies on `prefill_text` method to provide valid `past_key_values`. Please refer to docstring of this class for more details.
+
+        In this method, we borrowed a lot of codes from `https://github.com/2noise/ChatTTS/blob/main/ChatTTS/model/gpt.py`.
+
+        Args:
+            input_ids (torch.Tensor): Input token ids.
+            past_key_values (List[Tuple[torch.Tensor, torch.Tensor]]): Past key values for attention mechanism.
+            temperature (torch.Tensor): Temperature for sampling.
+            eos_token (Union[int, torch.Tensor]): End of sequence token.
+            streaming_tts_text_mask (Optional[torch.Tensor], optional): Mask for streaming TTS text. Defaults to None.
+            force_no_stop (bool, optional): If True, suppress the eos token so that generation is never stopped early. Defaults to False.
+            min_new_token (int, optional): Minimum number of new tokens to generate before the eos token is allowed. Defaults to 10.
+            max_new_token (int, optional): Maximum number of new tokens to generate. Defaults to 50.
+            logits_warpers (List[LogitsWarper], optional): List of logits warpers. Defaults to [].
+            logits_processors (List[CustomRepetitionPenaltyLogitsProcessorRepeat], optional): List of logits processors. Defaults to [].
+            show_tqdm (bool, optional): Whether to show a progress bar. Defaults to False.
+
+        Returns:
+            GenerationOutputs: Generation outputs.
+        """
+
+        # We only support batch size `1` for now
+        assert input_ids.shape[0] == 1
+        assert past_key_values is not None
+
+        # `start_idx` is the length of the conditioning prefix (bos token + speaker embeddings + reserved text + audio bos token), not `input_ids.shape[1]`.
+        # start_idx = input_ids.shape[1]
+        start_idx = 1 + self.num_spk_embs * self.use_speaker_embedding + self.streaming_text_reserved_len + 1
+
+        finish = torch.zeros(input_ids.shape[0], device=input_ids.device).bool()
+
+        temperature = temperature.unsqueeze(0).expand(input_ids.shape[0], -1).contiguous().view(-1, 1)
+
+        progress = input_ids.shape[1]
+
+        # Pre-allocate input_ids, shape is [batch_size=1, max_possible_seq_len, self.num_vqs]
+        input_ids_buf = torch.zeros(
+            input_ids.shape[0],  # batch_size
+            progress + max_new_token,  # max_possible_seq_len = input_ids.shape[1] + max_new_token
+            input_ids.shape[2],  # self.num_vqs
+            dtype=input_ids.dtype,
+            device=input_ids.device,
+        )
+
+        # Copy existing `input_ids` to `input_ids_buf`
+        input_ids_buf.narrow(1, 0, progress).copy_(input_ids)
+
+        del input_ids
+        input_ids = input_ids_buf.narrow(1, 0, progress)
+
+        pbar: Optional[tqdm] = None
+        if show_tqdm:
+            pbar = tqdm(
+                total=max_new_token,
+                desc="code",
+                bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt}(max) [{elapsed}, {rate_fmt}{postfix}]",
+            )
+
+        condition_length = 1 + self.num_spk_embs * self.use_speaker_embedding + self.streaming_text_reserved_len + 1
+
+        for i in range(max_new_token):
+            # Prepare generation inputs
+            audio_bos = False
+
+            # If this is the first audio token, the case is SPECIAL
+            if progress == condition_length:
+                audio_bos = True
+
+            assert progress == (
+                past_key_values[0][0].shape[2] + 1
+            )  # If the model is used according to the guidelines above, this assertion should hold.
+
+            if audio_bos:
+                # Generate the first token, activate the model with `self.audio_bos_token_id`, the model will predict a new audio token. This is a special case because without the `audio bos token`, it is impossible to generate the first audio token in our streaming setting.
+                narrowed_input_ids = torch.tensor([[self.audio_bos_token_id]], dtype=torch.long, device=self.device)
+                inputs_embeds = self.emb_text(narrowed_input_ids)
+                del narrowed_input_ids
+            else:
+                # Generate the following audio tokens, it is applicable to all other cases, including second and the following calling of `generate`.
+                narrowed_input_ids = input_ids.narrow(dim=1, start=input_ids.shape[1] - 1, length=1)
+                code_emb = [self.emb_code[i](narrowed_input_ids[:, :, i]) for i in range(self.num_vq)]
+                inputs_embeds = torch.stack(code_emb, 3).sum(3)
+
+            position_ids = torch.tensor(
+                [past_key_values[0][0].shape[2] + 1], dtype=torch.long, device=self.device
+            ).unsqueeze(0)
+
+            cache_position = position_ids.clone()
+
+            # Make causal mask
+            causal_mask = make_streaming_chunk_mask_generation(
+                inputs_embeds=inputs_embeds,
+                past_seen_tokens=past_key_values[0][0].shape[2],
+                streaming_tts_text_mask=streaming_tts_text_mask,
+                streaming_reserved_length=self.streaming_text_reserved_len,
+                streaming_text_chunk_size=self.streaming_text_chunk_size,
+            )
+
+            # Model forward
+            outputs: BaseModelOutputWithPast = self.model(
+                attention_mask=causal_mask,
+                position_ids=position_ids,
+                past_key_values=past_key_values,
+                inputs_embeds=inputs_embeds,
+                use_cache=True,
+                output_attentions=False,
+                cache_position=cache_position,
+            )
+
+            del position_ids
+            del inputs_embeds
+            del cache_position
+            del causal_mask
+
+            hidden_states = outputs.last_hidden_state
+            past_key_values = outputs.past_key_values
+
+            with P.cached():
+                logits = torch.empty(
+                    hidden_states.size(0),
+                    hidden_states.size(1),
+                    self.num_audio_tokens,
+                    self.num_vq,
+                    dtype=torch.float,
+                    device=self.device,
+                )
+                for num_vq_iter in range(self.num_vq):
+                    x: torch.Tensor = self.head_code[num_vq_iter](hidden_states)
+                    logits[..., num_vq_iter] = x
+                    del x
+
+            del hidden_states
+
+            # logits = logits[:, -1].float()
+            logits = logits.narrow(1, -1, 1).squeeze_(1).float()
+
+            # logits = rearrange(logits, "b c n -> (b n) c")
+            logits = logits.permute(0, 2, 1)
+            logits = logits.reshape(-1, logits.size(2))
+            # logits_token = rearrange(input_ids[:, start_idx:], "b c n -> (b n) c")
+            input_ids_sliced = input_ids.narrow(
+                1,
+                start_idx,
+                input_ids.size(1) - start_idx,
+            ).permute(0, 2, 1)
+            logits_token = input_ids_sliced.reshape(
+                input_ids_sliced.size(0) * input_ids_sliced.size(1),
+                -1,
+            ).to(self.device)
+            del input_ids_sliced
+
+            logits /= temperature
+
+            if not audio_bos:
+                for logits_processor in logits_processors:
+                    logits = logits_processor(logits_token, logits)
+                for logits_warper in logits_warpers:
+                    logits = logits_warper(logits_token, logits)
+
+            del logits_token
+
+            if i < min_new_token:
+                logits[:, eos_token] = -torch.inf
+
+            if force_no_stop:
+                logits[:, eos_token] = -torch.inf
+
+            scores = F.softmax(logits, dim=-1)
+
+            del logits
+            idx_next = torch.multinomial(scores, num_samples=1)  # .to(finish.device)
+
+            del scores
+
+            # idx_next = rearrange(idx_next, "(b n) 1 -> b n", n=self.num_vq)
+            idx_next = idx_next.view(-1, self.num_vq)
+            finish_or = idx_next.eq(eos_token).any(1)
+            finish.logical_or_(finish_or)
+
+            del finish_or
+            # Store new `token` into `input_ids_buf`
+            input_ids_buf.narrow(1, progress, 1).copy_(idx_next.unsqueeze_(1))
+
+            if i == 0 and finish.any():
+                # The very first sampled token already hit eos; nothing useful was generated.
+                break
+
+            del idx_next
+            progress += 1
+            input_ids = input_ids_buf.narrow(1, 0, progress)
+
+            if finish.all():
+                break
+
+            if pbar is not None:
+                pbar.update(1)
+
+        if pbar is not None:
+            pbar.close()
+
+        if not finish.all():
+            if show_tqdm:
+                logger.info(f"incomplete result. hit max_new_token: {max_new_token}")
+
+        del input_ids_buf
+
+        if finish.all():
+            # The last position may contain an eos token, so drop it.
+            generated_input_ids = input_ids[:, condition_length:-1, :]
+        else:
+            # No eos token was generated.
+            generated_input_ids = input_ids[:, condition_length:, :]
+
+        return ConditionalChatTTSGenerationOutput(
+            new_ids=generated_input_ids,
+            audio_input_ids=input_ids,  # for update purpose
+            past_key_values=past_key_values,  # for update purpose
+            finished=finish.all(),
+        )
+
+    @torch.inference_mode()
+    def decode_to_mel_specs(
+        self,
+        result_list: List[torch.Tensor],
+    ):
+        """Decode discrete audio codes to mel spectrograms.
+
+        Borrowed from `https://github.com/2noise/ChatTTS/blob/main/ChatTTS/core.py`
+
+        Args:
+            result_list (List[torch.Tensor]): Audio codes output from `generate`.
+
+        Returns:
+            torch.Tensor: Mel spectrograms.
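+
+        Example (a minimal sketch; assumes `outputs` was returned by `generate`, so `outputs.new_ids` has shape (1, seq_len, num_vq)):
+
+        ```python
+        # Each list element is expected to have shape (seq_len, num_vq).
+        mel_specs = model.decode_to_mel_specs([outputs.new_ids.squeeze(0)])
+        ```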
+        """
+
+        decoder = self.dvae
+        max_x_len = -1
+        if len(result_list) == 0:
+            return np.array([], dtype=np.float32)
+        for result in result_list:
+            if result.size(0) > max_x_len:
+                max_x_len = result.size(0)
+        batch_result = torch.zeros(
+            (len(result_list), result_list[0].size(1), max_x_len),
+            dtype=result_list[0].dtype,
+            device=result_list[0].device,
+        )
+        for i in range(len(result_list)):
+            src = result_list[i]
+            batch_result[i].narrow(1, 0, src.size(0)).copy_(src.permute(1, 0))
+            del src
+
+        mel_specs = decoder(batch_result)
+        del batch_result
+        return mel_specs
+
+
+# Borrowed from `https://github.com/2noise/ChatTTS/blob/main/ChatTTS/model/processors.py`
+def gen_logits(
+    num_code: int,
+    top_P=0.7,
+    top_K=20,
+    repetition_penalty=1.0,
+):
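+    """Build sampling processors for audio code generation.
+
+    Returns a tuple `(logits_warpers, logits_processors)`: top-p / top-k logits warpers and, when `repetition_penalty != 1`, a repetition penalty processor. Both lists are intended to be passed to `ConditionalChatTTS.generate`.
+    """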
+    logits_warpers = []
+    if top_P is not None:
+        logits_warpers.append(TopPLogitsWarper(top_P, min_tokens_to_keep=3))
+    if top_K is not None:
+        logits_warpers.append(TopKLogitsWarper(top_K, min_tokens_to_keep=3))
+
+    logits_processors = []
+    if repetition_penalty is not None and repetition_penalty != 1:
+        logits_processors.append(CustomRepetitionPenaltyLogitsProcessorRepeat(repetition_penalty, num_code, 16))
+
+    return logits_warpers, logits_processors
+
+
+# Copy and modified from transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation
+def prepare_inputs_for_generation(
+    self,
+    input_ids,
+    past_key_values=None,
+    attention_mask=None,
+    inputs_embeds=None,
+    cache_position=None,
+    position_ids=None,
+    use_cache=True,
+    **kwargs,
+):
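+    """Prepare model inputs for one generation step: keep only the tokens not yet covered by the KV cache, derive `position_ids` from the attention mask when needed, and build a 4D causal mask when a `StaticCache` is used."""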
+    if past_key_values is not None:
+        if isinstance(past_key_values, Cache):
+            cache_length = past_key_values.get_seq_length()
+            past_length = past_key_values.seen_tokens
+        else:
+            cache_length = past_length = past_key_values[0][0].shape[2]
+
+        # Keep only the unprocessed tokens:
+        # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
+        # some of the inputs are exclusively passed as part of the cache (e.g. when passing inputs_embeds as
+        # input)
+        if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
+            input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
+        # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
+        # input_ids based on the past_length.
+        elif past_length < input_ids.shape[1]:
+            input_ids = input_ids[:, past_length:]
+        # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
+
+    if attention_mask is not None and position_ids is None:
+        # create position_ids on the fly for batch generation
+        position_ids = attention_mask.long().cumsum(-1) - 1
+        position_ids.masked_fill_(attention_mask == 0, 1)
+        if past_key_values:
+            position_ids = position_ids[:, -input_ids.shape[1] :]
+
+            # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead"`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture.
+            position_ids = position_ids.clone(memory_format=torch.contiguous_format)
+
+    # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+    if inputs_embeds is not None and cache_position[0] == 0:
+        model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
+    else:
+        # The clone here is for the same reason as for `position_ids`.
+        model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
+
+    if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
+        if model_inputs["inputs_embeds"] is not None:
+            batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
+            device = model_inputs["inputs_embeds"].device
+        else:
+            batch_size, sequence_length = model_inputs["input_ids"].shape
+            device = model_inputs["input_ids"].device
+
+        dtype = self.lm_head.weight.dtype
+        min_dtype = torch.finfo(dtype).min
+
+        attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
+            attention_mask,
+            sequence_length=sequence_length,
+            target_length=past_key_values.get_max_length(),
+            dtype=dtype,
+            device=device,
+            min_dtype=min_dtype,
+            cache_position=cache_position,
+            batch_size=batch_size,
+        )
+
+    model_inputs.update(
+        {
+            "position_ids": position_ids,
+            # "cache_position": cache_position,
+            "past_key_values": past_key_values,
+            "use_cache": use_cache,
+            "attention_mask": attention_mask,
+        }
+    )
+    return model_inputs