|
11 | 11 | from garage.experiment.experiment import dump_json |
12 | 12 | from garage.experiment.snapshotter import Snapshotter |
13 | 13 |
|
14 | | -# pylint: disable=no-name-in-module |
15 | | - |
16 | | -tf = False |
17 | | -try: |
18 | | - import tensorflow as tf |
19 | | -except ImportError: |
20 | | - pass |
| 14 | +tf = None |
21 | 15 |
|
22 | 16 |
|
23 | 17 | class ExperimentStats: |
@@ -536,6 +530,7 @@ class NotSetupError(Exception): |
536 | 530 | """Raise when an experiment is about to run without setup.""" |
537 | 531 |
|
538 | 532 |
|
| 533 | +# pylint: disable=no-member |
539 | 534 | class TFTrainer(Trainer): |
540 | 535 | """This class implements a trainer for TensorFlow algorithms. |
541 | 536 |
|
@@ -590,6 +585,11 @@ class TFTrainer(Trainer): |
590 | 585 | """ |
591 | 586 |
|
592 | 587 | def __init__(self, snapshot_config, sess=None): |
| 588 | + # pylint: disable=import-outside-toplevel |
| 589 | + import tensorflow |
| 590 | + # pylint: disable=global-statement |
| 591 | + global tf |
| 592 | + tf = tensorflow |
593 | 593 | super().__init__(snapshot_config=snapshot_config) |
594 | 594 | self.sess = sess or tf.compat.v1.Session() |
595 | 595 | self.sess_entered = False |
@@ -663,17 +663,3 @@ def initialize_tf_vars(self): |
663 | 663 | v for v in tf.compat.v1.global_variables() |
664 | 664 | if v.name.split(':')[0] in uninited_set |
665 | 665 | ])) |
666 | | - |
667 | | - |
668 | | -class __FakeTFTrainer: |
669 | | - # noqa: E501; pylint: disable=missing-param-doc,too-few-public-methods,no-method-argument |
670 | | - """Raises an ImportError for environments without TensorFlow.""" |
671 | | - |
672 | | - def __init__(*args, **kwargs): |
673 | | - raise ImportError( |
674 | | - 'TFTrainer requires TensorFlow. To use it, please install ' |
675 | | - 'TensorFlow.') |
676 | | - |
677 | | - |
678 | | -if not tf: |
679 | | - TFTrainer = __FakeTFTrainer # noqa: F811 |
0 commit comments