Skip to content

Commit b72e6c7

Browse files
committed
feat: add register buffer in InfiniCoreModule
1 parent 5875295 commit b72e6c7

2 files changed

Lines changed: 58 additions & 3 deletions

File tree

python/infinicore/nn/modules/module.py

Lines changed: 58 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,7 @@
1414
from collections import OrderedDict, namedtuple
1515
import itertools
1616
import torch
17-
import torch.nn as nn
1817
from typing import Optional, Iterator, Tuple, Dict, Mapping, List, Any, overload, TypeVar
19-
import inspect
2018
import warnings
2119

2220
_EXTRA_STATE_KEY_SUFFIX = '_extra_state'
@@ -95,6 +93,7 @@ def remove_from(*dicts_or_sets) -> None:
9593
self.__dict__,
9694
self._parameters,
9795
self._buffers,
96+
self._non_persistent_buffers_set,
9897
)
9998
modules[name] = value
10099
elif modules is not None and name in modules:
@@ -113,7 +112,63 @@ def remove_from(*dicts_or_sets) -> None:
113112
)
114113
buffers[name] = value
115114
else:
116-
super.__setattr__(name, value)
115+
super().__setattr__(name, value)
116+
117+
def register_buffer(self, name: str, tensor: Optional[torch.Tensor], persistent: bool = True) -> None:
    r"""Add a buffer to the module.

    This is typically used to register a buffer that should not be
    considered a model parameter. For example, BatchNorm's ``running_mean``
    is not a parameter, but is part of the module's state. Buffers, by
    default, are persistent and will be saved alongside parameters. This
    behavior can be changed by setting :attr:`persistent` to ``False``. The
    only difference between a persistent buffer and a non-persistent buffer
    is that the latter will not be a part of this module's
    :attr:`state_dict`.

    Buffers can be accessed as attributes using given names.

    Args:
        name (str): name of the buffer. The buffer can be accessed
            from this module using the given name
        tensor (Tensor or None): buffer to be registered. If ``None``, then operations
            that run on buffers, such as :attr:`cuda`, are ignored. If ``None``,
            the buffer is **not** included in the module's :attr:`state_dict`.
        persistent (bool): whether the buffer is part of this module's
            :attr:`state_dict`.

    Raises:
        RuntimeError: if ``persistent`` is ``False`` and this module is a
            ``torch.jit.ScriptModule``.
        AttributeError: if called before ``_buffers`` exists on the instance
            (i.e. before ``Module.__init__()`` has run).
        TypeError: if ``name`` is not a string, or ``tensor`` is neither
            ``None`` nor a ``torch.Tensor``.
        KeyError: if ``name`` contains ``"."``, is empty, or collides with an
            existing attribute that is not already a registered buffer.

    Example::

        >>> # xdoctest: +SKIP("undefined vars")
        >>> self.register_buffer('running_mean', torch.zeros(num_features))

    """
    if persistent is False and isinstance(self, torch.jit.ScriptModule):
        raise RuntimeError("ScriptModule does not support non-persistent buffers")

    # Look in the instance __dict__ directly so a missing ``_buffers`` is
    # reported with a clear message instead of going through attribute lookup.
    if '_buffers' not in self.__dict__:
        raise AttributeError(
            "cannot assign buffer before Module.__init__() call")
    if not isinstance(name, str):
        raise TypeError("buffer name should be a string. "
                        "Got {}".format(torch.typename(name)))
    if '.' in name:
        raise KeyError("buffer name can't contain \".\"")
    if name == '':
        raise KeyError("buffer name can't be empty string \"\"")
    # Re-registering an existing buffer is allowed; shadowing any other
    # attribute (parameter, submodule, plain attribute, method) is not.
    if hasattr(self, name) and name not in self._buffers:
        raise KeyError("attribute '{}' already exists".format(name))
    if tensor is not None and not isinstance(tensor, torch.Tensor):
        raise TypeError("cannot assign '{}' object to buffer '{}' "
                        "(torch Tensor or None required)"
                        .format(torch.typename(tensor), name))

    self._buffers[name] = tensor
    # Keep the non-persistent set in sync, including when an existing buffer
    # is re-registered with a different ``persistent`` flag.
    if persistent:
        self._non_persistent_buffers_set.discard(name)
    else:
        self._non_persistent_buffers_set.add(name)
171+
117172

118173
def register_parameter(self, name: str, param: Optional[torch.nn.Parameter]) -> None:
119174
r"""Add a parameter to the module.

torch_model.safetensors

504 Bytes
Binary file not shown.

0 commit comments

Comments
 (0)