Spaces:
Running
on
Zero
Running
on
Zero
File size: 5,527 Bytes
d711508 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 |
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# Adapted from https://botorch.org/api/_modules/botorch/utils/torch.html
# TODO: To be removed once (if) https://github.com/pytorch/pytorch/pull/37385 lands
from __future__ import annotations
import collections
from collections import OrderedDict
import torch
from torch.nn import Module
class BufferDict(Module):
    r"""
    Holds buffers in a dictionary.

    `BufferDict` can be indexed like a regular Python dictionary, but the buffers it contains are properly registered
    (via `Module.register_buffer`) and are therefore visible to all `Module` methods. `BufferDict` is an **ordered**
    dictionary that respects

    * the order of insertion, and
    * in `BufferDict.update`, the order of the merged `OrderedDict`, another `BufferDict`, or an iterable of
      key-value pairs (the argument to `BufferDict.update`).

    Note that `BufferDict.update` with any other mapping type (e.g., Python's plain `dict`) does not preserve the
    order of the merged mapping: its items are inserted in sorted key order.

    Args:
        buffers (iterable, optional):
            A mapping (dictionary) of (string : `torch.Tensor`) or an iterable of key-value pairs of type (string,
            `torch.Tensor`).
        persistent (`bool`, optional, defaults to `False`):
            Forwarded as the `persistent` argument of `Module.register_buffer` for every buffer stored in this
            dictionary.

    ```python
    class MyModule(nn.Module):
        def __init__(self):
            super().__init__()
            self.buffer_dict = BufferDict({"left": torch.randn(5, 10), "right": torch.randn(5, 10)})

        def forward(self, x, choice):
            x = self.buffer_dict[choice].mm(x)
            return x
    ```
    """

    def __init__(self, buffers=None, persistent: bool = False):
        r"""
        Args:
            buffers (`dict`, optional):
                A mapping (dictionary) from string to `torch.Tensor`, or an iterable of key-value pairs of type
                (string, `torch.Tensor`). When `None`, the dictionary starts empty.
            persistent (`bool`, optional, defaults to `False`):
                Value passed as `persistent` to `register_buffer` whenever a buffer is stored.
        """
        super().__init__()
        # `persistent` must exist before `update` runs: `__setitem__` reads it for every inserted buffer.
        self.persistent = persistent
        if buffers is not None:
            self.update(buffers)

    def __getitem__(self, key: str):
        """Return the buffer stored under `key`; raises `KeyError` if it is missing."""
        return self._buffers[key]

    def __setitem__(self, key: str, buffer) -> None:
        """Register `buffer` under `key`, using this dictionary's `persistent` setting."""
        self.register_buffer(key, buffer, persistent=self.persistent)

    def __delitem__(self, key: str) -> None:
        """Remove the buffer stored under `key`; raises `KeyError` if it is missing."""
        del self._buffers[key]
        # Mirror `Module.__delattr__`, which also forgets the name in the non-persistent bookkeeping set.
        self._non_persistent_buffers_set.discard(key)

    def __len__(self) -> int:
        """Return the number of stored buffers."""
        return len(self._buffers)

    def __iter__(self):
        """Iterate over the keys in insertion order."""
        return iter(self._buffers.keys())

    def __contains__(self, key) -> bool:
        """Return whether a buffer is stored under `key`."""
        return key in self._buffers

    def clear(self) -> None:
        """Remove all items from the BufferDict."""
        self._buffers.clear()
        # Keep the non-persistent bookkeeping in sync so that no stale names are left behind.
        self._non_persistent_buffers_set.clear()

    def pop(self, key: str):
        r"""Remove key from the BufferDict and return its buffer.

        Args:
            key (`str`):
                Key to pop from the BufferDict

        Raises:
            KeyError: If `key` is not present.
        """
        buffer = self[key]
        del self[key]
        return buffer

    def keys(self):
        r"""Return an iterable of the BufferDict keys."""
        return self._buffers.keys()

    def items(self):
        r"""Return an iterable of the BufferDict key/value pairs."""
        return self._buffers.items()

    def values(self):
        r"""Return an iterable of the BufferDict values."""
        return self._buffers.values()

    def update(self, buffers) -> None:
        r"""
        Update the `BufferDict` with the key-value pairs from a mapping or an iterable, overwriting existing keys.

        Note:
            If `buffers` is an `OrderedDict`, a `BufferDict`, or an iterable of key-value pairs, the order of new
            elements in it is preserved. Items of any other mapping type are inserted in sorted key order.

        Args:
            buffers (iterable):
                A mapping (dictionary) from string to `torch.Tensor`, or an iterable of key-value pairs of type
                (string, `torch.Tensor`).

        Raises:
            TypeError: If `buffers` is not iterable, or if an element of a non-mapping iterable is not iterable.
            ValueError: If an element of a non-mapping iterable does not have exactly two items.
        """
        if not isinstance(buffers, collections.abc.Iterable):
            raise TypeError(
                "BufferDict.update should be called with an "
                "iterable of key/value pairs, but got " + type(buffers).__name__
            )

        if isinstance(buffers, collections.abc.Mapping):
            if isinstance(buffers, (OrderedDict, BufferDict)):
                # Ordered sources: keep their own ordering.
                for key, buffer in buffers.items():
                    self[key] = buffer
            else:
                # Any other mapping: insert in sorted key order.
                for key, buffer in sorted(buffers.items()):
                    self[key] = buffer
            return

        for index, pair in enumerate(buffers):
            if not isinstance(pair, collections.abc.Iterable):
                raise TypeError(
                    f"BufferDict update sequence element #{index} should be Iterable; is {type(pair).__name__}"
                )
            if len(pair) != 2:
                raise ValueError(
                    f"BufferDict update sequence element #{index} has length {len(pair)}; 2 is required"
                )
            self[pair[0]] = pair[1]

    def extra_repr(self) -> str:
        """Return one line per buffer giving its key, tensor type name, size and (for CUDA tensors) device index."""
        child_lines = []
        for key, buffer in self._buffers.items():
            size_str = "x".join(str(size) for size in buffer.size())
            device_str = f" (GPU {buffer.get_device()})" if buffer.is_cuda else ""
            parastr = f"Buffer containing: [{torch.typename(buffer)} of size {size_str}{device_str}]"
            child_lines.append(" (" + key + "): " + parastr)
        return "\n".join(child_lines)

    def __call__(self, input):
        """Always raise `RuntimeError`: a `BufferDict` is a container and is not meant to be called."""
        raise RuntimeError("BufferDict should not be called.")
|