Spaces:
Running
on
Zero
Running
on
Zero
added audio sr files, adapted them to zerogpu and optimization for memory
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- README.md +1 -1
- app.py +46 -6
- audiosr/__init__.py +2 -0
- audiosr/__main__.py +123 -0
- audiosr/__pycache__/__init__.cpython-310.pyc +0 -0
- audiosr/__pycache__/lowpass.cpython-310.pyc +0 -0
- audiosr/__pycache__/pipeline.cpython-310.pyc +0 -0
- audiosr/__pycache__/utils.cpython-310.pyc +0 -0
- audiosr/clap/__init__.py +0 -0
- audiosr/clap/__pycache__/__init__.cpython-310.pyc +0 -0
- audiosr/clap/open_clip/__init__.py +25 -0
- audiosr/clap/open_clip/__pycache__/__init__.cpython-310.pyc +0 -0
- audiosr/clap/open_clip/__pycache__/factory.cpython-310.pyc +0 -0
- audiosr/clap/open_clip/__pycache__/feature_fusion.cpython-310.pyc +0 -0
- audiosr/clap/open_clip/__pycache__/htsat.cpython-310.pyc +0 -0
- audiosr/clap/open_clip/__pycache__/loss.cpython-310.pyc +0 -0
- audiosr/clap/open_clip/__pycache__/model.cpython-310.pyc +0 -0
- audiosr/clap/open_clip/__pycache__/openai.cpython-310.pyc +0 -0
- audiosr/clap/open_clip/__pycache__/pann_model.cpython-310.pyc +0 -0
- audiosr/clap/open_clip/__pycache__/pretrained.cpython-310.pyc +0 -0
- audiosr/clap/open_clip/__pycache__/tokenizer.cpython-310.pyc +0 -0
- audiosr/clap/open_clip/__pycache__/transform.cpython-310.pyc +0 -0
- audiosr/clap/open_clip/__pycache__/utils.cpython-310.pyc +0 -0
- audiosr/clap/open_clip/bpe_simple_vocab_16e6.txt.gz +3 -0
- audiosr/clap/open_clip/factory.py +276 -0
- audiosr/clap/open_clip/feature_fusion.py +192 -0
- audiosr/clap/open_clip/htsat.py +1304 -0
- audiosr/clap/open_clip/loss.py +397 -0
- audiosr/clap/open_clip/model.py +931 -0
- audiosr/clap/open_clip/model_configs/HTSAT-base.json +23 -0
- audiosr/clap/open_clip/model_configs/HTSAT-large.json +23 -0
- audiosr/clap/open_clip/model_configs/HTSAT-tiny-win-1536.json +23 -0
- audiosr/clap/open_clip/model_configs/HTSAT-tiny.json +23 -0
- audiosr/clap/open_clip/model_configs/PANN-10.json +23 -0
- audiosr/clap/open_clip/model_configs/PANN-14-fmax-18k.json +23 -0
- audiosr/clap/open_clip/model_configs/PANN-14-fmax-8k-20s.json +23 -0
- audiosr/clap/open_clip/model_configs/PANN-14-tiny-transformer.json +23 -0
- audiosr/clap/open_clip/model_configs/PANN-14-win-1536.json +23 -0
- audiosr/clap/open_clip/model_configs/PANN-14.json +23 -0
- audiosr/clap/open_clip/model_configs/PANN-6.json +23 -0
- audiosr/clap/open_clip/model_configs/RN101-quickgelu.json +22 -0
- audiosr/clap/open_clip/model_configs/RN101.json +21 -0
- audiosr/clap/open_clip/model_configs/RN50-quickgelu.json +22 -0
- audiosr/clap/open_clip/model_configs/RN50.json +21 -0
- audiosr/clap/open_clip/model_configs/RN50x16.json +21 -0
- audiosr/clap/open_clip/model_configs/RN50x4.json +21 -0
- audiosr/clap/open_clip/model_configs/ViT-B-16.json +16 -0
- audiosr/clap/open_clip/model_configs/ViT-B-32-quickgelu.json +17 -0
- audiosr/clap/open_clip/model_configs/ViT-B-32.json +16 -0
- audiosr/clap/open_clip/model_configs/ViT-L-14.json +16 -0
README.md
CHANGED
@@ -6,7 +6,7 @@ colorTo: blue
|
|
6 |
sdk: gradio
|
7 |
sdk_version: 4.31.0
|
8 |
app_file: app.py
|
9 |
-
pinned:
|
10 |
license: mit
|
11 |
short_description: Fixed fork of the original audio sr!
|
12 |
---
|
|
|
6 |
sdk: gradio
|
7 |
sdk_version: 4.31.0
|
8 |
app_file: app.py
|
9 |
+
pinned: true
|
10 |
license: mit
|
11 |
short_description: Fixed fork of the original audio sr!
|
12 |
---
|
app.py
CHANGED
@@ -1,10 +1,50 @@
|
|
1 |
-
import
|
2 |
-
import
|
|
|
|
|
|
|
3 |
|
4 |
-
|
|
|
|
|
|
|
|
|
|
|
5 |
|
6 |
-
|
7 |
|
8 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
|
10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from audiosr import super_resolution, build_model
|
3 |
+
import torch
|
4 |
+
import gc # free up memory
|
5 |
+
import spaces
|
6 |
|
7 |
+
@spaces.GPU(duration=300)
|
8 |
+
def inference(audio_file, model_name, guidance_scale, ddim_steps, seed):
|
9 |
+
audiosr = build_model(model_name=model_name)
|
10 |
+
|
11 |
+
if torch.cuda.is_avaible():
|
12 |
+
torch.cuda.empty_cache() # empty cuda cache
|
13 |
|
14 |
+
gc.collect()
|
15 |
|
16 |
+
# set random seed when seed input value is 0
|
17 |
+
if seed == 0:
|
18 |
+
import random
|
19 |
+
seed = random.randint(1, 2**32-1)
|
20 |
+
|
21 |
+
waveform = super_resolution(
|
22 |
+
audiosr,
|
23 |
+
audio_file,
|
24 |
+
seed,
|
25 |
+
guidance_scale=guidance_scale,
|
26 |
+
ddim_steps=ddim_steps
|
27 |
+
)
|
28 |
|
29 |
+
if torch.cuda.is_avaible():
|
30 |
+
torch.cuda.empty_cache()
|
31 |
+
|
32 |
+
gc.collect()
|
33 |
+
|
34 |
+
return (48000, waveform)
|
35 |
+
|
36 |
+
iface = gr.Interface(
|
37 |
+
fn=inference,
|
38 |
+
inputs=[
|
39 |
+
gr.Audio(type="filepath", label="Input Audio"),
|
40 |
+
gr.Dropdown(["basic", "speech"], value="basic", label="Model"),
|
41 |
+
gr.Slider(1, 10, value=3.5, step=0.1, label="Guidance Scale", info="Guidance scale (Large => better quality and relavancy to text; Small => better diversity)"),
|
42 |
+
gr.Slider(1, 100, value=50, step=1, label="DDIM Steps", info="The sampling step for DDIM"),
|
43 |
+
gr.Number(value=42, precision=0, label="Seed", info="Changing this value (any integer number) will lead to a different generation result, put 0 for a random one.")
|
44 |
+
],
|
45 |
+
outputs=gr.Audio(type="numpy", label="Output Audio"),
|
46 |
+
title="AudioSR",
|
47 |
+
description="Audio Super Resolution with AudioSR"
|
48 |
+
)
|
49 |
+
|
50 |
+
iface.launch(share=False)
|
audiosr/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
from .utils import seed_everything, save_wave, get_time, get_duration, read_list
|
2 |
+
from .pipeline import *
|
audiosr/__main__.py
ADDED
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python3
|
2 |
+
import os
|
3 |
+
import torch
|
4 |
+
import logging
|
5 |
+
from audiosr import super_resolution, build_model, save_wave, get_time, read_list
|
6 |
+
import argparse
|
7 |
+
|
8 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "true"
|
9 |
+
matplotlib_logger = logging.getLogger('matplotlib')
|
10 |
+
matplotlib_logger.setLevel(logging.WARNING)
|
11 |
+
|
12 |
+
parser = argparse.ArgumentParser()
|
13 |
+
|
14 |
+
parser.add_argument(
|
15 |
+
"-i",
|
16 |
+
"--input_audio_file",
|
17 |
+
type=str,
|
18 |
+
required=False,
|
19 |
+
help="Input audio file for audio super resolution",
|
20 |
+
)
|
21 |
+
|
22 |
+
parser.add_argument(
|
23 |
+
"-il",
|
24 |
+
"--input_file_list",
|
25 |
+
type=str,
|
26 |
+
required=False,
|
27 |
+
default="",
|
28 |
+
help="A file that contains all audio files that need to perform audio super resolution",
|
29 |
+
)
|
30 |
+
|
31 |
+
parser.add_argument(
|
32 |
+
"-s",
|
33 |
+
"--save_path",
|
34 |
+
type=str,
|
35 |
+
required=False,
|
36 |
+
help="The path to save model output",
|
37 |
+
default="./output",
|
38 |
+
)
|
39 |
+
|
40 |
+
parser.add_argument(
|
41 |
+
"--model_name",
|
42 |
+
type=str,
|
43 |
+
required=False,
|
44 |
+
help="The checkpoint you gonna use",
|
45 |
+
default="basic",
|
46 |
+
choices=["basic","speech"]
|
47 |
+
)
|
48 |
+
|
49 |
+
parser.add_argument(
|
50 |
+
"-d",
|
51 |
+
"--device",
|
52 |
+
type=str,
|
53 |
+
required=False,
|
54 |
+
help="The device for computation. If not specified, the script will automatically choose the device based on your environment.",
|
55 |
+
default="auto",
|
56 |
+
)
|
57 |
+
|
58 |
+
parser.add_argument(
|
59 |
+
"--ddim_steps",
|
60 |
+
type=int,
|
61 |
+
required=False,
|
62 |
+
default=50,
|
63 |
+
help="The sampling step for DDIM",
|
64 |
+
)
|
65 |
+
|
66 |
+
parser.add_argument(
|
67 |
+
"-gs",
|
68 |
+
"--guidance_scale",
|
69 |
+
type=float,
|
70 |
+
required=False,
|
71 |
+
default=3.5,
|
72 |
+
help="Guidance scale (Large => better quality and relavancy to text; Small => better diversity)",
|
73 |
+
)
|
74 |
+
|
75 |
+
parser.add_argument(
|
76 |
+
"--seed",
|
77 |
+
type=int,
|
78 |
+
required=False,
|
79 |
+
default=42,
|
80 |
+
help="Changing this value (any integer number) will lead to a different generation result.",
|
81 |
+
)
|
82 |
+
|
83 |
+
parser.add_argument(
|
84 |
+
"--suffix",
|
85 |
+
type=str,
|
86 |
+
required=False,
|
87 |
+
help="Suffix for the output file",
|
88 |
+
default="_AudioSR_Processed_48K",
|
89 |
+
)
|
90 |
+
|
91 |
+
args = parser.parse_args()
|
92 |
+
torch.set_float32_matmul_precision("high")
|
93 |
+
save_path = os.path.join(args.save_path, get_time())
|
94 |
+
|
95 |
+
assert args.input_file_list is not None or args.input_audio_file is not None,"Please provide either a list of audio files or a single audio file"
|
96 |
+
|
97 |
+
input_file = args.input_audio_file
|
98 |
+
random_seed = args.seed
|
99 |
+
sample_rate=48000
|
100 |
+
latent_t_per_second=12.8
|
101 |
+
guidance_scale = args.guidance_scale
|
102 |
+
|
103 |
+
os.makedirs(save_path, exist_ok=True)
|
104 |
+
audiosr = build_model(model_name=args.model_name, device=args.device)
|
105 |
+
|
106 |
+
if(args.input_file_list):
|
107 |
+
print("Generate audio based on the text prompts in %s" % args.input_file_list)
|
108 |
+
files_todo = read_list(args.input_file_list)
|
109 |
+
else:
|
110 |
+
files_todo = [input_file]
|
111 |
+
|
112 |
+
for input_file in files_todo:
|
113 |
+
name = os.path.splitext(os.path.basename(input_file))[0] + args.suffix
|
114 |
+
|
115 |
+
waveform = super_resolution(
|
116 |
+
audiosr,
|
117 |
+
input_file,
|
118 |
+
seed=random_seed,
|
119 |
+
guidance_scale=guidance_scale,
|
120 |
+
ddim_steps=args.ddim_steps,
|
121 |
+
latent_t_per_second=latent_t_per_second
|
122 |
+
)
|
123 |
+
save_wave(waveform, inputpath=input_file, savepath=save_path, name=name, samplerate=sample_rate)
|
audiosr/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (312 Bytes). View file
|
|
audiosr/__pycache__/lowpass.cpython-310.pyc
ADDED
Binary file (5.22 kB). View file
|
|
audiosr/__pycache__/pipeline.cpython-310.pyc
ADDED
Binary file (4.18 kB). View file
|
|
audiosr/__pycache__/utils.cpython-310.pyc
ADDED
Binary file (13 kB). View file
|
|
audiosr/clap/__init__.py
ADDED
File without changes
|
audiosr/clap/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (165 Bytes). View file
|
|
audiosr/clap/open_clip/__init__.py
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .factory import (
|
2 |
+
list_models,
|
3 |
+
create_model,
|
4 |
+
create_model_and_transforms,
|
5 |
+
add_model_config,
|
6 |
+
)
|
7 |
+
from .loss import ClipLoss, gather_features, LPLoss, lp_gather_features, LPMetrics
|
8 |
+
from .model import (
|
9 |
+
CLAP,
|
10 |
+
CLAPTextCfg,
|
11 |
+
CLAPVisionCfg,
|
12 |
+
CLAPAudioCfp,
|
13 |
+
convert_weights_to_fp16,
|
14 |
+
trace_model,
|
15 |
+
)
|
16 |
+
from .openai import load_openai_model, list_openai_models
|
17 |
+
from .pretrained import (
|
18 |
+
list_pretrained,
|
19 |
+
list_pretrained_tag_models,
|
20 |
+
list_pretrained_model_tags,
|
21 |
+
get_pretrained_url,
|
22 |
+
download_pretrained,
|
23 |
+
)
|
24 |
+
from .tokenizer import SimpleTokenizer, tokenize
|
25 |
+
from .transform import image_transform
|
audiosr/clap/open_clip/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (971 Bytes). View file
|
|
audiosr/clap/open_clip/__pycache__/factory.cpython-310.pyc
ADDED
Binary file (6.65 kB). View file
|
|
audiosr/clap/open_clip/__pycache__/feature_fusion.cpython-310.pyc
ADDED
Binary file (4.13 kB). View file
|
|
audiosr/clap/open_clip/__pycache__/htsat.cpython-310.pyc
ADDED
Binary file (30.8 kB). View file
|
|
audiosr/clap/open_clip/__pycache__/loss.cpython-310.pyc
ADDED
Binary file (7.92 kB). View file
|
|
audiosr/clap/open_clip/__pycache__/model.cpython-310.pyc
ADDED
Binary file (23.7 kB). View file
|
|
audiosr/clap/open_clip/__pycache__/openai.cpython-310.pyc
ADDED
Binary file (4.53 kB). View file
|
|
audiosr/clap/open_clip/__pycache__/pann_model.cpython-310.pyc
ADDED
Binary file (13 kB). View file
|
|
audiosr/clap/open_clip/__pycache__/pretrained.cpython-310.pyc
ADDED
Binary file (5.04 kB). View file
|
|
audiosr/clap/open_clip/__pycache__/tokenizer.cpython-310.pyc
ADDED
Binary file (7.37 kB). View file
|
|
audiosr/clap/open_clip/__pycache__/transform.cpython-310.pyc
ADDED
Binary file (989 Bytes). View file
|
|
audiosr/clap/open_clip/__pycache__/utils.cpython-310.pyc
ADDED
Binary file (9.73 kB). View file
|
|
audiosr/clap/open_clip/bpe_simple_vocab_16e6.txt.gz
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
|
3 |
+
size 1356917
|
audiosr/clap/open_clip/factory.py
ADDED
@@ -0,0 +1,276 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import logging
|
3 |
+
import os
|
4 |
+
import re
|
5 |
+
from copy import deepcopy
|
6 |
+
from pathlib import Path
|
7 |
+
|
8 |
+
import torch
|
9 |
+
|
10 |
+
from .model import CLAP, convert_weights_to_fp16
|
11 |
+
from .openai import load_openai_model
|
12 |
+
from .pretrained import get_pretrained_url, download_pretrained
|
13 |
+
from .transform import image_transform
|
14 |
+
|
15 |
+
_MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"]
|
16 |
+
_MODEL_CONFIGS = {} # directory (model_name: config) of model architecture configs
|
17 |
+
|
18 |
+
|
19 |
+
def _natural_key(string_):
|
20 |
+
return [int(s) if s.isdigit() else s for s in re.split(r"(\d+)", string_.lower())]
|
21 |
+
|
22 |
+
|
23 |
+
def _rescan_model_configs():
|
24 |
+
global _MODEL_CONFIGS
|
25 |
+
|
26 |
+
config_ext = (".json",)
|
27 |
+
config_files = []
|
28 |
+
for config_path in _MODEL_CONFIG_PATHS:
|
29 |
+
if config_path.is_file() and config_path.suffix in config_ext:
|
30 |
+
config_files.append(config_path)
|
31 |
+
elif config_path.is_dir():
|
32 |
+
for ext in config_ext:
|
33 |
+
config_files.extend(config_path.glob(f"*{ext}"))
|
34 |
+
|
35 |
+
for cf in config_files:
|
36 |
+
if os.path.basename(cf)[0] == ".":
|
37 |
+
continue # Ignore hidden files
|
38 |
+
|
39 |
+
with open(cf, "r") as f:
|
40 |
+
model_cfg = json.load(f)
|
41 |
+
if all(a in model_cfg for a in ("embed_dim", "audio_cfg", "text_cfg")):
|
42 |
+
_MODEL_CONFIGS[cf.stem] = model_cfg
|
43 |
+
|
44 |
+
_MODEL_CONFIGS = {
|
45 |
+
k: v
|
46 |
+
for k, v in sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0]))
|
47 |
+
}
|
48 |
+
|
49 |
+
|
50 |
+
_rescan_model_configs() # initial populate of model config registry
|
51 |
+
|
52 |
+
|
53 |
+
def load_state_dict(checkpoint_path: str, map_location="cpu", skip_params=True):
|
54 |
+
checkpoint = torch.load(checkpoint_path, map_location=map_location)
|
55 |
+
if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
|
56 |
+
state_dict = checkpoint["state_dict"]
|
57 |
+
else:
|
58 |
+
state_dict = checkpoint
|
59 |
+
if skip_params:
|
60 |
+
if next(iter(state_dict.items()))[0].startswith("module"):
|
61 |
+
state_dict = {k[7:]: v for k, v in state_dict.items()}
|
62 |
+
# for k in state_dict:
|
63 |
+
# if k.startswith('transformer'):
|
64 |
+
# v = state_dict.pop(k)
|
65 |
+
# state_dict['text_branch.' + k[12:]] = v
|
66 |
+
return state_dict
|
67 |
+
|
68 |
+
|
69 |
+
def create_model(
|
70 |
+
amodel_name: str,
|
71 |
+
tmodel_name: str,
|
72 |
+
pretrained: str = "",
|
73 |
+
precision: str = "fp32",
|
74 |
+
device: torch.device = torch.device("cpu"),
|
75 |
+
jit: bool = False,
|
76 |
+
force_quick_gelu: bool = False,
|
77 |
+
openai_model_cache_dir: str = os.path.expanduser("~/.cache/clip"),
|
78 |
+
skip_params=True,
|
79 |
+
pretrained_audio: str = "",
|
80 |
+
pretrained_text: str = "",
|
81 |
+
enable_fusion: bool = False,
|
82 |
+
fusion_type: str = "None"
|
83 |
+
# pretrained_image: bool = False,
|
84 |
+
):
|
85 |
+
amodel_name = amodel_name.replace(
|
86 |
+
"/", "-"
|
87 |
+
) # for callers using old naming with / in ViT names
|
88 |
+
pretrained_orig = pretrained
|
89 |
+
pretrained = pretrained.lower()
|
90 |
+
if pretrained == "openai":
|
91 |
+
if amodel_name in _MODEL_CONFIGS:
|
92 |
+
logging.info(f"Loading {amodel_name} model config.")
|
93 |
+
model_cfg = deepcopy(_MODEL_CONFIGS[amodel_name])
|
94 |
+
else:
|
95 |
+
logging.error(
|
96 |
+
f"Model config for {amodel_name} not found; available models {list_models()}."
|
97 |
+
)
|
98 |
+
raise RuntimeError(f"Model config for {amodel_name} not found.")
|
99 |
+
|
100 |
+
logging.info(f"Loading pretrained ViT-B-16 text encoder from OpenAI.")
|
101 |
+
# Hard Code in model name
|
102 |
+
model_cfg["text_cfg"]["model_type"] = tmodel_name
|
103 |
+
model = load_openai_model(
|
104 |
+
"ViT-B-16",
|
105 |
+
model_cfg,
|
106 |
+
device=device,
|
107 |
+
jit=jit,
|
108 |
+
cache_dir=openai_model_cache_dir,
|
109 |
+
enable_fusion=enable_fusion,
|
110 |
+
fusion_type=fusion_type,
|
111 |
+
)
|
112 |
+
# See https://discuss.pytorch.org/t/valueerror-attemting-to-unscale-fp16-gradients/81372
|
113 |
+
if precision == "amp" or precision == "fp32":
|
114 |
+
model = model.float()
|
115 |
+
else:
|
116 |
+
if amodel_name in _MODEL_CONFIGS:
|
117 |
+
logging.info(f"Loading {amodel_name} model config.")
|
118 |
+
model_cfg = deepcopy(_MODEL_CONFIGS[amodel_name])
|
119 |
+
else:
|
120 |
+
logging.error(
|
121 |
+
f"Model config for {amodel_name} not found; available models {list_models()}."
|
122 |
+
)
|
123 |
+
raise RuntimeError(f"Model config for {amodel_name} not found.")
|
124 |
+
|
125 |
+
if force_quick_gelu:
|
126 |
+
# override for use of QuickGELU on non-OpenAI transformer models
|
127 |
+
model_cfg["quick_gelu"] = True
|
128 |
+
|
129 |
+
# if pretrained_image:
|
130 |
+
# if 'timm_amodel_name' in model_cfg.get('vision_cfg', {}):
|
131 |
+
# # pretrained weight loading for timm models set via vision_cfg
|
132 |
+
# model_cfg['vision_cfg']['timm_model_pretrained'] = True
|
133 |
+
# else:
|
134 |
+
# assert False, 'pretrained image towers currently only supported for timm models'
|
135 |
+
model_cfg["text_cfg"]["model_type"] = tmodel_name
|
136 |
+
model_cfg["enable_fusion"] = enable_fusion
|
137 |
+
model_cfg["fusion_type"] = fusion_type
|
138 |
+
model = CLAP(**model_cfg)
|
139 |
+
|
140 |
+
if pretrained:
|
141 |
+
checkpoint_path = ""
|
142 |
+
url = get_pretrained_url(amodel_name, pretrained)
|
143 |
+
if url:
|
144 |
+
checkpoint_path = download_pretrained(url, root=openai_model_cache_dir)
|
145 |
+
elif os.path.exists(pretrained_orig):
|
146 |
+
checkpoint_path = pretrained_orig
|
147 |
+
if checkpoint_path:
|
148 |
+
logging.info(
|
149 |
+
f"Loading pretrained {amodel_name}-{tmodel_name} weights ({pretrained})."
|
150 |
+
)
|
151 |
+
ckpt = load_state_dict(checkpoint_path, skip_params=True)
|
152 |
+
model.load_state_dict(ckpt)
|
153 |
+
param_names = [n for n, p in model.named_parameters()]
|
154 |
+
# for n in param_names:
|
155 |
+
# print(n, "\t", "Loaded" if n in ckpt else "Unloaded")
|
156 |
+
else:
|
157 |
+
logging.warning(
|
158 |
+
f"Pretrained weights ({pretrained}) not found for model {amodel_name}."
|
159 |
+
)
|
160 |
+
raise RuntimeError(
|
161 |
+
f"Pretrained weights ({pretrained}) not found for model {amodel_name}."
|
162 |
+
)
|
163 |
+
|
164 |
+
if pretrained_audio:
|
165 |
+
if amodel_name.startswith("PANN"):
|
166 |
+
if "Cnn14_mAP" in pretrained_audio: # official checkpoint
|
167 |
+
audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
|
168 |
+
audio_ckpt = audio_ckpt["model"]
|
169 |
+
keys = list(audio_ckpt.keys())
|
170 |
+
for key in keys:
|
171 |
+
if (
|
172 |
+
"spectrogram_extractor" not in key
|
173 |
+
and "logmel_extractor" not in key
|
174 |
+
):
|
175 |
+
v = audio_ckpt.pop(key)
|
176 |
+
audio_ckpt["audio_branch." + key] = v
|
177 |
+
elif os.path.basename(pretrained_audio).startswith(
|
178 |
+
"PANN"
|
179 |
+
): # checkpoint trained via HTSAT codebase
|
180 |
+
audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
|
181 |
+
audio_ckpt = audio_ckpt["state_dict"]
|
182 |
+
keys = list(audio_ckpt.keys())
|
183 |
+
for key in keys:
|
184 |
+
if key.startswith("sed_model"):
|
185 |
+
v = audio_ckpt.pop(key)
|
186 |
+
audio_ckpt["audio_branch." + key[10:]] = v
|
187 |
+
elif os.path.basename(pretrained_audio).startswith(
|
188 |
+
"finetuned"
|
189 |
+
): # checkpoint trained via linear probe codebase
|
190 |
+
audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
|
191 |
+
else:
|
192 |
+
raise ValueError("Unknown audio checkpoint")
|
193 |
+
elif amodel_name.startswith("HTSAT"):
|
194 |
+
if "HTSAT_AudioSet_Saved" in pretrained_audio: # official checkpoint
|
195 |
+
audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
|
196 |
+
audio_ckpt = audio_ckpt["state_dict"]
|
197 |
+
keys = list(audio_ckpt.keys())
|
198 |
+
for key in keys:
|
199 |
+
if key.startswith("sed_model") and (
|
200 |
+
"spectrogram_extractor" not in key
|
201 |
+
and "logmel_extractor" not in key
|
202 |
+
):
|
203 |
+
v = audio_ckpt.pop(key)
|
204 |
+
audio_ckpt["audio_branch." + key[10:]] = v
|
205 |
+
elif os.path.basename(pretrained_audio).startswith(
|
206 |
+
"HTSAT"
|
207 |
+
): # checkpoint trained via HTSAT codebase
|
208 |
+
audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
|
209 |
+
audio_ckpt = audio_ckpt["state_dict"]
|
210 |
+
keys = list(audio_ckpt.keys())
|
211 |
+
for key in keys:
|
212 |
+
if key.startswith("sed_model"):
|
213 |
+
v = audio_ckpt.pop(key)
|
214 |
+
audio_ckpt["audio_branch." + key[10:]] = v
|
215 |
+
elif os.path.basename(pretrained_audio).startswith(
|
216 |
+
"finetuned"
|
217 |
+
): # checkpoint trained via linear probe codebase
|
218 |
+
audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
|
219 |
+
else:
|
220 |
+
raise ValueError("Unknown audio checkpoint")
|
221 |
+
else:
|
222 |
+
raise f"this audio encoder pretrained checkpoint is not support"
|
223 |
+
|
224 |
+
model.load_state_dict(audio_ckpt, strict=False)
|
225 |
+
logging.info(
|
226 |
+
f"Loading pretrained {amodel_name} weights ({pretrained_audio})."
|
227 |
+
)
|
228 |
+
param_names = [n for n, p in model.named_parameters()]
|
229 |
+
for n in param_names:
|
230 |
+
print(n, "\t", "Loaded" if n in audio_ckpt else "Unloaded")
|
231 |
+
|
232 |
+
model.to(device=device)
|
233 |
+
if precision == "fp16":
|
234 |
+
assert device.type != "cpu"
|
235 |
+
convert_weights_to_fp16(model)
|
236 |
+
|
237 |
+
if jit:
|
238 |
+
model = torch.jit.script(model)
|
239 |
+
|
240 |
+
return model, model_cfg
|
241 |
+
|
242 |
+
|
243 |
+
def create_model_and_transforms(
|
244 |
+
model_name: str,
|
245 |
+
pretrained: str = "",
|
246 |
+
precision: str = "fp32",
|
247 |
+
device: torch.device = torch.device("cpu"),
|
248 |
+
jit: bool = False,
|
249 |
+
force_quick_gelu: bool = False,
|
250 |
+
# pretrained_image: bool = False,
|
251 |
+
):
|
252 |
+
model = create_model(
|
253 |
+
model_name,
|
254 |
+
pretrained,
|
255 |
+
precision,
|
256 |
+
device,
|
257 |
+
jit,
|
258 |
+
force_quick_gelu=force_quick_gelu,
|
259 |
+
# pretrained_image=pretrained_image
|
260 |
+
)
|
261 |
+
preprocess_train = image_transform(model.visual.image_size, is_train=True)
|
262 |
+
preprocess_val = image_transform(model.visual.image_size, is_train=False)
|
263 |
+
return model, preprocess_train, preprocess_val
|
264 |
+
|
265 |
+
|
266 |
+
def list_models():
|
267 |
+
"""enumerate available model architectures based on config files"""
|
268 |
+
return list(_MODEL_CONFIGS.keys())
|
269 |
+
|
270 |
+
|
271 |
+
def add_model_config(path):
|
272 |
+
"""add model config path or file and update registry"""
|
273 |
+
if not isinstance(path, Path):
|
274 |
+
path = Path(path)
|
275 |
+
_MODEL_CONFIG_PATHS.append(path)
|
276 |
+
_rescan_model_configs()
|
audiosr/clap/open_clip/feature_fusion.py
ADDED
@@ -0,0 +1,192 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Feature Fusion for Varible-Length Data Processing
|
3 |
+
AFF/iAFF is referred and modified from https://github.com/YimianDai/open-aff/blob/master/aff_pytorch/aff_net/fusion.py
|
4 |
+
According to the paper: Yimian Dai et al, Attentional Feature Fusion, IEEE Winter Conference on Applications of Computer Vision, WACV 2021
|
5 |
+
"""
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
|
10 |
+
|
11 |
+
class DAF(nn.Module):
|
12 |
+
"""
|
13 |
+
直接相加 DirectAddFuse
|
14 |
+
"""
|
15 |
+
|
16 |
+
def __init__(self):
|
17 |
+
super(DAF, self).__init__()
|
18 |
+
|
19 |
+
def forward(self, x, residual):
|
20 |
+
return x + residual
|
21 |
+
|
22 |
+
|
23 |
+
class iAFF(nn.Module):
|
24 |
+
"""
|
25 |
+
多特征融合 iAFF
|
26 |
+
"""
|
27 |
+
|
28 |
+
def __init__(self, channels=64, r=4, type="2D"):
|
29 |
+
super(iAFF, self).__init__()
|
30 |
+
inter_channels = int(channels // r)
|
31 |
+
|
32 |
+
if type == "1D":
|
33 |
+
# 本地注意力
|
34 |
+
self.local_att = nn.Sequential(
|
35 |
+
nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
|
36 |
+
nn.BatchNorm1d(inter_channels),
|
37 |
+
nn.ReLU(inplace=True),
|
38 |
+
nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
|
39 |
+
nn.BatchNorm1d(channels),
|
40 |
+
)
|
41 |
+
|
42 |
+
# 全局注意力
|
43 |
+
self.global_att = nn.Sequential(
|
44 |
+
nn.AdaptiveAvgPool1d(1),
|
45 |
+
nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
|
46 |
+
nn.BatchNorm1d(inter_channels),
|
47 |
+
nn.ReLU(inplace=True),
|
48 |
+
nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
|
49 |
+
nn.BatchNorm1d(channels),
|
50 |
+
)
|
51 |
+
|
52 |
+
# 第二次本地注意力
|
53 |
+
self.local_att2 = nn.Sequential(
|
54 |
+
nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
|
55 |
+
nn.BatchNorm1d(inter_channels),
|
56 |
+
nn.ReLU(inplace=True),
|
57 |
+
nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
|
58 |
+
nn.BatchNorm1d(channels),
|
59 |
+
)
|
60 |
+
# 第二次全局注意力
|
61 |
+
self.global_att2 = nn.Sequential(
|
62 |
+
nn.AdaptiveAvgPool1d(1),
|
63 |
+
nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
|
64 |
+
nn.BatchNorm1d(inter_channels),
|
65 |
+
nn.ReLU(inplace=True),
|
66 |
+
nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
|
67 |
+
nn.BatchNorm1d(channels),
|
68 |
+
)
|
69 |
+
elif type == "2D":
|
70 |
+
# 本地注意力
|
71 |
+
self.local_att = nn.Sequential(
|
72 |
+
nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
|
73 |
+
nn.BatchNorm2d(inter_channels),
|
74 |
+
nn.ReLU(inplace=True),
|
75 |
+
nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
|
76 |
+
nn.BatchNorm2d(channels),
|
77 |
+
)
|
78 |
+
|
79 |
+
# 全局注意力
|
80 |
+
self.global_att = nn.Sequential(
|
81 |
+
nn.AdaptiveAvgPool2d(1),
|
82 |
+
nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
|
83 |
+
nn.BatchNorm2d(inter_channels),
|
84 |
+
nn.ReLU(inplace=True),
|
85 |
+
nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
|
86 |
+
nn.BatchNorm2d(channels),
|
87 |
+
)
|
88 |
+
|
89 |
+
# 第二次本地注意力
|
90 |
+
self.local_att2 = nn.Sequential(
|
91 |
+
nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
|
92 |
+
nn.BatchNorm2d(inter_channels),
|
93 |
+
nn.ReLU(inplace=True),
|
94 |
+
nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
|
95 |
+
nn.BatchNorm2d(channels),
|
96 |
+
)
|
97 |
+
# 第二次全局注意力
|
98 |
+
self.global_att2 = nn.Sequential(
|
99 |
+
nn.AdaptiveAvgPool2d(1),
|
100 |
+
nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
|
101 |
+
nn.BatchNorm2d(inter_channels),
|
102 |
+
nn.ReLU(inplace=True),
|
103 |
+
nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
|
104 |
+
nn.BatchNorm2d(channels),
|
105 |
+
)
|
106 |
+
else:
|
107 |
+
raise f"the type is not supported"
|
108 |
+
|
109 |
+
self.sigmoid = nn.Sigmoid()
|
110 |
+
|
111 |
+
def forward(self, x, residual):
|
112 |
+
flag = False
|
113 |
+
xa = x + residual
|
114 |
+
if xa.size(0) == 1:
|
115 |
+
xa = torch.cat([xa, xa], dim=0)
|
116 |
+
flag = True
|
117 |
+
xl = self.local_att(xa)
|
118 |
+
xg = self.global_att(xa)
|
119 |
+
xlg = xl + xg
|
120 |
+
wei = self.sigmoid(xlg)
|
121 |
+
xi = x * wei + residual * (1 - wei)
|
122 |
+
|
123 |
+
xl2 = self.local_att2(xi)
|
124 |
+
xg2 = self.global_att(xi)
|
125 |
+
xlg2 = xl2 + xg2
|
126 |
+
wei2 = self.sigmoid(xlg2)
|
127 |
+
xo = x * wei2 + residual * (1 - wei2)
|
128 |
+
if flag:
|
129 |
+
xo = xo[0].unsqueeze(0)
|
130 |
+
return xo
|
131 |
+
|
132 |
+
|
133 |
+
class AFF(nn.Module):
|
134 |
+
"""
|
135 |
+
多特征融合 AFF
|
136 |
+
"""
|
137 |
+
|
138 |
+
def __init__(self, channels=64, r=4, type="2D"):
|
139 |
+
super(AFF, self).__init__()
|
140 |
+
inter_channels = int(channels // r)
|
141 |
+
|
142 |
+
if type == "1D":
|
143 |
+
self.local_att = nn.Sequential(
|
144 |
+
nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
|
145 |
+
nn.BatchNorm1d(inter_channels),
|
146 |
+
nn.ReLU(inplace=True),
|
147 |
+
nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
|
148 |
+
nn.BatchNorm1d(channels),
|
149 |
+
)
|
150 |
+
self.global_att = nn.Sequential(
|
151 |
+
nn.AdaptiveAvgPool1d(1),
|
152 |
+
nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
|
153 |
+
nn.BatchNorm1d(inter_channels),
|
154 |
+
nn.ReLU(inplace=True),
|
155 |
+
nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
|
156 |
+
nn.BatchNorm1d(channels),
|
157 |
+
)
|
158 |
+
elif type == "2D":
|
159 |
+
self.local_att = nn.Sequential(
|
160 |
+
nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
|
161 |
+
nn.BatchNorm2d(inter_channels),
|
162 |
+
nn.ReLU(inplace=True),
|
163 |
+
nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
|
164 |
+
nn.BatchNorm2d(channels),
|
165 |
+
)
|
166 |
+
self.global_att = nn.Sequential(
|
167 |
+
nn.AdaptiveAvgPool2d(1),
|
168 |
+
nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
|
169 |
+
nn.BatchNorm2d(inter_channels),
|
170 |
+
nn.ReLU(inplace=True),
|
171 |
+
nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
|
172 |
+
nn.BatchNorm2d(channels),
|
173 |
+
)
|
174 |
+
else:
|
175 |
+
raise f"the type is not supported."
|
176 |
+
|
177 |
+
self.sigmoid = nn.Sigmoid()
|
178 |
+
|
179 |
+
def forward(self, x, residual):
|
180 |
+
flag = False
|
181 |
+
xa = x + residual
|
182 |
+
if xa.size(0) == 1:
|
183 |
+
xa = torch.cat([xa, xa], dim=0)
|
184 |
+
flag = True
|
185 |
+
xl = self.local_att(xa)
|
186 |
+
xg = self.global_att(xa)
|
187 |
+
xlg = xl + xg
|
188 |
+
wei = self.sigmoid(xlg)
|
189 |
+
xo = 2 * x * wei + 2 * residual * (1 - wei)
|
190 |
+
if flag:
|
191 |
+
xo = xo[0].unsqueeze(0)
|
192 |
+
return xo
|
audiosr/clap/open_clip/htsat.py
ADDED
@@ -0,0 +1,1304 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ke Chen
|
2 | |
3 |
+
# HTS-AT: A HIERARCHICAL TOKEN-SEMANTIC AUDIO TRANSFORMER FOR SOUND CLASSIFICATION AND DETECTION
|
4 |
+
# Some layers designed on the model
|
5 |
+
# below codes are based and referred from https://github.com/microsoft/Swin-Transformer
|
6 |
+
# Swin Transformer for Computer Vision: https://arxiv.org/pdf/2103.14030.pdf
|
7 |
+
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
from itertools import repeat
|
11 |
+
import collections.abc
|
12 |
+
import math
|
13 |
+
import warnings
|
14 |
+
|
15 |
+
from torch.nn.init import _calculate_fan_in_and_fan_out
|
16 |
+
import torch.utils.checkpoint as checkpoint
|
17 |
+
|
18 |
+
import random
|
19 |
+
|
20 |
+
from torchlibrosa.stft import Spectrogram, LogmelFilterBank
|
21 |
+
from torchlibrosa.augmentation import SpecAugmentation
|
22 |
+
|
23 |
+
from itertools import repeat
|
24 |
+
from .utils import do_mixup, interpolate
|
25 |
+
|
26 |
+
from .feature_fusion import iAFF, AFF, DAF
|
27 |
+
|
28 |
+
|
29 |
+
# from PyTorch internals
|
30 |
+
def _ntuple(n):
|
31 |
+
def parse(x):
|
32 |
+
if isinstance(x, collections.abc.Iterable):
|
33 |
+
return x
|
34 |
+
return tuple(repeat(x, n))
|
35 |
+
|
36 |
+
return parse
|
37 |
+
|
38 |
+
|
39 |
+
to_1tuple = _ntuple(1)
|
40 |
+
to_2tuple = _ntuple(2)
|
41 |
+
to_3tuple = _ntuple(3)
|
42 |
+
to_4tuple = _ntuple(4)
|
43 |
+
to_ntuple = _ntuple
|
44 |
+
|
45 |
+
|
46 |
+
def drop_path(x, drop_prob: float = 0.0, training: bool = False):
|
47 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
48 |
+
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
|
49 |
+
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
|
50 |
+
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
|
51 |
+
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
|
52 |
+
'survival rate' as the argument.
|
53 |
+
"""
|
54 |
+
if drop_prob == 0.0 or not training:
|
55 |
+
return x
|
56 |
+
keep_prob = 1 - drop_prob
|
57 |
+
shape = (x.shape[0],) + (1,) * (
|
58 |
+
x.ndim - 1
|
59 |
+
) # work with diff dim tensors, not just 2D ConvNets
|
60 |
+
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
|
61 |
+
random_tensor.floor_() # binarize
|
62 |
+
output = x.div(keep_prob) * random_tensor
|
63 |
+
return output
|
64 |
+
|
65 |
+
|
66 |
+
class DropPath(nn.Module):
|
67 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
|
68 |
+
|
69 |
+
def __init__(self, drop_prob=None):
|
70 |
+
super(DropPath, self).__init__()
|
71 |
+
self.drop_prob = drop_prob
|
72 |
+
|
73 |
+
def forward(self, x):
|
74 |
+
return drop_path(x, self.drop_prob, self.training)
|
75 |
+
|
76 |
+
|
77 |
+
class PatchEmbed(nn.Module):
|
78 |
+
"""2D Image to Patch Embedding"""
|
79 |
+
|
80 |
+
def __init__(
|
81 |
+
self,
|
82 |
+
img_size=224,
|
83 |
+
patch_size=16,
|
84 |
+
in_chans=3,
|
85 |
+
embed_dim=768,
|
86 |
+
norm_layer=None,
|
87 |
+
flatten=True,
|
88 |
+
patch_stride=16,
|
89 |
+
enable_fusion=False,
|
90 |
+
fusion_type="None",
|
91 |
+
):
|
92 |
+
super().__init__()
|
93 |
+
img_size = to_2tuple(img_size)
|
94 |
+
patch_size = to_2tuple(patch_size)
|
95 |
+
patch_stride = to_2tuple(patch_stride)
|
96 |
+
self.img_size = img_size
|
97 |
+
self.patch_size = patch_size
|
98 |
+
self.patch_stride = patch_stride
|
99 |
+
self.grid_size = (
|
100 |
+
img_size[0] // patch_stride[0],
|
101 |
+
img_size[1] // patch_stride[1],
|
102 |
+
)
|
103 |
+
self.num_patches = self.grid_size[0] * self.grid_size[1]
|
104 |
+
self.flatten = flatten
|
105 |
+
self.in_chans = in_chans
|
106 |
+
self.embed_dim = embed_dim
|
107 |
+
|
108 |
+
self.enable_fusion = enable_fusion
|
109 |
+
self.fusion_type = fusion_type
|
110 |
+
|
111 |
+
padding = (
|
112 |
+
(patch_size[0] - patch_stride[0]) // 2,
|
113 |
+
(patch_size[1] - patch_stride[1]) // 2,
|
114 |
+
)
|
115 |
+
|
116 |
+
if (self.enable_fusion) and (self.fusion_type == "channel_map"):
|
117 |
+
self.proj = nn.Conv2d(
|
118 |
+
in_chans * 4,
|
119 |
+
embed_dim,
|
120 |
+
kernel_size=patch_size,
|
121 |
+
stride=patch_stride,
|
122 |
+
padding=padding,
|
123 |
+
)
|
124 |
+
else:
|
125 |
+
self.proj = nn.Conv2d(
|
126 |
+
in_chans,
|
127 |
+
embed_dim,
|
128 |
+
kernel_size=patch_size,
|
129 |
+
stride=patch_stride,
|
130 |
+
padding=padding,
|
131 |
+
)
|
132 |
+
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
|
133 |
+
|
134 |
+
if (self.enable_fusion) and (
|
135 |
+
self.fusion_type in ["daf_2d", "aff_2d", "iaff_2d"]
|
136 |
+
):
|
137 |
+
self.mel_conv2d = nn.Conv2d(
|
138 |
+
in_chans,
|
139 |
+
embed_dim,
|
140 |
+
kernel_size=(patch_size[0], patch_size[1] * 3),
|
141 |
+
stride=(patch_stride[0], patch_stride[1] * 3),
|
142 |
+
padding=padding,
|
143 |
+
)
|
144 |
+
if self.fusion_type == "daf_2d":
|
145 |
+
self.fusion_model = DAF()
|
146 |
+
elif self.fusion_type == "aff_2d":
|
147 |
+
self.fusion_model = AFF(channels=embed_dim, type="2D")
|
148 |
+
elif self.fusion_type == "iaff_2d":
|
149 |
+
self.fusion_model = iAFF(channels=embed_dim, type="2D")
|
150 |
+
|
151 |
+
def forward(self, x, longer_idx=None):
|
152 |
+
if (self.enable_fusion) and (
|
153 |
+
self.fusion_type in ["daf_2d", "aff_2d", "iaff_2d"]
|
154 |
+
):
|
155 |
+
global_x = x[:, 0:1, :, :]
|
156 |
+
|
157 |
+
# global processing
|
158 |
+
B, C, H, W = global_x.shape
|
159 |
+
assert (
|
160 |
+
H == self.img_size[0] and W == self.img_size[1]
|
161 |
+
), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
|
162 |
+
global_x = self.proj(global_x)
|
163 |
+
TW = global_x.size(-1)
|
164 |
+
if len(longer_idx) > 0:
|
165 |
+
# local processing
|
166 |
+
local_x = x[longer_idx, 1:, :, :].contiguous()
|
167 |
+
B, C, H, W = local_x.shape
|
168 |
+
local_x = local_x.view(B * C, 1, H, W)
|
169 |
+
local_x = self.mel_conv2d(local_x)
|
170 |
+
local_x = local_x.view(
|
171 |
+
B, C, local_x.size(1), local_x.size(2), local_x.size(3)
|
172 |
+
)
|
173 |
+
local_x = local_x.permute((0, 2, 3, 1, 4)).contiguous().flatten(3)
|
174 |
+
TB, TC, TH, _ = local_x.size()
|
175 |
+
if local_x.size(-1) < TW:
|
176 |
+
local_x = torch.cat(
|
177 |
+
[
|
178 |
+
local_x,
|
179 |
+
torch.zeros(
|
180 |
+
(TB, TC, TH, TW - local_x.size(-1)),
|
181 |
+
device=global_x.device,
|
182 |
+
),
|
183 |
+
],
|
184 |
+
dim=-1,
|
185 |
+
)
|
186 |
+
else:
|
187 |
+
local_x = local_x[:, :, :, :TW]
|
188 |
+
|
189 |
+
global_x[longer_idx] = self.fusion_model(global_x[longer_idx], local_x)
|
190 |
+
x = global_x
|
191 |
+
else:
|
192 |
+
B, C, H, W = x.shape
|
193 |
+
assert (
|
194 |
+
H == self.img_size[0] and W == self.img_size[1]
|
195 |
+
), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
|
196 |
+
x = self.proj(x)
|
197 |
+
|
198 |
+
if self.flatten:
|
199 |
+
x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
|
200 |
+
x = self.norm(x)
|
201 |
+
return x
|
202 |
+
|
203 |
+
|
204 |
+
class Mlp(nn.Module):
|
205 |
+
"""MLP as used in Vision Transformer, MLP-Mixer and related networks"""
|
206 |
+
|
207 |
+
def __init__(
|
208 |
+
self,
|
209 |
+
in_features,
|
210 |
+
hidden_features=None,
|
211 |
+
out_features=None,
|
212 |
+
act_layer=nn.GELU,
|
213 |
+
drop=0.0,
|
214 |
+
):
|
215 |
+
super().__init__()
|
216 |
+
out_features = out_features or in_features
|
217 |
+
hidden_features = hidden_features or in_features
|
218 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
219 |
+
self.act = act_layer()
|
220 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
221 |
+
self.drop = nn.Dropout(drop)
|
222 |
+
|
223 |
+
def forward(self, x):
|
224 |
+
x = self.fc1(x)
|
225 |
+
x = self.act(x)
|
226 |
+
x = self.drop(x)
|
227 |
+
x = self.fc2(x)
|
228 |
+
x = self.drop(x)
|
229 |
+
return x
|
230 |
+
|
231 |
+
|
232 |
+
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
|
233 |
+
# Cut & paste from PyTorch official master until it's in a few official releases - RW
|
234 |
+
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
|
235 |
+
def norm_cdf(x):
|
236 |
+
# Computes standard normal cumulative distribution function
|
237 |
+
return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
|
238 |
+
|
239 |
+
if (mean < a - 2 * std) or (mean > b + 2 * std):
|
240 |
+
warnings.warn(
|
241 |
+
"mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
|
242 |
+
"The distribution of values may be incorrect.",
|
243 |
+
stacklevel=2,
|
244 |
+
)
|
245 |
+
|
246 |
+
with torch.no_grad():
|
247 |
+
# Values are generated by using a truncated uniform distribution and
|
248 |
+
# then using the inverse CDF for the normal distribution.
|
249 |
+
# Get upper and lower cdf values
|
250 |
+
l = norm_cdf((a - mean) / std)
|
251 |
+
u = norm_cdf((b - mean) / std)
|
252 |
+
|
253 |
+
# Uniformly fill tensor with values from [l, u], then translate to
|
254 |
+
# [2l-1, 2u-1].
|
255 |
+
tensor.uniform_(2 * l - 1, 2 * u - 1)
|
256 |
+
|
257 |
+
# Use inverse cdf transform for normal distribution to get truncated
|
258 |
+
# standard normal
|
259 |
+
tensor.erfinv_()
|
260 |
+
|
261 |
+
# Transform to proper mean, std
|
262 |
+
tensor.mul_(std * math.sqrt(2.0))
|
263 |
+
tensor.add_(mean)
|
264 |
+
|
265 |
+
# Clamp to ensure it's in the proper range
|
266 |
+
tensor.clamp_(min=a, max=b)
|
267 |
+
return tensor
|
268 |
+
|
269 |
+
|
270 |
+
def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
|
271 |
+
# type: (Tensor, float, float, float, float) -> Tensor
|
272 |
+
r"""Fills the input Tensor with values drawn from a truncated
|
273 |
+
normal distribution. The values are effectively drawn from the
|
274 |
+
normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
|
275 |
+
with values outside :math:`[a, b]` redrawn until they are within
|
276 |
+
the bounds. The method used for generating the random values works
|
277 |
+
best when :math:`a \leq \text{mean} \leq b`.
|
278 |
+
Args:
|
279 |
+
tensor: an n-dimensional `torch.Tensor`
|
280 |
+
mean: the mean of the normal distribution
|
281 |
+
std: the standard deviation of the normal distribution
|
282 |
+
a: the minimum cutoff value
|
283 |
+
b: the maximum cutoff value
|
284 |
+
Examples:
|
285 |
+
>>> w = torch.empty(3, 5)
|
286 |
+
>>> nn.init.trunc_normal_(w)
|
287 |
+
"""
|
288 |
+
return _no_grad_trunc_normal_(tensor, mean, std, a, b)
|
289 |
+
|
290 |
+
|
291 |
+
def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"):
|
292 |
+
fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
|
293 |
+
if mode == "fan_in":
|
294 |
+
denom = fan_in
|
295 |
+
elif mode == "fan_out":
|
296 |
+
denom = fan_out
|
297 |
+
elif mode == "fan_avg":
|
298 |
+
denom = (fan_in + fan_out) / 2
|
299 |
+
|
300 |
+
variance = scale / denom
|
301 |
+
|
302 |
+
if distribution == "truncated_normal":
|
303 |
+
# constant is stddev of standard normal truncated to (-2, 2)
|
304 |
+
trunc_normal_(tensor, std=math.sqrt(variance) / 0.87962566103423978)
|
305 |
+
elif distribution == "normal":
|
306 |
+
tensor.normal_(std=math.sqrt(variance))
|
307 |
+
elif distribution == "uniform":
|
308 |
+
bound = math.sqrt(3 * variance)
|
309 |
+
tensor.uniform_(-bound, bound)
|
310 |
+
else:
|
311 |
+
raise ValueError(f"invalid distribution {distribution}")
|
312 |
+
|
313 |
+
|
314 |
+
def lecun_normal_(tensor):
|
315 |
+
variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal")
|
316 |
+
|
317 |
+
|
318 |
+
def window_partition(x, window_size):
|
319 |
+
"""
|
320 |
+
Args:
|
321 |
+
x: (B, H, W, C)
|
322 |
+
window_size (int): window size
|
323 |
+
Returns:
|
324 |
+
windows: (num_windows*B, window_size, window_size, C)
|
325 |
+
"""
|
326 |
+
B, H, W, C = x.shape
|
327 |
+
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
|
328 |
+
windows = (
|
329 |
+
x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
|
330 |
+
)
|
331 |
+
return windows
|
332 |
+
|
333 |
+
|
334 |
+
def window_reverse(windows, window_size, H, W):
|
335 |
+
"""
|
336 |
+
Args:
|
337 |
+
windows: (num_windows*B, window_size, window_size, C)
|
338 |
+
window_size (int): Window size
|
339 |
+
H (int): Height of image
|
340 |
+
W (int): Width of image
|
341 |
+
Returns:
|
342 |
+
x: (B, H, W, C)
|
343 |
+
"""
|
344 |
+
B = int(windows.shape[0] / (H * W / window_size / window_size))
|
345 |
+
x = windows.view(
|
346 |
+
B, H // window_size, W // window_size, window_size, window_size, -1
|
347 |
+
)
|
348 |
+
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
|
349 |
+
return x
|
350 |
+
|
351 |
+
|
352 |
+
class WindowAttention(nn.Module):
|
353 |
+
r"""Window based multi-head self attention (W-MSA) module with relative position bias.
|
354 |
+
It supports both of shifted and non-shifted window.
|
355 |
+
Args:
|
356 |
+
dim (int): Number of input channels.
|
357 |
+
window_size (tuple[int]): The height and width of the window.
|
358 |
+
num_heads (int): Number of attention heads.
|
359 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
360 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
|
361 |
+
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
|
362 |
+
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
|
363 |
+
"""
|
364 |
+
|
365 |
+
def __init__(
|
366 |
+
self,
|
367 |
+
dim,
|
368 |
+
window_size,
|
369 |
+
num_heads,
|
370 |
+
qkv_bias=True,
|
371 |
+
qk_scale=None,
|
372 |
+
attn_drop=0.0,
|
373 |
+
proj_drop=0.0,
|
374 |
+
):
|
375 |
+
super().__init__()
|
376 |
+
self.dim = dim
|
377 |
+
self.window_size = window_size # Wh, Ww
|
378 |
+
self.num_heads = num_heads
|
379 |
+
head_dim = dim // num_heads
|
380 |
+
self.scale = qk_scale or head_dim**-0.5
|
381 |
+
|
382 |
+
# define a parameter table of relative position bias
|
383 |
+
self.relative_position_bias_table = nn.Parameter(
|
384 |
+
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
|
385 |
+
) # 2*Wh-1 * 2*Ww-1, nH
|
386 |
+
|
387 |
+
# get pair-wise relative position index for each token inside the window
|
388 |
+
coords_h = torch.arange(self.window_size[0])
|
389 |
+
coords_w = torch.arange(self.window_size[1])
|
390 |
+
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
|
391 |
+
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
|
392 |
+
relative_coords = (
|
393 |
+
coords_flatten[:, :, None] - coords_flatten[:, None, :]
|
394 |
+
) # 2, Wh*Ww, Wh*Ww
|
395 |
+
relative_coords = relative_coords.permute(
|
396 |
+
1, 2, 0
|
397 |
+
).contiguous() # Wh*Ww, Wh*Ww, 2
|
398 |
+
relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
|
399 |
+
relative_coords[:, :, 1] += self.window_size[1] - 1
|
400 |
+
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
|
401 |
+
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
|
402 |
+
self.register_buffer("relative_position_index", relative_position_index)
|
403 |
+
|
404 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
405 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
406 |
+
self.proj = nn.Linear(dim, dim)
|
407 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
408 |
+
|
409 |
+
trunc_normal_(self.relative_position_bias_table, std=0.02)
|
410 |
+
self.softmax = nn.Softmax(dim=-1)
|
411 |
+
|
412 |
+
def forward(self, x, mask=None):
|
413 |
+
"""
|
414 |
+
Args:
|
415 |
+
x: input features with shape of (num_windows*B, N, C)
|
416 |
+
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
|
417 |
+
"""
|
418 |
+
B_, N, C = x.shape
|
419 |
+
qkv = (
|
420 |
+
self.qkv(x)
|
421 |
+
.reshape(B_, N, 3, self.num_heads, C // self.num_heads)
|
422 |
+
.permute(2, 0, 3, 1, 4)
|
423 |
+
)
|
424 |
+
q, k, v = (
|
425 |
+
qkv[0],
|
426 |
+
qkv[1],
|
427 |
+
qkv[2],
|
428 |
+
) # make torchscript happy (cannot use tensor as tuple)
|
429 |
+
|
430 |
+
q = q * self.scale
|
431 |
+
attn = q @ k.transpose(-2, -1)
|
432 |
+
|
433 |
+
relative_position_bias = self.relative_position_bias_table[
|
434 |
+
self.relative_position_index.view(-1)
|
435 |
+
].view(
|
436 |
+
self.window_size[0] * self.window_size[1],
|
437 |
+
self.window_size[0] * self.window_size[1],
|
438 |
+
-1,
|
439 |
+
) # Wh*Ww,Wh*Ww,nH
|
440 |
+
relative_position_bias = relative_position_bias.permute(
|
441 |
+
2, 0, 1
|
442 |
+
).contiguous() # nH, Wh*Ww, Wh*Ww
|
443 |
+
attn = attn + relative_position_bias.unsqueeze(0)
|
444 |
+
|
445 |
+
if mask is not None:
|
446 |
+
nW = mask.shape[0]
|
447 |
+
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(
|
448 |
+
1
|
449 |
+
).unsqueeze(0)
|
450 |
+
attn = attn.view(-1, self.num_heads, N, N)
|
451 |
+
attn = self.softmax(attn)
|
452 |
+
else:
|
453 |
+
attn = self.softmax(attn)
|
454 |
+
|
455 |
+
attn = self.attn_drop(attn)
|
456 |
+
|
457 |
+
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
|
458 |
+
x = self.proj(x)
|
459 |
+
x = self.proj_drop(x)
|
460 |
+
return x, attn
|
461 |
+
|
462 |
+
def extra_repr(self):
|
463 |
+
return f"dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}"
|
464 |
+
|
465 |
+
|
466 |
+
# We use the model based on Swintransformer Block, therefore we can use the swin-transformer pretrained model
|
467 |
+
class SwinTransformerBlock(nn.Module):
|
468 |
+
r"""Swin Transformer Block.
|
469 |
+
Args:
|
470 |
+
dim (int): Number of input channels.
|
471 |
+
input_resolution (tuple[int]): Input resulotion.
|
472 |
+
num_heads (int): Number of attention heads.
|
473 |
+
window_size (int): Window size.
|
474 |
+
shift_size (int): Shift size for SW-MSA.
|
475 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
476 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
477 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
|
478 |
+
drop (float, optional): Dropout rate. Default: 0.0
|
479 |
+
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
480 |
+
drop_path (float, optional): Stochastic depth rate. Default: 0.0
|
481 |
+
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
|
482 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
483 |
+
"""
|
484 |
+
|
485 |
+
def __init__(
|
486 |
+
self,
|
487 |
+
dim,
|
488 |
+
input_resolution,
|
489 |
+
num_heads,
|
490 |
+
window_size=7,
|
491 |
+
shift_size=0,
|
492 |
+
mlp_ratio=4.0,
|
493 |
+
qkv_bias=True,
|
494 |
+
qk_scale=None,
|
495 |
+
drop=0.0,
|
496 |
+
attn_drop=0.0,
|
497 |
+
drop_path=0.0,
|
498 |
+
act_layer=nn.GELU,
|
499 |
+
norm_layer=nn.LayerNorm,
|
500 |
+
norm_before_mlp="ln",
|
501 |
+
):
|
502 |
+
super().__init__()
|
503 |
+
self.dim = dim
|
504 |
+
self.input_resolution = input_resolution
|
505 |
+
self.num_heads = num_heads
|
506 |
+
self.window_size = window_size
|
507 |
+
self.shift_size = shift_size
|
508 |
+
self.mlp_ratio = mlp_ratio
|
509 |
+
self.norm_before_mlp = norm_before_mlp
|
510 |
+
if min(self.input_resolution) <= self.window_size:
|
511 |
+
# if window size is larger than input resolution, we don't partition windows
|
512 |
+
self.shift_size = 0
|
513 |
+
self.window_size = min(self.input_resolution)
|
514 |
+
assert (
|
515 |
+
0 <= self.shift_size < self.window_size
|
516 |
+
), "shift_size must in 0-window_size"
|
517 |
+
|
518 |
+
self.norm1 = norm_layer(dim)
|
519 |
+
self.attn = WindowAttention(
|
520 |
+
dim,
|
521 |
+
window_size=to_2tuple(self.window_size),
|
522 |
+
num_heads=num_heads,
|
523 |
+
qkv_bias=qkv_bias,
|
524 |
+
qk_scale=qk_scale,
|
525 |
+
attn_drop=attn_drop,
|
526 |
+
proj_drop=drop,
|
527 |
+
)
|
528 |
+
|
529 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
530 |
+
if self.norm_before_mlp == "ln":
|
531 |
+
self.norm2 = nn.LayerNorm(dim)
|
532 |
+
elif self.norm_before_mlp == "bn":
|
533 |
+
self.norm2 = lambda x: nn.BatchNorm1d(dim)(x.transpose(1, 2)).transpose(
|
534 |
+
1, 2
|
535 |
+
)
|
536 |
+
else:
|
537 |
+
raise NotImplementedError
|
538 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
539 |
+
self.mlp = Mlp(
|
540 |
+
in_features=dim,
|
541 |
+
hidden_features=mlp_hidden_dim,
|
542 |
+
act_layer=act_layer,
|
543 |
+
drop=drop,
|
544 |
+
)
|
545 |
+
|
546 |
+
if self.shift_size > 0:
|
547 |
+
# calculate attention mask for SW-MSA
|
548 |
+
H, W = self.input_resolution
|
549 |
+
img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
|
550 |
+
h_slices = (
|
551 |
+
slice(0, -self.window_size),
|
552 |
+
slice(-self.window_size, -self.shift_size),
|
553 |
+
slice(-self.shift_size, None),
|
554 |
+
)
|
555 |
+
w_slices = (
|
556 |
+
slice(0, -self.window_size),
|
557 |
+
slice(-self.window_size, -self.shift_size),
|
558 |
+
slice(-self.shift_size, None),
|
559 |
+
)
|
560 |
+
cnt = 0
|
561 |
+
for h in h_slices:
|
562 |
+
for w in w_slices:
|
563 |
+
img_mask[:, h, w, :] = cnt
|
564 |
+
cnt += 1
|
565 |
+
|
566 |
+
mask_windows = window_partition(
|
567 |
+
img_mask, self.window_size
|
568 |
+
) # nW, window_size, window_size, 1
|
569 |
+
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
|
570 |
+
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
|
571 |
+
attn_mask = attn_mask.masked_fill(
|
572 |
+
attn_mask != 0, float(-100.0)
|
573 |
+
).masked_fill(attn_mask == 0, float(0.0))
|
574 |
+
else:
|
575 |
+
attn_mask = None
|
576 |
+
|
577 |
+
self.register_buffer("attn_mask", attn_mask)
|
578 |
+
|
579 |
+
def forward(self, x):
|
580 |
+
# pdb.set_trace()
|
581 |
+
H, W = self.input_resolution
|
582 |
+
# print("H: ", H)
|
583 |
+
# print("W: ", W)
|
584 |
+
# pdb.set_trace()
|
585 |
+
B, L, C = x.shape
|
586 |
+
# assert L == H * W, "input feature has wrong size"
|
587 |
+
|
588 |
+
shortcut = x
|
589 |
+
x = self.norm1(x)
|
590 |
+
x = x.view(B, H, W, C)
|
591 |
+
|
592 |
+
# cyclic shift
|
593 |
+
if self.shift_size > 0:
|
594 |
+
shifted_x = torch.roll(
|
595 |
+
x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)
|
596 |
+
)
|
597 |
+
else:
|
598 |
+
shifted_x = x
|
599 |
+
|
600 |
+
# partition windows
|
601 |
+
x_windows = window_partition(
|
602 |
+
shifted_x, self.window_size
|
603 |
+
) # nW*B, window_size, window_size, C
|
604 |
+
x_windows = x_windows.view(
|
605 |
+
-1, self.window_size * self.window_size, C
|
606 |
+
) # nW*B, window_size*window_size, C
|
607 |
+
|
608 |
+
# W-MSA/SW-MSA
|
609 |
+
attn_windows, attn = self.attn(
|
610 |
+
x_windows, mask=self.attn_mask
|
611 |
+
) # nW*B, window_size*window_size, C
|
612 |
+
|
613 |
+
# merge windows
|
614 |
+
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
|
615 |
+
shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
|
616 |
+
|
617 |
+
# reverse cyclic shift
|
618 |
+
if self.shift_size > 0:
|
619 |
+
x = torch.roll(
|
620 |
+
shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)
|
621 |
+
)
|
622 |
+
else:
|
623 |
+
x = shifted_x
|
624 |
+
x = x.view(B, H * W, C)
|
625 |
+
|
626 |
+
# FFN
|
627 |
+
x = shortcut + self.drop_path(x)
|
628 |
+
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
629 |
+
|
630 |
+
return x, attn
|
631 |
+
|
632 |
+
def extra_repr(self):
|
633 |
+
return (
|
634 |
+
f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, "
|
635 |
+
f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
|
636 |
+
)
|
637 |
+
|
638 |
+
|
639 |
+
class PatchMerging(nn.Module):
|
640 |
+
r"""Patch Merging Layer.
|
641 |
+
Args:
|
642 |
+
input_resolution (tuple[int]): Resolution of input feature.
|
643 |
+
dim (int): Number of input channels.
|
644 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
645 |
+
"""
|
646 |
+
|
647 |
+
def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
|
648 |
+
super().__init__()
|
649 |
+
self.input_resolution = input_resolution
|
650 |
+
self.dim = dim
|
651 |
+
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
|
652 |
+
self.norm = norm_layer(4 * dim)
|
653 |
+
|
654 |
+
def forward(self, x):
|
655 |
+
"""
|
656 |
+
x: B, H*W, C
|
657 |
+
"""
|
658 |
+
H, W = self.input_resolution
|
659 |
+
B, L, C = x.shape
|
660 |
+
assert L == H * W, "input feature has wrong size"
|
661 |
+
assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
|
662 |
+
|
663 |
+
x = x.view(B, H, W, C)
|
664 |
+
|
665 |
+
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
|
666 |
+
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
|
667 |
+
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
|
668 |
+
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
|
669 |
+
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
|
670 |
+
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
|
671 |
+
|
672 |
+
x = self.norm(x)
|
673 |
+
x = self.reduction(x)
|
674 |
+
|
675 |
+
return x
|
676 |
+
|
677 |
+
def extra_repr(self):
|
678 |
+
return f"input_resolution={self.input_resolution}, dim={self.dim}"
|
679 |
+
|
680 |
+
|
681 |
+
class BasicLayer(nn.Module):
|
682 |
+
"""A basic Swin Transformer layer for one stage.
|
683 |
+
Args:
|
684 |
+
dim (int): Number of input channels.
|
685 |
+
input_resolution (tuple[int]): Input resolution.
|
686 |
+
depth (int): Number of blocks.
|
687 |
+
num_heads (int): Number of attention heads.
|
688 |
+
window_size (int): Local window size.
|
689 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
690 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
691 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
|
692 |
+
drop (float, optional): Dropout rate. Default: 0.0
|
693 |
+
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
694 |
+
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
|
695 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
696 |
+
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
|
697 |
+
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
|
698 |
+
"""
|
699 |
+
|
700 |
+
def __init__(
|
701 |
+
self,
|
702 |
+
dim,
|
703 |
+
input_resolution,
|
704 |
+
depth,
|
705 |
+
num_heads,
|
706 |
+
window_size,
|
707 |
+
mlp_ratio=4.0,
|
708 |
+
qkv_bias=True,
|
709 |
+
qk_scale=None,
|
710 |
+
drop=0.0,
|
711 |
+
attn_drop=0.0,
|
712 |
+
drop_path=0.0,
|
713 |
+
norm_layer=nn.LayerNorm,
|
714 |
+
downsample=None,
|
715 |
+
use_checkpoint=False,
|
716 |
+
norm_before_mlp="ln",
|
717 |
+
):
|
718 |
+
super().__init__()
|
719 |
+
self.dim = dim
|
720 |
+
self.input_resolution = input_resolution
|
721 |
+
self.depth = depth
|
722 |
+
self.use_checkpoint = use_checkpoint
|
723 |
+
|
724 |
+
# build blocks
|
725 |
+
self.blocks = nn.ModuleList(
|
726 |
+
[
|
727 |
+
SwinTransformerBlock(
|
728 |
+
dim=dim,
|
729 |
+
input_resolution=input_resolution,
|
730 |
+
num_heads=num_heads,
|
731 |
+
window_size=window_size,
|
732 |
+
shift_size=0 if (i % 2 == 0) else window_size // 2,
|
733 |
+
mlp_ratio=mlp_ratio,
|
734 |
+
qkv_bias=qkv_bias,
|
735 |
+
qk_scale=qk_scale,
|
736 |
+
drop=drop,
|
737 |
+
attn_drop=attn_drop,
|
738 |
+
drop_path=drop_path[i]
|
739 |
+
if isinstance(drop_path, list)
|
740 |
+
else drop_path,
|
741 |
+
norm_layer=norm_layer,
|
742 |
+
norm_before_mlp=norm_before_mlp,
|
743 |
+
)
|
744 |
+
for i in range(depth)
|
745 |
+
]
|
746 |
+
)
|
747 |
+
|
748 |
+
# patch merging layer
|
749 |
+
if downsample is not None:
|
750 |
+
self.downsample = downsample(
|
751 |
+
input_resolution, dim=dim, norm_layer=norm_layer
|
752 |
+
)
|
753 |
+
else:
|
754 |
+
self.downsample = None
|
755 |
+
|
756 |
+
def forward(self, x):
|
757 |
+
attns = []
|
758 |
+
for blk in self.blocks:
|
759 |
+
if self.use_checkpoint:
|
760 |
+
x = checkpoint.checkpoint(blk, x)
|
761 |
+
else:
|
762 |
+
x, attn = blk(x)
|
763 |
+
if not self.training:
|
764 |
+
attns.append(attn.unsqueeze(0))
|
765 |
+
if self.downsample is not None:
|
766 |
+
x = self.downsample(x)
|
767 |
+
if not self.training:
|
768 |
+
attn = torch.cat(attns, dim=0)
|
769 |
+
attn = torch.mean(attn, dim=0)
|
770 |
+
return x, attn
|
771 |
+
|
772 |
+
def extra_repr(self):
|
773 |
+
return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
|
774 |
+
|
775 |
+
|
776 |
+
# The Core of HTSAT
|
777 |
+
class HTSAT_Swin_Transformer(nn.Module):
|
778 |
+
r"""HTSAT based on the Swin Transformer
|
779 |
+
Args:
|
780 |
+
spec_size (int | tuple(int)): Input Spectrogram size. Default 256
|
781 |
+
patch_size (int | tuple(int)): Patch size. Default: 4
|
782 |
+
path_stride (iot | tuple(int)): Patch Stride for Frequency and Time Axis. Default: 4
|
783 |
+
in_chans (int): Number of input image channels. Default: 1 (mono)
|
784 |
+
num_classes (int): Number of classes for classification head. Default: 527
|
785 |
+
embed_dim (int): Patch embedding dimension. Default: 96
|
786 |
+
depths (tuple(int)): Depth of each HTSAT-Swin Transformer layer.
|
787 |
+
num_heads (tuple(int)): Number of attention heads in different layers.
|
788 |
+
window_size (int): Window size. Default: 8
|
789 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
|
790 |
+
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
|
791 |
+
qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
|
792 |
+
drop_rate (float): Dropout rate. Default: 0
|
793 |
+
attn_drop_rate (float): Attention dropout rate. Default: 0
|
794 |
+
drop_path_rate (float): Stochastic depth rate. Default: 0.1
|
795 |
+
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
|
796 |
+
ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
|
797 |
+
patch_norm (bool): If True, add normalization after patch embedding. Default: True
|
798 |
+
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
|
799 |
+
config (module): The configuration Module from config.py
|
800 |
+
"""
|
801 |
+
|
802 |
+
def __init__(
|
803 |
+
self,
|
804 |
+
spec_size=256,
|
805 |
+
patch_size=4,
|
806 |
+
patch_stride=(4, 4),
|
807 |
+
in_chans=1,
|
808 |
+
num_classes=527,
|
809 |
+
embed_dim=96,
|
810 |
+
depths=[2, 2, 6, 2],
|
811 |
+
num_heads=[4, 8, 16, 32],
|
812 |
+
window_size=8,
|
813 |
+
mlp_ratio=4.0,
|
814 |
+
qkv_bias=True,
|
815 |
+
qk_scale=None,
|
816 |
+
drop_rate=0.0,
|
817 |
+
attn_drop_rate=0.0,
|
818 |
+
drop_path_rate=0.1,
|
819 |
+
norm_layer=nn.LayerNorm,
|
820 |
+
ape=False,
|
821 |
+
patch_norm=True,
|
822 |
+
use_checkpoint=False,
|
823 |
+
norm_before_mlp="ln",
|
824 |
+
config=None,
|
825 |
+
enable_fusion=False,
|
826 |
+
fusion_type="None",
|
827 |
+
**kwargs,
|
828 |
+
):
|
829 |
+
super(HTSAT_Swin_Transformer, self).__init__()
|
830 |
+
|
831 |
+
self.config = config
|
832 |
+
self.spec_size = spec_size
|
833 |
+
self.patch_stride = patch_stride
|
834 |
+
self.patch_size = patch_size
|
835 |
+
self.window_size = window_size
|
836 |
+
self.embed_dim = embed_dim
|
837 |
+
self.depths = depths
|
838 |
+
self.ape = ape
|
839 |
+
self.in_chans = in_chans
|
840 |
+
self.num_classes = num_classes
|
841 |
+
self.num_heads = num_heads
|
842 |
+
self.num_layers = len(self.depths)
|
843 |
+
self.num_features = int(self.embed_dim * 2 ** (self.num_layers - 1))
|
844 |
+
|
845 |
+
self.drop_rate = drop_rate
|
846 |
+
self.attn_drop_rate = attn_drop_rate
|
847 |
+
self.drop_path_rate = drop_path_rate
|
848 |
+
|
849 |
+
self.qkv_bias = qkv_bias
|
850 |
+
self.qk_scale = None
|
851 |
+
|
852 |
+
self.patch_norm = patch_norm
|
853 |
+
self.norm_layer = norm_layer if self.patch_norm else None
|
854 |
+
self.norm_before_mlp = norm_before_mlp
|
855 |
+
self.mlp_ratio = mlp_ratio
|
856 |
+
|
857 |
+
self.use_checkpoint = use_checkpoint
|
858 |
+
|
859 |
+
self.enable_fusion = enable_fusion
|
860 |
+
self.fusion_type = fusion_type
|
861 |
+
|
862 |
+
# process mel-spec ; used only once
|
863 |
+
self.freq_ratio = self.spec_size // self.config.mel_bins
|
864 |
+
window = "hann"
|
865 |
+
center = True
|
866 |
+
pad_mode = "reflect"
|
867 |
+
ref = 1.0
|
868 |
+
amin = 1e-10
|
869 |
+
top_db = None
|
870 |
+
self.interpolate_ratio = 32 # Downsampled ratio
|
871 |
+
# Spectrogram extractor
|
872 |
+
self.spectrogram_extractor = Spectrogram(
|
873 |
+
n_fft=config.window_size,
|
874 |
+
hop_length=config.hop_size,
|
875 |
+
win_length=config.window_size,
|
876 |
+
window=window,
|
877 |
+
center=center,
|
878 |
+
pad_mode=pad_mode,
|
879 |
+
freeze_parameters=True,
|
880 |
+
)
|
881 |
+
# Logmel feature extractor
|
882 |
+
self.logmel_extractor = LogmelFilterBank(
|
883 |
+
sr=config.sample_rate,
|
884 |
+
n_fft=config.window_size,
|
885 |
+
n_mels=config.mel_bins,
|
886 |
+
fmin=config.fmin,
|
887 |
+
fmax=config.fmax,
|
888 |
+
ref=ref,
|
889 |
+
amin=amin,
|
890 |
+
top_db=top_db,
|
891 |
+
freeze_parameters=True,
|
892 |
+
)
|
893 |
+
# Spec augmenter
|
894 |
+
self.spec_augmenter = SpecAugmentation(
|
895 |
+
time_drop_width=64,
|
896 |
+
time_stripes_num=2,
|
897 |
+
freq_drop_width=8,
|
898 |
+
freq_stripes_num=2,
|
899 |
+
) # 2 2
|
900 |
+
self.bn0 = nn.BatchNorm2d(self.config.mel_bins)
|
901 |
+
|
902 |
+
# split spctrogram into non-overlapping patches
|
903 |
+
self.patch_embed = PatchEmbed(
|
904 |
+
img_size=self.spec_size,
|
905 |
+
patch_size=self.patch_size,
|
906 |
+
in_chans=self.in_chans,
|
907 |
+
embed_dim=self.embed_dim,
|
908 |
+
norm_layer=self.norm_layer,
|
909 |
+
patch_stride=patch_stride,
|
910 |
+
enable_fusion=self.enable_fusion,
|
911 |
+
fusion_type=self.fusion_type,
|
912 |
+
)
|
913 |
+
|
914 |
+
num_patches = self.patch_embed.num_patches
|
915 |
+
patches_resolution = self.patch_embed.grid_size
|
916 |
+
self.patches_resolution = patches_resolution
|
917 |
+
|
918 |
+
# absolute position embedding
|
919 |
+
if self.ape:
|
920 |
+
self.absolute_pos_embed = nn.Parameter(
|
921 |
+
torch.zeros(1, num_patches, self.embed_dim)
|
922 |
+
)
|
923 |
+
trunc_normal_(self.absolute_pos_embed, std=0.02)
|
924 |
+
|
925 |
+
self.pos_drop = nn.Dropout(p=self.drop_rate)
|
926 |
+
|
927 |
+
# stochastic depth
|
928 |
+
dpr = [
|
929 |
+
x.item() for x in torch.linspace(0, self.drop_path_rate, sum(self.depths))
|
930 |
+
] # stochastic depth decay rule
|
931 |
+
|
932 |
+
# build layers
|
933 |
+
self.layers = nn.ModuleList()
|
934 |
+
for i_layer in range(self.num_layers):
|
935 |
+
layer = BasicLayer(
|
936 |
+
dim=int(self.embed_dim * 2**i_layer),
|
937 |
+
input_resolution=(
|
938 |
+
patches_resolution[0] // (2**i_layer),
|
939 |
+
patches_resolution[1] // (2**i_layer),
|
940 |
+
),
|
941 |
+
depth=self.depths[i_layer],
|
942 |
+
num_heads=self.num_heads[i_layer],
|
943 |
+
window_size=self.window_size,
|
944 |
+
mlp_ratio=self.mlp_ratio,
|
945 |
+
qkv_bias=self.qkv_bias,
|
946 |
+
qk_scale=self.qk_scale,
|
947 |
+
drop=self.drop_rate,
|
948 |
+
attn_drop=self.attn_drop_rate,
|
949 |
+
drop_path=dpr[
|
950 |
+
sum(self.depths[:i_layer]) : sum(self.depths[: i_layer + 1])
|
951 |
+
],
|
952 |
+
norm_layer=self.norm_layer,
|
953 |
+
downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
|
954 |
+
use_checkpoint=use_checkpoint,
|
955 |
+
norm_before_mlp=self.norm_before_mlp,
|
956 |
+
)
|
957 |
+
self.layers.append(layer)
|
958 |
+
|
959 |
+
self.norm = self.norm_layer(self.num_features)
|
960 |
+
self.avgpool = nn.AdaptiveAvgPool1d(1)
|
961 |
+
self.maxpool = nn.AdaptiveMaxPool1d(1)
|
962 |
+
|
963 |
+
SF = (
|
964 |
+
self.spec_size
|
965 |
+
// (2 ** (len(self.depths) - 1))
|
966 |
+
// self.patch_stride[0]
|
967 |
+
// self.freq_ratio
|
968 |
+
)
|
969 |
+
self.tscam_conv = nn.Conv2d(
|
970 |
+
in_channels=self.num_features,
|
971 |
+
out_channels=self.num_classes,
|
972 |
+
kernel_size=(SF, 3),
|
973 |
+
padding=(0, 1),
|
974 |
+
)
|
975 |
+
self.head = nn.Linear(num_classes, num_classes)
|
976 |
+
|
977 |
+
if (self.enable_fusion) and (
|
978 |
+
self.fusion_type in ["daf_1d", "aff_1d", "iaff_1d"]
|
979 |
+
):
|
980 |
+
self.mel_conv1d = nn.Sequential(
|
981 |
+
nn.Conv1d(64, 64, kernel_size=5, stride=3, padding=2),
|
982 |
+
nn.BatchNorm1d(64),
|
983 |
+
)
|
984 |
+
if self.fusion_type == "daf_1d":
|
985 |
+
self.fusion_model = DAF()
|
986 |
+
elif self.fusion_type == "aff_1d":
|
987 |
+
self.fusion_model = AFF(channels=64, type="1D")
|
988 |
+
elif self.fusion_type == "iaff_1d":
|
989 |
+
self.fusion_model = iAFF(channels=64, type="1D")
|
990 |
+
|
991 |
+
self.apply(self._init_weights)
|
992 |
+
|
993 |
+
def _init_weights(self, m):
|
994 |
+
if isinstance(m, nn.Linear):
|
995 |
+
trunc_normal_(m.weight, std=0.02)
|
996 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
997 |
+
nn.init.constant_(m.bias, 0)
|
998 |
+
elif isinstance(m, nn.LayerNorm):
|
999 |
+
nn.init.constant_(m.bias, 0)
|
1000 |
+
nn.init.constant_(m.weight, 1.0)
|
1001 |
+
|
1002 |
+
@torch.jit.ignore
|
1003 |
+
def no_weight_decay(self):
|
1004 |
+
return {"absolute_pos_embed"}
|
1005 |
+
|
1006 |
+
@torch.jit.ignore
|
1007 |
+
def no_weight_decay_keywords(self):
|
1008 |
+
return {"relative_position_bias_table"}
|
1009 |
+
|
1010 |
+
def forward_features(self, x, longer_idx=None):
|
1011 |
+
# A deprecated optimization for using a hierarchical output from different blocks
|
1012 |
+
|
1013 |
+
frames_num = x.shape[2]
|
1014 |
+
x = self.patch_embed(x, longer_idx=longer_idx)
|
1015 |
+
if self.ape:
|
1016 |
+
x = x + self.absolute_pos_embed
|
1017 |
+
x = self.pos_drop(x)
|
1018 |
+
for i, layer in enumerate(self.layers):
|
1019 |
+
x, attn = layer(x)
|
1020 |
+
# for x
|
1021 |
+
x = self.norm(x)
|
1022 |
+
B, N, C = x.shape
|
1023 |
+
SF = frames_num // (2 ** (len(self.depths) - 1)) // self.patch_stride[0]
|
1024 |
+
ST = frames_num // (2 ** (len(self.depths) - 1)) // self.patch_stride[1]
|
1025 |
+
x = x.permute(0, 2, 1).contiguous().reshape(B, C, SF, ST)
|
1026 |
+
B, C, F, T = x.shape
|
1027 |
+
# group 2D CNN
|
1028 |
+
c_freq_bin = F // self.freq_ratio
|
1029 |
+
x = x.reshape(B, C, F // c_freq_bin, c_freq_bin, T)
|
1030 |
+
x = x.permute(0, 1, 3, 2, 4).contiguous().reshape(B, C, c_freq_bin, -1)
|
1031 |
+
# get latent_output
|
1032 |
+
fine_grained_latent_output = torch.mean(x, dim=2)
|
1033 |
+
fine_grained_latent_output = interpolate(
|
1034 |
+
fine_grained_latent_output.permute(0, 2, 1).contiguous(),
|
1035 |
+
8 * self.patch_stride[1],
|
1036 |
+
)
|
1037 |
+
|
1038 |
+
latent_output = self.avgpool(torch.flatten(x, 2))
|
1039 |
+
latent_output = torch.flatten(latent_output, 1)
|
1040 |
+
|
1041 |
+
# display the attention map, if needed
|
1042 |
+
|
1043 |
+
x = self.tscam_conv(x)
|
1044 |
+
x = torch.flatten(x, 2) # B, C, T
|
1045 |
+
|
1046 |
+
fpx = interpolate(
|
1047 |
+
torch.sigmoid(x).permute(0, 2, 1).contiguous(), 8 * self.patch_stride[1]
|
1048 |
+
)
|
1049 |
+
|
1050 |
+
x = self.avgpool(x)
|
1051 |
+
x = torch.flatten(x, 1)
|
1052 |
+
|
1053 |
+
output_dict = {
|
1054 |
+
"framewise_output": fpx, # already sigmoided
|
1055 |
+
"clipwise_output": torch.sigmoid(x),
|
1056 |
+
"fine_grained_embedding": fine_grained_latent_output,
|
1057 |
+
"embedding": latent_output,
|
1058 |
+
}
|
1059 |
+
|
1060 |
+
return output_dict
|
1061 |
+
|
1062 |
+
def crop_wav(self, x, crop_size, spe_pos=None):
|
1063 |
+
time_steps = x.shape[2]
|
1064 |
+
tx = torch.zeros(x.shape[0], x.shape[1], crop_size, x.shape[3]).to(x.device)
|
1065 |
+
for i in range(len(x)):
|
1066 |
+
if spe_pos is None:
|
1067 |
+
crop_pos = random.randint(0, time_steps - crop_size - 1)
|
1068 |
+
else:
|
1069 |
+
crop_pos = spe_pos
|
1070 |
+
tx[i][0] = x[i, 0, crop_pos : crop_pos + crop_size, :]
|
1071 |
+
return tx
|
1072 |
+
|
1073 |
+
# Reshape the wavform to a img size, if you want to use the pretrained swin transformer model
|
1074 |
+
def reshape_wav2img(self, x):
|
1075 |
+
B, C, T, F = x.shape
|
1076 |
+
target_T = int(self.spec_size * self.freq_ratio)
|
1077 |
+
target_F = self.spec_size // self.freq_ratio
|
1078 |
+
assert (
|
1079 |
+
T <= target_T and F <= target_F
|
1080 |
+
), "the wav size should less than or equal to the swin input size"
|
1081 |
+
# to avoid bicubic zero error
|
1082 |
+
if T < target_T:
|
1083 |
+
x = nn.functional.interpolate(
|
1084 |
+
x, (target_T, x.shape[3]), mode="bicubic", align_corners=True
|
1085 |
+
)
|
1086 |
+
if F < target_F:
|
1087 |
+
x = nn.functional.interpolate(
|
1088 |
+
x, (x.shape[2], target_F), mode="bicubic", align_corners=True
|
1089 |
+
)
|
1090 |
+
x = x.permute(0, 1, 3, 2).contiguous()
|
1091 |
+
x = x.reshape(
|
1092 |
+
x.shape[0],
|
1093 |
+
x.shape[1],
|
1094 |
+
x.shape[2],
|
1095 |
+
self.freq_ratio,
|
1096 |
+
x.shape[3] // self.freq_ratio,
|
1097 |
+
)
|
1098 |
+
# print(x.shape)
|
1099 |
+
x = x.permute(0, 1, 3, 2, 4).contiguous()
|
1100 |
+
x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3], x.shape[4])
|
1101 |
+
return x
|
1102 |
+
|
1103 |
+
# Repeat the wavform to a img size, if you want to use the pretrained swin transformer model
|
1104 |
+
def repeat_wat2img(self, x, cur_pos):
|
1105 |
+
B, C, T, F = x.shape
|
1106 |
+
target_T = int(self.spec_size * self.freq_ratio)
|
1107 |
+
target_F = self.spec_size // self.freq_ratio
|
1108 |
+
assert (
|
1109 |
+
T <= target_T and F <= target_F
|
1110 |
+
), "the wav size should less than or equal to the swin input size"
|
1111 |
+
# to avoid bicubic zero error
|
1112 |
+
if T < target_T:
|
1113 |
+
x = nn.functional.interpolate(
|
1114 |
+
x, (target_T, x.shape[3]), mode="bicubic", align_corners=True
|
1115 |
+
)
|
1116 |
+
if F < target_F:
|
1117 |
+
x = nn.functional.interpolate(
|
1118 |
+
x, (x.shape[2], target_F), mode="bicubic", align_corners=True
|
1119 |
+
)
|
1120 |
+
x = x.permute(0, 1, 3, 2).contiguous() # B C F T
|
1121 |
+
x = x[:, :, :, cur_pos : cur_pos + self.spec_size]
|
1122 |
+
x = x.repeat(repeats=(1, 1, 4, 1))
|
1123 |
+
return x
|
1124 |
+
|
1125 |
+
def forward(
|
1126 |
+
self, x: torch.Tensor, mixup_lambda=None, infer_mode=False, device=None
|
1127 |
+
): # out_feat_keys: List[str] = None):
|
1128 |
+
if self.enable_fusion and x["longer"].sum() == 0:
|
1129 |
+
# if no audio is longer than 10s, then randomly select one audio to be longer
|
1130 |
+
x["longer"][torch.randint(0, x["longer"].shape[0], (1,))] = True
|
1131 |
+
|
1132 |
+
if not self.enable_fusion:
|
1133 |
+
x = x["waveform"].to(device=device, non_blocking=True)
|
1134 |
+
x = self.spectrogram_extractor(x) # (batch_size, 1, time_steps, freq_bins)
|
1135 |
+
x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins)
|
1136 |
+
x = x.transpose(1, 3)
|
1137 |
+
x = self.bn0(x)
|
1138 |
+
x = x.transpose(1, 3)
|
1139 |
+
if self.training:
|
1140 |
+
x = self.spec_augmenter(x)
|
1141 |
+
|
1142 |
+
if self.training and mixup_lambda is not None:
|
1143 |
+
x = do_mixup(x, mixup_lambda)
|
1144 |
+
|
1145 |
+
x = self.reshape_wav2img(x)
|
1146 |
+
output_dict = self.forward_features(x)
|
1147 |
+
else:
|
1148 |
+
longer_list = x["longer"].to(device=device, non_blocking=True)
|
1149 |
+
x = x["mel_fusion"].to(device=device, non_blocking=True)
|
1150 |
+
x = x.transpose(1, 3)
|
1151 |
+
x = self.bn0(x)
|
1152 |
+
x = x.transpose(1, 3)
|
1153 |
+
longer_list_idx = torch.where(longer_list)[0]
|
1154 |
+
if self.fusion_type in ["daf_1d", "aff_1d", "iaff_1d"]:
|
1155 |
+
new_x = x[:, 0:1, :, :].clone().contiguous()
|
1156 |
+
if len(longer_list_idx) > 0:
|
1157 |
+
# local processing
|
1158 |
+
fusion_x_local = x[longer_list_idx, 1:, :, :].clone().contiguous()
|
1159 |
+
FB, FC, FT, FF = fusion_x_local.size()
|
1160 |
+
fusion_x_local = fusion_x_local.view(FB * FC, FT, FF)
|
1161 |
+
fusion_x_local = torch.permute(
|
1162 |
+
fusion_x_local, (0, 2, 1)
|
1163 |
+
).contiguous()
|
1164 |
+
fusion_x_local = self.mel_conv1d(fusion_x_local)
|
1165 |
+
fusion_x_local = fusion_x_local.view(
|
1166 |
+
FB, FC, FF, fusion_x_local.size(-1)
|
1167 |
+
)
|
1168 |
+
fusion_x_local = (
|
1169 |
+
torch.permute(fusion_x_local, (0, 2, 1, 3))
|
1170 |
+
.contiguous()
|
1171 |
+
.flatten(2)
|
1172 |
+
)
|
1173 |
+
if fusion_x_local.size(-1) < FT:
|
1174 |
+
fusion_x_local = torch.cat(
|
1175 |
+
[
|
1176 |
+
fusion_x_local,
|
1177 |
+
torch.zeros(
|
1178 |
+
(FB, FF, FT - fusion_x_local.size(-1)),
|
1179 |
+
device=device,
|
1180 |
+
),
|
1181 |
+
],
|
1182 |
+
dim=-1,
|
1183 |
+
)
|
1184 |
+
else:
|
1185 |
+
fusion_x_local = fusion_x_local[:, :, :FT]
|
1186 |
+
# 1D fusion
|
1187 |
+
new_x = new_x.squeeze(1).permute((0, 2, 1)).contiguous()
|
1188 |
+
new_x[longer_list_idx] = self.fusion_model(
|
1189 |
+
new_x[longer_list_idx], fusion_x_local
|
1190 |
+
)
|
1191 |
+
x = new_x.permute((0, 2, 1)).contiguous()[:, None, :, :]
|
1192 |
+
else:
|
1193 |
+
x = new_x
|
1194 |
+
|
1195 |
+
elif self.fusion_type in ["daf_2d", "aff_2d", "iaff_2d", "channel_map"]:
|
1196 |
+
x = x # no change
|
1197 |
+
|
1198 |
+
if self.training:
|
1199 |
+
x = self.spec_augmenter(x)
|
1200 |
+
if self.training and mixup_lambda is not None:
|
1201 |
+
x = do_mixup(x, mixup_lambda)
|
1202 |
+
|
1203 |
+
x = self.reshape_wav2img(x)
|
1204 |
+
output_dict = self.forward_features(x, longer_idx=longer_list_idx)
|
1205 |
+
|
1206 |
+
# if infer_mode:
|
1207 |
+
# # in infer mode. we need to handle different length audio input
|
1208 |
+
# frame_num = x.shape[2]
|
1209 |
+
# target_T = int(self.spec_size * self.freq_ratio)
|
1210 |
+
# repeat_ratio = math.floor(target_T / frame_num)
|
1211 |
+
# x = x.repeat(repeats=(1,1,repeat_ratio,1))
|
1212 |
+
# x = self.reshape_wav2img(x)
|
1213 |
+
# output_dict = self.forward_features(x)
|
1214 |
+
# else:
|
1215 |
+
# if x.shape[2] > self.freq_ratio * self.spec_size:
|
1216 |
+
# if self.training:
|
1217 |
+
# x = self.crop_wav(x, crop_size=self.freq_ratio * self.spec_size)
|
1218 |
+
# x = self.reshape_wav2img(x)
|
1219 |
+
# output_dict = self.forward_features(x)
|
1220 |
+
# else:
|
1221 |
+
# # Change: Hard code here
|
1222 |
+
# overlap_size = (x.shape[2] - 1) // 4
|
1223 |
+
# output_dicts = []
|
1224 |
+
# crop_size = (x.shape[2] - 1) // 2
|
1225 |
+
# for cur_pos in range(0, x.shape[2] - crop_size - 1, overlap_size):
|
1226 |
+
# tx = self.crop_wav(x, crop_size = crop_size, spe_pos = cur_pos)
|
1227 |
+
# tx = self.reshape_wav2img(tx)
|
1228 |
+
# output_dicts.append(self.forward_features(tx))
|
1229 |
+
# clipwise_output = torch.zeros_like(output_dicts[0]["clipwise_output"]).float().to(x.device)
|
1230 |
+
# framewise_output = torch.zeros_like(output_dicts[0]["framewise_output"]).float().to(x.device)
|
1231 |
+
# for d in output_dicts:
|
1232 |
+
# clipwise_output += d["clipwise_output"]
|
1233 |
+
# framewise_output += d["framewise_output"]
|
1234 |
+
# clipwise_output = clipwise_output / len(output_dicts)
|
1235 |
+
# framewise_output = framewise_output / len(output_dicts)
|
1236 |
+
# output_dict = {
|
1237 |
+
# 'framewise_output': framewise_output,
|
1238 |
+
# 'clipwise_output': clipwise_output
|
1239 |
+
# }
|
1240 |
+
# else: # this part is typically used, and most easy one
|
1241 |
+
# x = self.reshape_wav2img(x)
|
1242 |
+
# output_dict = self.forward_features(x)
|
1243 |
+
# x = self.head(x)
|
1244 |
+
|
1245 |
+
# We process the data in the dataloader part, in that here we only consider the input_T < fixed_T
|
1246 |
+
|
1247 |
+
return output_dict
|
1248 |
+
|
1249 |
+
|
1250 |
+
def create_htsat_model(audio_cfg, enable_fusion=False, fusion_type="None"):
|
1251 |
+
try:
|
1252 |
+
assert audio_cfg.model_name in [
|
1253 |
+
"tiny",
|
1254 |
+
"base",
|
1255 |
+
"large",
|
1256 |
+
], "model name for HTS-AT is wrong!"
|
1257 |
+
if audio_cfg.model_name == "tiny":
|
1258 |
+
model = HTSAT_Swin_Transformer(
|
1259 |
+
spec_size=256,
|
1260 |
+
patch_size=4,
|
1261 |
+
patch_stride=(4, 4),
|
1262 |
+
num_classes=audio_cfg.class_num,
|
1263 |
+
embed_dim=96,
|
1264 |
+
depths=[2, 2, 6, 2],
|
1265 |
+
num_heads=[4, 8, 16, 32],
|
1266 |
+
window_size=8,
|
1267 |
+
config=audio_cfg,
|
1268 |
+
enable_fusion=enable_fusion,
|
1269 |
+
fusion_type=fusion_type,
|
1270 |
+
)
|
1271 |
+
elif audio_cfg.model_name == "base":
|
1272 |
+
model = HTSAT_Swin_Transformer(
|
1273 |
+
spec_size=256,
|
1274 |
+
patch_size=4,
|
1275 |
+
patch_stride=(4, 4),
|
1276 |
+
num_classes=audio_cfg.class_num,
|
1277 |
+
embed_dim=128,
|
1278 |
+
depths=[2, 2, 12, 2],
|
1279 |
+
num_heads=[4, 8, 16, 32],
|
1280 |
+
window_size=8,
|
1281 |
+
config=audio_cfg,
|
1282 |
+
enable_fusion=enable_fusion,
|
1283 |
+
fusion_type=fusion_type,
|
1284 |
+
)
|
1285 |
+
elif audio_cfg.model_name == "large":
|
1286 |
+
model = HTSAT_Swin_Transformer(
|
1287 |
+
spec_size=256,
|
1288 |
+
patch_size=4,
|
1289 |
+
patch_stride=(4, 4),
|
1290 |
+
num_classes=audio_cfg.class_num,
|
1291 |
+
embed_dim=256,
|
1292 |
+
depths=[2, 2, 12, 2],
|
1293 |
+
num_heads=[4, 8, 16, 32],
|
1294 |
+
window_size=8,
|
1295 |
+
config=audio_cfg,
|
1296 |
+
enable_fusion=enable_fusion,
|
1297 |
+
fusion_type=fusion_type,
|
1298 |
+
)
|
1299 |
+
|
1300 |
+
return model
|
1301 |
+
except:
|
1302 |
+
raise RuntimeError(
|
1303 |
+
f"Import Model for {audio_cfg.model_name} not found, or the audio cfg parameters are not enough."
|
1304 |
+
)
|
audiosr/clap/open_clip/loss.py
ADDED
@@ -0,0 +1,397 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.distributed.nn
|
3 |
+
from torch import distributed as dist, nn as nn
|
4 |
+
from torch.nn import functional as F
|
5 |
+
import numpy as np
|
6 |
+
from sklearn.metrics import average_precision_score, roc_auc_score, accuracy_score
|
7 |
+
|
8 |
+
try:
|
9 |
+
import horovod.torch as hvd
|
10 |
+
except ImportError:
|
11 |
+
hvd = None
|
12 |
+
|
13 |
+
|
14 |
+
def gather_features(
|
15 |
+
audio_features,
|
16 |
+
text_features,
|
17 |
+
audio_features_mlp=None,
|
18 |
+
text_features_mlp=None,
|
19 |
+
local_loss=False,
|
20 |
+
gather_with_grad=False,
|
21 |
+
rank=0,
|
22 |
+
world_size=1,
|
23 |
+
use_horovod=False,
|
24 |
+
mlp_loss=False,
|
25 |
+
):
|
26 |
+
if use_horovod:
|
27 |
+
assert hvd is not None, "Please install horovod"
|
28 |
+
if gather_with_grad:
|
29 |
+
all_audio_features = hvd.allgather(audio_features)
|
30 |
+
all_text_features = hvd.allgather(text_features)
|
31 |
+
if mlp_loss:
|
32 |
+
all_audio_features_mlp = hvd.allgather(audio_features_mlp)
|
33 |
+
all_text_features_mlp = hvd.allgather(text_features_mlp)
|
34 |
+
else:
|
35 |
+
with torch.no_grad():
|
36 |
+
all_audio_features = hvd.allgather(audio_features)
|
37 |
+
all_text_features = hvd.allgather(text_features)
|
38 |
+
if mlp_loss:
|
39 |
+
all_audio_features_mlp = hvd.allgather(audio_features_mlp)
|
40 |
+
all_text_features_mlp = hvd.allgather(text_features_mlp)
|
41 |
+
if not local_loss:
|
42 |
+
# ensure grads for local rank when all_* features don't have a gradient
|
43 |
+
gathered_audio_features = list(
|
44 |
+
all_audio_features.chunk(world_size, dim=0)
|
45 |
+
)
|
46 |
+
gathered_text_features = list(
|
47 |
+
all_text_features.chunk(world_size, dim=0)
|
48 |
+
)
|
49 |
+
gathered_audio_features[rank] = audio_features
|
50 |
+
gathered_text_features[rank] = text_features
|
51 |
+
all_audio_features = torch.cat(gathered_audio_features, dim=0)
|
52 |
+
all_text_features = torch.cat(gathered_text_features, dim=0)
|
53 |
+
if mlp_loss:
|
54 |
+
gathered_audio_features_mlp = list(
|
55 |
+
all_audio_features_mlp.chunk(world_size, dim=0)
|
56 |
+
)
|
57 |
+
gathered_text_features_mlp = list(
|
58 |
+
all_text_features_mlp.chunk(world_size, dim=0)
|
59 |
+
)
|
60 |
+
gathered_audio_features_mlp[rank] = audio_features_mlp
|
61 |
+
gathered_text_features_mlp[rank] = text_features_mlp
|
62 |
+
all_audio_features_mlp = torch.cat(
|
63 |
+
gathered_audio_features_mlp, dim=0
|
64 |
+
)
|
65 |
+
all_text_features_mlp = torch.cat(gathered_text_features_mlp, dim=0)
|
66 |
+
else:
|
67 |
+
# We gather tensors from all gpus
|
68 |
+
if gather_with_grad:
|
69 |
+
all_audio_features = torch.cat(
|
70 |
+
torch.distributed.nn.all_gather(audio_features), dim=0
|
71 |
+
)
|
72 |
+
all_text_features = torch.cat(
|
73 |
+
torch.distributed.nn.all_gather(text_features), dim=0
|
74 |
+
)
|
75 |
+
if mlp_loss:
|
76 |
+
all_audio_features_mlp = torch.cat(
|
77 |
+
torch.distributed.nn.all_gather(audio_features_mlp), dim=0
|
78 |
+
)
|
79 |
+
all_text_features_mlp = torch.cat(
|
80 |
+
torch.distributed.nn.all_gather(text_features_mlp), dim=0
|
81 |
+
)
|
82 |
+
else:
|
83 |
+
gathered_audio_features = [
|
84 |
+
torch.zeros_like(audio_features) for _ in range(world_size)
|
85 |
+
]
|
86 |
+
gathered_text_features = [
|
87 |
+
torch.zeros_like(text_features) for _ in range(world_size)
|
88 |
+
]
|
89 |
+
dist.all_gather(gathered_audio_features, audio_features)
|
90 |
+
dist.all_gather(gathered_text_features, text_features)
|
91 |
+
if mlp_loss:
|
92 |
+
gathered_audio_features_mlp = [
|
93 |
+
torch.zeros_like(audio_features_mlp) for _ in range(world_size)
|
94 |
+
]
|
95 |
+
gathered_text_features_mlp = [
|
96 |
+
torch.zeros_like(text_features_mlp) for _ in range(world_size)
|
97 |
+
]
|
98 |
+
dist.all_gather(gathered_audio_features_mlp, audio_features_mlp)
|
99 |
+
dist.all_gather(gathered_text_features_mlp, text_features_mlp)
|
100 |
+
if not local_loss:
|
101 |
+
# ensure grads for local rank when all_* features don't have a gradient
|
102 |
+
gathered_audio_features[rank] = audio_features
|
103 |
+
gathered_text_features[rank] = text_features
|
104 |
+
if mlp_loss:
|
105 |
+
gathered_audio_features_mlp[rank] = audio_features_mlp
|
106 |
+
gathered_text_features_mlp[rank] = text_features_mlp
|
107 |
+
|
108 |
+
all_audio_features = torch.cat(gathered_audio_features, dim=0)
|
109 |
+
all_text_features = torch.cat(gathered_text_features, dim=0)
|
110 |
+
if mlp_loss:
|
111 |
+
all_audio_features_mlp = torch.cat(gathered_audio_features_mlp, dim=0)
|
112 |
+
all_text_features_mlp = torch.cat(gathered_text_features_mlp, dim=0)
|
113 |
+
if mlp_loss:
|
114 |
+
return (
|
115 |
+
all_audio_features,
|
116 |
+
all_text_features,
|
117 |
+
all_audio_features_mlp,
|
118 |
+
all_text_features_mlp,
|
119 |
+
)
|
120 |
+
else:
|
121 |
+
return all_audio_features, all_text_features
|
122 |
+
|
123 |
+
|
124 |
+
class ClipLoss(nn.Module):
|
125 |
+
def __init__(
|
126 |
+
self,
|
127 |
+
local_loss=False,
|
128 |
+
gather_with_grad=False,
|
129 |
+
cache_labels=False,
|
130 |
+
rank=0,
|
131 |
+
world_size=1,
|
132 |
+
use_horovod=False,
|
133 |
+
mlp_loss=False,
|
134 |
+
weight_loss_kappa=0,
|
135 |
+
):
|
136 |
+
super().__init__()
|
137 |
+
self.local_loss = local_loss
|
138 |
+
self.gather_with_grad = gather_with_grad
|
139 |
+
self.cache_labels = cache_labels
|
140 |
+
self.rank = rank
|
141 |
+
self.world_size = world_size
|
142 |
+
self.use_horovod = use_horovod
|
143 |
+
self.mlp_loss = mlp_loss
|
144 |
+
self.weighted_loss = bool(weight_loss_kappa != 0)
|
145 |
+
self.weight_loss_kappa = weight_loss_kappa
|
146 |
+
# cache state
|
147 |
+
self.prev_num_logits = 0
|
148 |
+
self.labels = {}
|
149 |
+
|
150 |
+
def forward(
|
151 |
+
self,
|
152 |
+
audio_features,
|
153 |
+
text_features,
|
154 |
+
logit_scale_a,
|
155 |
+
logit_scale_t=None,
|
156 |
+
audio_features_mlp=None,
|
157 |
+
text_features_mlp=None,
|
158 |
+
):
|
159 |
+
device = audio_features.device
|
160 |
+
if self.mlp_loss:
|
161 |
+
if self.world_size > 1:
|
162 |
+
(
|
163 |
+
all_audio_features,
|
164 |
+
all_text_features,
|
165 |
+
all_audio_features_mlp,
|
166 |
+
all_text_features_mlp,
|
167 |
+
) = gather_features(
|
168 |
+
audio_features=audio_features,
|
169 |
+
text_features=text_features,
|
170 |
+
audio_features_mlp=audio_features_mlp,
|
171 |
+
text_features_mlp=text_features_mlp,
|
172 |
+
local_loss=self.local_loss,
|
173 |
+
gather_with_grad=self.gather_with_grad,
|
174 |
+
rank=self.rank,
|
175 |
+
world_size=self.world_size,
|
176 |
+
use_horovod=self.use_horovod,
|
177 |
+
mlp_loss=self.mlp_loss,
|
178 |
+
)
|
179 |
+
if self.local_loss:
|
180 |
+
a_logits_per_audio = (
|
181 |
+
logit_scale_a * audio_features @ all_text_features_mlp.T
|
182 |
+
)
|
183 |
+
a_logits_per_text = (
|
184 |
+
logit_scale_a * text_features_mlp @ all_audio_features.T
|
185 |
+
)
|
186 |
+
t_logits_per_audio = (
|
187 |
+
logit_scale_t * audio_features_mlp @ all_text_features.T
|
188 |
+
)
|
189 |
+
t_logits_per_text = (
|
190 |
+
logit_scale_t * text_features @ all_audio_features_mlp.T
|
191 |
+
)
|
192 |
+
else:
|
193 |
+
a_logits_per_audio = (
|
194 |
+
logit_scale_a * all_audio_features @ all_text_features_mlp.T
|
195 |
+
)
|
196 |
+
a_logits_per_text = a_logits_per_audio.T
|
197 |
+
t_logits_per_audio = (
|
198 |
+
logit_scale_t * all_audio_features_mlp @ all_text_features.T
|
199 |
+
)
|
200 |
+
t_logits_per_text = t_logits_per_audio.T
|
201 |
+
else:
|
202 |
+
a_logits_per_audio = (
|
203 |
+
logit_scale_a * audio_features @ text_features_mlp.T
|
204 |
+
)
|
205 |
+
a_logits_per_text = logit_scale_a * text_features_mlp @ audio_features.T
|
206 |
+
t_logits_per_audio = (
|
207 |
+
logit_scale_t * audio_features_mlp @ text_features.T
|
208 |
+
)
|
209 |
+
t_logits_per_text = logit_scale_t * text_features @ audio_features_mlp.T
|
210 |
+
|
211 |
+
# calculated ground-truth and cache if enabled
|
212 |
+
num_logits = a_logits_per_audio.shape[0]
|
213 |
+
if self.prev_num_logits != num_logits or device not in self.labels:
|
214 |
+
labels = torch.arange(num_logits, device=device, dtype=torch.long)
|
215 |
+
if self.world_size > 1 and self.local_loss:
|
216 |
+
labels = labels + num_logits * self.rank
|
217 |
+
if self.cache_labels:
|
218 |
+
self.labels[device] = labels
|
219 |
+
self.prev_num_logits = num_logits
|
220 |
+
else:
|
221 |
+
labels = self.labels[device]
|
222 |
+
|
223 |
+
if not self.weighted_loss:
|
224 |
+
total_loss = (
|
225 |
+
F.cross_entropy(a_logits_per_audio, labels)
|
226 |
+
+ F.cross_entropy(a_logits_per_text, labels)
|
227 |
+
+ F.cross_entropy(t_logits_per_audio, labels)
|
228 |
+
+ F.cross_entropy(t_logits_per_text, labels)
|
229 |
+
) / 4
|
230 |
+
else:
|
231 |
+
audio_weight = (audio_features @ audio_features.T).detach()
|
232 |
+
audio_weight = (
|
233 |
+
torch.exp(
|
234 |
+
torch.sum(audio_weight, axis=1)
|
235 |
+
/ (self.weight_loss_kappa * len(audio_weight))
|
236 |
+
)
|
237 |
+
).detach()
|
238 |
+
text_weight = (text_features @ text_features.T).detach()
|
239 |
+
text_weight = (
|
240 |
+
torch.exp(
|
241 |
+
torch.sum(text_weight, axis=1)
|
242 |
+
/ (self.weight_loss_kappa * len(text_features))
|
243 |
+
)
|
244 |
+
).detach()
|
245 |
+
total_loss = (
|
246 |
+
F.cross_entropy(a_logits_per_audio, labels, weight=audio_weight)
|
247 |
+
+ F.cross_entropy(a_logits_per_text, labels, weight=audio_weight)
|
248 |
+
+ F.cross_entropy(t_logits_per_audio, labels, weight=text_weight)
|
249 |
+
+ F.cross_entropy(t_logits_per_text, labels, weight=text_weight)
|
250 |
+
) / 4
|
251 |
+
else:
|
252 |
+
if self.world_size > 1:
|
253 |
+
all_audio_features, all_text_features = gather_features(
|
254 |
+
audio_features=audio_features,
|
255 |
+
text_features=text_features,
|
256 |
+
local_loss=self.local_loss,
|
257 |
+
gather_with_grad=self.gather_with_grad,
|
258 |
+
rank=self.rank,
|
259 |
+
world_size=self.world_size,
|
260 |
+
use_horovod=self.use_horovod,
|
261 |
+
mlp_loss=self.mlp_loss,
|
262 |
+
)
|
263 |
+
|
264 |
+
if self.local_loss:
|
265 |
+
logits_per_audio = (
|
266 |
+
logit_scale_a * audio_features @ all_text_features.T
|
267 |
+
)
|
268 |
+
logits_per_text = (
|
269 |
+
logit_scale_a * text_features @ all_audio_features.T
|
270 |
+
)
|
271 |
+
else:
|
272 |
+
logits_per_audio = (
|
273 |
+
logit_scale_a * all_audio_features @ all_text_features.T
|
274 |
+
)
|
275 |
+
logits_per_text = logits_per_audio.T
|
276 |
+
else:
|
277 |
+
logits_per_audio = logit_scale_a * audio_features @ text_features.T
|
278 |
+
logits_per_text = logit_scale_a * text_features @ audio_features.T
|
279 |
+
|
280 |
+
# calculated ground-truth and cache if enabled
|
281 |
+
num_logits = logits_per_audio.shape[0]
|
282 |
+
if self.prev_num_logits != num_logits or device not in self.labels:
|
283 |
+
labels = torch.arange(num_logits, device=device, dtype=torch.long)
|
284 |
+
if self.world_size > 1 and self.local_loss:
|
285 |
+
labels = labels + num_logits * self.rank
|
286 |
+
if self.cache_labels:
|
287 |
+
self.labels[device] = labels
|
288 |
+
self.prev_num_logits = num_logits
|
289 |
+
else:
|
290 |
+
labels = self.labels[device]
|
291 |
+
if not self.weighted_loss:
|
292 |
+
total_loss = (
|
293 |
+
F.cross_entropy(logits_per_audio, labels)
|
294 |
+
+ F.cross_entropy(logits_per_text, labels)
|
295 |
+
) / 2
|
296 |
+
else:
|
297 |
+
audio_weight = (all_audio_features @ all_audio_features.T).detach()
|
298 |
+
audio_weight = (
|
299 |
+
torch.exp(
|
300 |
+
torch.sum(audio_weight, axis=1)
|
301 |
+
/ (self.weight_loss_kappa * len(all_audio_features))
|
302 |
+
)
|
303 |
+
).detach()
|
304 |
+
text_weight = (all_text_features @ all_text_features.T).detach()
|
305 |
+
text_weight = (
|
306 |
+
torch.exp(
|
307 |
+
torch.sum(text_weight, axis=1)
|
308 |
+
/ (self.weight_loss_kappa * len(all_text_features))
|
309 |
+
)
|
310 |
+
).detach()
|
311 |
+
total_loss = (
|
312 |
+
F.cross_entropy(logits_per_audio, labels, weight=text_weight)
|
313 |
+
+ F.cross_entropy(logits_per_text, labels, weight=audio_weight)
|
314 |
+
) / 2
|
315 |
+
return total_loss
|
316 |
+
|
317 |
+
|
318 |
+
def lp_gather_features(pred, target, world_size=1, use_horovod=False):
|
319 |
+
if use_horovod:
|
320 |
+
assert hvd is not None, "Please install horovod"
|
321 |
+
with torch.no_grad():
|
322 |
+
all_preds = hvd.allgather(pred)
|
323 |
+
all_targets = hvd.allgath(target)
|
324 |
+
else:
|
325 |
+
gathered_preds = [torch.zeros_like(pred) for _ in range(world_size)]
|
326 |
+
gathered_targets = [torch.zeros_like(target) for _ in range(world_size)]
|
327 |
+
|
328 |
+
dist.all_gather(gathered_preds, pred)
|
329 |
+
dist.all_gather(gathered_targets, target)
|
330 |
+
all_preds = torch.cat(gathered_preds, dim=0)
|
331 |
+
all_targets = torch.cat(gathered_targets, dim=0)
|
332 |
+
|
333 |
+
return all_preds, all_targets
|
334 |
+
|
335 |
+
|
336 |
+
def get_map(pred, target):
|
337 |
+
pred = torch.sigmoid(pred).numpy()
|
338 |
+
target = target.numpy()
|
339 |
+
return np.mean(average_precision_score(target, pred, average=None))
|
340 |
+
|
341 |
+
|
342 |
+
def get_acc(pred, target):
|
343 |
+
pred = torch.argmax(pred, 1).numpy()
|
344 |
+
target = torch.argmax(target, 1).numpy()
|
345 |
+
return accuracy_score(target, pred)
|
346 |
+
|
347 |
+
|
348 |
+
def get_mauc(pred, target):
|
349 |
+
pred = torch.sigmoid(pred).numpy()
|
350 |
+
target = target.numpy()
|
351 |
+
return np.mean(roc_auc_score(target, pred, average=None))
|
352 |
+
|
353 |
+
|
354 |
+
class LPMetrics(object):
|
355 |
+
def __init__(self, metric_names=["map", "acc", "mauc"]):
|
356 |
+
self.metrics = []
|
357 |
+
for name in metric_names:
|
358 |
+
self.metrics.append(self.get_metric(name))
|
359 |
+
self.metric_names = metric_names
|
360 |
+
|
361 |
+
def get_metric(self, name):
|
362 |
+
if name == "map":
|
363 |
+
return get_map
|
364 |
+
elif name == "acc":
|
365 |
+
return get_acc
|
366 |
+
elif name == "mauc":
|
367 |
+
return get_mauc
|
368 |
+
else:
|
369 |
+
raise ValueError(f"the metric should be at least one of [map, acc, mauc]")
|
370 |
+
|
371 |
+
def evaluate_mertics(self, pred, target):
|
372 |
+
metric_dict = {}
|
373 |
+
for i in range(len(self.metric_names)):
|
374 |
+
metric_dict[self.metric_names[i]] = self.metrics[i](pred, target)
|
375 |
+
return metric_dict
|
376 |
+
|
377 |
+
|
378 |
+
def calc_celoss(pred, target):
|
379 |
+
target = torch.argmax(target, 1).long()
|
380 |
+
return nn.CrossEntropyLoss()(pred, target)
|
381 |
+
|
382 |
+
|
383 |
+
class LPLoss(nn.Module):
|
384 |
+
def __init__(self, loss_name):
|
385 |
+
super().__init__()
|
386 |
+
if loss_name == "bce":
|
387 |
+
self.loss_func = nn.BCEWithLogitsLoss()
|
388 |
+
elif loss_name == "ce":
|
389 |
+
self.loss_func = calc_celoss
|
390 |
+
elif loss_name == "mse":
|
391 |
+
self.loss_func = nn.MSELoss()
|
392 |
+
else:
|
393 |
+
raise ValueError(f"the loss func should be at least one of [bce, ce, mse]")
|
394 |
+
|
395 |
+
def forward(self, pred, target):
|
396 |
+
loss = self.loss_func(pred, target)
|
397 |
+
return loss
|
audiosr/clap/open_clip/model.py
ADDED
@@ -0,0 +1,931 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
""" CLAP Model
|
2 |
+
|
3 |
+
Adapted from CLIP: https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
|
4 |
+
Adapted to the Audio Task.
|
5 |
+
"""
|
6 |
+
|
7 |
+
from collections import OrderedDict
|
8 |
+
from dataclasses import dataclass
|
9 |
+
from typing import Tuple, Union, Callable, Optional
|
10 |
+
|
11 |
+
import numpy as np
|
12 |
+
import torch
|
13 |
+
import torch.nn.functional as F
|
14 |
+
from torch import nn
|
15 |
+
|
16 |
+
import logging
|
17 |
+
from .utils import freeze_batch_norm_2d
|
18 |
+
|
19 |
+
from .pann_model import create_pann_model
|
20 |
+
from .htsat import create_htsat_model
|
21 |
+
from transformers import BertModel, RobertaModel, BartModel, RobertaConfig
|
22 |
+
|
23 |
+
|
24 |
+
class MLPLayers(nn.Module):
|
25 |
+
def __init__(self, units=[512, 512, 512], nonlin=nn.ReLU(), dropout=0.1):
|
26 |
+
super(MLPLayers, self).__init__()
|
27 |
+
self.nonlin = nonlin
|
28 |
+
self.dropout = dropout
|
29 |
+
|
30 |
+
sequence = []
|
31 |
+
for u0, u1 in zip(units[:-1], units[1:]):
|
32 |
+
sequence.append(nn.Linear(u0, u1))
|
33 |
+
sequence.append(self.nonlin)
|
34 |
+
sequence.append(nn.Dropout(self.dropout))
|
35 |
+
sequence = sequence[:-2]
|
36 |
+
|
37 |
+
self.sequential = nn.Sequential(*sequence)
|
38 |
+
|
39 |
+
def forward(self, X):
|
40 |
+
X = self.sequential(X)
|
41 |
+
return X
|
42 |
+
|
43 |
+
|
44 |
+
class Bottleneck(nn.Module):
|
45 |
+
expansion = 4
|
46 |
+
|
47 |
+
def __init__(self, inplanes, planes, stride=1):
|
48 |
+
super().__init__()
|
49 |
+
|
50 |
+
# all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
|
51 |
+
self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
|
52 |
+
self.bn1 = nn.BatchNorm2d(planes)
|
53 |
+
|
54 |
+
self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
|
55 |
+
self.bn2 = nn.BatchNorm2d(planes)
|
56 |
+
|
57 |
+
self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
|
58 |
+
|
59 |
+
self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
|
60 |
+
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
|
61 |
+
|
62 |
+
self.relu = nn.ReLU(inplace=True)
|
63 |
+
self.downsample = None
|
64 |
+
self.stride = stride
|
65 |
+
|
66 |
+
if stride > 1 or inplanes != planes * Bottleneck.expansion:
|
67 |
+
# downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
|
68 |
+
self.downsample = nn.Sequential(
|
69 |
+
OrderedDict(
|
70 |
+
[
|
71 |
+
("-1", nn.AvgPool2d(stride)),
|
72 |
+
(
|
73 |
+
"0",
|
74 |
+
nn.Conv2d(
|
75 |
+
inplanes,
|
76 |
+
planes * self.expansion,
|
77 |
+
1,
|
78 |
+
stride=1,
|
79 |
+
bias=False,
|
80 |
+
),
|
81 |
+
),
|
82 |
+
("1", nn.BatchNorm2d(planes * self.expansion)),
|
83 |
+
]
|
84 |
+
)
|
85 |
+
)
|
86 |
+
|
87 |
+
def forward(self, x: torch.Tensor):
|
88 |
+
identity = x
|
89 |
+
|
90 |
+
out = self.relu(self.bn1(self.conv1(x)))
|
91 |
+
out = self.relu(self.bn2(self.conv2(out)))
|
92 |
+
out = self.avgpool(out)
|
93 |
+
out = self.bn3(self.conv3(out))
|
94 |
+
|
95 |
+
if self.downsample is not None:
|
96 |
+
identity = self.downsample(x)
|
97 |
+
|
98 |
+
out += identity
|
99 |
+
out = self.relu(out)
|
100 |
+
return out
|
101 |
+
|
102 |
+
|
103 |
+
class AttentionPool2d(nn.Module):
|
104 |
+
def __init__(
|
105 |
+
self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None
|
106 |
+
):
|
107 |
+
super().__init__()
|
108 |
+
self.positional_embedding = nn.Parameter(
|
109 |
+
torch.randn(spacial_dim**2 + 1, embed_dim) / embed_dim**0.5
|
110 |
+
)
|
111 |
+
self.k_proj = nn.Linear(embed_dim, embed_dim)
|
112 |
+
self.q_proj = nn.Linear(embed_dim, embed_dim)
|
113 |
+
self.v_proj = nn.Linear(embed_dim, embed_dim)
|
114 |
+
self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
|
115 |
+
self.num_heads = num_heads
|
116 |
+
|
117 |
+
def forward(self, x):
|
118 |
+
x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(
|
119 |
+
2, 0, 1
|
120 |
+
) # NCHW -> (HW)NC
|
121 |
+
x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
|
122 |
+
x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
|
123 |
+
x, _ = F.multi_head_attention_forward(
|
124 |
+
query=x,
|
125 |
+
key=x,
|
126 |
+
value=x,
|
127 |
+
embed_dim_to_check=x.shape[-1],
|
128 |
+
num_heads=self.num_heads,
|
129 |
+
q_proj_weight=self.q_proj.weight,
|
130 |
+
k_proj_weight=self.k_proj.weight,
|
131 |
+
v_proj_weight=self.v_proj.weight,
|
132 |
+
in_proj_weight=None,
|
133 |
+
in_proj_bias=torch.cat(
|
134 |
+
[self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]
|
135 |
+
),
|
136 |
+
bias_k=None,
|
137 |
+
bias_v=None,
|
138 |
+
add_zero_attn=False,
|
139 |
+
dropout_p=0,
|
140 |
+
out_proj_weight=self.c_proj.weight,
|
141 |
+
out_proj_bias=self.c_proj.bias,
|
142 |
+
use_separate_proj_weight=True,
|
143 |
+
training=self.training,
|
144 |
+
need_weights=False,
|
145 |
+
)
|
146 |
+
|
147 |
+
return x[0]
|
148 |
+
|
149 |
+
|
150 |
+
class ModifiedResNet(nn.Module):
|
151 |
+
"""
|
152 |
+
A ResNet class that is similar to torchvision's but contains the following changes:
|
153 |
+
- There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
|
154 |
+
- Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
|
155 |
+
- The final pooling layer is a QKV attention instead of an average pool
|
156 |
+
"""
|
157 |
+
|
158 |
+
def __init__(self, layers, output_dim, heads, image_size=224, width=64):
|
159 |
+
super().__init__()
|
160 |
+
self.output_dim = output_dim
|
161 |
+
self.image_size = image_size
|
162 |
+
|
163 |
+
# the 3-layer stem
|
164 |
+
self.conv1 = nn.Conv2d(
|
165 |
+
3, width // 2, kernel_size=3, stride=2, padding=1, bias=False
|
166 |
+
)
|
167 |
+
self.bn1 = nn.BatchNorm2d(width // 2)
|
168 |
+
self.conv2 = nn.Conv2d(
|
169 |
+
width // 2, width // 2, kernel_size=3, padding=1, bias=False
|
170 |
+
)
|
171 |
+
self.bn2 = nn.BatchNorm2d(width // 2)
|
172 |
+
self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
|
173 |
+
self.bn3 = nn.BatchNorm2d(width)
|
174 |
+
self.avgpool = nn.AvgPool2d(2)
|
175 |
+
self.relu = nn.ReLU(inplace=True)
|
176 |
+
|
177 |
+
# residual layers
|
178 |
+
self._inplanes = width # this is a *mutable* variable used during construction
|
179 |
+
self.layer1 = self._make_layer(width, layers[0])
|
180 |
+
self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
|
181 |
+
self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
|
182 |
+
self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
|
183 |
+
|
184 |
+
embed_dim = width * 32 # the ResNet feature dimension
|
185 |
+
self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim)
|
186 |
+
|
187 |
+
self.init_parameters()
|
188 |
+
|
189 |
+
def _make_layer(self, planes, blocks, stride=1):
|
190 |
+
layers = [Bottleneck(self._inplanes, planes, stride)]
|
191 |
+
|
192 |
+
self._inplanes = planes * Bottleneck.expansion
|
193 |
+
for _ in range(1, blocks):
|
194 |
+
layers.append(Bottleneck(self._inplanes, planes))
|
195 |
+
|
196 |
+
return nn.Sequential(*layers)
|
197 |
+
|
198 |
+
def init_parameters(self):
|
199 |
+
if self.attnpool is not None:
|
200 |
+
std = self.attnpool.c_proj.in_features**-0.5
|
201 |
+
nn.init.normal_(self.attnpool.q_proj.weight, std=std)
|
202 |
+
nn.init.normal_(self.attnpool.k_proj.weight, std=std)
|
203 |
+
nn.init.normal_(self.attnpool.v_proj.weight, std=std)
|
204 |
+
nn.init.normal_(self.attnpool.c_proj.weight, std=std)
|
205 |
+
|
206 |
+
for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]:
|
207 |
+
for name, param in resnet_block.named_parameters():
|
208 |
+
if name.endswith("bn3.weight"):
|
209 |
+
nn.init.zeros_(param)
|
210 |
+
|
211 |
+
def lock(self, unlocked_groups=0, freeze_bn_stats=False):
|
212 |
+
assert (
|
213 |
+
unlocked_groups == 0
|
214 |
+
), "partial locking not currently supported for this model"
|
215 |
+
for param in self.parameters():
|
216 |
+
param.requires_grad = False
|
217 |
+
if freeze_bn_stats:
|
218 |
+
freeze_batch_norm_2d(self)
|
219 |
+
|
220 |
+
def stem(self, x):
|
221 |
+
for conv, bn in [
|
222 |
+
(self.conv1, self.bn1),
|
223 |
+
(self.conv2, self.bn2),
|
224 |
+
(self.conv3, self.bn3),
|
225 |
+
]:
|
226 |
+
x = self.relu(bn(conv(x)))
|
227 |
+
x = self.avgpool(x)
|
228 |
+
return x
|
229 |
+
|
230 |
+
def forward(self, x):
|
231 |
+
x = self.stem(x)
|
232 |
+
x = self.layer1(x)
|
233 |
+
x = self.layer2(x)
|
234 |
+
x = self.layer3(x)
|
235 |
+
x = self.layer4(x)
|
236 |
+
x = self.attnpool(x)
|
237 |
+
|
238 |
+
return x
|
239 |
+
|
240 |
+
|
241 |
+
class LayerNorm(nn.LayerNorm):
|
242 |
+
"""Subclass torch's LayerNorm to handle fp16."""
|
243 |
+
|
244 |
+
def forward(self, x: torch.Tensor):
|
245 |
+
orig_type = x.dtype
|
246 |
+
x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
|
247 |
+
return x.to(orig_type)
|
248 |
+
|
249 |
+
|
250 |
+
class QuickGELU(nn.Module):
|
251 |
+
# NOTE This is slower than nn.GELU or nn.SiLU and uses more GPU memory
|
252 |
+
def forward(self, x: torch.Tensor):
|
253 |
+
return x * torch.sigmoid(1.702 * x)
|
254 |
+
|
255 |
+
|
256 |
+
class ResidualAttentionBlock(nn.Module):
|
257 |
+
def __init__(self, d_model: int, n_head: int, act_layer: Callable = nn.GELU):
|
258 |
+
super().__init__()
|
259 |
+
|
260 |
+
self.attn = nn.MultiheadAttention(d_model, n_head)
|
261 |
+
self.ln_1 = LayerNorm(d_model)
|
262 |
+
self.mlp = nn.Sequential(
|
263 |
+
OrderedDict(
|
264 |
+
[
|
265 |
+
("c_fc", nn.Linear(d_model, d_model * 4)),
|
266 |
+
("gelu", act_layer()),
|
267 |
+
("c_proj", nn.Linear(d_model * 4, d_model)),
|
268 |
+
]
|
269 |
+
)
|
270 |
+
)
|
271 |
+
self.ln_2 = LayerNorm(d_model)
|
272 |
+
|
273 |
+
def attention(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
|
274 |
+
return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask)[0]
|
275 |
+
|
276 |
+
def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
|
277 |
+
x = x + self.attention(self.ln_1(x), attn_mask=attn_mask)
|
278 |
+
x = x + self.mlp(self.ln_2(x))
|
279 |
+
return x
|
280 |
+
|
281 |
+
|
282 |
+
class Transformer(nn.Module):
|
283 |
+
def __init__(
|
284 |
+
self, width: int, layers: int, heads: int, act_layer: Callable = nn.GELU
|
285 |
+
):
|
286 |
+
super().__init__()
|
287 |
+
self.width = width
|
288 |
+
self.layers = layers
|
289 |
+
self.resblocks = nn.ModuleList(
|
290 |
+
[
|
291 |
+
ResidualAttentionBlock(width, heads, act_layer=act_layer)
|
292 |
+
for _ in range(layers)
|
293 |
+
]
|
294 |
+
)
|
295 |
+
|
296 |
+
def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
|
297 |
+
for r in self.resblocks:
|
298 |
+
x = r(x, attn_mask=attn_mask)
|
299 |
+
return x
|
300 |
+
|
301 |
+
|
302 |
+
class VisualTransformer(nn.Module):
|
303 |
+
def __init__(
|
304 |
+
self,
|
305 |
+
image_size: int,
|
306 |
+
patch_size: int,
|
307 |
+
width: int,
|
308 |
+
layers: int,
|
309 |
+
heads: int,
|
310 |
+
output_dim: int,
|
311 |
+
act_layer: Callable = nn.GELU,
|
312 |
+
):
|
313 |
+
super().__init__()
|
314 |
+
self.image_size = image_size
|
315 |
+
self.output_dim = output_dim
|
316 |
+
self.conv1 = nn.Conv2d(
|
317 |
+
in_channels=3,
|
318 |
+
out_channels=width,
|
319 |
+
kernel_size=patch_size,
|
320 |
+
stride=patch_size,
|
321 |
+
bias=False,
|
322 |
+
)
|
323 |
+
|
324 |
+
scale = width**-0.5
|
325 |
+
self.class_embedding = nn.Parameter(scale * torch.randn(width))
|
326 |
+
self.positional_embedding = nn.Parameter(
|
327 |
+
scale * torch.randn((image_size // patch_size) ** 2 + 1, width)
|
328 |
+
)
|
329 |
+
self.ln_pre = LayerNorm(width)
|
330 |
+
|
331 |
+
self.text_branch = Transformer(width, layers, heads, act_layer=act_layer)
|
332 |
+
|
333 |
+
self.ln_post = LayerNorm(width)
|
334 |
+
self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
|
335 |
+
|
336 |
+
def lock(self, unlocked_groups=0, freeze_bn_stats=False):
|
337 |
+
assert (
|
338 |
+
unlocked_groups == 0
|
339 |
+
), "partial locking not currently supported for this model"
|
340 |
+
for param in self.parameters():
|
341 |
+
param.requires_grad = False
|
342 |
+
|
343 |
+
def forward(self, x: torch.Tensor):
|
344 |
+
x = self.conv1(x) # shape = [*, width, grid, grid]
|
345 |
+
x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
|
346 |
+
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
|
347 |
+
x = torch.cat(
|
348 |
+
[
|
349 |
+
self.class_embedding.to(x.dtype)
|
350 |
+
+ torch.zeros(
|
351 |
+
x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device
|
352 |
+
),
|
353 |
+
x,
|
354 |
+
],
|
355 |
+
dim=1,
|
356 |
+
) # shape = [*, grid ** 2 + 1, width]
|
357 |
+
x = x + self.positional_embedding.to(x.dtype)
|
358 |
+
x = self.ln_pre(x)
|
359 |
+
|
360 |
+
x = x.permute(1, 0, 2) # NLD -> LND
|
361 |
+
x = self.text_branch(x)
|
362 |
+
x = x.permute(1, 0, 2) # LND -> NLD
|
363 |
+
|
364 |
+
x = self.ln_post(x[:, 0, :])
|
365 |
+
|
366 |
+
if self.proj is not None:
|
367 |
+
x = x @ self.proj
|
368 |
+
|
369 |
+
return x
|
370 |
+
|
371 |
+
|
372 |
+
@dataclass
|
373 |
+
class CLAPVisionCfg:
|
374 |
+
layers: Union[Tuple[int, int, int, int], int] = 12
|
375 |
+
width: int = 768
|
376 |
+
patch_size: int = 16
|
377 |
+
image_size: Union[Tuple[int, int], int] = 224
|
378 |
+
timm_model_name: str = (
|
379 |
+
None # a valid model name overrides layers, width, patch_size
|
380 |
+
)
|
381 |
+
timm_model_pretrained: bool = (
|
382 |
+
False # use (imagenet) pretrained weights for named model
|
383 |
+
)
|
384 |
+
timm_pool: str = (
|
385 |
+
"avg" # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')
|
386 |
+
)
|
387 |
+
timm_proj: str = (
|
388 |
+
"linear" # linear projection for timm model output ('linear', 'mlp', '')
|
389 |
+
)
|
390 |
+
|
391 |
+
|
392 |
+
# Audio Config Class
|
393 |
+
@dataclass
|
394 |
+
class CLAPAudioCfp:
|
395 |
+
model_type: str = "PANN"
|
396 |
+
model_name: str = "Cnn14"
|
397 |
+
sample_rate: int = 48000
|
398 |
+
# Param
|
399 |
+
audio_length: int = 1024
|
400 |
+
window_size: int = 1024
|
401 |
+
hop_size: int = 1024
|
402 |
+
fmin: int = 50
|
403 |
+
fmax: int = 14000
|
404 |
+
class_num: int = 527
|
405 |
+
mel_bins: int = 64
|
406 |
+
clip_samples: int = 480000
|
407 |
+
|
408 |
+
|
409 |
+
@dataclass
|
410 |
+
class CLAPTextCfg:
|
411 |
+
context_length: int
|
412 |
+
vocab_size: int
|
413 |
+
width: int
|
414 |
+
heads: int
|
415 |
+
layers: int
|
416 |
+
model_type: str
|
417 |
+
|
418 |
+
|
419 |
+
class CLAP(nn.Module):
|
420 |
+
def __init__(
|
421 |
+
self,
|
422 |
+
embed_dim: int,
|
423 |
+
audio_cfg: CLAPAudioCfp,
|
424 |
+
text_cfg: CLAPTextCfg,
|
425 |
+
quick_gelu: bool = False,
|
426 |
+
enable_fusion: bool = False,
|
427 |
+
fusion_type: str = "None",
|
428 |
+
joint_embed_shape: int = 512,
|
429 |
+
mlp_act: str = "relu",
|
430 |
+
):
|
431 |
+
super().__init__()
|
432 |
+
if isinstance(audio_cfg, dict):
|
433 |
+
audio_cfg = CLAPAudioCfp(**audio_cfg)
|
434 |
+
if isinstance(text_cfg, dict):
|
435 |
+
text_cfg = CLAPTextCfg(**text_cfg)
|
436 |
+
|
437 |
+
self.audio_cfg = audio_cfg
|
438 |
+
self.text_cfg = text_cfg
|
439 |
+
self.enable_fusion = enable_fusion
|
440 |
+
self.fusion_type = fusion_type
|
441 |
+
self.joint_embed_shape = joint_embed_shape
|
442 |
+
self.mlp_act = mlp_act
|
443 |
+
|
444 |
+
self.context_length = text_cfg.context_length
|
445 |
+
|
446 |
+
# OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more
|
447 |
+
# memory efficient in recent PyTorch releases (>= 1.10).
|
448 |
+
# NOTE: timm models always use native GELU regardless of quick_gelu flag.
|
449 |
+
act_layer = QuickGELU if quick_gelu else nn.GELU
|
450 |
+
|
451 |
+
if mlp_act == "relu":
|
452 |
+
mlp_act_layer = nn.ReLU()
|
453 |
+
elif mlp_act == "gelu":
|
454 |
+
mlp_act_layer = nn.GELU()
|
455 |
+
else:
|
456 |
+
raise NotImplementedError
|
457 |
+
|
458 |
+
# audio branch
|
459 |
+
# audio branch parameters
|
460 |
+
if audio_cfg.model_type == "PANN":
|
461 |
+
self.audio_branch = create_pann_model(audio_cfg, enable_fusion, fusion_type)
|
462 |
+
elif audio_cfg.model_type == "HTSAT":
|
463 |
+
self.audio_branch = create_htsat_model(
|
464 |
+
audio_cfg, enable_fusion, fusion_type
|
465 |
+
)
|
466 |
+
else:
|
467 |
+
logging.error(f"Model config for {audio_cfg.model_type} not found")
|
468 |
+
raise RuntimeError(f"Model config for {audio_cfg.model_type} not found.")
|
469 |
+
|
470 |
+
# text branch
|
471 |
+
# text branch parameters
|
472 |
+
if text_cfg.model_type == "transformer":
|
473 |
+
self.text_branch = Transformer(
|
474 |
+
width=text_cfg.width,
|
475 |
+
layers=text_cfg.layers,
|
476 |
+
heads=text_cfg.heads,
|
477 |
+
act_layer=act_layer,
|
478 |
+
)
|
479 |
+
self.vocab_size = text_cfg.vocab_size
|
480 |
+
self.token_embedding = nn.Embedding(text_cfg.vocab_size, text_cfg.width)
|
481 |
+
self.positional_embedding = nn.Parameter(
|
482 |
+
torch.empty(self.context_length, text_cfg.width)
|
483 |
+
)
|
484 |
+
self.ln_final = LayerNorm(text_cfg.width)
|
485 |
+
self.text_transform = MLPLayers(
|
486 |
+
units=[
|
487 |
+
self.joint_embed_shape,
|
488 |
+
self.joint_embed_shape,
|
489 |
+
self.joint_embed_shape,
|
490 |
+
],
|
491 |
+
dropout=0.1,
|
492 |
+
)
|
493 |
+
self.text_projection = nn.Sequential(
|
494 |
+
nn.Linear(text_cfg.width, self.joint_embed_shape),
|
495 |
+
mlp_act_layer,
|
496 |
+
nn.Linear(self.joint_embed_shape, self.joint_embed_shape),
|
497 |
+
)
|
498 |
+
elif text_cfg.model_type == "bert":
|
499 |
+
self.text_branch = BertModel.from_pretrained("bert-base-uncased")
|
500 |
+
self.text_transform = MLPLayers(
|
501 |
+
units=[
|
502 |
+
self.joint_embed_shape,
|
503 |
+
self.joint_embed_shape,
|
504 |
+
self.joint_embed_shape,
|
505 |
+
],
|
506 |
+
dropout=0.1,
|
507 |
+
)
|
508 |
+
self.text_projection = nn.Sequential(
|
509 |
+
nn.Linear(768, self.joint_embed_shape),
|
510 |
+
mlp_act_layer,
|
511 |
+
nn.Linear(self.joint_embed_shape, self.joint_embed_shape),
|
512 |
+
)
|
513 |
+
elif text_cfg.model_type == "roberta":
|
514 |
+
self.text_branch = RobertaModel(
|
515 |
+
RobertaConfig.from_pretrained("roberta-base")
|
516 |
+
)
|
517 |
+
self.text_transform = MLPLayers(
|
518 |
+
units=[
|
519 |
+
self.joint_embed_shape,
|
520 |
+
self.joint_embed_shape,
|
521 |
+
self.joint_embed_shape,
|
522 |
+
],
|
523 |
+
dropout=0.1,
|
524 |
+
)
|
525 |
+
self.text_projection = nn.Sequential(
|
526 |
+
nn.Linear(768, self.joint_embed_shape),
|
527 |
+
mlp_act_layer,
|
528 |
+
nn.Linear(self.joint_embed_shape, self.joint_embed_shape),
|
529 |
+
)
|
530 |
+
elif text_cfg.model_type == "bart":
|
531 |
+
self.text_branch = BartModel.from_pretrained("facebook/bart-base")
|
532 |
+
self.text_transform = MLPLayers(
|
533 |
+
units=[
|
534 |
+
self.joint_embed_shape,
|
535 |
+
self.joint_embed_shape,
|
536 |
+
self.joint_embed_shape,
|
537 |
+
],
|
538 |
+
dropout=0.1,
|
539 |
+
)
|
540 |
+
self.text_projection = nn.Sequential(
|
541 |
+
nn.Linear(768, self.joint_embed_shape),
|
542 |
+
mlp_act_layer,
|
543 |
+
nn.Linear(self.joint_embed_shape, self.joint_embed_shape),
|
544 |
+
)
|
545 |
+
else:
|
546 |
+
logging.error(f"Model config for {text_cfg.model_type} not found")
|
547 |
+
raise RuntimeError(f"Model config for {text_cfg.model_type} not found.")
|
548 |
+
self.text_branch_type = text_cfg.model_type
|
549 |
+
# text branch parameters
|
550 |
+
|
551 |
+
# audio branch parameters
|
552 |
+
self.audio_transform = MLPLayers(
|
553 |
+
units=[
|
554 |
+
self.joint_embed_shape,
|
555 |
+
self.joint_embed_shape,
|
556 |
+
self.joint_embed_shape,
|
557 |
+
],
|
558 |
+
dropout=0.1,
|
559 |
+
)
|
560 |
+
|
561 |
+
# below here is text branch parameters
|
562 |
+
|
563 |
+
# ============================================================================================================
|
564 |
+
self.audio_projection = nn.Sequential(
|
565 |
+
nn.Linear(embed_dim, self.joint_embed_shape),
|
566 |
+
mlp_act_layer,
|
567 |
+
nn.Linear(self.joint_embed_shape, self.joint_embed_shape),
|
568 |
+
)
|
569 |
+
|
570 |
+
self.logit_scale_a = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
|
571 |
+
self.logit_scale_t = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
|
572 |
+
self.register_buffer("attn_mask", self.build_attention_mask(), persistent=False)
|
573 |
+
|
574 |
+
self.init_text_branch_parameters()
|
575 |
+
|
576 |
+
def init_text_branch_parameters(self):
|
577 |
+
if self.text_branch_type == "transformer":
|
578 |
+
nn.init.normal_(self.token_embedding.weight, std=0.02)
|
579 |
+
nn.init.normal_(self.positional_embedding, std=0.01)
|
580 |
+
proj_std = (self.text_branch.width**-0.5) * (
|
581 |
+
(2 * self.text_branch.layers) ** -0.5
|
582 |
+
)
|
583 |
+
attn_std = self.text_branch.width**-0.5
|
584 |
+
fc_std = (2 * self.text_branch.width) ** -0.5
|
585 |
+
for block in self.text_branch.resblocks:
|
586 |
+
nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
|
587 |
+
nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
|
588 |
+
nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
|
589 |
+
nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
|
590 |
+
if self.text_branch_type == "bert" or self.text_branch_type == "roberta":
|
591 |
+
self.text_branch.embeddings.word_embeddings.weight.shape[-1]
|
592 |
+
elif self.text_branch_type == "bart":
|
593 |
+
self.text_branch.shared.weight.shape[-1]
|
594 |
+
else:
|
595 |
+
self.text_branch.width
|
596 |
+
nn.init.constant_(self.logit_scale_a, np.log(1 / 0.07))
|
597 |
+
nn.init.constant_(self.logit_scale_t, np.log(1 / 0.07))
|
598 |
+
|
599 |
+
# deprecated
|
600 |
+
# if hasattr(self.visual, 'init_parameters'):
|
601 |
+
# self.visual.init_parameters()
|
602 |
+
|
603 |
+
# if self.text_projection is not None:
|
604 |
+
# nn.init.normal_(self.text_projection, std=width**-0.5)
|
605 |
+
|
606 |
+
def build_attention_mask(self):
|
607 |
+
# lazily create causal attention mask, with full attention between the vision tokens
|
608 |
+
# pytorch uses additive attention mask; fill with -inf
|
609 |
+
mask = torch.empty(self.context_length, self.context_length)
|
610 |
+
mask.fill_(float("-inf"))
|
611 |
+
mask.triu_(1) # zero out the lower diagonal
|
612 |
+
return mask
|
613 |
+
|
614 |
+
def encode_audio(self, audio, device):
|
615 |
+
return self.audio_branch(
|
616 |
+
audio, mixup_lambda=None, device=device
|
617 |
+
) # mix lambda needs to add
|
618 |
+
|
619 |
+
# def list_of_dict_of_tensor2dict_of_tensor(self, x, device):
|
620 |
+
# tmp = {}
|
621 |
+
# for k in x[0].keys():
|
622 |
+
# tmp[k] = []
|
623 |
+
# for i in range(len(x)):
|
624 |
+
# tmp[k].append(x[i][k][:77])
|
625 |
+
# for k in x[0].keys():
|
626 |
+
# tmp[k] = torch.tensor(tmp[k]).to(device=device, non_blocking=True)
|
627 |
+
# return tmp
|
628 |
+
|
629 |
+
def encode_text(self, text, device):
|
630 |
+
if self.text_branch_type == "transformer":
|
631 |
+
text = text.to(device=device, non_blocking=True)
|
632 |
+
x = self.token_embedding(text) # [batch_size, n_ctx, d_model]
|
633 |
+
|
634 |
+
x = x + self.positional_embedding
|
635 |
+
x = x.permute(1, 0, 2) # NLD -> LND
|
636 |
+
x = self.text_branch(x, attn_mask=self.attn_mask)
|
637 |
+
x = x.permute(1, 0, 2) # LND -> NLD
|
638 |
+
x = self.ln_final(x)
|
639 |
+
|
640 |
+
# x.shape = [batch_size, n_ctx, transformer.width]
|
641 |
+
# take features from the eot embedding (eot_token is the highest number in each sequence)
|
642 |
+
x = self.text_projection(x[torch.arange(x.shape[0]), text.argmax(dim=-1)])
|
643 |
+
elif self.text_branch_type == "bert":
|
644 |
+
# text = self.list_of_dict_of_tensor2dict_of_tensor(text, device)
|
645 |
+
# text = BatchEncoding(text)
|
646 |
+
x = self.text_branch(
|
647 |
+
input_ids=text["input_ids"].to(device=device, non_blocking=True),
|
648 |
+
attention_mask=text["attention_mask"].to(
|
649 |
+
device=device, non_blocking=True
|
650 |
+
),
|
651 |
+
token_type_ids=text["token_type_ids"].to(
|
652 |
+
device=device, non_blocking=True
|
653 |
+
),
|
654 |
+
)["pooler_output"]
|
655 |
+
x = self.text_projection(x)
|
656 |
+
elif self.text_branch_type == "roberta":
|
657 |
+
x = self.text_branch(
|
658 |
+
input_ids=text["input_ids"].to(device=device, non_blocking=True),
|
659 |
+
attention_mask=text["attention_mask"].to(
|
660 |
+
device=device, non_blocking=True
|
661 |
+
),
|
662 |
+
)["pooler_output"]
|
663 |
+
x = self.text_projection(x)
|
664 |
+
elif self.text_branch_type == "bart":
|
665 |
+
x = torch.mean(
|
666 |
+
self.text_branch(
|
667 |
+
input_ids=text["input_ids"].to(device=device, non_blocking=True),
|
668 |
+
attention_mask=text["attention_mask"].to(
|
669 |
+
device=device, non_blocking=True
|
670 |
+
),
|
671 |
+
)["encoder_last_hidden_state"],
|
672 |
+
axis=1,
|
673 |
+
)
|
674 |
+
x = self.text_projection(x)
|
675 |
+
else:
|
676 |
+
logging.error(f"Model type {self.text_branch_type} not found")
|
677 |
+
raise RuntimeError(f"Model type {self.text_branch_type} not found.")
|
678 |
+
return x
|
679 |
+
|
680 |
+
def forward(self, audio, text, device=None):
|
681 |
+
"""Forward audio and text into the CLAP
|
682 |
+
|
683 |
+
Parameters
|
684 |
+
----------
|
685 |
+
audio: torch.Tensor (batch_size, audio_length)
|
686 |
+
the time-domain audio input / the batch of mel_spec and longer list.
|
687 |
+
text: torch.Tensor () // need to add
|
688 |
+
the text token input
|
689 |
+
"""
|
690 |
+
if device is None:
|
691 |
+
if audio is not None:
|
692 |
+
device = audio.device
|
693 |
+
elif text is not None:
|
694 |
+
device = text.device
|
695 |
+
if audio is None and text is None:
|
696 |
+
# a hack to get the logit scale
|
697 |
+
return self.logit_scale_a.exp(), self.logit_scale_t.exp()
|
698 |
+
elif audio is None:
|
699 |
+
return self.encode_text(text, device=device)
|
700 |
+
elif text is None:
|
701 |
+
return self.audio_projection(
|
702 |
+
self.encode_audio(audio, device=device)["embedding"]
|
703 |
+
)
|
704 |
+
audio_features = self.audio_projection(
|
705 |
+
self.encode_audio(audio, device=device)["embedding"]
|
706 |
+
)
|
707 |
+
audio_features = F.normalize(audio_features, dim=-1)
|
708 |
+
|
709 |
+
text_features = self.encode_text(text, device=device)
|
710 |
+
# print("text_features", text_features)
|
711 |
+
# print("text_features.shape", text_features.shape)
|
712 |
+
# print("text_features.type", type(text_features))
|
713 |
+
text_features = F.normalize(text_features, dim=-1)
|
714 |
+
|
715 |
+
audio_features_mlp = self.audio_transform(audio_features)
|
716 |
+
text_features_mlp = self.text_transform(text_features)
|
717 |
+
# Four outputs: audio features (basic & MLP), text features (basic & MLP)
|
718 |
+
return (
|
719 |
+
audio_features,
|
720 |
+
text_features,
|
721 |
+
audio_features_mlp,
|
722 |
+
text_features_mlp,
|
723 |
+
self.logit_scale_a.exp(),
|
724 |
+
self.logit_scale_t.exp(),
|
725 |
+
)
|
726 |
+
|
727 |
+
def get_logit_scale(self):
|
728 |
+
return self.logit_scale_a.exp(), self.logit_scale_t.exp()
|
729 |
+
|
730 |
+
def get_text_embedding(self, data):
|
731 |
+
"""Get the text embedding from the model
|
732 |
+
|
733 |
+
Parameters
|
734 |
+
----------
|
735 |
+
data: torch.Tensor
|
736 |
+
a tensor of text embedding
|
737 |
+
|
738 |
+
Returns
|
739 |
+
----------
|
740 |
+
text_embed: torch.Tensor
|
741 |
+
a tensor of text_embeds (N, D)
|
742 |
+
|
743 |
+
"""
|
744 |
+
device = next(self.parameters()).device
|
745 |
+
for k in data:
|
746 |
+
data[k] = data[k].to(device)
|
747 |
+
text_embeds = self.encode_text(data, device=device)
|
748 |
+
text_embeds = F.normalize(text_embeds, dim=-1)
|
749 |
+
|
750 |
+
return text_embeds
|
751 |
+
|
752 |
+
def get_audio_embedding(self, data):
|
753 |
+
"""Get the audio embedding from the model
|
754 |
+
|
755 |
+
Parameters
|
756 |
+
----------
|
757 |
+
data: a list of dict
|
758 |
+
the audio input dict list from 'get_audio_feature' method
|
759 |
+
|
760 |
+
Returns
|
761 |
+
----------
|
762 |
+
audio_embed: torch.Tensor
|
763 |
+
a tensor of audio_embeds (N, D)
|
764 |
+
|
765 |
+
"""
|
766 |
+
device = next(self.parameters()).device
|
767 |
+
# input_dict = {}
|
768 |
+
# keys = data[0].keys()
|
769 |
+
# for k in keys:
|
770 |
+
# input_dict[k] = torch.cat([d[k].unsqueeze(0) for d in data], dim=0).to(
|
771 |
+
# device
|
772 |
+
# )
|
773 |
+
audio_embeds = self.audio_projection(
|
774 |
+
self.encode_audio(data, device=device)["embedding"]
|
775 |
+
)
|
776 |
+
audio_embeds = F.normalize(audio_embeds, dim=-1)
|
777 |
+
|
778 |
+
return audio_embeds
|
779 |
+
|
780 |
+
def audio_infer(self, audio, hopsize=None, device=None):
|
781 |
+
"""Forward one audio and produce the audio embedding
|
782 |
+
|
783 |
+
Parameters
|
784 |
+
----------
|
785 |
+
audio: (audio_length)
|
786 |
+
the time-domain audio input, notice that it must be only one input
|
787 |
+
hopsize: int
|
788 |
+
the overlap hopsize as the sliding window
|
789 |
+
|
790 |
+
Returns
|
791 |
+
----------
|
792 |
+
output_dict: {
|
793 |
+
key: [n, (embedding_shape)] if "HTS-AT"
|
794 |
+
or
|
795 |
+
key: [(embedding_shape)] if "PANN"
|
796 |
+
}
|
797 |
+
the list of key values of the audio branch
|
798 |
+
|
799 |
+
"""
|
800 |
+
|
801 |
+
assert not self.training, "the inference mode must be run at eval stage"
|
802 |
+
output_dict = {}
|
803 |
+
# PANN
|
804 |
+
if self.audio_cfg.model_type == "PANN":
|
805 |
+
audio_input = audio.unsqueeze(dim=0)
|
806 |
+
output_dict[key] = self.encode_audio(audio_input, device=device)[
|
807 |
+
key
|
808 |
+
].squeeze(dim=0)
|
809 |
+
elif self.audio_cfg.model_type == "HTSAT":
|
810 |
+
# repeat
|
811 |
+
audio_len = len(audio)
|
812 |
+
k = self.audio_cfg.clip_samples // audio_len
|
813 |
+
if k > 1:
|
814 |
+
audio = audio.repeat(k)
|
815 |
+
audio_len = len(audio)
|
816 |
+
|
817 |
+
if hopsize is None:
|
818 |
+
hopsize = min(hopsize, audio_len)
|
819 |
+
|
820 |
+
if audio_len > self.audio_cfg.clip_samples:
|
821 |
+
audio_input = [
|
822 |
+
audio[pos : pos + self.audio_cfg.clip_samples].clone()
|
823 |
+
for pos in range(
|
824 |
+
0, audio_len - self.audio_cfg.clip_samples, hopsize
|
825 |
+
)
|
826 |
+
]
|
827 |
+
audio_input.append(audio[-self.audio_cfg.clip_samples :].clone())
|
828 |
+
audio_input = torch.stack(audio_input)
|
829 |
+
output_dict[key] = self.encode_audio(audio_input, device=device)[key]
|
830 |
+
else:
|
831 |
+
audio_input = audio.unsqueeze(dim=0)
|
832 |
+
output_dict[key] = self.encode_audio(audio_input, device=device)[
|
833 |
+
key
|
834 |
+
].squeeze(dim=0)
|
835 |
+
|
836 |
+
return output_dict
|
837 |
+
|
838 |
+
|
839 |
+
def convert_weights_to_fp16(model: nn.Module):
|
840 |
+
"""Convert applicable model parameters to fp16"""
|
841 |
+
|
842 |
+
def _convert_weights_to_fp16(l):
|
843 |
+
if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
|
844 |
+
l.weight.data = l.weight.data.half()
|
845 |
+
if l.bias is not None:
|
846 |
+
l.bias.data = l.bias.data.half()
|
847 |
+
|
848 |
+
if isinstance(l, nn.MultiheadAttention):
|
849 |
+
for attr in [
|
850 |
+
*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]],
|
851 |
+
"in_proj_bias",
|
852 |
+
"bias_k",
|
853 |
+
"bias_v",
|
854 |
+
]:
|
855 |
+
tensor = getattr(l, attr)
|
856 |
+
if tensor is not None:
|
857 |
+
tensor.data = tensor.data.half()
|
858 |
+
|
859 |
+
for name in ["text_projection", "proj"]:
|
860 |
+
if hasattr(l, name):
|
861 |
+
attr = getattr(l, name)
|
862 |
+
if attr is not None:
|
863 |
+
attr.data = attr.data.half()
|
864 |
+
|
865 |
+
model.apply(_convert_weights_to_fp16)
|
866 |
+
|
867 |
+
|
868 |
+
# Ignore the state dict of the vision part
|
869 |
+
def build_model_from_openai_state_dict(
|
870 |
+
state_dict: dict, model_cfg, enable_fusion: bool = False, fusion_type: str = "None"
|
871 |
+
):
|
872 |
+
embed_dim = model_cfg["embed_dim"]
|
873 |
+
audio_cfg = model_cfg["audio_cfg"]
|
874 |
+
text_cfg = model_cfg["text_cfg"]
|
875 |
+
state_dict["positional_embedding"].shape[0]
|
876 |
+
state_dict["token_embedding.weight"].shape[0]
|
877 |
+
transformer_width = state_dict["ln_final.weight"].shape[0]
|
878 |
+
transformer_width // 64
|
879 |
+
transformer_layers = len(
|
880 |
+
set(
|
881 |
+
k.split(".")[2]
|
882 |
+
for k in state_dict
|
883 |
+
if k.startswith(f"transformer.resblocks")
|
884 |
+
)
|
885 |
+
)
|
886 |
+
|
887 |
+
audio_cfg = CLAPAudioCfp(**audio_cfg)
|
888 |
+
text_cfg = CLAPTextCfg(**text_cfg)
|
889 |
+
|
890 |
+
model = CLAP(
|
891 |
+
embed_dim,
|
892 |
+
audio_cfg=audio_cfg,
|
893 |
+
text_cfg=text_cfg,
|
894 |
+
quick_gelu=True, # OpenAI models were trained with QuickGELU
|
895 |
+
enable_fusion=enable_fusion,
|
896 |
+
fusion_type=fusion_type,
|
897 |
+
)
|
898 |
+
state_dict["logit_scale_a"] = state_dict["logit_scale"]
|
899 |
+
state_dict["logit_scale_t"] = state_dict["logit_scale"]
|
900 |
+
pop_keys = list(state_dict.keys())[::]
|
901 |
+
# pop the visual branch saved weights
|
902 |
+
for key in pop_keys:
|
903 |
+
if key.startswith("visual."):
|
904 |
+
state_dict.pop(key, None)
|
905 |
+
|
906 |
+
for key in ["logit_scale", "input_resolution", "context_length", "vocab_size"]:
|
907 |
+
state_dict.pop(key, None)
|
908 |
+
|
909 |
+
# not use fp16
|
910 |
+
# convert_weights_to_fp16(model)
|
911 |
+
model.load_state_dict(state_dict, strict=False)
|
912 |
+
return model.eval()
|
913 |
+
|
914 |
+
|
915 |
+
def trace_model(model, batch_size=256, device=torch.device("cpu")):
|
916 |
+
model.eval()
|
917 |
+
audio_length = model.audio_cfg.audio_length
|
918 |
+
example_audio = torch.ones((batch_size, audio_length), device=device)
|
919 |
+
example_text = torch.zeros(
|
920 |
+
(batch_size, model.context_length), dtype=torch.int, device=device
|
921 |
+
)
|
922 |
+
model = torch.jit.trace_module(
|
923 |
+
model,
|
924 |
+
inputs=dict(
|
925 |
+
forward=(example_audio, example_text),
|
926 |
+
encode_text=(example_text,),
|
927 |
+
encode_image=(example_audio,),
|
928 |
+
),
|
929 |
+
)
|
930 |
+
model.audio_cfg.audio_length = audio_length # Question: what does this do?
|
931 |
+
return model
|
audiosr/clap/open_clip/model_configs/HTSAT-base.json
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 1024,
|
3 |
+
"audio_cfg": {
|
4 |
+
"audio_length": 1024,
|
5 |
+
"clip_samples": 480000,
|
6 |
+
"mel_bins": 64,
|
7 |
+
"sample_rate": 48000,
|
8 |
+
"window_size": 1024,
|
9 |
+
"hop_size": 480,
|
10 |
+
"fmin": 50,
|
11 |
+
"fmax": 14000,
|
12 |
+
"class_num": 527,
|
13 |
+
"model_type": "HTSAT",
|
14 |
+
"model_name": "base"
|
15 |
+
},
|
16 |
+
"text_cfg": {
|
17 |
+
"context_length": 77,
|
18 |
+
"vocab_size": 49408,
|
19 |
+
"width": 512,
|
20 |
+
"heads": 8,
|
21 |
+
"layers": 12
|
22 |
+
}
|
23 |
+
}
|
audiosr/clap/open_clip/model_configs/HTSAT-large.json
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 2048,
|
3 |
+
"audio_cfg": {
|
4 |
+
"audio_length": 1024,
|
5 |
+
"clip_samples": 480000,
|
6 |
+
"mel_bins": 64,
|
7 |
+
"sample_rate": 48000,
|
8 |
+
"window_size": 1024,
|
9 |
+
"hop_size": 480,
|
10 |
+
"fmin": 50,
|
11 |
+
"fmax": 14000,
|
12 |
+
"class_num": 527,
|
13 |
+
"model_type": "HTSAT",
|
14 |
+
"model_name": "large"
|
15 |
+
},
|
16 |
+
"text_cfg": {
|
17 |
+
"context_length": 77,
|
18 |
+
"vocab_size": 49408,
|
19 |
+
"width": 512,
|
20 |
+
"heads": 8,
|
21 |
+
"layers": 12
|
22 |
+
}
|
23 |
+
}
|
audiosr/clap/open_clip/model_configs/HTSAT-tiny-win-1536.json
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 768,
|
3 |
+
"audio_cfg": {
|
4 |
+
"audio_length": 1024,
|
5 |
+
"clip_samples": 480000,
|
6 |
+
"mel_bins": 64,
|
7 |
+
"sample_rate": 48000,
|
8 |
+
"window_size": 1536,
|
9 |
+
"hop_size": 480,
|
10 |
+
"fmin": 50,
|
11 |
+
"fmax": 14000,
|
12 |
+
"class_num": 527,
|
13 |
+
"model_type": "HTSAT",
|
14 |
+
"model_name": "tiny"
|
15 |
+
},
|
16 |
+
"text_cfg": {
|
17 |
+
"context_length": 77,
|
18 |
+
"vocab_size": 49408,
|
19 |
+
"width": 512,
|
20 |
+
"heads": 8,
|
21 |
+
"layers": 12
|
22 |
+
}
|
23 |
+
}
|
audiosr/clap/open_clip/model_configs/HTSAT-tiny.json
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 768,
|
3 |
+
"audio_cfg": {
|
4 |
+
"audio_length": 1024,
|
5 |
+
"clip_samples": 480000,
|
6 |
+
"mel_bins": 64,
|
7 |
+
"sample_rate": 48000,
|
8 |
+
"window_size": 1024,
|
9 |
+
"hop_size": 480,
|
10 |
+
"fmin": 50,
|
11 |
+
"fmax": 14000,
|
12 |
+
"class_num": 527,
|
13 |
+
"model_type": "HTSAT",
|
14 |
+
"model_name": "tiny"
|
15 |
+
},
|
16 |
+
"text_cfg": {
|
17 |
+
"context_length": 77,
|
18 |
+
"vocab_size": 49408,
|
19 |
+
"width": 512,
|
20 |
+
"heads": 8,
|
21 |
+
"layers": 12
|
22 |
+
}
|
23 |
+
}
|
audiosr/clap/open_clip/model_configs/PANN-10.json
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 1024,
|
3 |
+
"audio_cfg": {
|
4 |
+
"audio_length": 1024,
|
5 |
+
"clip_samples": 480000,
|
6 |
+
"mel_bins": 64,
|
7 |
+
"sample_rate": 48000,
|
8 |
+
"window_size": 1024,
|
9 |
+
"hop_size": 480,
|
10 |
+
"fmin": 50,
|
11 |
+
"fmax": 14000,
|
12 |
+
"class_num": 527,
|
13 |
+
"model_type": "PANN",
|
14 |
+
"model_name": "Cnn10"
|
15 |
+
},
|
16 |
+
"text_cfg": {
|
17 |
+
"context_length": 77,
|
18 |
+
"vocab_size": 49408,
|
19 |
+
"width": 512,
|
20 |
+
"heads": 8,
|
21 |
+
"layers": 12
|
22 |
+
}
|
23 |
+
}
|
audiosr/clap/open_clip/model_configs/PANN-14-fmax-18k.json
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 2048,
|
3 |
+
"audio_cfg": {
|
4 |
+
"audio_length": 1024,
|
5 |
+
"clip_samples": 480000,
|
6 |
+
"mel_bins": 64,
|
7 |
+
"sample_rate": 48000,
|
8 |
+
"window_size": 1024,
|
9 |
+
"hop_size": 480,
|
10 |
+
"fmin": 50,
|
11 |
+
"fmax": 18000,
|
12 |
+
"class_num": 527,
|
13 |
+
"model_type": "PANN",
|
14 |
+
"model_name": "Cnn14"
|
15 |
+
},
|
16 |
+
"text_cfg": {
|
17 |
+
"context_length": 77,
|
18 |
+
"vocab_size": 49408,
|
19 |
+
"width": 512,
|
20 |
+
"heads": 8,
|
21 |
+
"layers": 12
|
22 |
+
}
|
23 |
+
}
|
audiosr/clap/open_clip/model_configs/PANN-14-fmax-8k-20s.json
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 2048,
|
3 |
+
"audio_cfg": {
|
4 |
+
"audio_length": 1024,
|
5 |
+
"clip_samples": 960000,
|
6 |
+
"mel_bins": 64,
|
7 |
+
"sample_rate": 48000,
|
8 |
+
"window_size": 1024,
|
9 |
+
"hop_size": 360,
|
10 |
+
"fmin": 50,
|
11 |
+
"fmax": 8000,
|
12 |
+
"class_num": 527,
|
13 |
+
"model_type": "PANN",
|
14 |
+
"model_name": "Cnn14"
|
15 |
+
},
|
16 |
+
"text_cfg": {
|
17 |
+
"context_length": 77,
|
18 |
+
"vocab_size": 49408,
|
19 |
+
"width": 512,
|
20 |
+
"heads": 8,
|
21 |
+
"layers": 12
|
22 |
+
}
|
23 |
+
}
|
audiosr/clap/open_clip/model_configs/PANN-14-tiny-transformer.json
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 2048,
|
3 |
+
"audio_cfg": {
|
4 |
+
"audio_length": 1024,
|
5 |
+
"clip_samples": 480000,
|
6 |
+
"mel_bins": 64,
|
7 |
+
"sample_rate": 48000,
|
8 |
+
"window_size": 1024,
|
9 |
+
"hop_size": 480,
|
10 |
+
"fmin": 50,
|
11 |
+
"fmax": 14000,
|
12 |
+
"class_num": 527,
|
13 |
+
"model_type": "PANN",
|
14 |
+
"model_name": "Cnn14"
|
15 |
+
},
|
16 |
+
"text_cfg": {
|
17 |
+
"context_length": 77,
|
18 |
+
"vocab_size": 49408,
|
19 |
+
"width": 512,
|
20 |
+
"heads": 8,
|
21 |
+
"layers": 4
|
22 |
+
}
|
23 |
+
}
|
audiosr/clap/open_clip/model_configs/PANN-14-win-1536.json
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 2048,
|
3 |
+
"audio_cfg": {
|
4 |
+
"audio_length": 1024,
|
5 |
+
"clip_samples": 480000,
|
6 |
+
"mel_bins": 64,
|
7 |
+
"sample_rate": 48000,
|
8 |
+
"window_size": 1536,
|
9 |
+
"hop_size": 480,
|
10 |
+
"fmin": 50,
|
11 |
+
"fmax": 14000,
|
12 |
+
"class_num": 527,
|
13 |
+
"model_type": "PANN",
|
14 |
+
"model_name": "Cnn14"
|
15 |
+
},
|
16 |
+
"text_cfg": {
|
17 |
+
"context_length": 77,
|
18 |
+
"vocab_size": 49408,
|
19 |
+
"width": 512,
|
20 |
+
"heads": 8,
|
21 |
+
"layers": 12
|
22 |
+
}
|
23 |
+
}
|
audiosr/clap/open_clip/model_configs/PANN-14.json
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 2048,
|
3 |
+
"audio_cfg": {
|
4 |
+
"audio_length": 1024,
|
5 |
+
"clip_samples": 480000,
|
6 |
+
"mel_bins": 64,
|
7 |
+
"sample_rate": 48000,
|
8 |
+
"window_size": 1024,
|
9 |
+
"hop_size": 480,
|
10 |
+
"fmin": 50,
|
11 |
+
"fmax": 14000,
|
12 |
+
"class_num": 527,
|
13 |
+
"model_type": "PANN",
|
14 |
+
"model_name": "Cnn14"
|
15 |
+
},
|
16 |
+
"text_cfg": {
|
17 |
+
"context_length": 77,
|
18 |
+
"vocab_size": 49408,
|
19 |
+
"width": 512,
|
20 |
+
"heads": 8,
|
21 |
+
"layers": 12
|
22 |
+
}
|
23 |
+
}
|
audiosr/clap/open_clip/model_configs/PANN-6.json
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 512,
|
3 |
+
"audio_cfg": {
|
4 |
+
"audio_length": 1024,
|
5 |
+
"clip_samples": 480000,
|
6 |
+
"mel_bins": 64,
|
7 |
+
"sample_rate": 48000,
|
8 |
+
"window_size": 1024,
|
9 |
+
"hop_size": 480,
|
10 |
+
"fmin": 50,
|
11 |
+
"fmax": 14000,
|
12 |
+
"class_num": 527,
|
13 |
+
"model_type": "PANN",
|
14 |
+
"model_name": "Cnn6"
|
15 |
+
},
|
16 |
+
"text_cfg": {
|
17 |
+
"context_length": 77,
|
18 |
+
"vocab_size": 49408,
|
19 |
+
"width": 512,
|
20 |
+
"heads": 8,
|
21 |
+
"layers": 12
|
22 |
+
}
|
23 |
+
}
|
audiosr/clap/open_clip/model_configs/RN101-quickgelu.json
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 512,
|
3 |
+
"quick_gelu": true,
|
4 |
+
"vision_cfg": {
|
5 |
+
"image_size": 224,
|
6 |
+
"layers": [
|
7 |
+
3,
|
8 |
+
4,
|
9 |
+
23,
|
10 |
+
3
|
11 |
+
],
|
12 |
+
"width": 64,
|
13 |
+
"patch_size": null
|
14 |
+
},
|
15 |
+
"text_cfg": {
|
16 |
+
"context_length": 77,
|
17 |
+
"vocab_size": 49408,
|
18 |
+
"width": 512,
|
19 |
+
"heads": 8,
|
20 |
+
"layers": 12
|
21 |
+
}
|
22 |
+
}
|
audiosr/clap/open_clip/model_configs/RN101.json
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 512,
|
3 |
+
"vision_cfg": {
|
4 |
+
"image_size": 224,
|
5 |
+
"layers": [
|
6 |
+
3,
|
7 |
+
4,
|
8 |
+
23,
|
9 |
+
3
|
10 |
+
],
|
11 |
+
"width": 64,
|
12 |
+
"patch_size": null
|
13 |
+
},
|
14 |
+
"text_cfg": {
|
15 |
+
"context_length": 77,
|
16 |
+
"vocab_size": 49408,
|
17 |
+
"width": 512,
|
18 |
+
"heads": 8,
|
19 |
+
"layers": 12
|
20 |
+
}
|
21 |
+
}
|
audiosr/clap/open_clip/model_configs/RN50-quickgelu.json
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 1024,
|
3 |
+
"quick_gelu": true,
|
4 |
+
"vision_cfg": {
|
5 |
+
"image_size": 224,
|
6 |
+
"layers": [
|
7 |
+
3,
|
8 |
+
4,
|
9 |
+
6,
|
10 |
+
3
|
11 |
+
],
|
12 |
+
"width": 64,
|
13 |
+
"patch_size": null
|
14 |
+
},
|
15 |
+
"text_cfg": {
|
16 |
+
"context_length": 77,
|
17 |
+
"vocab_size": 49408,
|
18 |
+
"width": 512,
|
19 |
+
"heads": 8,
|
20 |
+
"layers": 12
|
21 |
+
}
|
22 |
+
}
|
audiosr/clap/open_clip/model_configs/RN50.json
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 1024,
|
3 |
+
"vision_cfg": {
|
4 |
+
"image_size": 224,
|
5 |
+
"layers": [
|
6 |
+
3,
|
7 |
+
4,
|
8 |
+
6,
|
9 |
+
3
|
10 |
+
],
|
11 |
+
"width": 64,
|
12 |
+
"patch_size": null
|
13 |
+
},
|
14 |
+
"text_cfg": {
|
15 |
+
"context_length": 77,
|
16 |
+
"vocab_size": 49408,
|
17 |
+
"width": 512,
|
18 |
+
"heads": 8,
|
19 |
+
"layers": 12
|
20 |
+
}
|
21 |
+
}
|
audiosr/clap/open_clip/model_configs/RN50x16.json
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 768,
|
3 |
+
"vision_cfg": {
|
4 |
+
"image_size": 384,
|
5 |
+
"layers": [
|
6 |
+
6,
|
7 |
+
8,
|
8 |
+
18,
|
9 |
+
8
|
10 |
+
],
|
11 |
+
"width": 96,
|
12 |
+
"patch_size": null
|
13 |
+
},
|
14 |
+
"text_cfg": {
|
15 |
+
"context_length": 77,
|
16 |
+
"vocab_size": 49408,
|
17 |
+
"width": 768,
|
18 |
+
"heads": 12,
|
19 |
+
"layers": 12
|
20 |
+
}
|
21 |
+
}
|
audiosr/clap/open_clip/model_configs/RN50x4.json
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 640,
|
3 |
+
"vision_cfg": {
|
4 |
+
"image_size": 288,
|
5 |
+
"layers": [
|
6 |
+
4,
|
7 |
+
6,
|
8 |
+
10,
|
9 |
+
6
|
10 |
+
],
|
11 |
+
"width": 80,
|
12 |
+
"patch_size": null
|
13 |
+
},
|
14 |
+
"text_cfg": {
|
15 |
+
"context_length": 77,
|
16 |
+
"vocab_size": 49408,
|
17 |
+
"width": 640,
|
18 |
+
"heads": 10,
|
19 |
+
"layers": 12
|
20 |
+
}
|
21 |
+
}
|
audiosr/clap/open_clip/model_configs/ViT-B-16.json
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 512,
|
3 |
+
"vision_cfg": {
|
4 |
+
"image_size": 224,
|
5 |
+
"layers": 12,
|
6 |
+
"width": 768,
|
7 |
+
"patch_size": 16
|
8 |
+
},
|
9 |
+
"text_cfg": {
|
10 |
+
"context_length": 77,
|
11 |
+
"vocab_size": 49408,
|
12 |
+
"width": 512,
|
13 |
+
"heads": 8,
|
14 |
+
"layers": 12
|
15 |
+
}
|
16 |
+
}
|
audiosr/clap/open_clip/model_configs/ViT-B-32-quickgelu.json
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 512,
|
3 |
+
"quick_gelu": true,
|
4 |
+
"vision_cfg": {
|
5 |
+
"image_size": 224,
|
6 |
+
"layers": 12,
|
7 |
+
"width": 768,
|
8 |
+
"patch_size": 32
|
9 |
+
},
|
10 |
+
"text_cfg": {
|
11 |
+
"context_length": 77,
|
12 |
+
"vocab_size": 49408,
|
13 |
+
"width": 512,
|
14 |
+
"heads": 8,
|
15 |
+
"layers": 12
|
16 |
+
}
|
17 |
+
}
|
audiosr/clap/open_clip/model_configs/ViT-B-32.json
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 512,
|
3 |
+
"vision_cfg": {
|
4 |
+
"image_size": 224,
|
5 |
+
"layers": 12,
|
6 |
+
"width": 768,
|
7 |
+
"patch_size": 32
|
8 |
+
},
|
9 |
+
"text_cfg": {
|
10 |
+
"context_length": 77,
|
11 |
+
"vocab_size": 49408,
|
12 |
+
"width": 512,
|
13 |
+
"heads": 8,
|
14 |
+
"layers": 12
|
15 |
+
}
|
16 |
+
}
|
audiosr/clap/open_clip/model_configs/ViT-L-14.json
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 768,
|
3 |
+
"vision_cfg": {
|
4 |
+
"image_size": 224,
|
5 |
+
"layers": 24,
|
6 |
+
"width": 1024,
|
7 |
+
"patch_size": 14
|
8 |
+
},
|
9 |
+
"text_cfg": {
|
10 |
+
"context_length": 77,
|
11 |
+
"vocab_size": 49408,
|
12 |
+
"width": 768,
|
13 |
+
"heads": 12,
|
14 |
+
"layers": 12
|
15 |
+
}
|
16 |
+
}
|