Spaces:
Sleeping
Sleeping
import matplotlib.pyplot as plt | |
import librosa | |
import librosa.display | |
import numpy as np | |
import os,sys | |
import ruptures as rpt | |
from glob import glob | |
from tqdm import tqdm | |
import soundfile | |
import pandas as pd | |
import csv | |
import gradio as gr | |
def fig_ax(figsize=(15, 5), dpi=150):
    """Build a matplotlib figure/axes pair of the requested size.

    `figsize` and `dpi` are forwarded unchanged to `plt.subplots`, and
    whatever that call returns is handed straight back to the caller.
    """
    subplot_options = {"figsize": figsize, "dpi": dpi}
    return plt.subplots(**subplot_options)
def get_sum_of_cost(algo, n_bkps) -> float:
    """Segment with `n_bkps` change points and report that segmentation's total cost.

    `algo` must offer `predict(n_bkps=...)` and a `cost` attribute exposing
    `sum_of_costs(bkps)`; the predicted change points are passed to the latter.
    """
    breakpoints = algo.predict(n_bkps=n_bkps)
    cost_function = algo.cost
    return cost_function.sum_of_costs(breakpoints)
def variable_outputs(k):
    """Return visibility updates for the ten result players.

    The first `k` entries are `gr.Audio(visible=True)` and the remaining ones
    `gr.Audio(visible=False)`.

    Args:
        k: Number of players to show. Converted with `int()`; `None` is
            treated as 0 and values outside 0..10 are clamped.

    Returns:
        A list of exactly 10 `gr.Audio` update objects, so the length always
        matches the ten output components regardless of `k`.
    """
    total = 10  # number of result players the UI creates
    shown = 0 if k is None else int(k)
    # Without clamping, k > 10 or k < 0 would yield a list whose length is not 10.
    shown = max(0, min(shown, total))
    return [gr.Audio(visible=True)] * shown + [gr.Audio(visible=False)] * (total - shown)
_MAX_SEGMENTS = 10  # the UI provides ten audio players for the exported segments
_MIN_SEGMENT_SECONDS = 10  # a change point closer than this to the previous one is dropped
_MIN_TAIL_SECONDS = 5  # a shorter final stretch is merged into the preceding segment


def _positive_int(value, default):
    """Return `value` as a positive int, or `default` if it is missing or not positive."""
    try:
        number = int(value)
    except (TypeError, ValueError, OverflowError):
        return default
    return number if number > 0 else default


def _pick_n_bkps(costs):
    """Choose a number of change points from the per-count segmentation costs.

    `costs[j]` is the cost obtained with `j + 1` change points. The largest
    absolute difference between consecutive costs is located at index `i` and
    `i + 2` is returned. When there is no strictly positive difference (fewer
    than two costs, or all equal), 1 is returned.
    """
    # NOTE(review): the `+ 2` offset is kept as found; confirm it is the
    # intended elbow position rather than `i + 1`.
    chosen = 1
    largest_jump = 0
    for i, (previous, current) in enumerate(zip(costs, costs[1:]), start=1):
        jump = abs(current - previous)
        if jump > largest_jump:
            largest_jump = jump
            chosen = i + 2
    return chosen


def _filter_breakpoints(indexes, sampling_rate):
    """Drop change points (sample indexes) that would create very short segments.

    Lists with two or fewer entries are returned unchanged. Otherwise the first
    entry is kept only if it lies at least `_MIN_SEGMENT_SECONDS` into the
    signal, and a middle entry only if it is at least that far from the entry
    before it. The final entry is always kept; if it is less than
    `_MIN_TAIL_SECONDS` after the entry before it, the most recently kept
    change point is removed first.
    """
    if len(indexes) <= 2:
        return list(indexes)
    min_gap = _MIN_SEGMENT_SECONDS * sampling_rate
    min_tail = _MIN_TAIL_SECONDS * sampling_rate
    last_position = len(indexes) - 1
    kept = []
    for i, index in enumerate(indexes):
        if i == 0:
            if index >= min_gap:
                kept.append(index)
        elif i == last_position:
            # `kept` may be empty here; guard so the merge never raises.
            if index - indexes[i - 1] < min_tail and kept:
                kept.pop()
            kept.append(index)
        elif index - indexes[i - 1] >= min_gap:
            kept.append(index)
    return kept


