1414from collections import OrderedDict , namedtuple
1515import itertools
1616import torch
17- import torch .nn as nn
1817from typing import Optional , Iterator , Tuple , Dict , Mapping , List , Any , overload , TypeVar
19- import inspect
2018import warnings
2119
2220_EXTRA_STATE_KEY_SUFFIX = '_extra_state'
@@ -95,6 +93,7 @@ def remove_from(*dicts_or_sets) -> None:
9593 self .__dict__ ,
9694 self ._parameters ,
9795 self ._buffers ,
96+ self ._non_persistent_buffers_set ,
9897 )
9998 modules [name ] = value
10099 elif modules is not None and name in modules :
@@ -113,7 +112,63 @@ def remove_from(*dicts_or_sets) -> None:
113112 )
114113 buffers [name ] = value
115114 else :
116- super .__setattr__ (name , value )
115+ super ().__setattr__ (name , value )
116+
117+ def register_buffer (self , name : str , tensor : Optional [torch .tensor ], persistent : bool = True ) -> None :
118+ r"""Adds a buffer to the module.
119+
120+ This is typically used to register a buffer that should not to be
121+ considered a model parameter. For example, BatchNorm's ``running_mean``
122+ is not a parameter, but is part of the module's state. Buffers, by
123+ default, are persistent and will be saved alongside parameters. This
124+ behavior can be changed by setting :attr:`persistent` to ``False``. The
125+ only difference between a persistent buffer and a non-persistent buffer
126+ is that the latter will not be a part of this module's
127+ :attr:`state_dict`.
128+
129+ Buffers can be accessed as attributes using given names.
130+
131+ Args:
132+ name (str): name of the buffer. The buffer can be accessed
133+ from this module using the given name
134+ tensor (Tensor or None): buffer to be registered. If ``None``, then operations
135+ that run on buffers, such as :attr:`cuda`, are ignored. If ``None``,
136+ the buffer is **not** included in the module's :attr:`state_dict`.
137+ persistent (bool): whether the buffer is part of this module's
138+ :attr:`state_dict`.
139+
140+ Example::
141+
142+ >>> # xdoctest: +SKIP("undefined vars")
143+ >>> self.register_buffer('running_mean', torch.zeros(num_features))
144+
145+ """
146+ if persistent is False and isinstance (self , torch .jit .ScriptModule ):
147+ raise RuntimeError ("ScriptModule does not support non-persistent buffers" )
148+
149+ if '_buffers' not in self .__dict__ :
150+ raise AttributeError (
151+ "cannot assign buffer before Module.__init__() call" )
152+ elif not isinstance (name , str ):
153+ raise TypeError ("buffer name should be a string. "
154+ "Got {}" .format (torch .typename (name )))
155+ elif '.' in name :
156+ raise KeyError ("buffer name can't contain \" .\" " )
157+ elif name == '' :
158+ raise KeyError ("buffer name can't be empty string \" \" " )
159+ elif hasattr (self , name ) and name not in self ._buffers :
160+ raise KeyError ("attribute '{}' already exists" .format (name ))
161+ elif tensor is not None and not isinstance (tensor , torch .Tensor ):
162+ raise TypeError ("cannot assign '{}' object to buffer '{}' "
163+ "(torch Tensor or None required)"
164+ .format (torch .typename (tensor ), name ))
165+ else :
166+ self ._buffers [name ] = tensor
167+ if persistent :
168+ self ._non_persistent_buffers_set .discard (name )
169+ else :
170+ self ._non_persistent_buffers_set .add (name )
171+
117172
118173 def register_parameter (self , name : str , param : Optional [torch .nn .Parameter ]) -> None :
119174 r"""Add a parameter to the module.
0 commit comments