|
import torch |
|
from typing import Tuple, Callable |
|
def hacer_nada(x: torch.Tensor, modo: str = None): |
|
return x |
|
def brujeria_mps(entrada, dim, indice): |
|
if entrada.shape[-1] == 1: |
|
return torch.gather(entrada.unsqueeze(-1), dim - 1 if dim < 0 else dim, indice.unsqueeze(-1)).squeeze(-1) |
|
else: |
|
return torch.gather(entrada, dim, indice) |
|
def emparejamiento_suave_aleatorio_2d( |
|
metrica: torch.Tensor, |
|
ancho: int, |
|
alto: int, |
|
paso_x: int, |
|
paso_y: int, |
|
radio: int, |
|
sin_aleatoriedad: bool = False, |
|
generador: torch.Generator = None |
|
) -> Tuple[Callable, Callable]: |
|
lote, num_nodos, _ = metrica.shape |
|
if radio <= 0: |
|
return hacer_nada, hacer_nada |
|
recopilar = brujeria_mps if metrica.device.type == "mps" else torch.gather |
|
with torch.no_grad(): |
|
alto_paso_y, ancho_paso_x = alto // paso_y, ancho // paso_x |
|
if sin_aleatoriedad: |
|
indice_aleatorio = torch.zeros(alto_paso_y, ancho_paso_x, 1, device=metrica.device, dtype=torch.int64) |
|
else: |
|
indice_aleatorio = torch.randint(paso_y * paso_x, size=(alto_paso_y, ancho_paso_x, 1), device=generador.device, generator=generador).to(metrica.device) |
|
vista_buffer_indice = torch.zeros(alto_paso_y, ancho_paso_x, paso_y * paso_x, device=metrica.device, dtype=torch.int64) |
|
vista_buffer_indice.scatter_(dim=2, index=indice_aleatorio, src=-torch.ones_like(indice_aleatorio, dtype=indice_aleatorio.dtype)) |
|
vista_buffer_indice = vista_buffer_indice.view(alto_paso_y, ancho_paso_x, paso_y, paso_x).transpose(1, 2).reshape(alto_paso_y * paso_y, ancho_paso_x * paso_x) |
|
if (alto_paso_y * paso_y) < alto or (ancho_paso_x * paso_x) < ancho: |
|
buffer_indice = torch.zeros(alto, ancho, device=metrica.device, dtype=torch.int64) |
|
buffer_indice[:(alto_paso_y * paso_y), :(ancho_paso_x * paso_x)] = vista_buffer_indice |
|
else: |
|
buffer_indice = vista_buffer_indice |
|
indice_aleatorio = buffer_indice.reshape(1, -1, 1).argsort(dim=1) |
|
del buffer_indice, vista_buffer_indice |
|
num_destino = alto_paso_y * ancho_paso_x |
|
indices_a = indice_aleatorio[:, num_destino:, :] |
|
indices_b = indice_aleatorio[:, :num_destino, :] |
|
def dividir(x): |
|
canales = x.shape[-1] |
|
origen = recopilar(x, dim=1, index=indices_a.expand(lote, num_nodos - num_destino, canales)) |
|
destino = recopilar(x, dim=1, index=indices_b.expand(lote, num_destino, canales)) |
|
return origen, destino |
|
metrica = metrica / metrica.norm(dim=-1, keepdim=True) |
|
a, b = dividir(metrica) |
|
puntuaciones = a @ b.transpose(-1, -2) |
|
radio = min(a.shape[1], radio) |
|
nodo_max, nodo_indice = puntuaciones.max(dim=-1) |
|
indice_borde = nodo_max.argsort(dim=-1, descending=True)[..., None] |
|
indice_no_emparejado = indice_borde[..., radio:, :] |
|
indice_origen = indice_borde[..., :radio, :] |
|
indice_destino = recopilar(nodo_indice[..., None], dim=-2, index=indice_origen) |
|
def fusionar(x: torch.Tensor, modo="mean") -> torch.Tensor: |
|
origen, destino = dividir(x) |
|
n, t1, c = origen.shape |
|
no_emparejado = recopilar(origen, dim=-2, index=indice_no_emparejado.expand(n, t1 - radio, c)) |
|
origen = recopilar(origen, dim=-2, index=indice_origen.expand(n, radio, c)) |
|
destino = destino.scatter_reduce(-2, indice_destino.expand(n, radio, c), origen, reduce=modo) |
|
return torch.cat([no_emparejado, destino], dim=1) |
|
def desfusionar(x: torch.Tensor) -> torch.Tensor: |
|
longitud_no_emparejado = indice_no_emparejado.shape[1] |
|
no_emparejado, destino = x[..., :longitud_no_emparejado, :], x[..., longitud_no_emparejado:, :] |
|
_, _, c = no_emparejado.shape |
|
origen = recopilar(destino, dim=-2, index=indice_destino.expand(lote, radio, c)) |
|
salida = torch.zeros(lote, num_nodos, c, device=x.device, dtype=x.dtype) |
|
salida.scatter_(dim=-2, index=indices_b.expand(lote, num_destino, c), src=destino) |
|
salida.scatter_(dim=-2, index=recopilar(indices_a.expand(lote, indices_a.shape[1], 1), dim=1, index=indice_no_emparejado).expand(lote, longitud_no_emparejado, c), src=no_emparejado) |
|
salida.scatter_(dim=-2, index=recopilar(indices_a.expand(lote, indices_a.shape[1], 1), dim=1, index=indice_origen).expand(lote, radio, c), src=origen) |
|
return salida |
|
return fusionar, desfusionar |