def generate(wavfile, target_sampling_rate, hop_length_tempo, n_bkps_max):
    """Segment an audio file at tempo change points and export each part as MP3.

    The file is loaded with librosa, a tempogram is computed from its onset
    strength, and `rpt.KernelCPD(kernel="linear")` is fitted on it. The number
    of change points is chosen from the segmentation costs for 1..`n_bkps_max`
    change points, the resulting change points are filtered to avoid very short
    segments, and every segment is written to `output_<n>.mp3` (relative path).

    Args:
        wavfile: Path of the audio file to segment.
        target_sampling_rate: Sampling rate passed to `librosa.load`; when
            `None`, `librosa.load` is called without an `sr` argument.
        hop_length_tempo: Hop length for the onset envelope and tempogram.
            Falls back to 512 when missing or not positive.
        n_bkps_max: Largest number of change points evaluated. Falls back to
            10 when missing or not positive.

    Returns:
        A tuple `(fig, seg_len, *paths)`: the matplotlib figure showing the
        tempogram with the kept change points, the number of exported
        segments, and exactly ten entries that are segment file names followed
        by `None` for unused slots.

    Raises:
        gr.Error: If no audio file was provided.
    """
    if wavfile is None:
        raise gr.Error("Please upload an audio file first.")

    # The UI values may arrive as floats or None; the hop length and the
    # change-point count must be positive integers.
    hop_length_tempo = _positive_int(hop_length_tempo, 512)
    n_bkps_max = _positive_int(n_bkps_max, 10)

    load_options = {"mono": False}
    if target_sampling_rate is not None:
        load_options["sr"] = target_sampling_rate
    signal2, sampling_rate = librosa.load(wavfile, **load_options)

    # A single-channel file loads as a 1-D array; promote it so the channel
    # axis always exists. The mean equals the former `sum / 2` for stereo.
    channels = np.atleast_2d(signal2)
    signal = channels.mean(axis=0)

    # Compute the onset strength
    oenv = librosa.onset.onset_strength(
        y=signal, sr=sampling_rate, hop_length=hop_length_tempo
    )
    # Compute the tempogram
    tempogram = librosa.feature.tempogram(
        onset_envelope=oenv,
        sr=sampling_rate,
        hop_length=hop_length_tempo,
    )

    algo = rpt.KernelCPD(kernel="linear").fit(tempogram.T)
    # Compute the segmentation with the most changes first, then collect the
    # cost of every candidate number of change points.
    _ = algo.predict(n_bkps_max)
    candidate_counts = np.arange(1, n_bkps_max + 1)
    costs = [get_sum_of_cost(algo=algo, n_bkps=count) for count in candidate_counts]
    n_bkps = _pick_n_bkps(costs)
    bkps = algo.predict(n_bkps=n_bkps)

    # Convert the estimated change points (frame counts) to timestamps, then
    # to sample indexes in the loaded signal.
    bkps_times = librosa.frames_to_time(
        bkps, sr=sampling_rate, hop_length=hop_length_tempo
    )
    bkps_time_indexes = (sampling_rate * bkps_times).astype(int).tolist()
    bkps_time_indexes = _filter_breakpoints(bkps_time_indexes, sampling_rate)

    # Never produce more segments than the UI has players for: keep the
    # earliest change points and let the last segment run to the final index.
    if len(bkps_time_indexes) > _MAX_SEGMENTS:
        bkps_time_indexes = (
            bkps_time_indexes[: _MAX_SEGMENTS - 1] + bkps_time_indexes[-1:]
        )

    fig, ax = fig_ax()
    _ = librosa.display.specshow(
        tempogram,
        ax=ax,
        x_axis="s",
        y_axis="tempo",
        hop_length=hop_length_tempo,
        sr=sampling_rate,
    )
    for index in bkps_time_indexes:
        ax.axvline(index / sampling_rate, ls="--", color="white", lw=4)

    # NOTE(review): file names are fixed and relative, so simultaneous runs
    # would overwrite each other's output; confirm whether that can happen.
    seg_list = []
    for segment_number, (start, end) in enumerate(
        rpt.utils.pairwise([0] + bkps_time_indexes), start=1
    ):
        save_name = f"output_{segment_number}.mp3"
        segment = channels[:, start:end]
        seg_list.append(save_name)
        soundfile.write(save_name, segment.T, int(sampling_rate), format="MP3")

    seg_len = len(seg_list)
    # Pad with None (not the string "None") so unused players receive no file.
    seg_list.extend([None] * (_MAX_SEGMENTS - seg_len))
    return fig, seg_len, *seg_list
def list_map(lists):
    """Copy the string form of each item in `lists` into the global RESULTS list.

    Prints the lengths of `lists` and RESULTS, overwrites RESULTS in place at
    the matching positions, and returns RESULTS itself.
    """
    print(len(lists), len(RESULTS))
    for position, item in enumerate(lists):
        RESULTS[position] = str(item)
    return RESULTS
# Gradio UI: upload an audio file, run `generate`, show the tempogram plot and
# up to ten exported segments.
# NOTE(review): the source this was recovered from had no indentation, so the
# container nesting below (which widgets sit in which Column/Row) is a
# reconstruction; confirm it against the intended layout.
with gr.Blocks() as demo:
    gr.Markdown(
        '''
# Demo of Music Segmentation(Intro, Verse, Outro..) using Change Detection Algorithm
'''
    )
    # Not wired to any event handler in this script.
    result_list = gr.State()
    with gr.Column():
        with gr.Row():
            with gr.Column():
                wavfile = gr.Audio(sources="upload", type="filepath")
                btn_submit = gr.Button()
                result_image = gr.Plot(label="result")
                with gr.Accordion(label="Settings", open=False):
                    target_sampling_rate = gr.Number(
                        label="target_sampling_rate", value=44100, interactive=True
                    )
                    hop_length_tempo = gr.Number(
                        label="hop_length_tempo", value=512, interactive=True
                    )
                    n_bkps_max = gr.Number(
                        label="n_bkps_max", value=10, interactive=True
                    )
                result_len = gr.Number(label="result_len", value=10, interactive=False)
            # Ten initially hidden players; `variable_outputs` reveals as many
            # as there are segments.
            RESULTS = []
            with gr.Column():
                for i in range(1, 11):
                    RESULTS.append(
                        gr.Audio(label=f"result part {i}", visible=False, type="filepath")
                    )
    segmentation_done = btn_submit.click(
        fn=generate,
        inputs=[wavfile, target_sampling_rate, hop_length_tempo, n_bkps_max],
        outputs=[result_image, result_len, *RESULTS],
    )
    # Refresh player visibility after every run instead of on
    # `result_len.change`, so it also happens when the new segment count
    # equals the value already shown (e.g. the initial 10).
    segmentation_done.then(variable_outputs, result_len, RESULTS)
demo.queue().launch(server_name="0.0.0.0")