# Spaces: Running on Zero
import numpy as np
import torch
import torch.nn as nn
class HHGCodecEmbedding(nn.Module):
    """Embedding layer backed by a pretrained two-level codec codebook.

    Loads a codebook array of shape (2, codebook_size, codebook_dim) from a
    ``.npy`` file. Each combined token id is split into a low-order and a
    high-order sub-index, looked up in the two codebooks, and the two
    embeddings are concatenated along the last dimension (optionally
    projected to ``out_channels``).
    """

    def __init__(self, out_channels, codebook_path: str, freeze=True):
        """
        Args:
            out_channels: Desired output embedding dimension.
            codebook_path: Path to a ``.npy`` file of shape
                (2, 128, codebook_dim).
            freeze: If True (default), the codebook embedding weights are
                not trainable.
        """
        super().__init__()
        # Expected layout: (num_codebooks=2, codebook_size=128, codebook_dim).
        # .copy() guards against non-writable arrays from np.load (mmap etc.).
        codebook = torch.from_numpy(np.load(codebook_path).copy())
        assert codebook.shape[0] == 2 and codebook.shape[1] == 128
        # Store the per-level vocabulary size instead of re-hard-coding 128
        # in forward(), so the token split always matches the loaded data.
        self.codebook_size = codebook.shape[1]
        self.codebook_dim = codebook.shape[2]
        self.codebook = torch.nn.ModuleList([
            torch.nn.Embedding.from_pretrained(codebook[i], freeze=freeze)
            for i in range(codebook.shape[0])
        ])
        # Project only when the concatenated width differs from out_channels.
        if self.codebook_dim * 2 != out_channels:
            self.proj = nn.Linear(self.codebook_dim * 2, out_channels)
        else:
            self.proj = nn.Identity()

    def forward(self, tokens):
        """Map combined token ids to embeddings.

        Args:
            tokens: Integer tensor of token ids, any shape (*,); each id
                encodes (high * codebook_size + low).
        Returns:
            Float tensor of shape (*, out_channels).
        """
        token_embs = torch.cat([
            self.codebook[0](tokens % self.codebook_size),   # low-order index
            self.codebook[1](tokens // self.codebook_size),  # high-order index
        ], dim=-1)
        return self.proj(token_embs)