|
| 1 | +# -------------------------------------------------------------------------------------------- |
| 2 | +# Copyright (c) Microsoft Corporation. All rights reserved. |
| 3 | +# Licensed under the MIT License. See License.txt in the project root for license information. |
| 4 | +# -------------------------------------------------------------------------------------------- |
| 5 | + |
| 6 | +import json |
| 7 | +from collections import OrderedDict |
| 8 | +from knack.util import todict |
| 9 | +from knack.log import get_logger |
| 10 | + |
| 11 | +from .vendored_sdks.resourcegraph.models import ResultTruncated |
| 12 | +from .vendored_sdks.resourcegraph.models import QueryRequest, QueryRequestOptions, QueryResponse, ResultFormat, Error |
| 13 | +from azure.cli.core._profile import Profile |
| 14 | +from azure.core.exceptions import HttpResponseError |
| 15 | +from azure.cli.core.azclierror import BadRequestError, AzureInternalError |
| 16 | + |
| 17 | + |
| 18 | +__SUBSCRIPTION_LIMIT = 1000 |
| 19 | +__MANAGEMENT_GROUP_LIMIT = 10 |
| 20 | +__logger = get_logger(__name__) |
| 21 | + |
| 22 | + |
| 23 | +def build_arg_query(resource_groups, tags): |
| 24 | + # type: (list[str], list[str]) -> str |
| 25 | + |
| 26 | + query = "Resources" |
| 27 | + if resource_groups is not None and len(resource_groups) > 0: |
| 28 | + query += " | where resourceGroup in ({0})".format(','.join(f"'{item}'" for item in resource_groups.split(','))) |
| 29 | + |
| 30 | + if tags is not None: |
| 31 | + tagquery = [] |
| 32 | + for tag in tags.split(','): |
| 33 | + tag = tag.strip() |
| 34 | + if not tag: # Skip empty tags |
| 35 | + continue |
| 36 | + |
| 37 | + if '=' in tag: |
| 38 | + # Tag with a value (TagA=ValueA) |
| 39 | + tag_name, tag_value = tag.split('=', 1) |
| 40 | + # Escape single quotes in the value |
| 41 | + tag_value = tag_value.replace("'", "''") |
| 42 | + tagquery.append(f"tags['{tag_name}'] == '{tag_value}'") |
| 43 | + else: |
| 44 | + # Tag without a value. We don't support those. |
| 45 | + pass |
| 46 | + |
| 47 | + if tagquery: # Only proceed if tagquery has items |
| 48 | + query += " | where " + " and ".join(tagquery) |
| 49 | + |
| 50 | + return query |
| 51 | + |
| 52 | + |
| 53 | +def execute_arg_query( |
| 54 | + client, graph_query, first, skip, subscriptions, management_groups, allow_partial_scopes, skip_token): |
| 55 | + |
| 56 | + mgs_list = management_groups |
| 57 | + if mgs_list is not None and len(mgs_list) > __MANAGEMENT_GROUP_LIMIT: |
| 58 | + mgs_list = mgs_list[:__MANAGEMENT_GROUP_LIMIT] |
| 59 | + warning_message = "The query included more management groups than allowed. "\ |
| 60 | + "Only the first {0} management groups were included for the results. "\ |
| 61 | + "To use more than {0} management groups, "\ |
| 62 | + "see the docs for examples: "\ |
| 63 | + "https://aka.ms/arg-error-toomanysubs".format(__MANAGEMENT_GROUP_LIMIT) |
| 64 | + __logger.warning(warning_message) |
| 65 | + |
| 66 | + subs_list = None |
| 67 | + if mgs_list is None: |
| 68 | + subs_list = subscriptions or _get_cached_subscriptions() |
| 69 | + if subs_list is not None and len(subs_list) > __SUBSCRIPTION_LIMIT: |
| 70 | + subs_list = subs_list[:__SUBSCRIPTION_LIMIT] |
| 71 | + warning_message = "The query included more subscriptions than allowed. "\ |
| 72 | + "Only the first {0} subscriptions were included for the results. "\ |
| 73 | + "To use more than {0} subscriptions, "\ |
| 74 | + "see the docs for examples: "\ |
| 75 | + "https://aka.ms/arg-error-toomanysubs".format(__SUBSCRIPTION_LIMIT) |
| 76 | + __logger.warning(warning_message) |
| 77 | + |
| 78 | + response = None |
| 79 | + try: |
| 80 | + result_truncated = False |
| 81 | + |
| 82 | + request_options = QueryRequestOptions( |
| 83 | + top=first, |
| 84 | + skip=skip, |
| 85 | + skip_token=skip_token, |
| 86 | + result_format=ResultFormat.object_array, |
| 87 | + allow_partial_scopes=allow_partial_scopes |
| 88 | + ) |
| 89 | + |
| 90 | + request = QueryRequest( |
| 91 | + query=graph_query, |
| 92 | + subscriptions=subs_list, |
| 93 | + management_groups=mgs_list, |
| 94 | + options=request_options) |
| 95 | + response = client.resources(request) # type: QueryResponse |
| 96 | + if response.result_truncated == ResultTruncated.true: |
| 97 | + result_truncated = True |
| 98 | + |
| 99 | + if result_truncated and first is not None and len(response.data) < first: |
| 100 | + __logger.warning("Unable to paginate the results of the query. " |
| 101 | + "Some resources may be missing from the results. " |
| 102 | + "To rewrite the query and enable paging, " |
| 103 | + "see the docs for an example: https://aka.ms/arg-results-truncated") |
| 104 | + |
| 105 | + except HttpResponseError as ex: |
| 106 | + if ex.model.error.code == 'BadRequest': |
| 107 | + raise BadRequestError(json.dumps(_to_dict(ex.model.error), indent=4)) from ex |
| 108 | + |
| 109 | + raise AzureInternalError(json.dumps(_to_dict(ex.model.error), indent=4)) from ex |
| 110 | + |
| 111 | + result_dict = dict() |
| 112 | + result_dict['data'] = response.data |
| 113 | + result_dict['count'] = response.count |
| 114 | + result_dict['total_records'] = response.total_records |
| 115 | + result_dict['skip_token'] = response.skip_token |
| 116 | + |
| 117 | + return result_dict |
| 118 | + |
| 119 | + |
| 120 | +def _get_cached_subscriptions(): |
| 121 | + # type: () -> list[str] |
| 122 | + |
| 123 | + cached_subs = Profile().load_cached_subscriptions() |
| 124 | + return [sub['id'] for sub in cached_subs] |
| 125 | + |
| 126 | + |
| 127 | +def _to_dict(obj): |
| 128 | + if isinstance(obj, Error): |
| 129 | + return _to_dict(todict(obj)) |
| 130 | + |
| 131 | + if isinstance(obj, dict): |
| 132 | + result = OrderedDict() |
| 133 | + |
| 134 | + # Complex objects should be displayed last |
| 135 | + sorted_keys = sorted(obj.keys(), key=lambda k: (isinstance(obj[k], dict), isinstance(obj[k], list), k)) |
| 136 | + for key in sorted_keys: |
| 137 | + if obj[key] is None or obj[key] == [] or obj[key] == {}: |
| 138 | + continue |
| 139 | + |
| 140 | + result[key] = _to_dict(obj[key]) |
| 141 | + return result |
| 142 | + |
| 143 | + if isinstance(obj, list): |
| 144 | + return [_to_dict(v) for v in obj] |
| 145 | + |
| 146 | + return obj |
0 commit comments