-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Expand file tree
/
Copy path__init__.py
More file actions
159 lines (137 loc) · 6 KB
/
__init__.py
File metadata and controls
159 lines (137 loc) · 6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
# Copyright 2018 The TensorFlow Probability Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Tools for probabilistic reasoning in TensorFlow."""
import functools
import sys
import types
from tensorflow_probability.python.internal import all_util
from tensorflow_probability.python.internal import lazy_loader
# pylint: disable=g-import-not-at-top
def _validate_tf_environment(package):
  """Check TF version and (depending on package) warn about TensorFloat32.

  This is registered as the `on_first_access` hook of every lazily-loaded TFP
  subpackage, so it runs before that subpackage (which itself imports
  TensorFlow) is loaded.

  Every import here is deliberately function-local: this module's top-level
  symbols get pruned to `__all__` once the module finishes importing, so a
  module-level name may no longer be resolvable by the time this function is
  called. For the same reason the version parser below is a nested function.

  Args:
    package: Python `str` indicating which package is being imported. Used for
      package-dependent warning about TensorFloat32.

  Raises:
    ImportError: if tensorflow is not importable, if its version is
      inadequate, or (when requiring TF2) if tf-keras is not importable.
  """
  import re

  def _release_tuple(version):
    """Return the leading numeric release components of `version` as ints.

    Stands in for `distutils.version.LooseVersion`; `distutils` was removed
    from the standard library in Python 3.12 (PEP 632). Examples:
    '2.18' -> (2, 18); '2.18.0rc1' -> (2, 18, 0);
    '2.19.0-dev20241010' -> (2, 19, 0). Parsing stops after the first
    dot-separated component that is not purely numeric, so pre-release and
    dev builds compare equal to the release they belong to.
    """
    release = []
    for component in version.split('.'):
      digits = re.match(r'[0-9]+', component)
      if digits is None:
        break
      release.append(int(digits.group()))
      if digits.end() != len(component):
        break  # Non-numeric suffix such as 'rc1' or '-dev20241010'.
    return tuple(release)

  try:
    import tensorflow as tf
  except ImportError as err:  # Also covers ModuleNotFoundError (a subclass).
    # Raise same type of error, but with more informative error message.
    # Using print will lead to a message above the stacktrace and can easily
    # be overlooked. For Python 3.11+ may switch to "err.add_note(...)"
    raise type(err)('\n\nFailed to import TensorFlow. Please note that TensorFlow '
                    'is not installed by default when you install TensorFlow Probability. '
                    'This is so that users can decide whether to install the GPU-enabled '
                    'TensorFlow package. To use TensorFlow Probability, please install '
                    'the most recent version of TensorFlow, by following instructions at '
                    'https://tensorflow.org/install.\n\n') from err

  #
  # Update this whenever we need to depend on a newer TensorFlow release.
  #
  required_tensorflow_version = '2.18'
  # required_tensorflow_version = '1.15' # Needed internally -- DisableOnExport

  # Compare numeric release tuples. A version string with no parseable leading
  # number yields () and is therefore rejected as too old, instead of raising
  # TypeError the way LooseVersion could on mixed int/str components.
  if (_release_tuple(tf.__version__) <
      _release_tuple(required_tensorflow_version)):
    raise ImportError(
        'This version of TensorFlow Probability requires TensorFlow '
        'version >= {required}; Detected an installation of version {present}. '
        'Please upgrade TensorFlow to proceed.'.format(
            required=required_tensorflow_version,
            present=tf.__version__))

  if (package == 'mcmc' and
      tf.config.experimental.tensor_float_32_execution_enabled()):
    # Must import here, because symbols get pruned to __all__.
    import warnings
    warnings.warn(
        'TensorFloat-32 matmul/conv are enabled for NVIDIA Ampere+ GPUs. The '
        'resulting loss of precision may hinder MCMC convergence. To turn off, '
        'run `tf.config.experimental.enable_tensor_float_32_execution(False)`. '
        'For more detail, see https://github.com/tensorflow/community/pull/287.'
    )

  if required_tensorflow_version[0] == '2':
    try:
      import tf_keras  # pylint: disable=unused-import
    except ImportError as err:  # Also covers ModuleNotFoundError.
      # Raise same type of error, but with more informative error message.
      # Using print will lead to a message above the stacktrace and can easily
      # be overlooked. For Python 3.11+ may switch to "err.add_note(...)"
      raise type(err)('\n\nFailed to import TF-Keras. '
                      'Please note that TF-Keras is not '
                      'installed by default when you install TensorFlow Probability. '
                      'This is so that JAX-only users do not have to install TensorFlow '
                      'or TF-Keras. To use TensorFlow Probability with TensorFlow, '
                      'please install the tf-keras or tf-keras-nightly package.\n'
                      'This can be done through installing the '
                      'tensorflow-probability[tf] extra or directly installing tf-keras '
                      '(i.e. when using conda/mamba).\n\n') from err
