|
6 | 6 |
|
7 | 7 | import logging |
8 | 8 | import operator |
| 9 | + |
| 10 | +from collections import deque |
9 | 11 | from typing import Any |
10 | 12 |
|
11 | 13 | import executorch.backends.vulkan.utils as utils |
@@ -332,81 +334,111 @@ def trace_node_users_to_constrain_repset( # noqa: C901 |
332 | 334 | search_depth: list[int] | None = None, |
333 | 335 | ) -> utils.TensorRepSet: |
334 | 336 | """ |
335 | | - For an ambiguous repset, try to constrain the repset by tracing the required |
336 | | - repsets of the users of `origin_node`. The idea is to try to find a representation |
337 | | - that can be used the longest without needing user nodes to insert a transition |
338 | | - for its arguments. |
| 337 | + BFS over downstream users to constrain an ambiguous repset. Explores all |
| 338 | + immediate users at each level before going deeper, so that nearby constrained |
| 339 | + ops (e.g. linear requiring width_packed) are discovered before the search |
| 340 | + budget is spent on a single deep branch. |
339 | 341 | """ |
340 | | - # Optionally limit the total number of nodes explored to improve export |
341 | | - # time. search_depth is a mutable list so that all branches of a fan-out |
342 | | - # share a single counter, preventing exponential blowup. |
343 | 342 | if self.max_trace_search_depth is not None: |
344 | 343 | if search_depth is None: |
345 | 344 | search_depth = [self.max_trace_search_depth] |
346 | | - search_depth[0] -= 1 |
347 | | - if search_depth[0] <= 0: |
| 345 | + |
| 346 | + queue: deque[torch.fx.Node] = deque() |
| 347 | + queue.append(origin_node) |
| 348 | + |
| 349 | + while queue: |
| 350 | + if repset.is_constrained(): |
348 | 351 | return repset |
349 | 352 |
|
350 | | - users_to_trace = origin_node.users |
| 353 | + if self.max_trace_search_depth is not None: |
| 354 | + search_depth[0] -= 1 |
| 355 | + if search_depth[0] <= 0: |
| 356 | + return repset |
| 357 | + |
| 358 | + node = queue.popleft() |
| 359 | + |
| 360 | + users_to_trace = node.users |
| 361 | + |
| 362 | + sync_outs_repr = True |
| 363 | + if self.is_valid_op_node(node): |
| 364 | + sync_outs_repr = self.get_node_cached_repsets(node).sync_outs_repr |
351 | 365 |
|
352 | | - sync_outs_repr = True |
353 | | - if self.is_valid_op_node(origin_node): |
354 | | - sync_outs_repr = self.get_node_cached_repsets(origin_node).sync_outs_repr |
| 366 | + if utils.num_tensors_in_node(node) > 1 and not sync_outs_repr: |
| 367 | + users_to_trace = [] |
| 368 | + for usage_node in node.users: |
| 369 | + if ( |
| 370 | + usage_node.target == operator.getitem |
| 371 | + and usage_node.args[1] == 1 |
| 372 | + ): |
| 373 | + users_to_trace.append(usage_node) |
355 | 374 |
|
356 | | - if utils.num_tensors_in_node(origin_node) > 1 and not sync_outs_repr: |
357 | | - users_to_trace = [] |
358 | | - for usage_node in origin_node.users: |
359 | | - if usage_node.target == operator.getitem and usage_node.args[1] == 1: |
360 | | - users_to_trace.append(usage_node) |
| 375 | + for usage_node in users_to_trace: |
| 376 | + if repset.is_constrained(): |
| 377 | + return repset |
361 | 378 |
|
362 | | - for usage_node in users_to_trace: |
363 | | - arg_i_in_user = None |
364 | | - for i in range(len(usage_node.args)): |
365 | | - if origin_node == usage_node.args[i]: |
366 | | - arg_i_in_user = i |
367 | | - break |
| 379 | + arg_i_in_user = None |
| 380 | + for i in range(len(usage_node.args)): |
| 381 | + if node == usage_node.args[i]: |
| 382 | + arg_i_in_user = i |
| 383 | + break |
368 | 384 |
|
369 | | - if arg_i_in_user is not None: |
370 | | - repset = self.constrain_repset_with_user( |
371 | | - usage_node, arg_i_in_user, repset, search_depth |
| 385 | + if arg_i_in_user is None: |
| 386 | + continue |
| 387 | + |
| 388 | + if not self.is_valid_op_node(usage_node): |
| 389 | + continue |
| 390 | + |
| 391 | + cur_node_repsets = self.get_node_cached_repsets(usage_node) |
| 392 | + req_arg_repset = cur_node_repsets.get_arg_repset(arg_i_in_user) |
| 393 | + |
| 394 | + if not req_arg_repset.any_in_common(repset): |
| 395 | + continue |
| 396 | + |
| 397 | + repset = repset.make_intersect(req_arg_repset) |
| 398 | + |
| 399 | + repset_propagates_to_output = ( |
| 400 | + cur_node_repsets.sync_primary_io_repr |
| 401 | + and ( |
| 402 | + cur_node_repsets.sync_args_repr |
| 403 | + or arg_i_in_user == cur_node_repsets.primary_arg_idx |
| 404 | + ) |
372 | 405 | ) |
373 | 406 |
|
374 | | - if repset.is_constrained(): |
375 | | - return repset |
| 407 | + if repset_propagates_to_output: |
| 408 | + queue.append(usage_node) |
376 | 409 |
|
377 | 410 | return repset |
378 | 411 |
|
379 | 412 | def constrain_op_arg_repset(self, arg_i: int, op_repsets: utils.OpRepSets) -> None: |
380 | 413 | """ |
381 | 414 | Attempts to constrain the repset of the argument at index `arg_i` of the op |
382 | | - associated with `op_repsets`. Does this with two stages: |
383 | | -
|
384 | | - 1. First, account for any existing representation that has already been determined |
385 | | - for the argument. If no existing representation has been determined, then use |
386 | | - the output repset of the operator that produces the argument. |
387 | | - 2. Then, try to trace through the users of the argument to find a representation |
388 | | - that can be used for as long as possible without needing a transition. |
| 415 | + associated with `op_repsets`. Prefers downstream consumers' layout requirements |
| 416 | + over the upstream source's existing layout, falling back to the source only when |
| 417 | + downstream tracing does not fully constrain the repset. |
389 | 418 | """ |
390 | | - # If forcing fp16, then try to use texture storage whenever possible. This is |
391 | | - # a temporary stopgap measure until all buffer implementations properly account |
392 | | - # for potential overflow of fp16 representation range when doing math in fp16. |
393 | 419 | if self.force_fp16: |
394 | 420 | op_repsets.try_constrain_with_arg_repset(arg_i, utils.ANY_TEXTURE) |
395 | 421 |
|
396 | | - arg_source_repset = self.get_arg_tensor_source_repset(op_repsets.op_node, arg_i) |
397 | | - op_repsets.try_constrain_with_arg_repset(arg_i, arg_source_repset) |
398 | | - |
399 | | - arg_repset = op_repsets.get_arg_repset(arg_i) |
400 | | - if arg_repset.is_constrained(): |
401 | | - return |
402 | | - |
| 422 | + # First, trace downstream users to discover what layout they prefer. |
403 | 423 | arg_node = op_repsets.op_node.args[arg_i] |
404 | | - |
405 | 424 | if isinstance(arg_node, list): |
406 | 425 | arg_node = arg_node[0] |
407 | 426 |
|
408 | | - arg_repset = self.trace_node_users_to_constrain_repset(arg_node, arg_repset) |
409 | | - op_repsets.try_constrain_with_arg_repset(arg_i, arg_repset) |
| 427 | + arg_repset = op_repsets.get_arg_repset(arg_i) |
| 428 | + if not arg_repset.is_constrained(): |
| 429 | + downstream_repset = self.trace_node_users_to_constrain_repset( |
| 430 | + arg_node, arg_repset |
| 431 | + ) |
| 432 | + op_repsets.try_constrain_with_arg_repset(arg_i, downstream_repset) |
| 433 | + |
| 434 | + # Fall back to the upstream source's existing layout only if downstream |
| 435 | + # tracing did not fully constrain the repset. |
| 436 | + arg_repset = op_repsets.get_arg_repset(arg_i) |
| 437 | + if not arg_repset.is_constrained(): |
| 438 | + arg_source_repset = self.get_arg_tensor_source_repset( |
| 439 | + op_repsets.op_node, arg_i |
| 440 | + ) |
| 441 | + op_repsets.try_constrain_with_arg_repset(arg_i, arg_source_repset) |
410 | 442 |
|
411 | 443 | def constrain_op_out_repset(self, op_repsets: utils.OpRepSets) -> None: |
412 | 444 | """ |
|
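For reference, the traversal change above can be summarized outside of the torch.fx / ExecuTorch machinery. The sketch below is illustrative only: `Node`, `required_layout`, and `trace_users_bfs` are hypothetical stand-ins for `torch.fx.Node`, the per-op `TensorRepSet` requirements, and `trace_node_users_to_constrain_repset`; a plain set of layout strings plays the role of the repset, and the shared `budget` counter mirrors the mutable `search_depth` list.

```python
from collections import deque
from dataclasses import dataclass, field


@dataclass
class Node:
    # Hypothetical stand-in for a torch.fx.Node: just a name and its users.
    name: str
    users: list["Node"] = field(default_factory=list)
    # Hypothetical per-node layout requirement; None means "no preference".
    required_layout: str | None = None


def trace_users_bfs(origin: Node, candidates: set[str], budget: int) -> set[str]:
    """Narrow `candidates` by visiting users breadth-first.

    Every branch of a fan-out shares the single `budget` counter, so the
    traversal cannot blow up exponentially, and nearby users are inspected
    before the budget is spent on any one deep branch.
    """
    queue: deque[Node] = deque([origin])
    while queue and budget > 0 and len(candidates) > 1:
        budget -= 1
        node = queue.popleft()
        for user in node.users:
            if user.required_layout is not None and user.required_layout in candidates:
                # Intersect with the user's requirement (mirrors make_intersect).
                candidates = {user.required_layout}
            # Keep tracing through this user's own consumers.
            queue.append(user)
            if len(candidates) == 1:
                break
    return candidates


if __name__ == "__main__":
    # conv feeds both relu and linear; only linear has a layout requirement.
    linear = Node("linear", required_layout="width_packed")
    relu = Node("relu")
    conv = Node("conv", users=[relu, linear])
    print(trace_users_bfs(conv, {"width_packed", "channels_packed"}, budget=8))
    # -> {'width_packed'}: the nearby constrained user is found before the
    #    budget is spent descending any single branch.
```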
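The reordering in `constrain_op_arg_repset` reduces to a similar two-stage narrowing, sketched below with hypothetical names (`constrain_arg_layout` and its set-valued arguments are illustrations, not APIs from this commit): downstream consumers are consulted first, and the producer's existing layout is applied only if the candidate set is still ambiguous afterwards.

```python
def constrain_arg_layout(
    candidates: set[str],
    downstream_preferences: list[set[str]],
    source_layouts: set[str],
) -> set[str]:
    """Hypothetical two-stage narrowing mirroring the reordered logic:
    consult downstream consumers first, then fall back to the upstream
    source only if the candidate set is still ambiguous."""
    # Stage 1: intersect with every compatible downstream preference.
    for pref in downstream_preferences:
        if len(candidates) == 1:
            break
        if candidates & pref:
            candidates &= pref
    # Stage 2: fall back to the producer's existing layout if still ambiguous.
    if len(candidates) > 1 and candidates & source_layouts:
        candidates &= source_layouts
    return candidates


# Example: downstream wants width_packed, the producer already emits
# channels_packed; the downstream preference wins because it is tried first.
print(constrain_arg_layout(
    {"width_packed", "channels_packed"},
    [{"width_packed"}],
    {"channels_packed"},
))  # -> {'width_packed'}
```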