File size: 705 Bytes
ac91715 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 |
from typing import Dict, List, Union
import numpy as np
import torch
def detach_to_numpy(
    data: Union[List, Dict, torch.Tensor, np.ndarray],
) -> Union[List, Dict, np.ndarray]:
    """Recursively convert every tensor inside ``data`` to a numpy array.

    Lists and dicts are traversed recursively. The input containers are
    never modified: a new list / a shallow copy of the dict is returned
    with converted values, so the caller's tensors stay usable.

    Args:
        data: A tensor, a numpy array, or an arbitrarily nested list / dict
            of those.

    Returns:
        The same structure with each tensor replaced by a numpy array.
        Numpy arrays found in the input are passed through as-is (not
        copied).

    Raises:
        ValueError: If ``data`` (or any nested element) is not a tensor,
            numpy array, dict, or list.
    """
    if isinstance(data, torch.Tensor):
        # Detach before moving to host so the device copy is not tracked
        # by autograd.
        return data.detach().cpu().numpy()  # pytype: disable=attribute-error
    if isinstance(data, np.ndarray):
        return data
    if isinstance(data, list):
        return [detach_to_numpy(item) for item in data]
    if isinstance(data, dict):
        # Work on a shallow copy: keeps key order (and the mapping type for
        # dict subclasses that implement ``copy``) without overwriting the
        # caller's dictionary in place.
        converted = data.copy()
        for key, value in data.items():
            converted[key] = detach_to_numpy(value)
        return converted
    raise ValueError("data should be tensor, numpy array, dict, or list.")
|