Skip to content

Commit dfb32ae

Browse files
committed
A new 'load_params_from' is more clever about loading weights.
Specifically, it will match layers by name and copy weights if names match, instead of the brute-force approach using list indices, which doesn't work reliably in some cases. Aims to supercede #77. Still needs better testing and possibly a way to override some of the matching.
1 parent ea54a6d commit dfb32ae

2 files changed

Lines changed: 44 additions & 13 deletions

File tree

nolearn/lasagne/base.py

Lines changed: 36 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import functools
66
import itertools
77
import operator
8+
from warnings import warn
89
from time import time
910
import pdb
1011

@@ -414,27 +415,50 @@ def get_all_params(self):
414415
params = sum([l.get_params() for l in layers], [])
415416
return unique(params)
416417

417-
def load_weights_from(self, source):
418+
def get_all_params_values(self):
419+
return_value = OrderedDict()
420+
for name, layer in self.layers_.items():
421+
return_value[name] = [p.get_value() for p in layer.get_params()]
422+
return return_value
423+
424+
def load_params_from(self, source):
418425
self.initialize()
419426

420427
if isinstance(source, str):
421-
source = np.load(source)
428+
with open(source, 'rb') as f:
429+
source = pickle.load(f)
422430

423431
if isinstance(source, NeuralNet):
424-
source = source.get_all_params()
432+
source = source.get_all_params_values()
425433

426-
source_weights = [
427-
w.get_value() if hasattr(w, 'get_value') else w for w in source]
434+
for key, values in source.items():
435+
layer = self.layers_.get(key)
436+
if layer is not None:
437+
for p1, p2v in zip(layer.get_params(), values):
438+
if p1.get_value().shape == p2v.shape:
439+
p1.set_value(p2v)
428440

429-
for w1, w2 in zip(source_weights, self.get_all_params()):
430-
if w1.shape != w2.get_value().shape:
431-
continue
432-
w2.set_value(w1)
441+
def save_params_to(self, fname):
442+
params = self.get_all_params_values()
443+
with open(fname, 'wb') as f:
444+
pickle.dump(params, f, -1)
445+
446+
def load_weights_from(self, source):
447+
warn("The 'load_weights_from' method will be removed in nolearn 0.6. "
448+
"Please use 'load_params_from' instead.")
449+
450+
if isinstance(source, list):
451+
raise ValueError(
452+
"Loading weights from a list of parameter values is no "
453+
"longer supported. Please send me something like the "
454+
"return value of 'net.get_all_param_values()' instead.")
455+
456+
return self.load_params_from(source)
433457

434458
def save_weights_to(self, fname):
435-
weights = [w.get_value() for w in self.get_all_params()]
436-
with open(fname, 'wb') as f:
437-
pickle.dump(weights, f, -1)
459+
warn("The 'save_weights_to' method will be removed in nolearn 0.6. "
460+
"Please use 'save_params_to' instead.")
461+
return self.save_params_to(fname)
438462

439463
def __getstate__(self):
440464
state = dict(self.__dict__)

nolearn/tests/test_lasagne.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,9 +118,16 @@ def on_epoch_finished(nn, train_history):
118118

119119
# Use load_weights_from to initialize an untrained model:
120120
nn3 = clone(nn_def)
121-
nn3.load_weights_from(nn2)
121+
nn3.load_params_from(nn2)
122122
assert np.array_equal(nn3.predict(X_test), y_pred)
123123

124+
# Use save_params_to and load_params_from with a path:
125+
path = '/tmp/test_lasagne_functional_mnist.params'
126+
nn.save_params_to(path)
127+
nn4 = clone(nn_def)
128+
nn4.load_params_from(path)
129+
assert np.array_equal(nn4.predict(X_test), y_pred)
130+
124131

125132
def test_lasagne_functional_grid_search(mnist, monkeypatch):
126133
# Make sure that we can satisfy the grid search interface.

0 commit comments

Comments
 (0)