Skip to content

Commit bf9ab57

Browse files
dbellicoso-bdaiexploy-bot
authored andcommitted
Allow tensor proxy to handle kwargs (#80)
### What change is being made Fix todo. ### Why this change is being made Missing feature. ### Tested Added unit test. GitOrigin-RevId: 7f7acb063249b5491dfd7ec2c10c6b47ee56832e
1 parent d1e2bad commit bf9ab57

2 files changed

Lines changed: 15 additions & 3 deletions

File tree

python/exploy/exporter/core/tensor_proxy.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -159,12 +159,13 @@ def __torch_function__(cls, func, types, args=(), kwargs=None):
159159
# Convert all `TensorProxy`` elements in `args` to `torch.Tensor` objects.
160160
new_args = args_to_tensor(args=args)
161161

162-
# todo: Convert TensorProxy to Tensor in kwargs.
162+
# Convert TensorProxy to Tensor in kwargs.
163+
new_kwargs = {k: args_to_tensor(args=(v,))[0] for k, v in kwargs.items()}
163164

164165
return (
165-
func.__wrapped__(*new_args, **kwargs)
166+
func.__wrapped__(*new_args, **new_kwargs)
166167
if hasattr(func, "__wrapped__")
167-
else func(*new_args, **kwargs)
168+
else func(*new_args, **new_kwargs)
168169
)
169170

170171
def __repr__(self):

python/exploy/exporter/core/tests/test_tensor_proxy.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -366,6 +366,17 @@ def test_torch_cat(self):
366366
expected = torch.cat([tensor1, tensor2], dim=0)
367367
assert torch.equal(result, expected)
368368

369+
def test_torch_cat_with_kwargs(self):
370+
"""Test using TensorProxy with torch.cat using keyword arguments."""
371+
tensor1 = torch.rand((2, 3))
372+
tensor2 = torch.rand((2, 3))
373+
proxy1 = TensorProxy(tensor1, split_dim=0)
374+
proxy2 = TensorProxy(tensor2, split_dim=0)
375+
376+
result = torch.cat(tensors=[proxy1, proxy2], dim=0)
377+
expected = torch.cat([tensor1, tensor2], dim=0)
378+
assert torch.equal(result, expected)
379+
369380
def test_repr(self):
370381
"""Test string representation."""
371382
tensor = torch.rand((2, 3, 4))

0 commit comments

Comments
 (0)