@@ -74,28 +74,18 @@ def test_to_batches(ds: pd.Dataset):
7474
7575def test_use_threads_configures_worker_pool (monkeypatch : pytest .MonkeyPatch ):
7676 current_workers = 3
77- calls : list [tuple [ str , int ] ] = []
77+ calls : list [int | None ] = []
7878
79- def fake_worker_count () -> int :
79+ def fake_worker_threads () -> int :
8080 return current_workers
8181
82- def fake_set_worker_threads (count : int ) -> None :
82+ def fake_set_worker_threads (count : int | None ) -> None :
8383 nonlocal current_workers
84- calls .append (( "set" , count ) )
85- current_workers = count
84+ calls .append (count )
85+ current_workers = 11 if count is None else count
8686
87- def fake_set_worker_threads_to_available_parallelism () -> None :
88- nonlocal current_workers
89- calls .append (("available" , current_workers ))
90- current_workers = 11
91-
92- monkeypatch .setattr (vx_dataset , "_worker_count" , fake_worker_count )
87+ monkeypatch .setattr (vx_dataset , "_worker_threads" , fake_worker_threads )
9388 monkeypatch .setattr (vx_dataset , "_set_worker_threads" , fake_set_worker_threads )
94- monkeypatch .setattr (
95- vx_dataset ,
96- "_set_worker_threads_to_available_parallelism" ,
97- fake_set_worker_threads_to_available_parallelism ,
98- )
9989
10090 with vx_dataset ._temporary_worker_threads (True ): # pyright: ignore[reportPrivateUsage]
10191 assert current_workers == 11
@@ -106,7 +96,7 @@ def fake_set_worker_threads_to_available_parallelism() -> None:
10696 assert current_workers == 0
10797
10898 assert current_workers == 3
109- assert calls == [( "available" , 3 ), ( "set" , 3 ), ( "set" , 0 ), ( "set" , 3 ) ]
99+ assert calls == [None , 3 , 0 , 3 ]
110100
111101 calls .clear ()
112102 reader = pa .RecordBatchReader .from_batches (
@@ -121,7 +111,7 @@ def fake_set_worker_threads_to_available_parallelism() -> None:
121111
122112 assert [batch .to_pylist () for batch in batches ] == [[{"x" : 1 }], [{"x" : 2 }]]
123113 assert current_workers == 3
124- assert calls == [( "available" , 3 ), ( "set" , 3 ) ]
114+ assert calls == [None , 3 ]
125115
126116
127117@pytest .mark .parametrize ("batch_size" , [1234 , 8192 , 1 << 31 ])
0 commit comments