# File size: 5,610 Bytes
# 9178cf3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
import matplotlib.pyplot as plt
import librosa
import librosa.display
import numpy as np
import os,sys
import ruptures as rpt  
from glob import glob
import soundfile
import csv
import gradio as gr

def fig_ax(figsize=(15, 5), dpi=150):
    """Create a matplotlib figure/axes pair via ``plt.subplots``.

    Args:
        figsize: Figure size forwarded to ``plt.subplots``.
        dpi: Resolution forwarded to ``plt.subplots``.

    Returns:
        The value returned by ``plt.subplots`` for these options.
    """
    subplot_options = {"figsize": figsize, "dpi": dpi}
    return plt.subplots(**subplot_options)

def get_sum_of_cost(algo, n_bkps) -> float:
    """Return the total cost of the segmentation `algo` predicts.

    The breakpoints for `n_bkps` changes are requested from `algo` and then
    handed to `algo.cost.sum_of_costs`.
    """
    change_points = algo.predict(n_bkps=n_bkps)
    cost_model = algo.cost
    return cost_model.sum_of_costs(change_points)
def variable_outputs(k):
    """Build visibility updates for the ten result audio players.

    Args:
        k: Number of players to show. ``None`` is treated as 0 and values
            outside ``0..10`` are clamped, so the returned list always has
            exactly ten entries.

    Returns:
        A list of ten ``gr.Audio`` updates: the first `k` visible, the
        remaining ones hidden.
    """
    total_slots = 10  # length of the list this callback must return
    # Clamp so that a negative or oversized count cannot produce a list whose
    # length differs from the number of bound outputs.
    visible_count = min(max(int(k or 0), 0), total_slots)
    shown = [gr.Audio(visible=True)] * visible_count
    hidden = [gr.Audio(visible=False)] * (total_slots - visible_count)
    return shown + hidden
def _positive_int(value, default):
    """Coerce a UI number to a positive int, falling back to `default`."""
    try:
        number = int(value)
    except (TypeError, ValueError):
        return default
    return number if number > 0 else default


def _pick_n_bkps(costs, n_bkps_max):
    """Choose a breakpoint count from per-count costs (largest-drop heuristic)."""
    drops = np.abs(np.diff(costs))
    if drops.size == 0 or drops.max() <= 0:
        # No informative drop between consecutive costs: use a single change.
        return 1
    # costs[j] is the cost for j + 1 breakpoints. The heuristic takes the
    # position of the largest drop and goes one count past it, which is
    # argmax + 3; clamp so the result never exceeds the requested maximum.
    return min(int(np.argmax(drops)) + 3, n_bkps_max)


def _filter_breakpoints(indexes, sampling_rate, min_gap_s=10, min_tail_s=5):
    """Drop breakpoints that sit too close to the previous raw breakpoint."""
    kept = []
    last_position = len(indexes) - 1
    for position, index in enumerate(indexes):
        if position == last_position:
            # The last entry is always kept because it closes the final
            # segment. When the tail is short, the most recently kept boundary
            # (if there is one) is removed so the tail merges backwards.
            tail_is_short = (
                position > 0
                and index - indexes[position - 1] < min_tail_s * sampling_rate
            )
            if tail_is_short and kept:
                kept.pop()
            kept.append(index)
        elif position == 0:
            if index >= min_gap_s * sampling_rate:
                kept.append(index)
        elif index - indexes[position - 1] >= min_gap_s * sampling_rate:
            kept.append(index)
    return kept


