-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathMyDataset.py
More file actions
85 lines (81 loc) · 3.2 KB
/
MyDataset.py
File metadata and controls
85 lines (81 loc) · 3.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
'''
Author: xiaoniu
Date: 2026-01-06 16:46:34
LastEditors: xiaoniu
LastEditTime: 2026-01-06 17:01:47
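Description: Provides a custom PyTorch dataset class for loading NetCDF (.nc) files organized in per-year directories, plus a helper (SR_dataset) that wraps it in SRImplicitDownsampled.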
'''
import numpy as np
import torch
import os
from torch.utils.data import Dataset as dataset
from torch.utils.data import DataLoader
import torch.nn.functional as F
import xarray as xr
import random,re
from wrappers import SRImplicitDownsampled
class MyDataset(dataset):
    """Dataset with one sample per NetCDF (``.nc``) file found under ``data_dir/<year>/``.

    Each item is the array stored under ``var_name`` in one file. It is unpacked
    as ``value * scale_factor + add_offset`` and then standardized with the
    global mean and standard deviation computed when the dataset is built.

    Args:
        data_dir: Root directory containing one sub-directory per year.
        start_year: First year (inclusive) to include.
        end_year: Last year (inclusive) to include.
        var_name: Name of the variable to read from each file.
        get_name: If True, ``__getitem__`` returns ``(tensor, file_path)``
            instead of just the tensor.
    """

    def __init__(self, data_dir, start_year, end_year, var_name='t2m', get_name=False):
        # Stored so _get_file_list does not depend on a global named data_dir.
        self.data_dir = data_dir
        self.start_year = start_year
        self.end_year = end_year
        self.variable = var_name
        self.get_name = get_name
        # Hard-coded unpacking constants applied to the raw values.
        # NOTE(review): presumably these match the files' packing attributes;
        # confirm that xarray is not already applying them (mask_and_scale).
        self.scale_factor = 0.00145877124883646
        self.add_offset = 270.095568832149
        self.file_list = self._get_file_list()
        print(f"Found {len(self.file_list)} files for years {self.start_year}-{self.end_year}")
        # Reads every listed file once (sample=False) to compute the statistics.
        self.mean, self.std = self._mean_std()
        print(f"Mean: {self.mean}, Std: {self.std}")

    def _get_file_list(self):
        """Return the ``.nc`` paths under ``data_dir/<year>`` for each year in range.

        Years whose directory does not exist are skipped. Paths are sorted within
        each year so that dataset indices are reproducible across runs.
        """
        file_list = []
        for year in range(self.start_year, self.end_year + 1):
            year_dir = os.path.join(self.data_dir, str(year))
            if not os.path.isdir(year_dir):
                continue
            year_files = [
                os.path.join(root, name)
                for root, _dirs, files in os.walk(year_dir)
                for name in files
                if name.endswith('.nc')
            ]
            file_list.extend(sorted(year_files))
        return file_list

    def _load_unpacked(self, file_path):
        """Read ``self.variable`` from one file and apply the unpacking transform."""
        # The context manager closes the file even if the variable lookup fails.
        # ``.values`` materializes the data, so it stays usable after closing.
        with xr.open_dataset(file_path) as ds:
            data = ds[self.variable].values
        return data * self.scale_factor + self.add_offset

    def _mean_std(self, sample=False):
        """Compute the global mean and population standard deviation of the unpacked data.

        Statistics are accumulated file by file by combining each file's count,
        mean and sum of squared deviations (Chan et al. pairwise update). Peak
        memory is therefore one file rather than the whole dataset.

        Args:
            sample: If True, use a random subset of at most 2000 files.

        Returns:
            ``(mean, std)`` as float32 scalar tensors.

        Raises:
            ValueError: If the selected files contain no values at all.
        """
        if sample:
            sampled_files = random.sample(self.file_list, min(2000, len(self.file_list)))
        else:
            sampled_files = self.file_list
        count = 0
        mean = 0.0
        m2 = 0.0  # running sum of squared deviations from the running mean
        for file in sampled_files:
            values = np.asarray(self._load_unpacked(file), dtype=np.float64)
            n = values.size
            if n == 0:
                continue
            file_mean = values.mean()
            file_m2 = np.square(values - file_mean).sum()
            total = count + n
            delta = file_mean - mean
            mean = mean + delta * n / total
            m2 = m2 + file_m2 + delta * delta * count * n / total
            count = total
        if count == 0:
            raise ValueError(
                f"No '{self.variable}' values found in {len(sampled_files)} file(s) "
                f"for years {self.start_year}-{self.end_year}; cannot compute mean/std"
            )
        print(f'Computed mean/std over {count} values from {len(sampled_files)} files')
        std = np.sqrt(m2 / count)
        return torch.tensor(mean, dtype=torch.float32), torch.tensor(std, dtype=torch.float32)

    def __len__(self):
        """Number of samples, i.e. the number of discovered ``.nc`` files."""
        return len(self.file_list)

    def __getitem__(self, index):
        """Return the standardized float32 tensor for one file (plus its path if ``get_name``)."""
        file_path = self.file_list[index]
        data = torch.from_numpy(self._load_unpacked(file_path)).float()
        data = (data - self.mean) / self.std
        if self.get_name:
            return data, file_path
        return data
def SR_dataset(data_dir, start_year, end_year, scale_max=8, scale_min=2, inp_size=90, sample_q=5000, var_name='t2m'):
    """Build an ``SRImplicitDownsampled`` dataset over the NetCDF files of a year range.

    Creates a ``MyDataset`` for ``data_dir`` / ``start_year``..``end_year`` and
    wraps it in ``SRImplicitDownsampled``. ``scale_max``, ``scale_min``,
    ``inp_size`` and ``sample_q`` are forwarded unchanged. Their exact meaning
    is defined by ``SRImplicitDownsampled``; presumably they are the downscale
    factor range, the input patch size and the number of sampled query points.
    TODO: confirm against wrappers.py.

    Args:
        var_name: Variable read from each file (defaults to ``'t2m'``, as before).

    Returns:
        The wrapped dataset.
    """
    base_dataset = MyDataset(data_dir, start_year, end_year, var_name=var_name)
    return SRImplicitDownsampled(
        dataset=base_dataset,
        scale_max=scale_max,
        scale_min=scale_min,
        inp_size=inp_size,
        sample_q=sample_q,
    )
if __name__ == "__main__":
    # Smoke test: build the dataset and print the shape of one batch.
    # Replace the placeholder path with the real data root before running.
    data_dir = '/path/to/your/data'
    # Named train_set so the module-level `dataset` alias (torch Dataset) is not rebound.
    train_set = MyDataset(data_dir, start_year=2000, end_year=2020, var_name='t2m')
    dataloader = DataLoader(train_set, batch_size=4, shuffle=True, num_workers=2)
    for batch in dataloader:
        print(batch.shape)
        break