|
| 1 | +%%% @doc Regression contract for context thread affinity with real |
| 2 | +%%% ML libraries. |
| 3 | +%%% |
| 4 | +%%% v3.0 fixed numpy / torch / tensorflow segfaults caused by the |
| 5 | +%%% executor pool moving calls across OS threads. The fix is per-context |
| 6 | +%%% worker pthreads with stable thread affinity. `py_thread_affinity_SUITE' |
| 7 | +%%% checks the threading.get_native_id invariants in isolation; |
| 8 | +%%% this suite drives actual numpy and tensorflow operations through |
| 9 | +%%% exec / eval / call paths to confirm the libraries' thread-local |
| 10 | +%%% state survives the round-trip. |
| 11 | +%%% |
| 12 | +%%% Skip-on-missing for both libraries: cases that need an unavailable |
| 13 | +%%% module return {skip, ...} from init_per_testcase. TensorFlow is |
| 14 | +%%% always skipped on CI (too heavy to install). Owngil cases additionally |
| 15 | +%%% skip on Python <3.14. |
| 16 | +-module(py_ml_libs_SUITE). |
| 17 | + |
| 18 | +-include_lib("common_test/include/ct.hrl"). |
| 19 | + |
| 20 | +-export([ |
| 21 | + all/0, |
| 22 | + init_per_suite/1, |
| 23 | + end_per_suite/1, |
| 24 | + init_per_testcase/2, |
| 25 | + end_per_testcase/2 |
| 26 | +]). |
| 27 | + |
| 28 | +-export([ |
| 29 | + numpy_basic_ops/1, |
| 30 | + numpy_call_thread_affinity/1, |
| 31 | + numpy_parallel_processes/1, |
| 32 | + numpy_owngil_basic/1, |
| 33 | + tensorflow_basic_ops/1, |
| 34 | + tensorflow_call_thread_affinity/1 |
| 35 | +]). |
| 36 | + |
| 37 | +all() -> |
| 38 | + [ |
| 39 | + numpy_basic_ops, |
| 40 | + numpy_call_thread_affinity, |
| 41 | + numpy_parallel_processes, |
| 42 | + numpy_owngil_basic, |
| 43 | + tensorflow_basic_ops, |
| 44 | + tensorflow_call_thread_affinity |
| 45 | + ]. |
| 46 | + |
| 47 | +init_per_suite(Config) -> |
| 48 | + {ok, _} = application:ensure_all_started(erlang_python), |
| 49 | + %% Suppress TensorFlow's chatty C++ logging at import time. |
| 50 | + %% Must be set before TF is imported in any context. |
| 51 | + os:putenv("TF_CPP_MIN_LOG_LEVEL", "3"), |
| 52 | + Config. |
| 53 | + |
| 54 | +end_per_suite(_Config) -> |
| 55 | + ok = application:stop(erlang_python), |
| 56 | + ok. |
| 57 | + |
| 58 | +init_per_testcase(TC, Config) -> |
| 59 | + Mode = case TC of |
| 60 | + numpy_owngil_basic -> owngil; |
| 61 | + _ -> worker |
| 62 | + end, |
| 63 | + case Mode of |
| 64 | + owngil -> |
| 65 | + case py_nif:owngil_supported() of |
| 66 | + false -> |
| 67 | + {skip, "owngil mode requires Python 3.14+"}; |
| 68 | + true -> |
| 69 | + setup_case(TC, Mode, Config) |
| 70 | + end; |
| 71 | + worker -> |
| 72 | + setup_case(TC, Mode, Config) |
| 73 | + end. |
| 74 | + |
| 75 | +end_per_testcase(_TC, Config) -> |
| 76 | + case proplists:get_value(ctx, Config) of |
| 77 | + undefined -> ok; |
| 78 | + Ctx -> py_context:stop(Ctx) |
| 79 | + end. |
| 80 | + |
| 81 | +%%% --------------------------------------------------------------------------- |
| 82 | +%%% Helpers |
| 83 | +%%% --------------------------------------------------------------------------- |
| 84 | + |
| 85 | +setup_case(TC, Mode, Config) -> |
| 86 | + case py_context:new(#{mode => Mode}) of |
| 87 | + {ok, Ctx} -> |
| 88 | + case require_module(Ctx, required_module(TC)) of |
| 89 | + ok -> |
| 90 | + [{ctx, Ctx} | Config]; |
| 91 | + {skip, _} = Skip -> |
| 92 | + py_context:stop(Ctx), |
| 93 | + Skip |
| 94 | + end; |
| 95 | + {error, Reason} -> |
| 96 | + ct:fail({context_create_failed, Mode, Reason}) |
| 97 | + end. |
| 98 | + |
| 99 | +required_module(numpy_basic_ops) -> "numpy"; |
| 100 | +required_module(numpy_call_thread_affinity) -> "numpy"; |
| 101 | +required_module(numpy_parallel_processes) -> "numpy"; |
| 102 | +required_module(numpy_owngil_basic) -> "numpy"; |
| 103 | +required_module(tensorflow_basic_ops) -> "tensorflow"; |
| 104 | +required_module(tensorflow_call_thread_affinity) -> "tensorflow". |
| 105 | + |
| 106 | +%% Reflect the import status into a Python variable so we can |
| 107 | +%% distinguish "module not installed" (skip) from any other error |
| 108 | +%% (let it bubble up). A native-extension crash that surfaces as a |
| 109 | +%% non-ImportError must not silently turn into a skip. |
| 110 | +require_module(Ctx, Mod) -> |
| 111 | + Code = iolist_to_binary([ |
| 112 | + "try:\n", |
| 113 | + " import ", Mod, "\n", |
| 114 | + " _import_status = 'ok'\n", |
| 115 | + "except ImportError:\n", |
| 116 | + " _import_status = 'not_found'\n" |
| 117 | + ]), |
| 118 | + ok = py_context:exec(Ctx, Code), |
| 119 | + {ok, Status} = py_context:eval(Ctx, <<"_import_status">>, #{}), |
| 120 | + case Status of |
| 121 | + <<"ok">> -> |
| 122 | + ok; |
| 123 | + <<"not_found">> -> |
| 124 | + {skip, "Python module " ++ Mod ++ " not available"} |
| 125 | + end. |
| 126 | + |
| 127 | +native_id(Ctx) -> |
| 128 | + {ok, Tid} = py_context:eval(Ctx, |
| 129 | + <<"__import__('threading').get_native_id()">>, #{}), |
| 130 | + Tid. |
| 131 | + |
| 132 | +%%% --------------------------------------------------------------------------- |
| 133 | +%%% numpy cases |
| 134 | +%%% --------------------------------------------------------------------------- |
| 135 | + |
| 136 | +numpy_basic_ops(Config) -> |
| 137 | + Ctx = ?config(ctx, Config), |
| 138 | + %% Define a numpy-backed function and exercise both call and eval |
| 139 | + %% paths so a thread-state regression in either direction crashes. |
| 140 | + ok = py_context:exec(Ctx, << |
| 141 | + "import numpy as np\n" |
| 142 | + "def vec_dot_self(xs):\n" |
| 143 | + " v = np.array(xs, dtype=np.float64)\n" |
| 144 | + " return float(np.dot(v, v))\n" |
| 145 | + >>), |
| 146 | + {ok, 30.0} = py_context:call(Ctx, '__main__', vec_dot_self, |
| 147 | + [[1.0, 2.0, 3.0, 4.0]]), |
| 148 | + ok = py_context:exec(Ctx, |
| 149 | + <<"_w = vec_dot_self([10.0, 0.0, 0.0])">>), |
| 150 | + {ok, 100.0} = py_context:eval(Ctx, <<"_w">>, #{}), |
| 151 | + ok. |
| 152 | + |
| 153 | +numpy_call_thread_affinity(Config) -> |
| 154 | + Ctx = ?config(ctx, Config), |
| 155 | + ok = py_context:exec(Ctx, << |
| 156 | + "import numpy as np\n" |
| 157 | + "import threading\n" |
| 158 | + "def numpy_with_tid(xs):\n" |
| 159 | + " v = np.array(xs, dtype=np.float64)\n" |
| 160 | + " return (threading.get_native_id(), float(np.sum(v)))\n" |
| 161 | + >>), |
| 162 | + Results = [py_context:call(Ctx, '__main__', numpy_with_tid, |
| 163 | + [[float(I), float(I + 1), float(I + 2)]]) |
| 164 | + || I <- lists:seq(1, 50)], |
| 165 | + Tids = [Tid || {ok, {Tid, _Sum}} <- Results], |
| 166 | + Sums = [Sum || {ok, {_Tid, Sum}} <- Results], |
| 167 | + 50 = length(Tids), |
| 168 | + [SingleTid] = lists:usort(Tids), |
| 169 | + true = is_integer(SingleTid), |
| 170 | + %% Spot-check a few sums. |
| 171 | + Expected = [3.0 * I + 3.0 || I <- lists:seq(1, 50)], |
| 172 | + Expected = Sums, |
| 173 | + ok. |
| 174 | + |
| 175 | +numpy_parallel_processes(Config) -> |
| 176 | + Ctx = ?config(ctx, Config), |
| 177 | + ok = py_context:exec(Ctx, << |
| 178 | + "import numpy as np\n" |
| 179 | + "import threading\n" |
| 180 | + "def numpy_dot_with_tid(xs, ys):\n" |
| 181 | + " a = np.array(xs, dtype=np.float64)\n" |
| 182 | + " b = np.array(ys, dtype=np.float64)\n" |
| 183 | + " return (threading.get_native_id(), float(np.dot(a, b)))\n" |
| 184 | + >>), |
| 185 | + Parent = self(), |
| 186 | + N = 8, |
| 187 | + Pids = [spawn_link(fun() -> |
| 188 | + Xs = [float(K * J) || J <- lists:seq(1, 4)], |
| 189 | + Ys = [float(K + J) || J <- lists:seq(1, 4)], |
| 190 | + R = py_context:call(Ctx, '__main__', numpy_dot_with_tid, |
| 191 | + [Xs, Ys]), |
| 192 | + Parent ! {result, K, R} |
| 193 | + end) || K <- lists:seq(1, N)], |
| 194 | + Results = [receive {result, K, R} -> {K, R} after 5000 -> ct:fail(timeout) end |
| 195 | + || _ <- Pids], |
| 196 | + %% All calls converged on one thread. |
| 197 | + Tids = [Tid || {_K, {ok, {Tid, _Sum}}} <- Results], |
| 198 | + N = length(Tids), |
| 199 | + [SingleTid] = lists:usort(Tids), |
| 200 | + true = is_integer(SingleTid), |
| 201 | + %% Each result matches the expected dot product. |
| 202 | + lists:foreach(fun({K, {ok, {_Tid, Got}}}) -> |
| 203 | + Xs = [float(K * J) || J <- lists:seq(1, 4)], |
| 204 | + Ys = [float(K + J) || J <- lists:seq(1, 4)], |
| 205 | + Expected = lists:sum([X * Y || {X, Y} <- lists:zip(Xs, Ys)]), |
| 206 | + true = abs(Got - Expected) < 1.0e-9 |
| 207 | + end, Results), |
| 208 | + %% Drain time + mailbox sanity (no orphan results). |
| 209 | + timer:sleep(50), |
| 210 | + {messages, []} = erlang:process_info(self(), messages), |
| 211 | + ok. |
| 212 | + |
| 213 | +numpy_owngil_basic(Config) -> |
| 214 | + Ctx = ?config(ctx, Config), |
| 215 | + %% Same shape as numpy_basic_ops but inside an OWN_GIL subinterpreter. |
| 216 | + %% Numpy in OWN_GIL was the original v3.0 motivator on Python 3.14. |
| 217 | + ok = py_context:exec(Ctx, << |
| 218 | + "import numpy as np\n" |
| 219 | + "def vec_dot_self(xs):\n" |
| 220 | + " v = np.array(xs, dtype=np.float64)\n" |
| 221 | + " return float(np.dot(v, v))\n" |
| 222 | + >>), |
| 223 | + {ok, 30.0} = py_context:call(Ctx, '__main__', vec_dot_self, |
| 224 | + [[1.0, 2.0, 3.0, 4.0]]), |
| 225 | + {ok, 25.0} = py_context:call(Ctx, '__main__', vec_dot_self, [[5.0]]), |
| 226 | + %% Confirm the thread is stable across owngil calls. |
| 227 | + Tid1 = native_id(Ctx), |
| 228 | + Tid2 = native_id(Ctx), |
| 229 | + Tid1 = Tid2, |
| 230 | + ok. |
| 231 | + |
| 232 | +%%% --------------------------------------------------------------------------- |
| 233 | +%%% tensorflow cases |
| 234 | +%%% --------------------------------------------------------------------------- |
| 235 | + |
| 236 | +tensorflow_basic_ops(Config) -> |
| 237 | + Ctx = ?config(ctx, Config), |
| 238 | + ok = py_context:exec(Ctx, << |
| 239 | + "import tensorflow as tf\n" |
| 240 | + "def matmul_22():\n" |
| 241 | + " a = tf.constant([[1.0, 2.0], [3.0, 4.0]])\n" |
| 242 | + " return tf.linalg.matmul(a, a).numpy().tolist()\n" |
| 243 | + >>), |
| 244 | + {ok, [[7.0, 10.0], [15.0, 22.0]]} = |
| 245 | + py_context:call(Ctx, '__main__', matmul_22, []), |
| 246 | + ok. |
| 247 | + |
| 248 | +tensorflow_call_thread_affinity(Config) -> |
| 249 | + Ctx = ?config(ctx, Config), |
| 250 | + ok = py_context:exec(Ctx, << |
| 251 | + "import tensorflow as tf\n" |
| 252 | + "import threading\n" |
| 253 | + "def tf_sum_with_tid(xs):\n" |
| 254 | + " t = tf.constant(xs, dtype=tf.float64)\n" |
| 255 | + " return (threading.get_native_id(),\n" |
| 256 | + " float(tf.math.reduce_sum(t).numpy()))\n" |
| 257 | + >>), |
| 258 | + Results = [py_context:call(Ctx, '__main__', tf_sum_with_tid, |
| 259 | + [[float(I), float(I + 1)]]) |
| 260 | + || I <- lists:seq(1, 20)], |
| 261 | + Tids = [Tid || {ok, {Tid, _Sum}} <- Results], |
| 262 | + 20 = length(Tids), |
| 263 | + [SingleTid] = lists:usort(Tids), |
| 264 | + true = is_integer(SingleTid), |
| 265 | + ok. |
0 commit comments