@@ -11,13 +11,13 @@ class Graph:
1111 size = 0
1212
@classmethod
def _add_node(cls, node):
    '''Append ``node`` to ``cls.node_list`` and increase ``cls.size`` by one.'''
    cls.node_list.append(node)
    cls.size = cls.size + 1
1818
1919 @classmethod
20- def free_node (cls , node ):
20+ def _free_node (cls , node ):
2121 node .last .clear ()
2222
2323 index = cls .node_list .index (node )
@@ -87,7 +87,7 @@ def __init__(
8787 with self .device :
8888 self .grad = self .xp .zeros (self .shape , dtype = dtype )
8989 self .last : list [Tensor ] = list ()
90- Graph .add_node (self )
90+ Graph ._add_node (self )
9191 else :
9292 self .grad = None
9393
@@ -249,72 +249,66 @@ def __abs__(self):
def __getitem__(self, key):
    '''Return the result of ``_get_slice(self, key)``.'''
    sliced = _get_slice(self, key)
    return sliced
251251
def _inplace(self, *others: Tensor, func):
    '''
    Run the in-place operation ``func`` on behalf of this object.

    Parameters
    ----------
    *others :
        Positional operands forwarded to ``func``. Every ``Tensor`` among
        them is replaced by its ``.data``; other values pass through as is.
    func : callable
        Keyword-only. Invoked as ``func(*others)`` inside
        ``with self.device``; it is expected to mutate the target itself.

    Returns
    -------
    ``self``, so the method can back the ``__i*__`` operator protocol.

    Raises
    ------
    ValueError
        If ``self.requires_grad`` is true while ``is_grad_enable()`` is true.
    TypeError
        If ``func`` returns ``NotImplemented``.
    '''
    if self.requires_grad and is_grad_enable():
        raise ValueError(
            "In-place operation is forbidden in node requires grad.")

    operands = tuple(
        other.data if isinstance(other, Tensor) else other
        for other in others)

    with self.device:
        outcome = func(*operands)

    # ``func`` is usually a bound dunder such as ``data.__iadd__``. Calling a
    # dunder directly bypasses the interpreter's handling of NotImplemented,
    # so without this check an unsupported operand would be a silent no-op.
    if outcome is NotImplemented:
        raise TypeError(
            "in-place operation is not supported for the given operands")
    return self
271263
def __setitem__(self, key, value):
    '''Assign ``value`` at ``key`` of ``self.data`` by way of ``_inplace``.'''
    assign = self.data.__setitem__
    return self._inplace(key, value, func=assign)
266+
def __iadd__(self, other):
    '''In-place addition: hand ``self.data.__iadd__`` to ``_inplace``.'''
    add_into = self.data.__iadd__
    return self._inplace(other, func=add_into)
274269
def __isub__(self, other):
    '''In-place subtraction: hand ``self.data.__isub__`` to ``_inplace``.'''
    subtract_from = self.data.__isub__
    return self._inplace(other, func=subtract_from)
277272
def __imul__(self, other):
    '''In-place multiplication: hand ``self.data.__imul__`` to ``_inplace``.'''
    scale_by = self.data.__imul__
    return self._inplace(other, func=scale_by)
280275
def __itruediv__(self, other):
    '''In-place true division: hand ``self.data.__itruediv__`` to ``_inplace``.'''
    divide_by = self.data.__itruediv__
    return self._inplace(other, func=divide_by)
283278
def __imatmul__(self, other):
    '''
    In-place matrix multiplication: overwrite ``self.data`` with
    ``self.data @ other`` and return ``self``.

    The product is computed out of place and then written back through
    ``self.data[...]``, because ``ndarray.__imatmul__`` raises ``TypeError``
    on NumPy releases before 1.25 and is not guaranteed to exist on every
    array backend. The product must therefore be assignable to the existing
    shape of ``self.data``.
    '''

    def assign_product(rhs):
        # ``rhs`` arrives already unwrapped by ``_inplace``.
        self.data[...] = self.data @ rhs

    return self._inplace(other, func=assign_product)
290281
def _compare(self, other, func):
    '''
    Apply the binary predicate ``func`` to ``self.data`` and ``other``.

    A ``Tensor`` operand is unwrapped to its ``.data`` first. The result is
    wrapped in a new ``Tensor`` built with ``self.xp.bool_``, ``None``,
    ``self.device`` and ``False`` as the remaining constructor arguments.
    '''
    with self.device:
        rhs = other.data if isinstance(other, Tensor) else other
        outcome = func(self.data, rhs)
        return Tensor(outcome, self.xp.bool_, None, self.device, False)
294288
@no_grad()
def eq(self, other):
    '''Compare with ``other`` using ``==`` via ``_compare``.'''

    def equal(lhs, rhs):
        return lhs == rhs

    return self._compare(other, equal)
298292
@no_grad()
def ne(self, other):
    '''Compare with ``other`` using ``!=`` via ``_compare``.'''

    def not_equal(lhs, rhs):
        return lhs != rhs

    return self._compare(other, not_equal)
302296
@no_grad()
def __lt__(self, other):
    '''Compare with ``other`` using ``<`` via ``_compare``.'''

    def less(lhs, rhs):
        return lhs < rhs

    return self._compare(other, less)
306300
@no_grad()
def __le__(self, other):
    '''Compare with ``other`` using ``<=`` via ``_compare``.'''

    def less_equal(lhs, rhs):
        return lhs <= rhs

    return self._compare(other, less_equal)
310304
@no_grad()
def __gt__(self, other):
    '''Compare with ``other`` using ``>`` via ``_compare``.'''

    def greater(lhs, rhs):
        return lhs > rhs

    return self._compare(other, greater)
314308
@no_grad()
def __ge__(self, other):
    '''Compare with ``other`` using ``>=`` via ``_compare``.'''

    def greater_equal(lhs, rhs):
        return lhs >= rhs

    return self._compare(other, greater_equal)
318312
319313 def backward (self , retain_graph : bool = False ):
320314 '''
@@ -364,7 +358,7 @@ def backward(self, retain_graph: bool = False):
364358
365359 # if not retain graph and node is not leaf, free it
366360 if not retain_graph and not node .is_leaf :
367- Graph .free_node (node )
361+ Graph ._free_node (node )
368362
369363 def _build_edge (self , node : Tensor ):
370364 node .last .append (self )
@@ -397,8 +391,8 @@ def to(self, device):
def cpu(self):
    '''Shorthand for ``self.to("cpu")``.'''
    target = "cpu"
    return self.to(target)
399393
def cuda(self, id: int = 0):
    '''
    Shorthand for ``self.to("cuda:<id>")``.

    ``id`` is the device index interpolated into the target string; it
    defaults to ``0``.
    '''
    target = f"cuda:{id}"
    return self.to(target)
402396
403397 @property
404398 def xp (self ):
@@ -809,7 +803,7 @@ def grad_fn(self, x: Tensor, grad) -> np.ndarray:
class minimum(_BinaryOperator):
    '''Element-wise minimum of two operands.'''

    def forward_(self, x: Tensor, y: Tensor) -> np.ndarray:
        '''Return ``self.xp.minimum`` of the operands' raw ``.data`` arrays.'''
        return self.xp.minimum(x.data, y.data)

    def grad_fn(self, x: Tensor, grad) -> np.ndarray:
        '''
        Return ``grad`` masked to the positions where the forward result
        equals operand ``x``.
        '''
        # Compare raw arrays, as ``forward_`` does; comparing ``self.data``
        # with the Tensor wrapper itself leaves the result up to how the
        # array library coerces an arbitrary object.
        operand = x.data if isinstance(x, Tensor) else x
        return (self.data == operand) * grad
@@ -983,3 +977,29 @@ def grad_fn(self, x, grad: np.ndarray):
983977 slc = [slice (None )] * grad .ndim
984978 slc [self .axis ] = slice (start , end )
985979 return grad [tuple (slc )]
980+
981+
class sigmoid(_UnaryOperator):
    '''Sigmoid operation; the forward pass is written to avoid overflow.'''

    def forward_(self, x: Tensor) -> np.ndarray:
        '''
        Compute ``1 / (1 + exp(-x))`` element-wise on ``x.data``.

        Each branch only exponentiates non-positive values, so ``exp``
        cannot overflow. The output buffer is allocated with ``x.shape``
        and ``x.dtype``.
        '''
        data = x.data
        out = self.xp.zeros(x.shape, dtype=x.dtype)

        positive = data > 0
        # The complement (not ``data <= 0``) also captures NaN, so NaN inputs
        # propagate as NaN instead of keeping the zero fill value.
        rest = ~positive

        out[positive] = 1 / (1 + self.xp.exp(-data[positive]))

        # exp(x) / (1 + exp(x)) equals 1 - 1 / (1 + exp(x)) mathematically,
        # but does not cancel to exactly 0 for strongly negative x.
        exp_rest = self.xp.exp(data[rest])
        out[rest] = exp_rest / (1 + exp_rest)
        return out

    def grad_fn(self, x: Tensor, grad: np.ndarray) -> np.ndarray:
        '''Return ``s * (1 - s) * grad`` where ``s`` is the forward output.'''
        return self.data * (1 - self.data) * grad
994+
class tanh(_UnaryOperator):
    '''Tanh operation; the forward pass is written to avoid overflow.'''

    def forward_(self, x: Tensor) -> np.ndarray:
        '''
        Compute ``tanh(x)`` element-wise on ``x.data``.

        Uses ``tanh(x) = expm1(2x) / (expm1(2x) + 2)`` on the non-positive
        side and the mirrored form on the positive side, so the exponent is
        never positive (no overflow) and small ``|x|`` keeps full precision.
        The output buffer is allocated with ``x.shape`` and ``x.dtype``.
        '''
        data = x.data
        out = self.xp.zeros(x.shape, dtype=x.dtype)

        positive = data > 0
        # The complement (not ``data <= 0``) also captures NaN, so NaN inputs
        # propagate as NaN instead of keeping the zero fill value.
        rest = ~positive

        # expm1 avoids the cancellation in ``2 / (1 + exp(-2x)) - 1`` near 0.
        em1 = self.xp.expm1(-2 * data[positive])
        out[positive] = -em1 / (em1 + 2)

        em1 = self.xp.expm1(2 * data[rest])
        out[rest] = em1 / (em1 + 2)
        return out

    def grad_fn(self, x: Tensor, grad: np.ndarray) -> np.ndarray:
        '''Return ``(1 - t ** 2) * grad`` where ``t`` is the forward output.'''
        return (1 - self.data ** 2) * grad
0 commit comments