feat: improve errors for invalid combinations of arguments in vector constructor methods #659

    @@ -2071,6 +2071,168 @@ def __setitem__(self, where: typing.Any, what: typing.Any) -> None:
             return _setitem(self, where, what, True)
     
     
    +def _validate_numpy_coordinates(fieldnames: tuple[str, ...]) -> None:
    +    """
    +    Validate coordinate field names using dimension-guard pattern.
    +
    +    This follows the same logic as _check_names in awkward_constructors to ensure
    +    consistent validation across backends.
    +
    +    Raises TypeError if duplicate or conflicting coordinates are detected.
    +    """
    +    complaint1 = "duplicate coordinates (through momentum-aliases): " + ", ".join(
    +        repr(x) for x in fieldnames
    +    )
    +    complaint2 = (
    +        "unrecognized combination of coordinates, allowed combinations are:\n\n"
    +        " (2D) x= y=\n"
    +        " (2D) rho= phi=\n"
    +        " (3D) x= y= z=\n"
    +        " (3D) x= y= theta=\n"
    +        " (3D) x= y= eta=\n"
    +        " (3D) rho= phi= z=\n"
    +        " (3D) rho= phi= theta=\n"
    +        " (3D) rho= phi= eta=\n"
    +        " (4D) x= y= z= t=\n"
    +        " (4D) x= y= z= tau=\n"
    +        " (4D) x= y= theta= t=\n"
    +        " (4D) x= y= theta= tau=\n"
    +        " (4D) x= y= eta= t=\n"
    +        " (4D) x= y= eta= tau=\n"
    +        " (4D) rho= phi= z= t=\n"
    +        " (4D) rho= phi= z= tau=\n"
    +        " (4D) rho= phi= theta= t=\n"
    +        " (4D) rho= phi= theta= tau=\n"
    +        " (4D) rho= phi= eta= t=\n"
    +        " (4D) rho= phi= eta= tau="
    +    )
    +
    +    is_momentum = False
    +    dimension = 0
    +    fieldnames_copy = list(fieldnames)
    +
    +    # 2D azimuthal coordinates
    +    if "x" in fieldnames_copy and "y" in fieldnames_copy:
    +        if dimension != 0:
    +            raise TypeError(complaint1 if is_momentum else complaint2)
    +        dimension = 2
    +        fieldnames_copy.remove("x")
    +        fieldnames_copy.remove("y")
    +    if "rho" in fieldnames_copy and "phi" in fieldnames_copy:
    +        if dimension != 0:
    +            raise TypeError(complaint1 if is_momentum else complaint2)
    +        dimension = 2
    +        fieldnames_copy.remove("rho")
    +        fieldnames_copy.remove("phi")
    +    if "x" in fieldnames_copy and "py" in fieldnames_copy:
    +        is_momentum = True
    +        if dimension != 0:
    +            raise TypeError(complaint1 if is_momentum else complaint2)
    +        dimension = 2
    +        fieldnames_copy.remove("x")
    +        fieldnames_copy.remove("py")
    +    if "px" in fieldnames_copy and "y" in fieldnames_copy:
    +        is_momentum = True
    +        if dimension != 0:
    +            raise TypeError(complaint1 if is_momentum else complaint2)
    +        dimension = 2
    +        fieldnames_copy.remove("px")
    +        fieldnames_copy.remove("y")
    +    if "px" in fieldnames_copy and "py" in fieldnames_copy:
    +        is_momentum = True
    +        if dimension != 0:
    +            raise TypeError(complaint1 if is_momentum else complaint2)
    +        dimension = 2
    +        fieldnames_copy.remove("px")
    +        fieldnames_copy.remove("py")
    +    if "pt" in fieldnames_copy and "phi" in fieldnames_copy:
    +        is_momentum = True
    +        if dimension != 0:
    +            raise TypeError(complaint1 if is_momentum else complaint2)
    +        dimension = 2
    +        fieldnames_copy.remove("pt")
    +        fieldnames_copy.remove("phi")
    +
    +    # 3D longitudinal coordinates
    +    if "z" in fieldnames_copy:
    +        if dimension != 2:
    +            raise TypeError(complaint1 if is_momentum else complaint2)
    +        dimension = 3
    +        fieldnames_copy.remove("z")
    +    if "theta" in fieldnames_copy:
    +        if dimension != 2:
    +            raise TypeError(complaint1 if is_momentum else complaint2)
    +        dimension = 3
    +        fieldnames_copy.remove("theta")
    +    if "eta" in fieldnames_copy:
    +        if dimension != 2:
    +            raise TypeError(complaint1 if is_momentum else complaint2)
    +        dimension = 3
    +        fieldnames_copy.remove("eta")
    +    if "pz" in fieldnames_copy:
    +        is_momentum = True
    +        if dimension != 2:
    +            raise TypeError(complaint1 if is_momentum else complaint2)
    +        dimension = 3
    +        fieldnames_copy.remove("pz")
    +
    +    # 4D temporal coordinates
    +    if "t" in fieldnames_copy:
    +        if dimension != 3:
    +            raise TypeError(complaint1 if is_momentum else complaint2)
    +        dimension = 4
    +        fieldnames_copy.remove("t")
    +    if "tau" in fieldnames_copy:
    +        if dimension != 3:
    +            raise TypeError(complaint1 if is_momentum else complaint2)
    +        dimension = 4
    +        fieldnames_copy.remove("tau")
    +    if "E" in fieldnames_copy:
    +        is_momentum = True
    +        if dimension != 3:
    +            raise TypeError(complaint1 if is_momentum else complaint2)
    +        dimension = 4
    +        fieldnames_copy.remove("E")
    +    if "e" in fieldnames_copy:
    +        is_momentum = True
    +        if dimension != 3:
    +            raise TypeError(complaint1 if is_momentum else complaint2)
    +        dimension = 4
    +        fieldnames_copy.remove("e")
    +    if "energy" in fieldnames_copy:
    +        is_momentum = True
    +        if dimension != 3:
    +            raise TypeError(complaint1 if is_momentum else complaint2)
    +        dimension = 4
    +        fieldnames_copy.remove("energy")
    +    if "M" in fieldnames_copy:
    +        is_momentum = True
    +        if dimension != 3:
    +            raise TypeError(complaint1 if is_momentum else complaint2)
    +        dimension = 4
    +        fieldnames_copy.remove("M")
    +    if "m" in fieldnames_copy:
    +        is_momentum = True
    +        if dimension != 3:
    +            raise TypeError(complaint1 if is_momentum else complaint2)
    +        dimension = 4
    +        fieldnames_copy.remove("m")
    +    if "mass" in fieldnames_copy:
    +        is_momentum = True
    +        if dimension != 3:
    +            raise TypeError(complaint1 if is_momentum else complaint2)
    +        dimension = 4
    +        fieldnames_copy.remove("mass")
    +
    +    # Check if any remaining fieldnames would conflict with already-processed coordinates
    +    # when mapped to generic names (e.g., pt was processed, rho shouldn't remain)
    +    if fieldnames_copy:
    +        # Map all original fieldnames to generic names to detect conflicts
    +        generic_names = [_repr_momentum_to_generic.get(x, x) for x in fieldnames]
    +        if len(generic_names) != len(set(generic_names)):
    +            raise TypeError(complaint1 if is_momentum else complaint2)
    +
    +
     def array(*args: typing.Any, **kwargs: typing.Any) -> VectorNumpy:
         """
         Constructs a NumPy array of vectors, whose type is determined by the dtype
    @@ -2138,6 +2300,9 @@ def array(*args: typing.Any, **kwargs: typing.Any) -> VectorNumpy:
     
         is_momentum = any(x in _repr_momentum_to_generic for x in names)
     
    +    # Validate coordinates using dimension-guard pattern (same as awkward _check_names)
    +    _validate_numpy_coordinates(names)

Member
vector/src/vector/backends/numpy.py Line 1168 in e374915
Contributor
Author
I vaguely remember that I had some issue with
Contributor
Author
I've moved calling this function inside the
Contributor
Author
Oh, now I think I remember. Look at this: vector/src/vector/backends/numpy.py, lines 2160 to 2169 in 776f405
It makes a decision whether it is momentum or not and which class constructor to use based on the field names. That's why I think I wanted it in the vector.array implementation and not in the finalize. I think it's best to validate the fields first and then decide which constructor to use. I think it's best to move it back to where it was. I have reverted the change.
Contributor
Author
Actually, it needs to be in both places because classes like So I think we need to validate in

    +
         if any(x in ("t", "E", "e", "energy", "tau", "M", "m", "mass") for x in names):
             cls = MomentumNumpy4D if is_momentum else VectorNumpy4D
         elif any(x in ("z", "pz", "theta", "eta") for x in names):

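The ordering argued for in the thread above (validate the field names first, then pick the constructor class) can be illustrated with a much smaller stand-in. This is only a sketch: `validate_names` and `choose_class_name` are hypothetical names, momentum aliases are left out, and rejecting input with no azimuthal pair is a choice made here for the example, not something the diff shows.

```python
# Hypothetical, simplified stand-in for the dimension-guard validation.
# Only generic coordinate names are handled; momentum aliases are omitted.
AZIMUTHAL = [("x", "y"), ("rho", "phi")]
LONGITUDINAL = ["z", "theta", "eta"]
TEMPORAL = ["t", "tau"]

COMPLAINT = "unrecognized combination of coordinates"


def validate_names(names):
    """Return the dimension (2, 3 or 4) or raise TypeError."""
    remaining = list(names)
    dimension = 0

    # Exactly one azimuthal pair may move the dimension from 0 to 2.
    for pair in AZIMUTHAL:
        if all(name in remaining for name in pair):
            if dimension != 0:
                raise TypeError(COMPLAINT)
            dimension = 2
            for name in pair:
                remaining.remove(name)

    # A longitudinal name is only legal on top of a 2D vector.
    for name in LONGITUDINAL:
        if name in remaining:
            if dimension != 2:
                raise TypeError(COMPLAINT)
            dimension = 3
            remaining.remove(name)

    # A temporal name is only legal on top of a 3D vector.
    for name in TEMPORAL:
        if name in remaining:
            if dimension != 3:
                raise TypeError(COMPLAINT)
            dimension = 4
            remaining.remove(name)

    # Sketch-only choice: reject input without any azimuthal pair.
    if dimension == 0:
        raise TypeError(COMPLAINT)
    return dimension


def choose_class_name(names):
    """Validate first, then decide which (hypothetical) class to use."""
    dimension = validate_names(names)
    return f"Vector{dimension}D"
```

Because each guard requires the dimension reached by the previous stage, `("x", "y", "z", "theta")` fails at `theta` (the dimension is already 3) and `("x", "y", "t")` fails at `t` (the dimension is still 2), so the class is never chosen from an invalid set of names.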

    @@ -3211,7 +3211,7 @@ def obj(**coordinates: float) -> VectorObject:
         if "E" in coordinates:

Member
Should this also have the
Contributor
Author
Isn't this unnecessary? This is the first if condition where
Contributor
Author
I guess we can add it for "visual OCD reasons", but it's a check that's always going to be false, right?

             is_momentum = True
             generic_coordinates["t"] = coordinates.pop("E")
    -    if "e" in coordinates:
    +    if "e" in coordinates and "t" not in generic_coordinates:
             is_momentum = True
             generic_coordinates["t"] = coordinates.pop("e")
         if "energy" in coordinates and "t" not in generic_coordinates:
    @@ -3220,7 +3220,7 @@ def obj(**coordinates: float) -> VectorObject:
         if "M" in coordinates:

Member
Ditto, for

             is_momentum = True
             generic_coordinates["tau"] = coordinates.pop("M")
    -    if "m" in coordinates:
    +    if "m" in coordinates and "tau" not in generic_coordinates:
             is_momentum = True
             generic_coordinates["tau"] = coordinates.pop("m")
         if "mass" in coordinates and "tau" not in generic_coordinates:

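To make the effect of the added guards concrete, here is a minimal sketch of guarded alias handling. `to_generic` is a hypothetical helper, not the body of `obj`; it applies the guard uniformly to every alias (including the first one, where it is trivially satisfied because nothing has been stored yet), and it returns the leftovers instead of deciding what to do with them, since the diff does not show how `obj` treats unconsumed keywords.

```python
def to_generic(**coordinates):
    """Map energy/mass aliases to generic names, keeping duplicates as leftovers.

    Hypothetical helper for illustration; returns (generic, leftovers).
    """
    generic = {}

    # The first alias found claims "t"; later aliases are left untouched.
    for alias in ("E", "e", "energy"):
        if alias in coordinates and "t" not in generic:
            generic["t"] = coordinates.pop(alias)

    # Same pattern for the mass aliases and "tau".
    for alias in ("M", "m", "mass"):
        if alias in coordinates and "tau" not in generic:
            generic["tau"] = coordinates.pop(alias)

    return generic, coordinates


generic, leftovers = to_generic(E=1.0, e=2.0)
# "E" claimed "t"; "e" was not popped, so a caller can still detect the duplicate.
```

Without the `"t" not in generic` condition, the second alias would overwrite the first and the duplicate would disappear silently; with it, the duplicate stays visible in the leftovers.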
It's a picky comment, but it's a bit sub-optimal that we duplicate these "complaint strings" in several submodules. For better maintainability, it's probably best to move them to a trivial submodule and import the strings wherever they are needed, such as here but also in awkward_constructors.py, for example.
is there a suggestion regarding where to put these complaint strings?
maybe @Saransh-cpp ?
I think we can put it in `_methods.py`, as all the common methods/base classes and protocols live there. This exact string is used in several other places as well, so those can be refactored too (perhaps in another PR if this PR gets too big).
Honestly, every "type" of vector constructor and backend validates the coordinates in a different place, each with its own implementation. I'd rather not do this here; let's keep this PR to a minimal working case and do a full refactoring of validation in a separate PR. The obj, numpy, awkward, numba, and sympy vectors all validate in different places, each with a slightly different mechanism. We need a general refactoring of the coordinate presence validation that is shared among all of those types of vectors.
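For whenever that refactoring happens, one possible shape for the shared complaint strings is sketched below. Everything here is hypothetical (the constant and function names, and where the module would live); the only thing taken from this PR is the message text, which the sketch regenerates from three small tuples instead of spelling out all twenty lines by hand.

```python
from itertools import product

# Hypothetical shared definitions; names are illustrative only.
AZIMUTHAL = ("x= y=", "rho= phi=")
LONGITUDINAL = ("z=", "theta=", "eta=")
TEMPORAL = ("t=", "tau=")


def allowed_combinations():
    """List every allowed keyword combination, lowest dimension first."""
    lines = [f"(2D) {a}" for a in AZIMUTHAL]
    lines += [f"(3D) {a} {l}" for a, l in product(AZIMUTHAL, LONGITUDINAL)]
    lines += [
        f"(4D) {a} {l} {t}"
        for a, l, t in product(AZIMUTHAL, LONGITUDINAL, TEMPORAL)
    ]
    return lines


# Message for an unrecognized combination, built once at import time.
UNRECOGNIZED_COMBINATION = (
    "unrecognized combination of coordinates, allowed combinations are:\n\n"
    + "\n".join(" " + line for line in allowed_combinations())
)


def duplicate_coordinates(fieldnames):
    """Message for duplicates that arise through momentum aliases."""
    return "duplicate coordinates (through momentum-aliases): " + ", ".join(
        repr(x) for x in fieldnames
    )
```

Each backend could then import the two messages rather than carrying its own copy, so a wording change would happen in one place.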