Skip to content

Add support for einsum operation to pytorch parser (requires 1116)#1273

Merged
calad0i merged 3 commits into
fastmachinelearning:mainfrom
JanFSchulte:einsum-torch
Jun 3, 2025
Merged

Add support for einsum operation to pytorch parser (requires 1116)#1273
calad0i merged 3 commits into
fastmachinelearning:mainfrom
JanFSchulte:einsum-torch

Conversation

@JanFSchulte

Copy link
Copy Markdown
Contributor

Builds on Chang's keras v3 PR #1116 and exposes the new einsum implementation through the pytorch parser. pytorch doesn't have an an equivalent to EinsumDense but allowing the use of einsum operations in some custom model would still be useful.

Type of change

  • New feature (non-breaking change which adds functionality)

Tests

Added 2 use cases (outer product and batch matrix multiplication to the pytests, works without issues.

Checklist

  • I have read the guidelines for contributing.
  • I have commented my code, particularly in hard-to-understand areas.
  • I have made corresponding changes to the documentation.
  • My changes generate no new warnings.
  • I have installed and run pre-commit on the files I edited or added.
  • I have added tests that prove my fix is effective or that my feature works.

@JanFSchulte JanFSchulte added the please test Trigger testing by creating local PR branch label Apr 15, 2025
@JanFSchulte

Copy link
Copy Markdown
Contributor Author

Now that #1116 is merged, this can also be added.

@JanFSchulte JanFSchulte added please test Trigger testing by creating local PR branch and removed please test Trigger testing by creating local PR branch labels Jun 3, 2025
Comment thread hls4ml/converters/pytorch/core.py Outdated

a = torch.randn(input_shapes_tmp[0])
b = torch.randn(input_shapes_tmp[1])
layer['out_shape'] = tuple(torch.einsum(layer['equation'], a, b).shape)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is preferred to use hls4ml.utils.einsum_utils._validate_einsum_expr here, something like:

inp_shape0 = input_shapes_tmp[0][1:]
inp_shape1 = input_shapes_tmp[1][1:]

layer['equation'], layer['out_shape'] = _validate_einsum_expr(node.args[0], inp_shape0, inp_shape1

In this way wildcard dimensions (...) will be fixed to indices and explicit torch call will not be required.
Otherwise all looks good.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the suggestion, implemented.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, actually I didn't expect the batch dimension when implementing this, and I explicitly striped off the batch dim in the keras impl. However, indeed setting batch dimension = 1 is functionally equivalent and should always work.

Merging after test pass.

@calad0i calad0i added please test Trigger testing by creating local PR branch and removed please test Trigger testing by creating local PR branch labels Jun 3, 2025
@JanFSchulte JanFSchulte added please test Trigger testing by creating local PR branch and removed please test Trigger testing by creating local PR branch labels Jun 3, 2025
@calad0i calad0i merged commit 6cdf842 into fastmachinelearning:main Jun 3, 2025
6 of 8 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

please test Trigger testing by creating local PR branch

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants