diff --git a/src/azure-cli/azure/cli/command_modules/rdbms/_params.py b/src/azure-cli/azure/cli/command_modules/rdbms/_params.py index fc386940e8e..830fe34ba10 100644 --- a/src/azure-cli/azure/cli/command_modules/rdbms/_params.py +++ b/src/azure-cli/azure/cli/command_modules/rdbms/_params.py @@ -261,11 +261,24 @@ def _flexible_server_params(command_group): arg_group='Authentication' ) + database_name_create_arg_type = CLIArgumentType( + metavar='NAME', + options_list=['--database-name', '-d'], + id_part='child_name_1', + help='The name of the database to be created when provisioning the database server. ' + 'Database name must begin with a letter (a-z) or underscore (_). Subsequent characters ' + 'in a name can be letters, digits (0-9), or underscores. Database name length must be less ' + 'than 32 characters.', + local_context_attribute=LocalContextAttribute( + name='database_name', + actions=[LocalContextAction.SET], + scopes=['{} flexible-server'.format(command_group)])) + database_name_arg_type = CLIArgumentType( metavar='NAME', options_list=['--database-name', '-d'], id_part='child_name_1', - help='The name of the database to be created when provisioning the database server', + help='The name of the database', local_context_attribute=LocalContextAttribute( name='database_name', actions=[LocalContextAction.GET, LocalContextAction.SET], @@ -641,7 +654,7 @@ def _flexible_server_params(command_group): c.argument('zone', zone_arg_type) c.argument('tags', tags_type) c.argument('standby_availability_zone', arg_type=standby_availability_zone_arg_type) - c.argument('database_name', arg_type=database_name_arg_type) + c.argument('database_name', arg_type=database_name_create_arg_type) c.argument('yes', arg_type=yes_arg_type) with self.argument_context('{} flexible-server list'.format(command_group)) as c: @@ -824,6 +837,10 @@ def _flexible_server_params(command_group): argument_context_string = '{} flexible-server db {}'.format(command_group, scope) with self.argument_context(argument_context_string) as c: c.argument('server_name', options_list=['--server-name', '-s'], arg_type=server_name_arg_type) + + for scope in ['delete', 'list', 'show', 'update']: + argument_context_string = '{} flexible-server db {}'.format(command_group, scope) + with self.argument_context(argument_context_string) as c: c.argument('database_name', arg_type=database_name_arg_type) with self.argument_context('{} flexible-server db list'.format(command_group)) as c: @@ -832,6 +849,7 @@ def _flexible_server_params(command_group): with self.argument_context('{} flexible-server db create'.format(command_group)) as c: c.argument('charset', help='The charset of the database. The default value is UTF8') c.argument('collation', help='The collation of the database.') + c.argument('database_name', arg_type=database_name_create_arg_type) with self.argument_context('{} flexible-server db delete'.format(command_group)) as c: c.argument('yes', arg_type=yes_arg_type) diff --git a/src/azure-cli/azure/cli/command_modules/rdbms/flexible_server_custom_postgres.py b/src/azure-cli/azure/cli/command_modules/rdbms/flexible_server_custom_postgres.py index cec391cf8dd..2aba3f0c2df 100644 --- a/src/azure-cli/azure/cli/command_modules/rdbms/flexible_server_custom_postgres.py +++ b/src/azure-cli/azure/cli/command_modules/rdbms/flexible_server_custom_postgres.py @@ -41,7 +41,7 @@ from .validators import pg_arguments_validator, validate_server_name, validate_and_format_restore_point_in_time, \ validate_postgres_replica, validate_georestore_network, pg_byok_validator, validate_migration_runtime_server, \ validate_resource_group, check_resource_group, validate_citus_cluster, cluster_byok_validator, validate_backup_name, \ - validate_virtual_endpoint_name_availability + validate_virtual_endpoint_name_availability, validate_database_name logger = get_logger(__name__) DEFAULT_DB_NAME = 'flexibleserverdb' @@ -87,6 +87,7 @@ def flexible_server_create(cmd, client, pg_arguments_validator(db_context, server_name=server_name, + database_name=database_name, location=location, tier=tier, sku_name=sku_name, storage_gb=storage_gb, @@ -891,6 +892,7 @@ def _create_database(db_context, cmd, resource_group_name, server_name, database def database_create_func(cmd, client, resource_group_name, server_name, database_name=None, charset=None, collation=None): + validate_database_name(database_name) validate_resource_group(resource_group_name) validate_citus_cluster(cmd, resource_group_name, server_name) diff --git a/src/azure-cli/azure/cli/command_modules/rdbms/validators.py b/src/azure-cli/azure/cli/command_modules/rdbms/validators.py index efee756a851..d997fdb99d0 100644 --- a/src/azure-cli/azure/cli/command_modules/rdbms/validators.py +++ b/src/azure-cli/azure/cli/command_modules/rdbms/validators.py @@ -305,15 +305,16 @@ def _mysql_iops_validator(iops, auto_io_scaling, instance): logger.warning("The server has enabled the auto scale iops. So the iops will be ignored.") -def pg_arguments_validator(db_context, location, tier, sku_name, storage_gb, server_name=None, zone=None, - standby_availability_zone=None, high_availability=None, subnet=None, public_access=None, - version=None, instance=None, geo_redundant_backup=None, +def pg_arguments_validator(db_context, location, tier, sku_name, storage_gb, server_name=None, database_name=None, + zone=None, standby_availability_zone=None, high_availability=None, subnet=None, + public_access=None, version=None, instance=None, geo_redundant_backup=None, byok_identity=None, byok_key=None, backup_byok_identity=None, backup_byok_key=None, auto_grow=None, performance_tier=None, storage_type=None, iops=None, throughput=None, create_cluster=None, cluster_size=None, password_auth=None, active_directory_auth=None, microsoft_entra_auth=None, admin_name=None, admin_id=None, admin_type=None): validate_server_name(db_context, server_name, 'Microsoft.DBforPostgreSQL/flexibleServers') + validate_database_name(database_name) is_create = not instance if is_create: list_location_capability_info = get_postgres_location_capability_info( @@ -989,3 +990,10 @@ def validate_backup_name(backup_name): # check if backup_name exceeds 128 characters if len(backup_name) > 128: raise CLIError('Backup name cannot exceed 128 characters.') + + +def validate_database_name(database_name): + if database_name is not None and not re.match(r'^[a-zA-Z_][a-zA-Z0-9_]{0,30}$', database_name): + raise ValidationError("Database name must begin with a letter (a-z) or underscore (_). " + "Subsequent characters in a name can be letters, digits (0-9), or underscores. " + "Database name length must be less than 32 characters.")