88import h5py
99
1010
11- class UKDataset :
11+ class UKSingleYearDataset :
1212 person : pd .DataFrame
1313 benunit : pd .DataFrame
1414 household : pd .DataFrame
@@ -61,6 +61,7 @@ def __init__(
6161
6262 self .data_format = "arrays"
6363 self .tables = (self .person , self .benunit , self .household )
64+ self .table_names = ("person" , "benunit" , "household" )
6465
6566 def save (self , file_path : str ):
6667 with pd .HDFStore (file_path ) as f :
@@ -80,10 +81,11 @@ def load(self):
8081 return data
8182
8283 def copy (self ):
83- return UKDataset (
84+ return UKSingleYearDataset (
8485 person = self .person .copy (),
8586 benunit = self .benunit .copy (),
8687 household = self .household .copy (),
88+ fiscal_year = self .time_period ,
8789 )
8890
8991 def validate (self ):
@@ -110,9 +112,120 @@ def from_simulation(
110112 input_variables , period = fiscal_year
111113 )
112114
113- return UKDataset (
115+ return UKSingleYearDataset (
114116 person = entity_dfs ["person" ],
115117 benunit = entity_dfs ["benunit" ],
116118 household = entity_dfs ["household" ],
117119 fiscal_year = fiscal_year ,
118120 )
121+
122+
class UKMultiYearDataset:
    """Collection of single-year datasets keyed by integer fiscal year.

    The collection can be built from in-memory single-year datasets, or read
    from an HDF5 file laid out as ``/person/<year>``, ``/benunit/<year>`` and
    ``/household/<year>`` (the layout written by :meth:`save`).
    """

    # Attribute names of the per-year tables, in the order they are
    # written to disk and merged by ``load``.
    _TABLES = ("person", "benunit", "household")

    def __init__(
        self,
        file_path: str = None,
        datasets: "list[UKSingleYearDataset] | None" = None,
    ):
        """Build the collection from ``datasets`` and/or ``file_path``.

        Args:
            file_path: Path to an ``.h5`` file in the multi-year layout.
                When given, the file contents replace anything passed via
                ``datasets``.
            datasets: Single-year datasets; each is stored under the year
                taken from the first four characters of its ``time_period``.

        Raises:
            ValueError: If neither argument is given, or no year was loaded.
            TypeError: If an item of ``datasets`` is not a
                ``UKSingleYearDataset``.
            FileNotFoundError: If ``file_path`` does not exist.
        """
        if file_path is None and datasets is None:
            raise ValueError(
                "Either file_path or datasets must be provided to "
                "UKMultiYearDataset."
            )

        self.datasets = {}

        if datasets is not None:
            for dataset in datasets:
                if not isinstance(dataset, UKSingleYearDataset):
                    raise TypeError(
                        "All items in datasets must be of type UKSingleYearDataset."
                    )
                # str() so that an integer time_period is accepted as well
                # as a string such as "2025" or "2025-01-01".
                year = int(str(dataset.time_period)[:4])
                self.datasets[year] = dataset

        if file_path is not None:
            UKMultiYearDataset.validate_file_path(file_path)
            # Read-only: loading must never create or modify the file.
            with pd.HDFStore(file_path, mode="r") as store:
                self.datasets = {}
                for key in store.keys():
                    # Person keys enumerate the available years; the other
                    # two tables are looked up by the same year.
                    if not key.startswith("/person/"):
                        continue
                    fiscal_year = int(key.split("/")[2])
                    self.datasets[fiscal_year] = UKSingleYearDataset(
                        person=store[key],
                        benunit=store[f"/benunit/{fiscal_year}"],
                        household=store[f"/household/{fiscal_year}"],
                        fiscal_year=fiscal_year,
                    )

        if not self.datasets:
            raise ValueError(
                "UKMultiYearDataset requires at least one yearly dataset."
            )

        self.data_format = "time_period_arrays"
        # The earliest available year is the collection's time period.
        self.time_period = min(self.datasets)

    def get_year(self, fiscal_year: int) -> "UKSingleYearDataset":
        """Return the dataset stored for ``fiscal_year``.

        Raises:
            ValueError: If no dataset is stored for that year.
        """
        try:
            return self.datasets[fiscal_year]
        except KeyError:
            raise ValueError(
                f"No dataset found for year {fiscal_year}."
            ) from None

    def __getitem__(self, fiscal_year: int):
        """Subscript access; equivalent to :meth:`get_year`."""
        return self.get_year(fiscal_year)

    def save(self, file_path: str):
        """Write every year to ``file_path``, replacing any existing file.

        Each table is stored under ``<table>/<year>`` and a one-element
        series holding the year is stored under ``time_period/<year>``.
        """
        # Start from a clean file so stale years from an earlier save
        # cannot survive.
        Path(file_path).unlink(missing_ok=True)
        with pd.HDFStore(file_path) as store:
            for year, dataset in self.datasets.items():
                for table in self._TABLES:
                    store.put(
                        f"{table}/{year}",
                        getattr(dataset, table),
                        format="table",
                        data_columns=True,
                    )
                store.put(
                    f"time_period/{year}",
                    pd.Series([year]),
                    format="table",
                    data_columns=True,
                )

    def copy(self):
        """Return a new collection built from copies of each yearly dataset."""
        return UKMultiYearDataset(
            datasets=[dataset.copy() for dataset in self.datasets.values()]
        )

    @staticmethod
    def validate_file_path(file_path: str):
        """Check that ``file_path`` is an existing multi-year ``.h5`` file.

        Raises:
            ValueError: If the path does not end with ``.h5`` or the file
                lacks a ``/person/``, ``/benunit/`` or ``/household/`` entry.
            FileNotFoundError: If the file does not exist.
        """
        if not file_path.endswith(".h5"):
            raise ValueError(
                "File path must end with '.h5' for UKMultiYearDataset."
            )
        if not Path(file_path).exists():
            raise FileNotFoundError(f"File not found: {file_path}")

        # HDFStore.keys() yields full paths such as "/person/2025", which
        # is what the prefix checks below (and the loader) rely on.
        with pd.HDFStore(file_path, mode="r") as store:
            keys = store.keys()
        for table in ("person", "benunit", "household"):
            if not any(key.startswith(f"/{table}/") for key in keys):
                raise ValueError(f"No {table} dataset found in the file.")

    def load(self):
        """Return all columns as ``{column: {year: values}}``.

        Columns from the person, benunit and household tables share one
        namespace; if a column name occurs in more than one table for the
        same year, the table processed last overwrites the earlier one.
        """
        data = {}
        for year, dataset in self.datasets.items():
            for df in (dataset.person, dataset.benunit, dataset.household):
                for col in df.columns:
                    data.setdefault(col, {})[year] = df[col].values
        return data
0 commit comments