Modernizing HEALPix UNet #1156
Subset of wy/healpix_updates: model/healpix modules, hpx registry, exports, and registry tests only (no experiment configs or trainer timing changes). Co-authored-by: Cursor <cursoragent@cursor.com>
…rences to recurrence
mcgibbon
left a comment
Just giving this a quick look with some comments/pointers.
A PR of this size isn't really something I can review for e.g. correctness, and I'd also want to do a quick agent review so it can look more thoroughly than I can at so many changes.
That would mean this code would be going in as kind of "unverified", until you show scientific results that demonstrate it's correct. This is perhaps the right path for this particular code, since the existing version wasn't "working" scientifically, but it's something to keep in mind. Once your model is indeed "working" and we have results we don't want to degrade, you'd need to separate changes out into individual PRs so a reviewer like me could understand that change and how it's tested.
| hpx_padding_mode=hpx_padding_mode, | ||
| nside=face_nside, | ||
| ) | ||
| conv_module = conv_cfg.build() |
Don't do it in this PR, there's too much going on, but generally speaking it would be nice to move "build"-type arguments out of the dataclass and pass them as arguments to .build instead. Specifically, the number of input and output channels comes to mind: anything that you wouldn't want the user to configure in a yaml file, but that should instead be determined at build time from other variables in the runtime.
The fact that you have to .replace attributes of something we've specified in a yaml is a bit scary in this respect. Those things should probably not be configurable at all, rather than being configurable and then silently ignored.
This is still an important issue. I am specifically reviewing this PR with few/no required changes so that these very large diffs can get merged in quickly (though this PR has already been open for two weeks), with these issues getting fixed afterwards.
The new (optional) comments are nice-to-haves that would make the code easier to configure, but this comment is an "I smell danger, the code may do something unexpected and undetectable" type of issue.
Got it. I'll address this issue in another PR along with all your (optional) comments I left open.
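For the follow-up PR, the pattern described in this thread could look roughly like the sketch below. It is only an illustration under stated assumptions: `ConvConfig` and its fields are hypothetical names, not the classes in this PR, and `build` returns a plain dict as a stand-in for the torch module so the example stays self-contained. The point is that channel counts are not dataclass fields, so a yaml file cannot set them and nothing needs to be `.replace`d later.

```python
import dataclasses


@dataclasses.dataclass
class ConvConfig:
    """Hypothetical config: only options a user should set in yaml."""

    kernel_size: int = 3
    dilation: int = 1

    def build(self, in_channels: int, out_channels: int) -> dict:
        # Channel counts come from the caller at build time, not from yaml.
        # A dict stands in for the real module to keep the sketch runnable.
        return {
            "in_channels": in_channels,
            "out_channels": out_channels,
            "kernel_size": self.kernel_size,
            "dilation": self.dilation,
        }


conv_cfg = ConvConfig(kernel_size=3)
conv_spec = conv_cfg.build(in_channels=4, out_channels=8)

# A yaml entry that tries to set a build-time value fails loudly
# instead of being silently overwritten.
try:
    ConvConfig(in_channels=4)
    rejected = False
except TypeError:
    rejected = True
```

With this shape, a stray `in_channels` key in a config raises an error when the config is constructed rather than being ignored.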
  with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".yaml") as f_train:
-     f_train.write(yaml.dump(dataclasses.asdict(train_config)))
+     f_train.write(yaml.safe_dump(dataclasses.asdict(train_config)))
Can we indeed safe_dump these? I thought there was a reason we didn't. In any case, please separate out things like this into their own PRs to make sure whoever looks at it understands what it's doing, I could easily have missed this and it has implications for the rest of the codebase.
If safe_dump indeed does work here, we might want to use safe_dump throughout more of our code. However, we might still decide to use dump to avoid authors of later changes needing to keep this safe (in particular, if we can't dump checkpoints safely, is there any point to dumping/loading anything safely?).
I changed this to be consistent with the use of safe_load in other parts of this file. However, I don't fully understand the difference between dump and safe_dump, so I agree that this should be its own PR if safe_dump works. Reverting these to dump for now.
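As background for that separate PR, here is a small sketch of how the two PyYAML calls differ. `DemoConfig` is a hypothetical dataclass, not the real `train_config`, and whether the real config contains any non-plain types is not something this thread establishes. `yaml.dump` may write Python-specific tags (for example for tuples) that `yaml.safe_load` refuses to read, while `yaml.safe_dump` restricts itself to standard YAML types and raises on objects it cannot represent.

```python
import dataclasses
import enum

import yaml  # PyYAML


class Mode(enum.Enum):
    """Hypothetical enum, used only to show an unrepresentable object."""

    FAST = "fast"


@dataclasses.dataclass
class DemoConfig:
    """Hypothetical config with a tuple field."""

    name: str = "demo"
    shape: tuple = (8, 8)


data = dataclasses.asdict(DemoConfig())  # asdict keeps the tuple as a tuple

full_text = yaml.dump(data)       # tuple is written with a Python-specific tag
safe_text = yaml.safe_dump(data)  # tuple is written as a plain YAML list


def loads_safely(text: str) -> bool:
    """Return True if yaml.safe_load accepts the text."""
    try:
        yaml.safe_load(text)
        return True
    except yaml.YAMLError:
        return False


# safe_dump refuses arbitrary objects instead of tagging them.
try:
    yaml.safe_dump({"mode": Mode.FAST})
    enum_rejected = False
except yaml.YAMLError:
    enum_rejected = True
```

So `safe_dump` output always round-trips through `safe_load` (tuples come back as lists), whereas `dump` output is only guaranteed to load with an unsafe loader.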
  if derived_forcings is None:
      derived_forcings = DerivedForcingsConfig()
- if nettype == "HEALPixRecUNet":
+ if nettype == "HEALPixUNet":
Comment: We're breaking the ability for any existing healpix checkpoints to be loaded, which I think is fine.
Tests appear to be failing due to earth2grid missing, which I didn't expect to see outside of the "no healpix" tests. They also require CUDA which makes the CPU tests fail, but even GPU is failing. Also, make sure to install pre-commit hooks ( |
…sing errors with amp)
Thanks for the comments. I've addressed them and the PR is ready for re-review. I think the tests are failing because the GitHub workflow is outdated for HEALPix. Looking at the workflow file I see
Not sure if this is something I can change since I'm not a maintainer.
Hm yes I think the issue was that extras don't let you specify a commit from a certain git repo as a requirement. I agree about your suggestion for how to update the github CI yaml for the workflow to work. It seems like perhaps we weren't actually using earth2grid, before your PR, or at least there was a fallback.
You're actually the only person who can change it, as owner of the PR (unless you've given permission to push to your repo, but it's still cleaner if you do it). Just go into the … Pre-commit also still has many failures; make sure you've installed the hooks and do a …
mcgibbon
left a comment
You can fix the nit: if you like, but the Suggestion (optional)'s should be left for another PR so we can get these changes merged.
Assuming tests are passing.
| enable_nhwc: bool = False | ||
| enable_healpixpad: bool = False | ||
| block_type: Literal["ConvGRUBlock", "ConvLSTMBlock"] = "ConvGRUBlock" | ||
| hpx_padding_mode: Literal["earth2grid", "karlbauer", "isolatitude"] = "earth2grid" |
Suggestion (optional): Some of these settings feel like they might be "global" settings that you're needing to duplicate on each configuration class. For example, you probably want to use the same hpx_padding_mode on all layers always. You could configure this as a higher level, and have it be passed as an argument to .build instead, as a way of sharing the configuration across levels/layers.
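One way this could be arranged is sketched below, with hypothetical `LayerConfig` and `EncoderConfig` classes that are not the ones in this PR. The padding mode is configured once on the parent and handed down through `.build`, so individual layers cannot drift apart. The built layers are plain dicts here, standing in for real modules.

```python
import dataclasses
from typing import List, Literal

PaddingMode = Literal["earth2grid", "karlbauer", "isolatitude"]


@dataclasses.dataclass
class LayerConfig:
    """Hypothetical per-layer config; note there is no padding-mode field."""

    kernel_size: int = 3

    def build(self, hpx_padding_mode: str) -> dict:
        # The shared setting arrives as an argument from the parent config.
        return {
            "kernel_size": self.kernel_size,
            "hpx_padding_mode": hpx_padding_mode,
        }


@dataclasses.dataclass
class EncoderConfig:
    """Hypothetical parent config owning the single padding-mode setting."""

    hpx_padding_mode: PaddingMode = "earth2grid"
    layers: List[LayerConfig] = dataclasses.field(
        default_factory=lambda: [LayerConfig(), LayerConfig(kernel_size=1)]
    )

    def build(self) -> list:
        return [layer.build(self.hpx_padding_mode) for layer in self.layers]


built = EncoderConfig(hpx_padding_mode="isolatitude").build()
```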
| kernel_size: int = 1 | ||
| block_type: Literal["MaxPool", "AvgPool", "DealiasedDownsample"] | ||
| pooling: int = 2 | ||
| enable_nhwc: bool = False |
Suggestion (optional): Some of these settings like enable_nhwc and stride feel like they aren't meant to be changed. We hard-code the memory formats of our inputs, and you don't have any options to change/permute them, so enable_nhwc could probably be removed. stride feels like something determined at code-time rather than configuration time, or at best you probably want pooling == stride.
| block_type: Literal[ | ||
| "TransposedConvUpsample", | ||
| "SmoothedInterpolateConv", | ||
| "Interpolate", | ||
| ] | ||
| in_channels: int = 3 | ||
| out_channels: int = 1 | ||
| stride: int = 2 | ||
| kernel_size: int = 3 | ||
| dilation: int = 1 | ||
| upsample_mode: str = "nearest" | ||
| activation: Optional[CappedGELUConfig] = None | ||
| enable_nhwc: bool = False | ||
| hpx_padding_mode: Literal["earth2grid", "karlbauer", "isolatitude"] = "earth2grid" | ||
| nside: Optional[int] = None | ||
| nside_after: Optional[int] = None | ||
| align_corners: bool = False | ||
| scale_factor: Optional[int] = None | ||
| mode: Optional[str] = None |
Issue (optional): There are several configuration options that only apply to specific configurations, and are ignored otherwise. You could make it easier to reason about these groups of settings if you split this configuration into multiple classes, each with its own specific e.g. block_type: Literal["Interpolate"] input (that is required, with no default and only one valid value) that dacite uses to select the correct configuration during load time.
This same pattern applies to some of the other configuration classes as well.
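A minimal sketch of the split is below. It uses only the standard library and a hand-written dispatch on `block_type` rather than dacite's own union handling, and the assignment of fields to each class is an illustrative guess, not a statement about which options each block in this PR actually uses.

```python
import dataclasses
from typing import Literal


@dataclasses.dataclass
class InterpolateConfig:
    """Hypothetical config holding only the options interpolation needs."""

    block_type: Literal["Interpolate"]  # required, single valid value
    scale_factor: int = 2
    mode: str = "nearest"
    align_corners: bool = False


@dataclasses.dataclass
class TransposedConvUpsampleConfig:
    """Hypothetical config holding only transposed-convolution options."""

    block_type: Literal["TransposedConvUpsample"]
    stride: int = 2
    kernel_size: int = 3


UPSAMPLE_CONFIGS = {
    "Interpolate": InterpolateConfig,
    "TransposedConvUpsample": TransposedConvUpsampleConfig,
}


def load_upsample_config(raw: dict):
    """Pick the config class from block_type, then validate the fields."""
    block_type = raw.get("block_type")
    if block_type not in UPSAMPLE_CONFIGS:
        raise ValueError(f"unknown block_type: {block_type!r}")
    # Raises TypeError if raw contains options that do not belong
    # to the selected block type, instead of ignoring them.
    return UPSAMPLE_CONFIGS[block_type](**raw)


cfg = load_upsample_config({"block_type": "Interpolate", "scale_factor": 2})
```

An option that belongs to a different block type, such as `kernel_size` on an `Interpolate` entry, is then rejected at load time.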
| This file contains padding and convolution classes to perform according operations on the twelve faces of the HEALPix. | ||
| HEALPix Face order 3D array representation | ||
| ----------------- | ||
| -------------------------- //\\ //\\ //\\ //\\ | | | | | | ||
| || 0 | 1 | 2 | 3 || // \\// \\// \\// \\ |0 |1 |2 |3 | | ||
| |\\ //\\ //\\ //\\ //| /\\0 //\\1 //\\2 //\\3 // ----------------- | ||
| | \\// \\// \\// \\// | // \\// \\// \\// \\// | | | | | | ||
| |4//\\5 //\\6 //\\7 //\\4| \\4//\\5 //\\6 //\\7 //\\ |4 |5 |6 |7 | | ||
| |// \\// \\// \\// \\| \\/ \\// \\// \\// \\ ----------------- | ||
| || 8 | 9 | 10 | 11 | \\8 //\\9 //\\10//\\11// | | | | | | ||
| -------------------------- \\// \\// \\// \\// |8 |9 |10 |11 | | ||
| ----------------- | ||
| "\\" are top and bottom, whereas | ||
| "//" are left and right borders | ||
| Details on the HEALPix can be found at https://iopscience.iop.org/article/10.1086/427976 |
Are we changing the layout we're using? If not, we should keep this in a comment somewhere (maybe here). If so, what are we moving to, and how is the layout controlled?
Not changing the layout, just moved this comment to healpix_paddings.py. I feel like this comment is most helpful there, but I'm happy to keep it here if you think otherwise.
- class HEALPixPadding(nn.Module):
+ class HEALPixLayer(th.nn.Module):
Suggestion (optional): Revert replacements of nn. with th.nn throughout this file.
| from collections.abc import Sequence | ||
| from typing import Literal | ||
| import torch as th |
It's a pre-existing issue in the healpix code so I'm not going to ask you to fix it in this PR, but I always read th as torch-harmonics, it would be nice to use this as "torch" like we do in the rest of the repo.
| self.scale_factor = scale_factor | ||
| self.mode = mode | ||
| self.trim_size = trim_size | ||
| self.interp = th.nn.functional.interpolate |
Suggestion (optional): Unusual to assign a function to an attribute so it can be called in one place (I only detected the call in forward) and never overridden - you could instead call th.nn.functional.interpolate directly where you're currently calling self.interp.
Some additional comments from AI review above which you should consider fixing, but I think are better to do after you get tests passing and merge in the current code.
This PR is a large refactor to the HEALPix model which brings it up to speed with the DLESyM codebase and also removes a lot of dead code relating to recurrence.
Changes:
Removes all code related to recurrence (e.g., unused recurrent blocks) as this was never implemented. We will leave adding recurrence to the model for a separate project.
Removes HEALPixRecUNet. The naming is confusing given the lack of recurrence, and its forward function tries to handle things that should be abstracted to the stepper like prognostic/diagnostic split and residual prediction. Instead, we add HEALPixUNet whose forward method is a simple pass of the encoder and decoder.
Moves padding layers to a separate healpix_paddings.py file. Adds HEALPixPaddingIsolatitude padding layer to resolve HEALPix grid imprinting issues.
Deprecates enable_healpixpad option in favor of hpx_padding_mode.
Adds layers DealiasedDownsample and SmoothedInterpolateConv to address checkerboarding artifacts.
Moves all downsampling, upsampling, and conv blocks to healpix_blocks.py. Previously, the up/downsampling layers lived in healpix_activations.py for no clear reason.
Tests added
If dependencies changed, "deps only" image rebuilt and "latest_deps_only_image.txt" file updated