jeduardogruiz
commited on
Create encoded.py
Browse files- encoded.py +105 -0
encoded.py
ADDED
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# [email protected]:facebookresearch/encodec.git
|
2 |
+
|
3 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
4 |
+
# All rights reserved.
|
5 |
+
#
|
6 |
+
# This source code is licensed under the license found in the
|
7 |
+
# LICENSE file in the root directory of this source tree.
|
8 |
+
|
9 |
+
"""Various utilities."""
|
10 |
+
|
11 |
+
from hashlib import sha256
|
12 |
+
from pathlib import Path
|
13 |
+
import typing as tp
|
14 |
+
|
15 |
+
import torch
|
16 |
+
import torchaudio
|
17 |
+
|
18 |
+
|
19 |
+
def _linear_overlap_add(frames: tp.List[torch.Tensor], stride: int):
|
20 |
+
# Generic overlap add, with linear fade-in/fade-out, supporting complex scenario
|
21 |
+
# e.g., more than 2 frames per position.
|
22 |
+
# The core idea is to use a weight function that is a triangle,
|
23 |
+
# with a maximum value at the middle of the segment.
|
24 |
+
# We use this weighting when summing the frames, and divide by the sum of weights
|
25 |
+
# for each positions at the end. Thus:
|
26 |
+
# - if a frame is the only one to cover a position, the weighting is a no-op.
|
27 |
+
# - if 2 frames cover a position:
|
28 |
+
# ... ...
|
29 |
+
# / \/ \
|
30 |
+
# / /\ \
|
31 |
+
# S T , i.e. S offset of second frame starts, T end of first frame.
|
32 |
+
# Then the weight function for each one is: (t - S), (T - t), with `t` a given offset.
|
33 |
+
# After the final normalization, the weight of the second frame at position `t` is
|
34 |
+
# (t - S) / (t - S + (T - t)) = (t - S) / (T - S), which is exactly what we want.
|
35 |
+
#
|
36 |
+
# - if more than 2 frames overlap at a given point, we hope that by induction
|
37 |
+
# something sensible happens.
|
38 |
+
assert len(frames)
|
39 |
+
device = frames[0].device
|
40 |
+
dtype = frames[0].dtype
|
41 |
+
shape = frames[0].shape[:-1]
|
42 |
+
total_size = stride * (len(frames) - 1) + frames[-1].shape[-1]
|
43 |
+
|
44 |
+
frame_length = frames[0].shape[-1]
|
45 |
+
t = torch.linspace(0, 1, frame_length + 2, device=device, dtype=dtype)[1: -1]
|
46 |
+
weight = 0.5 - (t - 0.5).abs()
|
47 |
+
|
48 |
+
sum_weight = torch.zeros(total_size, device=device, dtype=dtype)
|
49 |
+
out = torch.zeros(*shape, total_size, device=device, dtype=dtype)
|
50 |
+
offset: int = 0
|
51 |
+
|
52 |
+
for frame in frames:
|
53 |
+
frame_length = frame.shape[-1]
|
54 |
+
out[..., offset:offset + frame_length] += weight[:frame_length] * frame
|
55 |
+
sum_weight[offset:offset + frame_length] += weight[:frame_length]
|
56 |
+
offset += stride
|
57 |
+
assert sum_weight.min() > 0
|
58 |
+
return out / sum_weight
|
59 |
+
|
60 |
+
|
61 |
+
def _get_checkpoint_url(root_url: str, checkpoint: str):
|
62 |
+
if not root_url.endswith('/'):
|
63 |
+
root_url += '/'
|
64 |
+
return root_url + checkpoint
|
65 |
+
|
66 |
+
|
67 |
+
def _check_checksum(path: Path, checksum: str):
|
68 |
+
sha = sha256()
|
69 |
+
with open(path, 'rb') as file:
|
70 |
+
while True:
|
71 |
+
buf = file.read(2**20)
|
72 |
+
if not buf:
|
73 |
+
break
|
74 |
+
sha.update(buf)
|
75 |
+
actual_checksum = sha.hexdigest()[:len(checksum)]
|
76 |
+
if actual_checksum != checksum:
|
77 |
+
raise RuntimeError(f'Invalid checksum for file {path}, '
|
78 |
+
f'expected {checksum} but got {actual_checksum}')
|
79 |
+
|
80 |
+
|
81 |
+
def convert_audio(wav: torch.Tensor, sr: int, target_sr: int, target_channels: int):
|
82 |
+
assert wav.dim() >= 2, "Audio tensor must have at least 2 dimensions"
|
83 |
+
assert wav.shape[-2] in [1, 2], "Audio must be mono or stereo."
|
84 |
+
*shape, channels, length = wav.shape
|
85 |
+
if target_channels == 1:
|
86 |
+
wav = wav.mean(-2, keepdim=True)
|
87 |
+
elif target_channels == 2:
|
88 |
+
wav = wav.expand(*shape, target_channels, length)
|
89 |
+
elif channels == 1:
|
90 |
+
wav = wav.expand(target_channels, -1)
|
91 |
+
else:
|
92 |
+
raise RuntimeError(f"Impossible to convert from {channels} to {target_channels}")
|
93 |
+
wav = torchaudio.transforms.Resample(sr, target_sr)(wav)
|
94 |
+
return wav
|
95 |
+
|
96 |
+
|
97 |
+
def save_audio(wav: torch.Tensor, path: tp.Union[Path, str],
|
98 |
+
sample_rate: int, rescale: bool = False):
|
99 |
+
limit = 0.99
|
100 |
+
mx = wav.abs().max()
|
101 |
+
if rescale:
|
102 |
+
wav = wav * min(limit / mx, 1)
|
103 |
+
else:
|
104 |
+
wav = wav.clamp(-limit, limit)
|
105 |
+
torchaudio.save(str(path), wav, sample_rate=sample_rate, encoding='PCM_S', bits_per_sample=16)
|