|
5 | 5 | import functools |
6 | 6 | import itertools |
7 | 7 | import operator |
| 8 | +from warnings import warn |
8 | 9 | from time import time |
9 | 10 | import pdb |
10 | 11 |
|
@@ -414,27 +415,50 @@ def get_all_params(self): |
414 | 415 | params = sum([l.get_params() for l in layers], []) |
415 | 416 | return unique(params) |
416 | 417 |
|
417 | | - def load_weights_from(self, source): |
| 418 | + def get_all_params_values(self): |
| 419 | + return_value = OrderedDict() |
| 420 | + for name, layer in self.layers_.items(): |
| 421 | + return_value[name] = [p.get_value() for p in layer.get_params()] |
| 422 | + return return_value |
| 423 | + |
| 424 | + def load_params_from(self, source): |
418 | 425 | self.initialize() |
419 | 426 |
|
420 | 427 | if isinstance(source, str): |
421 | | - source = np.load(source) |
| 428 | + with open(source, 'rb') as f: |
| 429 | + source = pickle.load(f) |
422 | 430 |
|
423 | 431 | if isinstance(source, NeuralNet): |
424 | | - source = source.get_all_params() |
| 432 | + source = source.get_all_params_values() |
425 | 433 |
|
426 | | - source_weights = [ |
427 | | - w.get_value() if hasattr(w, 'get_value') else w for w in source] |
| 434 | + for key, values in source.items(): |
| 435 | + layer = self.layers_.get(key) |
| 436 | + if layer is not None: |
| 437 | + for p1, p2v in zip(layer.get_params(), values): |
| 438 | + if p1.get_value().shape == p2v.shape: |
| 439 | + p1.set_value(p2v) |
428 | 440 |
|
429 | | - for w1, w2 in zip(source_weights, self.get_all_params()): |
430 | | - if w1.shape != w2.get_value().shape: |
431 | | - continue |
432 | | - w2.set_value(w1) |
| 441 | + def save_params_to(self, fname): |
| 442 | + params = self.get_all_params_values() |
| 443 | + with open(fname, 'wb') as f: |
| 444 | + pickle.dump(params, f, -1) |
| 445 | + |
| 446 | + def load_weights_from(self, source): |
| 447 | + warn("The 'load_weights_from' method will be removed in nolearn 0.6. " |
| 448 | + "Please use 'load_params_from' instead.") |
| 449 | + |
| 450 | + if isinstance(source, list): |
| 451 | + raise ValueError( |
| 452 | + "Loading weights from a list of parameter values is no " |
| 453 | + "longer supported. Please send me something like the " |
| 454 | + "return value of 'net.get_all_param_values()' instead.") |
| 455 | + |
| 456 | + return self.load_params_from(source) |
433 | 457 |
|
434 | 458 | def save_weights_to(self, fname): |
435 | | - weights = [w.get_value() for w in self.get_all_params()] |
436 | | - with open(fname, 'wb') as f: |
437 | | - pickle.dump(weights, f, -1) |
| 459 | + warn("The 'save_weights_to' method will be removed in nolearn 0.6. " |
| 460 | + "Please use 'save_params_to' instead.") |
| 461 | + return self.save_params_to(fname) |
438 | 462 |
|
439 | 463 | def __getstate__(self): |
440 | 464 | state = dict(self.__dict__) |
|
0 commit comments