def generate(wavfile,target_sampling_rate,hop_length_tempo,n_bkps_max):
    """Segment an audio file at tempogram change points and export the parts.

    Args:
        wavfile: Path of the audio file to segment.
        target_sampling_rate: Sampling rate to load the audio with. When it is
            ``None`` (or not a positive number) ``librosa.load`` is called
            with its default rate.
        hop_length_tempo: Hop length used for the onset envelope, the
            tempogram and the frame-to-time conversion. Falls back to 512
            when missing or not positive.
        n_bkps_max: Largest number of breakpoints considered by the
            cost-drop heuristic. Falls back to 10 when missing or not
            positive.

    Returns:
        A tuple ``(fig, seg_len, *paths)``: the tempogram figure with the kept
        breakpoints drawn on it, the number of segments written, and exactly
        ten entries holding the written file names followed by ``None`` for
        the unused slots.

    Raises:
        gr.Error: If no audio file was provided.
    """
    max_segments = 10  # length of the path list returned after fig and seg_len

    if wavfile is None:
        raise gr.Error("Please upload an audio file first.")

    # Honour the caller-supplied settings instead of overwriting them.
    target_rate = _positive_int(target_sampling_rate, None)
    hop_length_tempo = _positive_int(hop_length_tempo, 512)
    n_bkps_max = _positive_int(n_bkps_max, 10)

    if target_rate is not None:
        signal2, sampling_rate = librosa.load(wavfile,sr=target_rate,mono=False)
    else:
        signal2, sampling_rate = librosa.load(wavfile,mono=False)
    # Guarantee a 2-D array so the channel average and the column slicing
    # below also work when the loader hands back a 1-D signal.
    signal2 = np.atleast_2d(signal2)
    signal = signal2.mean(axis=0)

    # Compute the onset strength
    oenv = librosa.onset.onset_strength(
        y=signal, sr=sampling_rate, hop_length=hop_length_tempo
    )
    # Compute the tempogram
    tempogram = librosa.feature.tempogram(
        onset_envelope=oenv,
        sr=sampling_rate,
        hop_length=hop_length_tempo,
    )
    algo = rpt.KernelCPD(kernel="linear").fit(tempogram.T)

    # Choose the number of changes (elbow heuristic).
    # Start by computing the segmentation with most changes.
    _ = algo.predict(n_bkps_max)
    costs = [
        get_sum_of_cost(algo=algo, n_bkps=count)
        for count in range(1, n_bkps_max + 1)
    ]
    n_bkps = _pick_n_bkps(costs, n_bkps_max)

    bkps = algo.predict(n_bkps=n_bkps)
    # Convert the estimated change points (frame counts) to actual timestamps
    bkps_times = librosa.frames_to_time(bkps, sr=sampling_rate, hop_length=hop_length_tempo)

    # Compute change points corresponding indexes in original signal
    raw_indexes = (sampling_rate * bkps_times).astype(int).tolist()
    bkps_time_indexes = _filter_breakpoints(raw_indexes, sampling_rate)
    if len(bkps_time_indexes) > max_segments:
        # Keep the earliest boundaries plus the final one so the number of
        # segments never exceeds the number of returned path slots.
        bkps_time_indexes = bkps_time_indexes[:max_segments - 1] + bkps_time_indexes[-1:]

    fig, ax = fig_ax()
    _ = librosa.display.specshow(
        tempogram,
        ax=ax,
        x_axis="s",
        y_axis="tempo",
        hop_length=hop_length_tempo,
        sr=sampling_rate,
    )
    for boundary in bkps_time_indexes:
        ax.axvline(boundary / sampling_rate, ls="--", color="white", lw=4)

    seg_list = []
    for segment_number, (start, end) in enumerate(
        rpt.utils.pairwise([0] + bkps_time_indexes), start=1
    ):
        save_name = f"output_{segment_number}.mp3"
        soundfile.write(
            save_name,
            signal2[:, start:end].T,
            int(sampling_rate),
            format='MP3',
        )
        seg_list.append(save_name)

    seg_len = len(seg_list)
    # Pad with real None values; the string "None" would be handled as a
    # file path by a filepath-typed audio output.
    seg_list.extend([None] * (max_segments - seg_len))
    return fig,seg_len,*seg_list
def list_map(lists):
    """Write the string form of each item of `lists` into RESULTS.

    Prints the lengths of `lists` and of the module-level RESULTS, overwrites
    RESULTS position by position, and returns RESULTS.
    """
    print(len(lists), len(RESULTS))
    for position, item in enumerate(lists):
        RESULTS[position] = str(item)
    return RESULTS
# UI definition: one upload widget, a settings accordion, the tempogram plot
# and ten audio players that are revealed according to `result_len`.
with gr.Blocks() as demo:
    gr.Markdown(
        '''
        # Demo of Music Segmentation(Intro, Verse, Outro..) using Change Detection Algoritm 
        '''
    )
    result_list = gr.State()  # not wired to any event at present
    with gr.Column():
        with gr.Row():
            with gr.Column():
                wavfile = gr.Audio(sources="upload", type="filepath")
                btn_submit = gr.Button()
        result_image = gr.Plot(label="result")
        with gr.Accordion(label="Settings", open=False):
            # precision=0 makes Gradio deliver these values as ints.
            target_sampling_rate = gr.Number(
                label="target_sampling_rate", value=44100, precision=0, interactive=True
            )
            hop_length_tempo = gr.Number(
                label="hop_length_tempo", value=512, precision=0, interactive=True
            )
            n_bkps_max = gr.Number(
                label="n_bkps_max", value=10, precision=0, interactive=True
            )
            # Starts at 0 (no players shown). Starting at 10 meant that a
            # first run producing exactly ten segments did not change the
            # value, so the change handler below never revealed the players.
            result_len = gr.Number(
                label="result_len", value=0, precision=0, interactive=False
            )
        RESULTS = []
        with gr.Column():
            for i in range(1,11):
                w = gr.Audio(label=f"result part {i}",visible=False,type="filepath")
                RESULTS.append(w)
        # Show as many players as there are segments whenever the count changes.
        result_len.change(variable_outputs,result_len,RESULTS)
    btn_submit.click(
        fn=generate,
        inputs=[
            wavfile,target_sampling_rate,hop_length_tempo,n_bkps_max
        ],
        outputs=[
            result_image,result_len,*RESULTS
        ],
    )

demo.queue().launch(server_name="0.0.0.0")