Skip to content

Commit fa0cd08

Browse files
authored
Change tf trainer import (#2271)
* Remove tensorflow import for module * Fix import
1 parent 90b6090 commit fa0cd08

1 file changed

Lines changed: 7 additions & 21 deletions

File tree

src/garage/trainer.py

Lines changed: 7 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,7 @@
1111
from garage.experiment.experiment import dump_json
1212
from garage.experiment.snapshotter import Snapshotter
1313

14-
# pylint: disable=no-name-in-module
15-
16-
tf = False
17-
try:
18-
import tensorflow as tf
19-
except ImportError:
20-
pass
14+
tf = None
2115

2216

2317
class ExperimentStats:
@@ -536,6 +530,7 @@ class NotSetupError(Exception):
536530
"""Raise when an experiment is about to run without setup."""
537531

538532

533+
# pylint: disable=no-member
539534
class TFTrainer(Trainer):
540535
"""This class implements a trainer for TensorFlow algorithms.
541536
@@ -590,6 +585,11 @@ class TFTrainer(Trainer):
590585
"""
591586

592587
def __init__(self, snapshot_config, sess=None):
588+
# pylint: disable=import-outside-toplevel
589+
import tensorflow
590+
# pylint: disable=global-statement
591+
global tf
592+
tf = tensorflow
593593
super().__init__(snapshot_config=snapshot_config)
594594
self.sess = sess or tf.compat.v1.Session()
595595
self.sess_entered = False
@@ -663,17 +663,3 @@ def initialize_tf_vars(self):
663663
v for v in tf.compat.v1.global_variables()
664664
if v.name.split(':')[0] in uninited_set
665665
]))
666-
667-
668-
class __FakeTFTrainer:
669-
# noqa: E501; pylint: disable=missing-param-doc,too-few-public-methods,no-method-argument
670-
"""Raises an ImportError for environments without TensorFlow."""
671-
672-
def __init__(*args, **kwargs):
673-
raise ImportError(
674-
'TFTrainer requires TensorFlow. To use it, please install '
675-
'TensorFlow.')
676-
677-
678-
if not tf:
679-
TFTrainer = __FakeTFTrainer # noqa: F811

0 commit comments

Comments
 (0)