edge-flux-sayan-4 / src /ghanta.py
loveisgone's picture
Upload folder using huggingface_hub
be08337 verified
import torch
from typing import Tuple, Callable
def hacer_nada(x: torch.Tensor, modo: str = None):
return x
def brujeria_mps(entrada, dim, indice):
if entrada.shape[-1] == 1:
return torch.gather(entrada.unsqueeze(-1), dim - 1 if dim < 0 else dim, indice.unsqueeze(-1)).squeeze(-1)
else:
return torch.gather(entrada, dim, indice)
def emparejamiento_suave_aleatorio_2d(
metrica: torch.Tensor,
ancho: int,
alto: int,
paso_x: int,
paso_y: int,
radio: int,
sin_aleatoriedad: bool = False,
generador: torch.Generator = None
) -> Tuple[Callable, Callable]:
lote, num_nodos, _ = metrica.shape
if radio <= 0:
return hacer_nada, hacer_nada
recopilar = brujeria_mps if metrica.device.type == "mps" else torch.gather
with torch.no_grad():
alto_paso_y, ancho_paso_x = alto // paso_y, ancho // paso_x
if sin_aleatoriedad:
indice_aleatorio = torch.zeros(alto_paso_y, ancho_paso_x, 1, device=metrica.device, dtype=torch.int64)
else:
indice_aleatorio = torch.randint(paso_y * paso_x, size=(alto_paso_y, ancho_paso_x, 1), device=generador.device, generator=generador).to(metrica.device)
vista_buffer_indice = torch.zeros(alto_paso_y, ancho_paso_x, paso_y * paso_x, device=metrica.device, dtype=torch.int64)
vista_buffer_indice.scatter_(dim=2, index=indice_aleatorio, src=-torch.ones_like(indice_aleatorio, dtype=indice_aleatorio.dtype))
vista_buffer_indice = vista_buffer_indice.view(alto_paso_y, ancho_paso_x, paso_y, paso_x).transpose(1, 2).reshape(alto_paso_y * paso_y, ancho_paso_x * paso_x)
if (alto_paso_y * paso_y) < alto or (ancho_paso_x * paso_x) < ancho:
buffer_indice = torch.zeros(alto, ancho, device=metrica.device, dtype=torch.int64)
buffer_indice[:(alto_paso_y * paso_y), :(ancho_paso_x * paso_x)] = vista_buffer_indice
else:
buffer_indice = vista_buffer_indice
indice_aleatorio = buffer_indice.reshape(1, -1, 1).argsort(dim=1)
del buffer_indice, vista_buffer_indice
num_destino = alto_paso_y * ancho_paso_x
indices_a = indice_aleatorio[:, num_destino:, :]
indices_b = indice_aleatorio[:, :num_destino, :]
def dividir(x):
canales = x.shape[-1]
origen = recopilar(x, dim=1, index=indices_a.expand(lote, num_nodos - num_destino, canales))
destino = recopilar(x, dim=1, index=indices_b.expand(lote, num_destino, canales))
return origen, destino
metrica = metrica / metrica.norm(dim=-1, keepdim=True)
a, b = dividir(metrica)
puntuaciones = a @ b.transpose(-1, -2)
radio = min(a.shape[1], radio)
nodo_max, nodo_indice = puntuaciones.max(dim=-1)
indice_borde = nodo_max.argsort(dim=-1, descending=True)[..., None]
indice_no_emparejado = indice_borde[..., radio:, :]
indice_origen = indice_borde[..., :radio, :]
indice_destino = recopilar(nodo_indice[..., None], dim=-2, index=indice_origen)
def fusionar(x: torch.Tensor, modo="mean") -> torch.Tensor:
origen, destino = dividir(x)
n, t1, c = origen.shape
no_emparejado = recopilar(origen, dim=-2, index=indice_no_emparejado.expand(n, t1 - radio, c))
origen = recopilar(origen, dim=-2, index=indice_origen.expand(n, radio, c))
destino = destino.scatter_reduce(-2, indice_destino.expand(n, radio, c), origen, reduce=modo)
return torch.cat([no_emparejado, destino], dim=1)
def desfusionar(x: torch.Tensor) -> torch.Tensor:
longitud_no_emparejado = indice_no_emparejado.shape[1]
no_emparejado, destino = x[..., :longitud_no_emparejado, :], x[..., longitud_no_emparejado:, :]
_, _, c = no_emparejado.shape
origen = recopilar(destino, dim=-2, index=indice_destino.expand(lote, radio, c))
salida = torch.zeros(lote, num_nodos, c, device=x.device, dtype=x.dtype)
salida.scatter_(dim=-2, index=indices_b.expand(lote, num_destino, c), src=destino)
salida.scatter_(dim=-2, index=recopilar(indices_a.expand(lote, indices_a.shape[1], 1), dim=1, index=indice_no_emparejado).expand(lote, longitud_no_emparejado, c), src=no_emparejado)
salida.scatter_(dim=-2, index=recopilar(indices_a.expand(lote, indices_a.shape[1], 1), dim=1, index=indice_origen).expand(lote, radio, c), src=origen)
return salida
return fusionar, desfusionar