---
language:
- en
license: cc-by-4.0
tags:
- music
- art
---
# Model Card for MusiLingo-long-v1
## Model Details
### Model Description
The model consists of a music encoder ```MERT-v1-330M```, a natural language decoder ```vicuna-7b-delta-v0```, and a linear projection layer between the two.

This checkpoint of MusiLingo is developed on the MusicInstruct (MI)-long dataset and can answer long-form instructions about raw music audio, such as questions about the subjective feelings a piece conveys.
You can use the [MI](https://huggingface.co/datasets/m-a-p/Music-Instruct) dataset for the following demo.
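
As a rough illustration of the encoder, projection, decoder layout described above, the sketch below maps per-frame audio features into a language model's embedding width and joins them with text embeddings. Everything here is an assumption made for illustration: `ToyAudioProjector`, the toy sizes `AUDIO_DIM` and `LLM_DIM`, and the random tensors standing in for encoder output and embedded prompt tokens. It is not MusiLingo's actual implementation.

```python
import torch
import torch.nn as nn

# Toy sizes for illustration only; not the real MERT or Vicuna dimensions.
AUDIO_DIM, LLM_DIM = 16, 32


class ToyAudioProjector(nn.Module):
    """Hypothetical stand-in for the linear layer between encoder and decoder."""

    def __init__(self, audio_dim, llm_dim):
        super().__init__()
        self.proj = nn.Linear(audio_dim, llm_dim)

    def forward(self, audio_features):
        # (batch, frames, audio_dim) -> (batch, frames, llm_dim)
        return self.proj(audio_features)


projector = ToyAudioProjector(AUDIO_DIM, LLM_DIM)
audio_features = torch.randn(1, 10, AUDIO_DIM)  # pretend encoder output
text_embeds = torch.randn(1, 5, LLM_DIM)        # pretend embedded prompt tokens

audio_embeds = projector(audio_features)
# After projection both tensors share the same width, so they can be
# concatenated along the sequence axis and fed to a decoder as one input.
inputs_embeds = torch.cat([audio_embeds, text_embeds], dim=1)
print(tuple(inputs_embeds.shape))
```

The point of the projection is only to make the audio frames the same width as the decoder's token embeddings, which is why the demo below can pass `inputs_embeds` directly to `generate`.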


### Model Sources
- **Repository:** [GitHub repo](https://github.com/zihaod/MusiLingo)
- **Paper:** __[MusiLingo: Bridging Music and Text with Pre-trained Language Models for Music Captioning and Query Response](https://arxiv.org/abs/2309.08730)__
<!-- - **Demo [optional]:** [More Information Needed] -->



## Getting Started
```
from tqdm.auto import tqdm

import torch
from torch.utils.data import DataLoader
from transformers import AutoModel, Wav2Vec2FeatureExtractor
from transformers import StoppingCriteria, StoppingCriteriaList



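class StoppingCriteriaSub(StoppingCriteria):
    """Stop generation once the output ends with any of the given token-id sequences."""
    def __init__(self, stops=None, encounters=1):
        super().__init__()
        # `encounters` is accepted for compatibility but is not used
        self.stops = stops if stops is not None else []
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs):
        for stop in self.stops:
            # compare the last len(stop) generated ids with the stop sequence
            if torch.all((stop == input_ids[0][-len(stop):])).item():
                return True
        return False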


def get_musilingo_pred(model, text, audio_path, stopping, length_penalty=1, temperature=0.1,
    max_new_tokens=300, num_beams=1, min_length=1, top_p=0.5, repetition_penalty=1.0):

    # see https://huggingface.co/m-a-p/MusiLingo-musicqa-v1 for load_audio function definition
    audio = load_audio(audio_path, target_sr=24000,
                        is_mono=True,
                        is_normalize=False,
                        crop_to_length_in_sample_points=int(30*16000)+1,
                        crop_randomly=True, 
                        pad=False).cuda()    
    processor = Wav2Vec2FeatureExtractor.from_pretrained("m-a-p/MERT-v1-330M",trust_remote_code=True) 
    audio = processor(audio, 
                    sampling_rate=24000, 
                    return_tensors="pt")['input_values'][0].cuda() 
        
    audio_embeds, atts_audio = model.encode_audio(audio)
        
    prompt = '<Audio><AudioHere></Audio> ' + text
    instruction_prompt = [model.prompt_template.format(prompt)]
    audio_embeds, atts_audio = model.instruction_prompt_wrap(audio_embeds, atts_audio, instruction_prompt)
    
    model.llama_tokenizer.padding_side = "right"
    batch_size = audio_embeds.shape[0]
    bos = torch.ones([batch_size, 1],
                    dtype=torch.long,
                    device=torch.device('cuda')) * model.llama_tokenizer.bos_token_id
    bos_embeds = model.llama_model.model.embed_tokens(bos)
    # atts_bos = atts_audio[:, :1]
    inputs_embeds = torch.cat([bos_embeds, audio_embeds], dim=1)
    # attention_mask = torch.cat([atts_bos, atts_audio], dim=1)
    outputs = model.llama_model.generate(
        inputs_embeds=inputs_embeds,
        max_new_tokens=max_new_tokens,
        stopping_criteria=stopping,
        num_beams=num_beams,
        do_sample=True,
        min_length=min_length,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        length_penalty=length_penalty,
        temperature=temperature,
    )
    output_token = outputs[0]
    if output_token[0] == 0:  # the model might output an unknown token <unk> at the beginning; remove it
        output_token = output_token[1:]
    if output_token[0] == 1:  # if there is a start token <s> at the beginning, remove it
        output_token = output_token[1:]
    output_text = model.llama_tokenizer.decode(output_token, add_special_tokens=False)
    output_text = output_text.split('###')[0]  # remove the stop sign '###'
    output_text = output_text.split('Assistant:')[-1].strip()
    return output_text

musilingo = AutoModel.from_pretrained("m-a-p/MusiLingo-long-v1", trust_remote_code=True)
musilingo.to("cuda")
musilingo.eval()

prompt = "this is the task instruction and input question for MusiLingo model"
audio_path = "/path/to/the/audio"
stopping = StoppingCriteriaList([StoppingCriteriaSub([torch.tensor([835]).cuda(),
                                  torch.tensor([2277, 29937]).cuda()])])
response = get_musilingo_pred(musilingo.model, prompt, audio_path, stopping, length_penalty=100, temperature=0.1)
```
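
The stopping check and the output clean-up in the demo can be tried without loading a model. The sketch below re-expresses both steps on plain Python lists and strings. The helper names `ends_with_stop` and `clean_response` are hypothetical and not part of MusiLingo; the stop id sequences are copied from the demo, and the sample response string is made up for illustration.

```python
def ends_with_stop(generated_ids, stop_sequences):
    """Return True if generated_ids ends with any of the stop sequences."""
    for stop in stop_sequences:
        # Guard against outputs shorter than the stop sequence (an addition
        # in this sketch), then compare the tail of the output with it.
        if len(stop) <= len(generated_ids) and generated_ids[-len(stop):] == stop:
            return True
    return False


def clean_response(raw_text):
    """Keep the text before the first '###' and after the last 'Assistant:'."""
    text = raw_text.split('###')[0]
    return text.split('Assistant:')[-1].strip()


# Stop id sequences copied from the demo above.
STOP_SEQUENCES = [[835], [2277, 29937]]

print(ends_with_stop([1, 450, 835], STOP_SEQUENCES))      # single-id stop matches
print(ends_with_stop([1, 2277, 29937], STOP_SEQUENCES))   # two-id stop matches
print(ends_with_stop([1, 2277], STOP_SEQUENCES))          # partial stop does not match

# Made-up decoder output, used only to show the clean-up.
print(clean_response("Assistant: A calm piano piece.###Human:"))
```

Matching on the tail of the sequence means generation stops as soon as a stop sequence is produced, and the string clean-up then drops the stop marker and any role prefix from the decoded text.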

## Citing This Work

If you find the work useful for your research, please consider citing it using the following BibTeX entry:
```
@inproceedings{deng2024musilingo,
  title={MusiLingo: Bridging Music and Text with Pre-trained Language Models for Music Captioning and Query Response},
  author={Deng, Zihao and Ma, Yinghao and Liu, Yudong and Guo, Rongchen and Zhang, Ge and Chen, Wenhu and Huang, Wenhao and Benetos, Emmanouil},
  booktitle={Proceedings of the 2024 Annual Conference of the North American Chapter of the Association for Computational Linguistics (NAACL 2024)},
  year={2024},
  organization={Association for Computational Linguistics}
}
```