diff --git a/openwisp_controller/connection/api/serializers.py b/openwisp_controller/connection/api/serializers.py index c837e1e10..1f5a891b5 100644 --- a/openwisp_controller/connection/api/serializers.py +++ b/openwisp_controller/connection/api/serializers.py @@ -31,6 +31,19 @@ class CommandSerializer(ValidatedDeviceFieldSerializer): pk_field=serializers.UUIDField(format='hex_verbose'), ) + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # show only connections and command types available for the device + if device_id := self.context.get('device_id'): + self.fields['connection'].queryset = self.fields[ + 'connection' + ].queryset.filter(device_id=device_id) + device = Device.objects.only('organization_id', 'id').get(pk=device_id) + # filter command types based on the device's organization + self.fields['type'].choices = Command.get_org_allowed_commands( + device.organization_id + ) + def to_representation(self, instance): repr = super().to_representation(instance) repr['type'] = instance.get_type_display() diff --git a/openwisp_controller/connection/base/models.py b/openwisp_controller/connection/base/models.py index 6baa9cc07..e472eac5f 100644 --- a/openwisp_controller/connection/base/models.py +++ b/openwisp_controller/connection/base/models.py @@ -437,10 +437,17 @@ class Meta: ordering = ('created',) @classmethod - def get_org_choices(self, organization_id=None): - return ORGANIZATION_ENABLED_COMMANDS.get( + def get_org_allowed_commands(self, organization_id=None): + """ + Returns a list of allowed commands for the given organization + """ + allowed_commands = ORGANIZATION_ENABLED_COMMANDS.get( str(organization_id), ORGANIZATION_ENABLED_COMMANDS.get('__all__') ) + commands_map = dict(COMMAND_CHOICES) + return [ + (cmd, commands_map[cmd]) for cmd in allowed_commands if cmd in commands_map + ] @classmethod def get_org_schema(self, organization_id=None): @@ -459,19 +466,26 @@ def __str__(self): def clean(self): self._verify_command_type_allowed() + self._verify_connection() try: jsonschema.Draft4Validator(self._schema).validate(self.input) except SchemaError as e: raise ValidationError({'input': e.message}) + def _verify_connection(self): + """Raises validation error if device has no connection and credentials.""" + if self.device and not self.device.deviceconnection_set.exists(): + raise ValidationError({'device': _('Device has no credentials assigned.')}) + def _verify_command_type_allowed(self): """Raises validation error if command type is not allowed.""" # if device is not set, skip to avoid uncaught exception # (standard model validation will kick in) if not hasattr(self, 'device'): return - if self.type not in self.get_org_choices( - organization_id=self.device.organization_id + + if self.type not in dict( + self.get_org_allowed_commands(organization_id=self.device.organization_id) ): raise ValidationError( { diff --git a/openwisp_controller/connection/tests/test_api.py b/openwisp_controller/connection/tests/test_api.py index b97ce6da2..dfe7f90d0 100644 --- a/openwisp_controller/connection/tests/test_api.py +++ b/openwisp_controller/connection/tests/test_api.py @@ -175,6 +175,22 @@ def test_command_attributes(self, payload): self.assertEqual(response.status_code, 201) test_command_attributes(self, payload) + # for ensuring that only related connections are shown + def test_available_connections(self): + device = self._create_device( + name='default.test.device2', mac_address='12:23:34:45:56:67' + ) + self._create_config(device=device) + credentials_2 = self._create_credentials(name='Test Credentials 2') + device_conn2 = self._create_device_connection( + device=device, credentials=credentials_2 + ) + url = self._get_path('device_command_list', self.device_id) + response = self.client.get(url, {'format': 'api'}) + self.assertEqual(response.status_code, 200) + self.assertContains(response, str(self.device_conn.id)) + self.assertNotContains(response, device_conn2.id) + def test_command_details_api(self): command_obj = self._create_command(device_conn=self.device_conn) url = self._get_path('device_command_details', self.device_id, command_obj.id) @@ -338,10 +354,30 @@ def test_non_existent_command(self): ) self.assertEqual(response.status_code, 400) self.assertIn( - '"custom" command is not available for this organization', - response.data['input'][0], + '"custom" is not a valid choice.', + response.data['type'][0], ) + def test_create_command_without_connection(self): + device = self._create_device( + name='default.test.device2', mac_address='11:22:33:44:55:66' + ) + url = self._get_path('device_command_list', device.pk) + payload = { + 'type': 'custom', + 'input': {'command': 'echo test'}, + } + response = self.client.post( + url, + data=json.dumps(payload), + content_type='application/json', + ) + self.assertEqual(response.status_code, 400) + self.assertIn( + 'Device has no credentials assigned.', + response.data['device'][0], + ) + class TestConnectionApi( TestAdminMixin, AuthenticationMixin, TestCase, CreateConnectionsMixin diff --git a/openwisp_controller/connection/tests/test_models.py b/openwisp_controller/connection/tests/test_models.py index 1653f6cff..4cdd5d5d6 100644 --- a/openwisp_controller/connection/tests/test_models.py +++ b/openwisp_controller/connection/tests/test_models.py @@ -563,6 +563,18 @@ def test_command_validation(self): ], ) + with self.subTest('Test command creation without device connection'): + device = dc.device + device.deviceconnection_set.all().delete() + with self.assertRaises(ValidationError) as context_manager: + command.full_clean() + exception = context_manager.exception + self.assertIn('device', exception.message_dict) + self.assertEqual( + exception.message_dict['device'], + ['Device has no credentials assigned.'], + ) + @tag('skip_prod') def test_enabled_command(self): self.assertEqual( @@ -786,7 +798,7 @@ def _command_assertions(destination_address, mocked_exec_command): @mock.patch(_connect_path) @mock.patch.dict(COMMANDS, {}) - @mock.patch.dict(ORGANIZATION_ENABLED_COMMANDS, {'__all__': ('restart_network')}) + @mock.patch.dict(ORGANIZATION_ENABLED_COMMANDS, {'__all__': ('restart_network',)}) @mock.patch(_exec_command_path) def test_execute_user_registered_command_without_input( self, mocked_exec_command, connect_mocked