Skip to content

Commit 061fd04

Browse files
authored
Merge pull request #80 from p-lambda/dev
v1.2.2
2 parents 1d06a18 + 88ba842 commit 061fd04

9 files changed

Lines changed: 76 additions & 48 deletions

File tree

README.md

Lines changed: 5 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -29,7 +29,7 @@ pip install wilds
2929
If you have already installed it, please check that you have the latest version:
3030
```bash
3131
python -c "import wilds; print(wilds.__version__)"
32-
# This should print "1.2.1". If it doesn't, update by running:
32+
# This should print "1.2.2". If it doesn't, update by running:
3333
pip install -U wilds
3434
```
3535

@@ -50,7 +50,10 @@ pip install -e .
5050
- torch>=1.7.0
5151
- torch-scatter>=2.0.5
5252
- torch-geometric>=1.6.1
53+
- torchvision>=0.8.2
5354
- tqdm>=4.53.0
55+
- scikit-learn>=0.20.0
56+
- scipy>=1.5.4
5457

5558
Running `pip install wilds` or `pip install -e .` will automatically check for and install all of these requirements
5659
except for the `torch-scatter` and `torch-geometric` packages, which require a [quick manual install](https://pytorch-geometric.readthedocs.io/en/latest/notes/installation.html#installation-via-binaries).
@@ -63,9 +66,8 @@ These scripts are not part of the installed WILDS package. To use them, you shou
6366
git clone git@github.com:p-lambda/wilds.git
6467
```
6568

66-
To run these scripts, you will need to install these additional dependencies:
69+
To run these scripts, you will also need to install this additional dependency:
6770

68-
- torchvision>=0.8.2
6971
- transformers>=3.5.0
7072

7173
All baseline experiments in the paper were run on Python 3.8.5 and CUDA 10.1.

examples/configs/utils.py

Lines changed: 25 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,4 @@
1+
import copy
12
from configs.algorithm import algorithm_defaults
23
from configs.model import model_defaults
34
from configs.scheduler import scheduler_defaults
@@ -7,41 +8,44 @@
78
def populate_defaults(config):
89
"""Populates hyperparameters with defaults implied by choices
910
of other hyperparameters."""
11+
12+
orig_config = copy.deepcopy(config)
1013
assert config.dataset is not None, 'dataset must be specified'
1114
assert config.algorithm is not None, 'algorithm must be specified'
15+
1216
# implied defaults from choice of dataset
1317
config = populate_config(
14-
config,
18+
config,
1519
dataset_defaults[config.dataset]
1620
)
1721

1822
# implied defaults from choice of split
1923
if config.dataset in split_defaults and config.split_scheme in split_defaults[config.dataset]:
2024
config = populate_config(
21-
config,
25+
config,
2226
split_defaults[config.dataset][config.split_scheme]
2327
)
24-
28+
2529
# implied defaults from choice of algorithm
2630
config = populate_config(
27-
config,
31+
config,
2832
algorithm_defaults[config.algorithm]
2933
)
3034

3135
# implied defaults from choice of loader
3236
config = populate_config(
33-
config,
37+
config,
3438
loader_defaults
3539
)
3640
# implied defaults from choice of model
3741
if config.model: config = populate_config(
38-
config,
42+
config,
3943
model_defaults[config.model],
4044
)
41-
45+
4246
# implied defaults from choice of scheduler
4347
if config.scheduler: config = populate_config(
44-
config,
48+
config,
4549
scheduler_defaults[config.scheduler]
4650
)
4751

@@ -52,12 +56,22 @@ def populate_defaults(config):
5256

5357
# basic checks
5458
required_fields = [
55-
'split_scheme', 'train_loader', 'uniform_over_groups', 'batch_size', 'eval_loader', 'model', 'loss_function',
59+
'split_scheme', 'train_loader', 'uniform_over_groups', 'batch_size', 'eval_loader', 'model', 'loss_function',
5660
'val_metric', 'val_metric_decreasing', 'n_epochs', 'optimizer', 'lr', 'weight_decay',
57-
]
61+
]
5862
for field in required_fields:
5963
assert getattr(config, field) is not None, f"Must manually specify {field} for this setup."
6064

65+
# data loader validations
66+
# we only raise this error if the train_loader is standard, and
67+
# n_groups_per_batch or distinct_groups are
68+
# specified by the user (instead of populated as a default)
69+
if config.train_loader == 'standard':
70+
if orig_config.n_groups_per_batch is not None:
71+
raise ValueError("n_groups_per_batch cannot be specified if the data loader is 'standard'. Consider using a 'group' data loader instead.")
72+
if orig_config.distinct_groups is not None:
73+
raise ValueError("distinct_groups cannot be specified if the data loader is 'standard'. Consider using a 'group' data loader instead.")
74+
6175
return config
6276

6377
def populate_config(config, template: dict, force_compatibility=False):
@@ -78,7 +92,7 @@ def populate_config(config, template: dict, force_compatibility=False):
7892
d_config[key] = val
7993
elif d_config[key] != val and force_compatibility:
8094
raise ValueError(f"Argument {key} must be set to {val}")
81-
95+
8296
else: # config[key] expected to be a kwarg dict
8397
for kwargs_key, kwargs_val in val.items():
8498
if kwargs_key not in d_config[key] or d_config[key][kwargs_key] is None:

examples/models/detection/fasterrcnn.py

Lines changed: 4 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -27,7 +27,6 @@
2727
from torchvision.models.utils import load_state_dict_from_url
2828
from torchvision.ops import misc as misc_nn_ops
2929
from torchvision.ops import MultiScaleRoIAlign
30-
from torchvision.models.detection import _utils as det_utils
3130
from torchvision.models.detection.anchor_utils import AnchorGenerator
3231
from torchvision.models.detection.generalized_rcnn import GeneralizedRCNN
3332
from torchvision.models.detection.faster_rcnn import TwoMLPHead
@@ -127,11 +126,11 @@ def compute_loss(self, objectness, pred_bbox_deltas, labels, regression_targets)
127126
sampled_neg_inds = torch.where(torch.cat(sampled_neg_inds, dim=0))[0]
128127
sampled_inds = torch.cat([sampled_pos_inds, sampled_neg_inds], dim=0)
129128

130-
box_loss.append(det_utils.smooth_l1_loss(
129+
box_loss.append(F.smooth_l1_loss(
131130
pred_bbox_deltas_[sampled_pos_inds],
132131
regression_targets_[sampled_pos_inds],
133132
beta=1 / 9,
134-
size_average=False,
133+
reduction='sum',
135134
) / (sampled_inds.numel()))
136135

137136
objectness_loss.append(F.binary_cross_entropy_with_logits(
@@ -226,11 +225,11 @@ def fastrcnn_loss(class_logits, box_regression, labels, regression_targets):
226225

227226
box_regression_ = box_regression_.reshape(N, -1, 4)
228227

229-
box_loss_ = det_utils.smooth_l1_loss(
228+
box_loss_ = F.smooth_l1_loss(
230229
box_regression_[sampled_pos_inds_subset, labels_pos],
231230
regression_targets_[sampled_pos_inds_subset],
232231
beta=1 / 9,
233-
size_average=False,
232+
reduction='sum',
234233
)
235234
box_loss.append(box_loss_ / labels_.numel())
236235

examples/run_expt.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -155,8 +155,9 @@ def main():
155155
split_scheme=config.split_scheme,
156156
**config.dataset_kwargs)
157157

158-
# To implement data augmentation (i.e., have different transforms
159-
# at training time vs. test time), modify these two lines:
158+
# To modify data augmentation, modify the following code block.
159+
# If you want to use transforms that modify both `x` and `y`,
160+
# set `do_transform_y` to True when initializing the `WILDSSubset` below.
160161
train_transform = initialize_transform(
161162
transform_name=config.transform,
162163
config=config,

examples/transforms.py

Lines changed: 8 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -7,8 +7,9 @@
77

88
def initialize_transform(transform_name, config, dataset, is_training):
99
"""
10-
Transforms should take in a single (x, y)
11-
and return (transformed_x, transformed_y).
10+
By default, transforms should take in `x` and return `transformed_x`.
11+
For transforms that take in `(x, y)` and return `(transformed_x, transformed_y)`,
12+
set `do_transform_y` to True when initializing the WILDSSubset.
1213
"""
1314
if transform_name is None:
1415
return None
@@ -25,11 +26,6 @@ def initialize_transform(transform_name, config, dataset, is_training):
2526
else:
2627
raise ValueError(f"{transform_name} not recognized")
2728

28-
def transform_input_only(input_transform):
29-
def transform(x, y):
30-
return input_transform(x), y
31-
return transform
32-
3329
def initialize_bert_transform(config):
3430
assert 'bert' in config.model
3531
assert config.max_token_length is not None
@@ -55,7 +51,7 @@ def transform(text):
5551
dim=2)
5652
x = torch.squeeze(x, dim=0) # First shape dim is always 1
5753
return x
58-
return transform_input_only(transform)
54+
return transform
5955

6056
def getBertTokenizer(model):
6157
if model == 'bert-base-uncased':
@@ -79,7 +75,7 @@ def initialize_image_base_transform(config, dataset):
7975
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
8076
]
8177
transform = transforms.Compose(transform_steps)
82-
return transform_input_only(transform)
78+
return transform
8379

8480
def initialize_image_resize_and_center_crop_transform(config, dataset):
8581
"""
@@ -98,7 +94,7 @@ def initialize_image_resize_and_center_crop_transform(config, dataset):
9894
transforms.ToTensor(),
9995
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
10096
])
101-
return transform_input_only(transform)
97+
return transform
10298

10399
def initialize_poverty_transform(is_training):
104100
if is_training:
@@ -115,7 +111,7 @@ def transform_rgb(img):
115111
img[:3] = rgb_transform(img[:3][[2,1,0]])[[2,1,0]]
116112
return img
117113
transform = transforms.Lambda(lambda x: transform_rgb(x))
118-
return transform_input_only(transform)
114+
return transform
119115
else:
120116
return None
121117

@@ -148,4 +144,4 @@ def random_rotation(x: torch.Tensor) -> torch.Tensor:
148144
t_standardize,
149145
]
150146
transform = transforms.Compose(transforms_ls)
151-
return transform_input_only(transform)
147+
return transform

setup.py

Lines changed: 6 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -22,14 +22,16 @@
2222
long_description_content_type="text/markdown",
2323
install_requires = [
2424
'numpy>=1.19.1',
25+
'ogb>=1.2.6',
26+
'outdated>=0.2.0',
2527
'pandas>=1.1.0',
26-
'scikit-learn>=0.20.0',
2728
'pillow>=7.2.0',
29+
'pytz>=2020.4',
2830
'torch>=1.7.0',
29-
'ogb>=1.2.6',
31+
'torchvision>=0.8.2',
3032
'tqdm>=4.53.0',
31-
'outdated>=0.2.0',
32-
'pytz>=2020.4',
33+
'scikit-learn>=0.20.0',
34+
'scipy>=1.5.4'
3335
],
3436
license='MIT',
3537
packages=setuptools.find_packages(exclude=['dataset_preprocessing', 'examples', 'examples.models', 'examples.models.bert']),

wilds/common/metrics/all_metrics.py

Lines changed: 10 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -1,10 +1,10 @@
1+
import numpy as np
12
import torch
23
import torch.nn as nn
4+
import torch.nn.functional as F
35
from torchvision.ops.boxes import box_iou
46
from torchvision.models.detection._utils import Matcher
57
from torchvision.ops import nms, box_convert
6-
import numpy as np
7-
import torch.nn.functional as F
88
from wilds.common.metrics.metric import Metric, ElementwiseMetric, MultiTaskMetric
99
from wilds.common.metrics.loss import ElementwiseLoss
1010
from wilds.common.utils import avg_over_groups, minimum, maximum, get_counts
@@ -243,12 +243,17 @@ def _accuracy(self, src_boxes,pred_boxes , iou_threshold):
243243
total_pred = len(pred_boxes)
244244
if total_gt > 0 and total_pred > 0:
245245
# Define the matcher and distance matrix based on iou
246-
matcher = Matcher(iou_threshold,iou_threshold,allow_low_quality_matches=False)
247-
match_quality_matrix = box_iou(src_boxes,pred_boxes)
246+
matcher = Matcher(
247+
iou_threshold,
248+
iou_threshold,
249+
allow_low_quality_matches=False)
250+
match_quality_matrix = box_iou(
251+
src_boxes,
252+
pred_boxes)
248253
results = matcher(match_quality_matrix)
249254
true_positive = torch.count_nonzero(results.unique() != -1)
250255
matched_elements = results[results > -1]
251-
#in Matcher, a pred element can be matched only twice
256+
# in Matcher, a pred element can be matched only twice
252257
false_positive = (
253258
torch.count_nonzero(results == -1) +
254259
(len(matched_elements) - len(matched_elements.unique()))

wilds/datasets/wilds_dataset.py

Lines changed: 14 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -433,11 +433,16 @@ def standard_group_eval(metric, grouper, y_pred, y_true, metadata, aggregate=Tru
433433

434434

435435
class WILDSSubset(WILDSDataset):
436-
def __init__(self, dataset, indices, transform):
436+
def __init__(self, dataset, indices, transform, do_transform_y=False):
437437
"""
438-
This acts like torch.utils.data.Subset, but on WILDSDatasets.
439-
We pass in transform explicitly because it can potentially vary at
440-
training vs. test time, if we're using data augmentation.
438+
This acts like `torch.utils.data.Subset`, but on `WILDSDatasets`.
439+
We pass in `transform` (which is used for data augmentation) explicitly
440+
because it can potentially vary on the training vs. test subsets.
441+
442+
`do_transform_y` (bool): When this is false (the default),
443+
`self.transform ` acts only on `x`.
444+
Set this to true if `self.transform` should
445+
operate on `(x,y)` instead of just `x`.
441446
"""
442447
self.dataset = dataset
443448
self.indices = indices
@@ -449,11 +454,15 @@ def __init__(self, dataset, indices, transform):
449454
if hasattr(dataset, attr_name):
450455
setattr(self, attr_name, getattr(dataset, attr_name))
451456
self.transform = transform
457+
self.do_transform_y = do_transform_y
452458

453459
def __getitem__(self, idx):
454460
x, y, metadata = self.dataset[self.indices[idx]]
455461
if self.transform is not None:
456-
x, y = self.transform(x, y)
462+
if self.do_transform_y:
463+
x, y = self.transform(x, y)
464+
else:
465+
x = self.transform(x)
457466
return x, y, metadata
458467

459468
def __len__(self):

wilds/version.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,7 @@
44
import logging
55
from threading import Thread
66

7-
__version__ = '1.2.1'
7+
__version__ = '1.2.2'
88

99
try:
1010
os.environ['OUTDATED_IGNORE'] = '1'

0 commit comments

Comments (0)