diff --git a/experimental/shoshin/configs/waterbirds_resnet_config_reweighting.py b/experimental/shoshin/configs/waterbirds_resnet_config_reweighting.py index af8bbe12b..38098c30c 100644 --- a/experimental/shoshin/configs/waterbirds_resnet_config_reweighting.py +++ b/experimental/shoshin/configs/waterbirds_resnet_config_reweighting.py @@ -32,6 +32,7 @@ def get_config() -> ml_collections.ConfigDict: model.dropout_rate = 0.2 config.train_bias = False + config.train_weight_ensemble = True reweighting = config.reweighting reweighting.do_reweighting = True