Modernizing HEALPix UNet #1156

Open
yikwill wants to merge 23 commits into ai2cm:main from yikwill:feature/hpx_unet_updates

@yikwill
Contributor

@yikwill yikwill commented May 11, 2026

This PR is a large refactor to the HEALPix model which brings it up to speed with the DLESyM codebase and also removes a lot of dead code relating to recurrence.

Changes:

  • Removes all code related to recurrence (e.g., unused recurrent blocks) as this was never implemented. We will leave adding recurrence to the model for a separate project.

  • Removes HEALPixRecUNet. The naming is confusing given the lack of recurrence, and its forward function tries to handle things that should be abstracted to the stepper like prognostic/diagnostic split and residual prediction. Instead, we add HEALPixUNet whose forward method is a simple pass of the encoder and decoder.

  • Moves padding layers to separate healpix_paddings.py file. Adds HEALPixPaddingIsolatitude padding layer to resolve HEALPix grid imprinting issues.

  • Deprecates enable_healpixpad option in favor of hpx_padding_mode.

  • Adds layers DealiasedDownsample and SmoothedInterpolateConv to address checkerboarding artifacts.

  • Moves all downsampling, upsampling, and conv blocks to healpix_blocks.py. Previously, the up/downsampling layers lived in healpix_activations.py for no clear reason.

  • Tests added

  • If dependencies changed, "deps only" image rebuilt and "latest_deps_only_image.txt" file updated

yikwill and others added 6 commits May 11, 2026 11:25
Subset of wy/healpix_updates: model/healpix modules, hpx registry, exports,
and registry tests only (no experiment configs or trainer timing changes).

Co-authored-by: Cursor <cursoragent@cursor.com>
Contributor

@mcgibbon mcgibbon left a comment

Just giving this a quick look with some comments/pointers.

A PR of this size isn't really something I can review for e.g. correctness, and I'd also want to do a quick agent review so it can look more thoroughly than I can at so many changes.

That would mean this code would be going in as kind of "unverified", until you show scientific results that demonstrate it's correct. This is perhaps the right path for this particular code, since the existing version wasn't "working" scientifically, but it's something to keep in mind. Once your model is indeed "working" and we have results we don't want to degrade, you'd need to separate changes out into individual PRs so a reviewer like me could understand that change and how it's tested.

Comment thread fme/ace/models/healpix/healpix_decoder.py Outdated
hpx_padding_mode=hpx_padding_mode,
nside=face_nside,
)
conv_module = conv_cfg.build()
Contributor

Don't do it in this PR (too much going on), but generally speaking it would be nice to move "build"-type arguments out of the dataclass and make them arguments to .build instead. The number of input and output channels comes to mind specifically: anything you wouldn't want the user to configure in a yaml file, and that should instead be determined at build time from other variables in the runtime.

The fact that you have to .replace attributes of something we've specified in a yaml is a bit scary in this respect. Those things should probably not be configurable at all, rather than being silently ignored.
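
As a rough illustration of the pattern described above, here is a minimal sketch in which only user-facing options are dataclass fields and the channel counts are arguments to .build. The class name ConvBlockConfig, its fields, and the dict returned in place of a real torch module are all hypothetical and are not taken from this PR.

```python
from dataclasses import dataclass


@dataclass
class ConvBlockConfig:
    """Hypothetical config: only options a user may set in yaml are fields."""

    kernel_size: int = 3
    dilation: int = 1

    def build(self, in_channels: int, out_channels: int) -> dict:
        # The channel counts are supplied by the caller at build time, so a
        # yaml file cannot set them and nothing needs to be overwritten later.
        # A plain dict stands in for the module a real build would return.
        return {
            "in_channels": in_channels,
            "out_channels": out_channels,
            "kernel_size": self.kernel_size,
            "dilation": self.dilation,
        }


config = ConvBlockConfig(kernel_size=5)  # as if loaded from a yaml file
block = config.build(in_channels=8, out_channels=16)
```

With this layout, passing in_channels to the config constructor fails loudly with a TypeError instead of being silently replaced.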

Contributor

This is still an important issue. I am specifically in the mode of reviewing this PR with few/no required changes so that these very large diffs can get merged in quickly (though this PR has already been open for two weeks), with these issues getting fixed after.

The new (optional) comments are nice-to-haves that would make the code easier to configure, but this comment is a "I smell danger, the code may do something unexpected and undetectable" type of issue.

Contributor Author

Got it. I'll address this issue in another PR along with all your (optional) comments I left open.

Comment thread fme/ace/models/healpix/healpix_encoder.py Outdated
Comment thread fme/ace/models/healpix/healpix_encoder.py Outdated
Comment thread fme/ace/models/healpix/healpix_encoder.py
Comment thread fme/ace/models/healpix/healpix_encoder.py Outdated
Comment thread fme/ace/registry/hpx.py Outdated
Comment thread fme/ace/test_train.py
Comment thread fme/ace/test_train.py Outdated

  with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".yaml") as f_train:
-     f_train.write(yaml.dump(dataclasses.asdict(train_config)))
+     f_train.write(yaml.safe_dump(dataclasses.asdict(train_config)))
Contributor

Can we indeed safe_dump these? I thought there was a reason we didn't. In any case, please separate out things like this into their own PRs to make sure whoever looks at it understands what it's doing. I could easily have missed this, and it has implications for the rest of the codebase.

If safe_dump indeed does work here, we might want to use safe_dump throughout more of our code. However, we might still decide to use dump to avoid authors of later changes needing to keep this safe (in particular, if we can't dump checkpoints safely, is there any point to dumping/loading anything safely?).
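
For context on the dump vs. safe_dump question, the sketch below shows one way the two differ in PyYAML. The config dict and the Custom class are made up for illustration; whether train_config actually contains such values is not established in this thread.

```python
import yaml  # PyYAML

config = {"img_shape": (16, 32)}  # hypothetical entry holding a tuple

# yaml.dump writes a Python-specific tag for the tuple...
tagged = yaml.dump(config)

# ...and yaml.safe_load rejects Python-specific tags.
try:
    yaml.safe_load(tagged)
    safe_load_accepts_tagged = True
except yaml.YAMLError:
    safe_load_accepts_tagged = False

# yaml.safe_dump only emits standard YAML, so the tuple becomes a plain list.
plain = yaml.safe_dump(config)
round_tripped = yaml.safe_load(plain)


class Custom:
    """Stand-in for any type safe_dump has no representer for."""


# safe_dump refuses arbitrary objects instead of serializing them.
try:
    yaml.safe_dump({"obj": Custom()})
    safe_dump_accepts_objects = True
except yaml.YAMLError:
    safe_dump_accepts_objects = False
```

In short, safe_dump output can always be read back with safe_load, at the cost of losing Python-specific types (the tuple comes back as a list) and raising on types it cannot represent.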

Contributor Author

I changed this to be consistent with the use of safe_load in other parts of this file. However, I don't fully understand the difference between dump and safe_dump, so I agree that this should be its own PR if safe_dump works. Reverting these to dump for now.

Comment thread fme/ace/test_train.py
  if derived_forcings is None:
      derived_forcings = DerivedForcingsConfig()
- if nettype == "HEALPixRecUNet":
+ if nettype == "HEALPixUNet":
Contributor

Comment: We're breaking the ability for any existing healpix checkpoints to be loaded, which I think is fine.

@mcgibbon
Contributor

Tests appear to be failing due to earth2grid missing, which I didn't expect to see outside of the "no healpix" tests. They also require CUDA which makes the CPU tests fail, but even GPU is failing. Also, make sure to install pre-commit hooks (pre-commit install) so they run before you commit.

@yikwill yikwill marked this pull request as ready for review May 26, 2026 17:33
@yikwill
Contributor Author

yikwill commented May 26, 2026

Thanks for the comments. I've addressed them and the PR is ready for re-review. I think the tests are failing because the GitHub workflow is outdated for HEALPix. Looking at the workflow file I see uv pip install --no-build-isolation -c constraints.txt -e .[dev,healpix,graphcast] but healpix is not an optional dependency in pyproject.toml. Matching the current Makefile, I think the workflow should do

uv pip install -c constraints.txt -e .[dev,docs,graphcast]
uv pip install --no-build-isolation -c constraints.txt -r requirements-healpix.txt

Not sure if this is something I can change since I'm not a maintainer.

@mcgibbon
Contributor

> Thanks for the comments. I've addressed them and the PR is ready for re-review. I think the tests are failing because the GitHub workflow is outdated for HEALPix. Looking at the workflow file I see uv pip install --no-build-isolation -c constraints.txt -e .[dev,healpix,graphcast] but healpix is not an optional dependency in pyproject.toml. Matching the current Makefile, I think the workflow should do
>
> uv pip install -c constraints.txt -e .[dev,docs,graphcast]
> uv pip install --no-build-isolation -c constraints.txt -r requirements-healpix.txt

Hm, yes, I think the issue was that extras don't let you specify a commit from a certain git repo as a requirement. I agree with your suggestion for how to update the GitHub CI yaml so the workflow works. It seems like perhaps we weren't actually using earth2grid before your PR, or at least there was a fallback.

> Not sure if this is something I can change since I'm not a maintainer.

You're actually the only person who can change it, as owner of the PR (unless you've given permission to push to your repo, but it's still cleaner if you do it). Just go into the .github folder in your branch and make the necessary changes to the CI workflows. Then when we run your PR, it will run the updated workflows.

Pre-commit also still has many failures, make sure you've installed the hooks and do a pre-commit run --all-files locally.

Contributor

@mcgibbon mcgibbon left a comment

You can fix the "nit:" comments if you like, but the "Suggestion (optional)" ones should be left for another PR so we can get these changes merged, assuming tests are passing.

enable_nhwc: bool = False
enable_healpixpad: bool = False
block_type: Literal["ConvGRUBlock", "ConvLSTMBlock"] = "ConvGRUBlock"
hpx_padding_mode: Literal["earth2grid", "karlbauer", "isolatitude"] = "earth2grid"
Contributor

Suggestion (optional): Some of these settings feel like they might be "global" settings that you're needing to duplicate on each configuration class. For example, you probably want to use the same hpx_padding_mode on all layers always. You could configure this at a higher level, and have it be passed as an argument to .build instead, as a way of sharing the configuration across levels/layers.
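
One possible shape for this suggestion is sketched below: the padding mode is configured once on a top-level config and handed to each layer's .build. The names ModelConfig and LayerConfig, and the dicts standing in for built layers, are hypothetical; only the three hpx_padding_mode values come from the diff above.

```python
from dataclasses import dataclass, field
from typing import List, Literal

PaddingMode = Literal["earth2grid", "karlbauer", "isolatitude"]


@dataclass
class LayerConfig:
    """Hypothetical per-layer config with no padding option of its own."""

    kernel_size: int = 3

    def build(self, hpx_padding_mode: PaddingMode) -> dict:
        # A dict stands in for the layer a real build would construct.
        return {
            "kernel_size": self.kernel_size,
            "hpx_padding_mode": hpx_padding_mode,
        }


@dataclass
class ModelConfig:
    """Hypothetical top-level config that owns the shared setting."""

    hpx_padding_mode: PaddingMode = "earth2grid"
    layers: List[LayerConfig] = field(
        default_factory=lambda: [LayerConfig(), LayerConfig(kernel_size=1)]
    )

    def build(self) -> List[dict]:
        # Every layer receives the same padding mode from a single source.
        return [layer.build(self.hpx_padding_mode) for layer in self.layers]


built_layers = ModelConfig(hpx_padding_mode="isolatitude").build()
```

Because the layer configs no longer carry the setting, two layers cannot end up with different padding modes by accident.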

kernel_size: int = 1
block_type: Literal["MaxPool", "AvgPool", "DealiasedDownsample"]
pooling: int = 2
enable_nhwc: bool = False
Contributor

Suggestion (optional): Some of these settings like enable_nhwc and stride feel like they aren't meant to be changed. We hard-code the memory formats of our inputs, and you don't have any options to change/permute them, so enable_nhwc could probably be removed. stride feels like something determined at code-time rather than configuration time, or at best you probably want pooling == stride.

Comment on lines +142 to +160
block_type: Literal[
"TransposedConvUpsample",
"SmoothedInterpolateConv",
"Interpolate",
]
in_channels: int = 3
out_channels: int = 1
stride: int = 2
kernel_size: int = 3
dilation: int = 1
upsample_mode: str = "nearest"
activation: Optional[CappedGELUConfig] = None
enable_nhwc: bool = False
hpx_padding_mode: Literal["earth2grid", "karlbauer", "isolatitude"] = "earth2grid"
nside: Optional[int] = None
nside_after: Optional[int] = None
align_corners: bool = False
scale_factor: Optional[int] = None
mode: Optional[str] = None
Contributor

Issue (optional): There are several configuration options that only apply to specific configurations, and are ignored otherwise. You could make it easier to reason about these groups of settings if you split this configuration into multiple classes, each with its own specific e.g. block_type: Literal["Interpolate"] input (that is required, with no default and only one valid value) that dacite uses to select the correct configuration during load time.

This same pattern applies to some of the other configuration classes as well.
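
A minimal sketch of the split described above follows. The grouping of fields under each block_type is an assumption made for illustration, and a hand-written lookup stands in for the selection dacite would perform at load time, so that the example needs only the standard library.

```python
from dataclasses import dataclass
from typing import Literal, Union


@dataclass
class InterpolateConfig:
    # Required, with a single valid value: it identifies this config class.
    block_type: Literal["Interpolate"]
    scale_factor: int = 2
    mode: str = "nearest"


@dataclass
class TransposedConvUpsampleConfig:
    block_type: Literal["TransposedConvUpsample"]
    stride: int = 2
    kernel_size: int = 3


UpsampleConfig = Union[InterpolateConfig, TransposedConvUpsampleConfig]

_CONFIG_BY_BLOCK_TYPE = {
    "Interpolate": InterpolateConfig,
    "TransposedConvUpsample": TransposedConvUpsampleConfig,
}


def load_upsample_config(data: dict) -> UpsampleConfig:
    """Select the config class from block_type and construct it."""
    config_cls = _CONFIG_BY_BLOCK_TYPE[data["block_type"]]
    # Keys that do not belong to the selected class raise TypeError here,
    # so an option can no longer be set and then silently ignored.
    return config_cls(**data)


cfg = load_upsample_config({"block_type": "Interpolate", "mode": "bilinear"})
```

Passing kernel_size together with block_type "Interpolate" would now fail at load time rather than being accepted and unused.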

Comment thread fme/ace/models/healpix/healpix_blocks.py Outdated

Comment on lines -19 to -37
This file contains padding and convolution classes to perform according operations on the twelve faces of the HEALPix.


HEALPix Face order 3D array representation
-----------------
-------------------------- //\\ //\\ //\\ //\\ | | | | |
|| 0 | 1 | 2 | 3 || // \\// \\// \\// \\ |0 |1 |2 |3 |
|\\ //\\ //\\ //\\ //| /\\0 //\\1 //\\2 //\\3 // -----------------
| \\// \\// \\// \\// | // \\// \\// \\// \\// | | | | |
|4//\\5 //\\6 //\\7 //\\4| \\4//\\5 //\\6 //\\7 //\\ |4 |5 |6 |7 |
|// \\// \\// \\// \\| \\/ \\// \\// \\// \\ -----------------
|| 8 | 9 | 10 | 11 | \\8 //\\9 //\\10//\\11// | | | | |
-------------------------- \\// \\// \\// \\// |8 |9 |10 |11 |
-----------------
"\\" are top and bottom, whereas
"//" are left and right borders


Details on the HEALPix can be found at https://iopscience.iop.org/article/10.1086/427976
Contributor

Are we changing the layout we're using? If not, we should keep this in a comment somewhere (maybe here). If so, what are we moving to, and how is the layout controlled?

Contributor Author

Not changing the layout, just moved this comment to healpix_paddings.py. I feel like this comment is most helpful there, but I'm happy to keep it here if you think otherwise.



class HEALPixPadding(nn.Module):
class HEALPixLayer(th.nn.Module):
Contributor

Suggestion (optional): Revert replacements of nn. with th.nn throughout this file.

from collections.abc import Sequence
from typing import Literal

import torch as th
Contributor

It's a pre-existing issue in the healpix code so I'm not going to ask you to fix it in this PR, but I always read th as torch-harmonics. It would be nice to import this as "torch" like we do in the rest of the repo.

Comment thread fme/ace/models/healpix/healpix_blocks.py Outdated
Comment thread fme/ace/models/healpix/healpix_layers.py Outdated
self.scale_factor = scale_factor
self.mode = mode
self.trim_size = trim_size
self.interp = th.nn.functional.interpolate
Contributor

Suggestion (optional): Unusual to assign a function to an attribute so it can be called in one place (I only detected the call in forward) and never overridden - you could instead call th.nn.functional.interpolate directly where you're currently calling self.interp.

@mcgibbon
Contributor

Some additional comments from AI review above which you should consider fixing, but I think are better to do after you get tests passing and merge in the current code.
