|
399 | 399 | " params: Optional[torch.Tensor] = None, \n", |
400 | 400 | " silent: bool = True,\n", |
401 | 401 | " n_jobs: int = 1,\n", |
402 | | - " filter_errs: bool = True) -> tuple[Sequence[any], int]:\n", |
| 402 | + " filter_errs: bool = True,\n", |
| 403 | + " return_tensors: bool = False,\n", |
| 404 | + " ) -> tuple[Sequence[any], int] | tuple[Sequence[any], int, torch.Tensor]:\n", |
403 | 405 | " tensors = tensors.cpu()\n", |
404 | 406 | "\n", |
405 | 407 | " if exists(params):\n", |
|
426 | 428 | " backend_obj_list = [pot_qc for pot_qc in pot_qcs if exists(pot_qc)]\n", |
427 | 429 | " err_cnt = sum(1 for pot_qc in pot_qcs if not_exists(pot_qc))\n", |
428 | 430 | " assert len(backend_obj_list) + err_cnt == len(pot_qcs)\n", |
| 431 | + "\n", |
| 432 | + " if return_tensors:\n", |
| 433 | + " tensors = tensors[torch.tensor([exists(pot_qc) for pot_qc in pot_qcs])]\n", |
| 434 | + " \n", |
429 | 435 | " else:\n", |
430 | 436 | " backend_obj_list = pot_qcs\n", |
431 | 437 | " err_cnt = None\n", |
432 | | - " \n", |
| 438 | + "\n", |
| 439 | + " if return_tensors:\n", |
| 440 | + " return backend_obj_list, err_cnt, tensors\n", |
433 | 441 | " return backend_obj_list, err_cnt" |
434 | 442 | ] |
435 | 443 | }, |
|
0 commit comments