11# SPDX-License-Identifier: LGPL-3.0-or-later
22import logging
3- import os
43from collections import (
54 defaultdict ,
65)
7- from concurrent .futures import (
8- ThreadPoolExecutor ,
9- )
106from typing import (
117 Any ,
128 Callable ,
def make_stat_input(
    datasets: list[Any], dataloaders: list[Any], nbatches: int
) -> list[dict[str, Any]]:
    """Pack data for statistics.

    Parameters
    ----------
    datasets : list[Any]
        A list of datasets to analyze, one per system.
    dataloaders : list[Any]
        The dataloader associated with each dataset, in the same order.
    nbatches : int
        The number of batches to sample from each dataloader.

    Returns
    -------
    list[dict[str, Any]]
        A list of dicts, each of which contains the packed data from one
        system: tensor entries are concatenated along dim 0, scalar
        (np.float32) entries are kept as-is, missing entries are None.
    """
    lst = []
    log.info(f"Packing data for statistics from {len(datasets)} systems")
    for i in range(len(datasets)):
        sys_stat = {}
        # Accumulate statistics on CPU regardless of the training device.
        with torch.device("cpu"):
            iterator = iter(dataloaders[i])
            numb_batches = min(nbatches, len(dataloaders[i]))
            for _ in range(numb_batches):
                try:
                    stat_data = next(iterator)
                except StopIteration:
                    # Restart the dataloader if it is exhausted early.
                    iterator = iter(dataloaders[i])
                    stat_data = next(iterator)
                if (
                    "find_fparam" in stat_data
                    and "fparam" in stat_data
                    and stat_data["find_fparam"] == 0.0
                ):
                    # for model using default fparam
                    stat_data.pop("fparam")
                    stat_data.pop("find_fparam")
                for dd in stat_data:
                    if stat_data[dd] is None:
                        sys_stat[dd] = None
                    elif isinstance(stat_data[dd], torch.Tensor):
                        if dd not in sys_stat:
                            sys_stat[dd] = []
                        sys_stat[dd].append(stat_data[dd])
                    elif isinstance(stat_data[dd], np.float32):
                        sys_stat[dd] = stat_data[dd]
                    else:
                        # silently ignore entries of unsupported type
                        pass

        for key in sys_stat:
            if isinstance(sys_stat[key], np.float32):
                pass
            elif sys_stat[key] is None or sys_stat[key][0] is None:
                sys_stat[key] = None
            elif isinstance(sys_stat[key][0], torch.Tensor):
                # BUGFIX: previously checked `isinstance(stat_data[dd], ...)`,
                # leaking the inner loop variables (NameError when zero
                # batches were drawn, and the wrong entry otherwise).
                sys_stat[key] = torch.cat(sys_stat[key], dim=0)
        dict_to_device(sys_stat)
        lst.append(sys_stat)
    return lst
13395
13496
13597def _restore_from_file (
0 commit comments