Skip to content
Draft
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions src/vector/backends/awkward_constructors.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@

import numpy

from vector._methods import _repr_momentum_to_generic


def _recname(is_momentum: bool, dimension: int) -> str:
name = "Momentum" if is_momentum else "Vector"
Expand Down Expand Up @@ -199,6 +201,20 @@ def _check_names(
if dimension == 0:
raise TypeError(complaint1 if is_momentum else complaint2)

# Check if any remaining fieldnames would conflict with already-processed coordinates
# or with each other when mapped to generic names (e.g., "x" and "px" both map to "x")
if fieldnames:
# Check leftovers against already-processed coordinates
for fname in fieldnames:
generic = _repr_momentum_to_generic.get(fname, fname)
if generic in names:
raise TypeError(complaint1 if is_momentum else complaint2)

# Check leftovers against each other for duplicates
leftover_generics = [_repr_momentum_to_generic.get(x, x) for x in fieldnames]
if len(leftover_generics) != len(set(leftover_generics)):
raise TypeError(complaint1 if is_momentum else complaint2)

for name in fieldnames:
names.append(name)
columns.append(projectable[name])
Expand Down
165 changes: 165 additions & 0 deletions src/vector/backends/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2071,6 +2071,168 @@ def __setitem__(self, where: typing.Any, what: typing.Any) -> None:
return _setitem(self, where, what, True)


def _validate_numpy_coordinates(fieldnames: tuple[str, ...]) -> None:
"""
Validate coordinate field names using dimension-guard pattern.

This follows the same logic as _check_names in awkward_constructors to ensure
consistent validation across backends.

Raises TypeError if duplicate or conflicting coordinates are detected.
"""
complaint1 = "duplicate coordinates (through momentum-aliases): " + ", ".join(
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a picky comment but it's a bit sub-optimal that we duplicate these "complaint strings" in several submodules. For better maintainability it's probably best to move them to a trival submodule and import the strings in the several places, such as here but also in awkward_constructions.py for example.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is there a suggestion regarding where to put these complaint strings?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe @Saransh-cpp ?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can put it in _methods.py as all the common methods/base classes and protocols live there. This exact string is used at several other places as well, so those can be refactored too (perhaps in another PR if this PR gets too big).

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Honestly, all "types" of vector constructors and backend literally validate the coordinates in different places each with their own implementations. I'd rather not do this here and just a minimal working case and do a whole refactoring of validation in a separate PR. obj numpy, awkward, numba, sympy vectors all validate in different places with slightly different mechanisms each. We need a general refactoring of the coordinate presence validation that is shared amongst all of those types of vectors.

repr(x) for x in fieldnames
)
complaint2 = (
"unrecognized combination of coordinates, allowed combinations are:\n\n"
" (2D) x= y=\n"
" (2D) rho= phi=\n"
" (3D) x= y= z=\n"
" (3D) x= y= theta=\n"
" (3D) x= y= eta=\n"
" (3D) rho= phi= z=\n"
" (3D) rho= phi= theta=\n"
" (3D) rho= phi= eta=\n"
" (4D) x= y= z= t=\n"
" (4D) x= y= z= tau=\n"
" (4D) x= y= theta= t=\n"
" (4D) x= y= theta= tau=\n"
" (4D) x= y= eta= t=\n"
" (4D) x= y= eta= tau=\n"
" (4D) rho= phi= z= t=\n"
" (4D) rho= phi= z= tau=\n"
" (4D) rho= phi= theta= t=\n"
" (4D) rho= phi= theta= tau=\n"
" (4D) rho= phi= eta= t=\n"
" (4D) rho= phi= eta= tau="
)

is_momentum = False
dimension = 0
fieldnames_copy = list(fieldnames)

# 2D azimuthal coordinates
if "x" in fieldnames_copy and "y" in fieldnames_copy:
if dimension != 0:
raise TypeError(complaint1 if is_momentum else complaint2)
dimension = 2
fieldnames_copy.remove("x")
fieldnames_copy.remove("y")
if "rho" in fieldnames_copy and "phi" in fieldnames_copy:
if dimension != 0:
raise TypeError(complaint1 if is_momentum else complaint2)
dimension = 2
fieldnames_copy.remove("rho")
fieldnames_copy.remove("phi")
if "x" in fieldnames_copy and "py" in fieldnames_copy:
is_momentum = True
if dimension != 0:
raise TypeError(complaint1 if is_momentum else complaint2)
dimension = 2
fieldnames_copy.remove("x")
fieldnames_copy.remove("py")
if "px" in fieldnames_copy and "y" in fieldnames_copy:
is_momentum = True
if dimension != 0:
raise TypeError(complaint1 if is_momentum else complaint2)
dimension = 2
fieldnames_copy.remove("px")
fieldnames_copy.remove("y")
if "px" in fieldnames_copy and "py" in fieldnames_copy:
is_momentum = True
if dimension != 0:
raise TypeError(complaint1 if is_momentum else complaint2)
dimension = 2
fieldnames_copy.remove("px")
fieldnames_copy.remove("py")
if "pt" in fieldnames_copy and "phi" in fieldnames_copy:
is_momentum = True
if dimension != 0:
raise TypeError(complaint1 if is_momentum else complaint2)
dimension = 2
fieldnames_copy.remove("pt")
fieldnames_copy.remove("phi")

# 3D longitudinal coordinates
if "z" in fieldnames_copy:
if dimension != 2:
raise TypeError(complaint1 if is_momentum else complaint2)
dimension = 3
fieldnames_copy.remove("z")
if "theta" in fieldnames_copy:
if dimension != 2:
raise TypeError(complaint1 if is_momentum else complaint2)
dimension = 3
fieldnames_copy.remove("theta")
if "eta" in fieldnames_copy:
if dimension != 2:
raise TypeError(complaint1 if is_momentum else complaint2)
dimension = 3
fieldnames_copy.remove("eta")
if "pz" in fieldnames_copy:
is_momentum = True
if dimension != 2:
raise TypeError(complaint1 if is_momentum else complaint2)
dimension = 3
fieldnames_copy.remove("pz")

# 4D temporal coordinates
if "t" in fieldnames_copy:
if dimension != 3:
raise TypeError(complaint1 if is_momentum else complaint2)
dimension = 4
fieldnames_copy.remove("t")
if "tau" in fieldnames_copy:
if dimension != 3:
raise TypeError(complaint1 if is_momentum else complaint2)
dimension = 4
fieldnames_copy.remove("tau")
if "E" in fieldnames_copy:
is_momentum = True
if dimension != 3:
raise TypeError(complaint1 if is_momentum else complaint2)
dimension = 4
fieldnames_copy.remove("E")
if "e" in fieldnames_copy:
is_momentum = True
if dimension != 3:
raise TypeError(complaint1 if is_momentum else complaint2)
dimension = 4
fieldnames_copy.remove("e")
if "energy" in fieldnames_copy:
is_momentum = True
if dimension != 3:
raise TypeError(complaint1 if is_momentum else complaint2)
dimension = 4
fieldnames_copy.remove("energy")
if "M" in fieldnames_copy:
is_momentum = True
if dimension != 3:
raise TypeError(complaint1 if is_momentum else complaint2)
dimension = 4
fieldnames_copy.remove("M")
if "m" in fieldnames_copy:
is_momentum = True
if dimension != 3:
raise TypeError(complaint1 if is_momentum else complaint2)
dimension = 4
fieldnames_copy.remove("m")
if "mass" in fieldnames_copy:
is_momentum = True
if dimension != 3:
raise TypeError(complaint1 if is_momentum else complaint2)
dimension = 4
fieldnames_copy.remove("mass")

# Check if any remaining fieldnames would conflict with already-processed coordinates
# when mapped to generic names (e.g., pt was processed, rho shouldn't remain)
if fieldnames_copy:
# Map all original fieldnames to generic names to detect conflicts
generic_names = [_repr_momentum_to_generic.get(x, x) for x in fieldnames]
if len(generic_names) != len(set(generic_names)):
raise TypeError(complaint1 if is_momentum else complaint2)


def array(*args: typing.Any, **kwargs: typing.Any) -> VectorNumpy:
"""
Constructs a NumPy array of vectors, whose type is determined by the dtype
Expand Down Expand Up @@ -2138,6 +2300,9 @@ def array(*args: typing.Any, **kwargs: typing.Any) -> VectorNumpy:

is_momentum = any(x in _repr_momentum_to_generic for x in names)

# Validate coordinates using dimension-guard pattern (same as awkward _check_names)
_validate_numpy_coordinates(names)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

vector.array is just a wrapper around individual constructors of Vector/MomentumNumpy*D, which can be used to construct vectors (unlike the Awkward backend). Hence, it would be better if we move this check to the __array_finalize__ method of each class:

def __array_finalize__(self, obj: typing.Any) -> None:

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I vaguely remember that I had some issue with __array_finalize__ regarding when is it being ran when I was making these edits but I will take a look again.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've moved calling this function inside the __array_finalize__ of the 2d, 3d, and 4d classes. The tests pass so I'll need to play with it a bit more in a terminal in case there's something I'm not remembering here.

Copy link
Copy Markdown
Contributor Author

@ikrommyd ikrommyd Feb 4, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh..now I think I remember. Look at this

is_momentum = any(x in _repr_momentum_to_generic for x in names)
if any(x in ("t", "E", "e", "energy", "tau", "M", "m", "mass") for x in names):
cls = MomentumNumpy4D if is_momentum else VectorNumpy4D
elif any(x in ("z", "pz", "theta", "eta") for x in names):
cls = MomentumNumpy3D if is_momentum else VectorNumpy3D
else:
cls = MomentumNumpy2D if is_momentum else VectorNumpy2D
return cls(*args, **kwargs)

It makes a decision whether it is momentum or not and which class constructor to use based on the field names. That's why I think I wanted it in the vector.array implementation and not in the finalize. I think it's best to validate the fields first and then decide which constructor to use. I think it's best to move it back to where it was. I have reverted the change.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, it needs to be in both places because classes like MomentumNumpy4D can be directly instantiated. And therefore this works fine and it shouldn't.

In [1]: import vector
   ...:
   ...: vec = vector.MomentumNumpy4D([(1.1, 2.1, 3.1, 4.1, 4.1), (1.2, 2.2, 3.2, 4.2, 4.2), (1.3, 2.3, 3.3, 4.3, 4.3), (1.4, 2.4, 3.4, 4.4, 4.4), (1.5, 2.5, 3.5, 4.5, 4.5)],
   ...:
   ...:               dtype=[('x', float), ('y', float), ('z', float), ('e', float), ('mass', float)])
   ...:
   ...: vec
Out[1]:
MomentumNumpy4D([(1.1, 2.1, 3.1, 4.1, 4.1), (1.2, 2.2, 3.2, 4.2, 4.2),
                 (1.3, 2.3, 3.3, 4.3, 4.3), (1.4, 2.4, 3.4, 4.4, 4.4),
                 (1.5, 2.5, 3.5, 4.5, 4.5)],
                dtype=[('x', '<f8'), ('y', '<f8'), ('z', '<f8'), ('t', '<f8'), ('tau', '<f8')])

So I think we need to validate in vector.array and in __array_finalize__ too because both are valid instantiation methods.


if any(x in ("t", "E", "e", "energy", "tau", "M", "m", "mass") for x in names):
cls = MomentumNumpy4D if is_momentum else VectorNumpy4D
elif any(x in ("z", "pz", "theta", "eta") for x in names):
Expand Down
4 changes: 2 additions & 2 deletions src/vector/backends/object.py
Original file line number Diff line number Diff line change
Expand Up @@ -3211,7 +3211,7 @@ def obj(**coordinates: float) -> VectorObject:
if "E" in coordinates:
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this also have the and "t" not in generic_coordinates condition?

Copy link
Copy Markdown
Contributor Author

@ikrommyd ikrommyd Feb 4, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't this unnecessary? This is the first if condition where generic_coordinates might be populated with t. So and "t" not in generic_coordinates is a useless check no? Same goes for tau in your other comment below.

Copy link
Copy Markdown
Contributor Author

@ikrommyd ikrommyd Feb 4, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess we can add it for "visual OCD reasons" but it's a check that's always going to be false right?

is_momentum = True
generic_coordinates["t"] = coordinates.pop("E")
if "e" in coordinates:
if "e" in coordinates and "t" not in generic_coordinates:
is_momentum = True
generic_coordinates["t"] = coordinates.pop("e")
if "energy" in coordinates and "t" not in generic_coordinates:
Expand All @@ -3220,7 +3220,7 @@ def obj(**coordinates: float) -> VectorObject:
if "M" in coordinates:
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ditto, for tau

is_momentum = True
generic_coordinates["tau"] = coordinates.pop("M")
if "m" in coordinates:
if "m" in coordinates and "tau" not in generic_coordinates:
is_momentum = True
generic_coordinates["tau"] = coordinates.pop("m")
if "mass" in coordinates and "tau" not in generic_coordinates:
Expand Down
Loading