# Declare the lazily-loaded subpackages explicitly so that pytype knows about
# them; it otherwise misses them, presumably because the names are only bound
# dynamically through `globals()` when the lazy loaders are registered.
# At runtime these are bare annotations: they record a type in
# `__annotations__` but do not bind the names.
# Keep this list in sync with `_lazy_load` and `_maybe_nonlazy_load`.
bijectors: types.ModuleType
debugging: types.ModuleType
distributions: types.ModuleType
experimental: types.ModuleType
glm: types.ModuleType
layers: types.ModuleType
math: types.ModuleType
mcmc: types.ModuleType
monte_carlo: types.ModuleType
optimizer: types.ModuleType
random: types.ModuleType
stats: types.ModuleType
sts: types.ModuleType
util: types.ModuleType
vi: types.ModuleType
# Subpackages that are always exposed through a lazy loader: they are only
# imported when one of their attributes is first accessed.
_lazy_load = [
    'bijectors', 'debugging', 'distributions', 'glm', 'math', 'mcmc',
    'monte_carlo', 'optimizer', 'random', 'stats', 'sts', 'util', 'vi',
]

# Subpackages that include registrations (e.g., Keras layer registrations and
# CompositeTensor registrations), which must already be in place when a
# TensorFlow saved model is deserialized. They also get a lazy loader, but are
# loaded eagerly when TensorFlow has already been imported.
_maybe_nonlazy_load = ['experimental', 'layers']
def _tf_loaded():
  """Report whether a `tensorflow` module exposing `compat` is already imported.

  Looks the module up in `sys.modules` without importing it, then inspects its
  `dir()` listing instead of using `getattr`/`hasattr`. Presumably this avoids
  triggering any attribute-level lazy loading inside TensorFlow itself.

  Returns:
    `True` if `sys.modules['tensorflow']` exists and lists a `compat`
    attribute, otherwise `False`.
  """
  tf_module = sys.modules.get('tensorflow')
  if tf_module is None:
    return False
  return 'compat' in dir(tf_module)
# To start with, lazy-load everything. Later we may replace some of the
# lazy-loaded modules by forcing a load.
for pkg_name in _lazy_load + _maybe_nonlazy_load:
  globals()[pkg_name] = lazy_loader.LazyLoader(
      pkg_name, globals(), 'tensorflow_probability.python.{}'.format(pkg_name),
      # These checks need to happen before lazy-loading, since the modules
      # themselves will try to import tensorflow, too.
      # `functools.partial` binds the current value of `pkg_name`, so each
      # loader's hook receives its own package name (no late binding of the
      # loop variable).
      on_first_access=functools.partial(_validate_tf_environment, pkg_name))

if _tf_loaded():
  # Non-lazy load of packages that register with tensorflow or keras.
  # Only done when TensorFlow is already imported, so that importing TFP alone
  # does not pull in TensorFlow.
  for pkg_name in _maybe_nonlazy_load:
    dir(globals()[pkg_name])  # Forces loading the package from its lazy loader.

# Prune the module namespace down to the documented/allowed symbols, passing
# the subpackage names as the allowed list. NOTE(review): this presumably also
# removes private module-level helpers and imports (e.g. `sys`, `_tf_loaded`),
# which is why code that runs later must do its own local imports; confirm
# against `all_util.remove_undocumented`.
all_util.remove_undocumented(__name__, _lazy_load + _maybe_nonlazy_load)