diff --git a/agent/installer.go b/agent/installer.go index fbd63a59d8b..84a8c2c4f5c 100644 --- a/agent/installer.go +++ b/agent/installer.go @@ -88,7 +88,7 @@ func registerInstallerCommands(rootCmd *cobra.Command) { installCmd.Flags().String("preferred-identity", "", "Preferred device identity") installCmd.Flags().Uint("keepalive-interval", 30, "Keepalive interval in seconds") installCmd.MarkFlagRequired("server-address") //nolint:errcheck - installCmd.MarkFlagRequired("tenant-id") //nolint:errcheck + installCmd.MarkFlagRequired("tenant-id") //nolint:errcheck rootCmd.AddCommand(installCmd) @@ -169,7 +169,7 @@ func writeAgentEnvFile(cfg installerConfig) error { fmt.Fprintf(&buf, "SHELLHUB_KEEPALIVE_INTERVAL=%d\n", cfg.KeepaliveInterval) } - return os.WriteFile(agentEnvFile, buf.Bytes(), 0600) + return os.WriteFile(agentEnvFile, buf.Bytes(), 0o600) } func writeAgentServiceFile(binaryPath string) error { @@ -183,7 +183,7 @@ func writeAgentServiceFile(binaryPath string) error { return err } - return os.WriteFile(agentServiceFile, buf.Bytes(), 0644) + return os.WriteFile(agentServiceFile, buf.Bytes(), 0o644) } func agentUninstall() error { diff --git a/api/routes/device.go b/api/routes/device.go index c63e2cfc249..9019c9f3a21 100644 --- a/api/routes/device.go +++ b/api/routes/device.go @@ -21,6 +21,8 @@ const ( LookupDeviceURL = "/device/lookup" UpdateDeviceStatusURL = "/devices/:uid/:status" UpdateDevice = "/devices/:uid" + GetDeviceSettingsURL = "/devices/:uid/settings" + UpdateDeviceSettingsURL = "/devices/:uid/settings" SetDeviceCustomFieldURL = "/devices/:uid/custom_fields/:key" DeleteDeviceCustomFieldURL = "/devices/:uid/custom_fields/:key" ) @@ -263,6 +265,10 @@ func (h *Handler) UpdateDevice(c gateway.Context) error { return err } + if c.Tenant() != nil { + req.TenantID = c.Tenant().ID + } + if err := h.service.UpdateDevice(c.Ctx(), req); err != nil { return err } @@ -270,6 +276,51 @@ func (h *Handler) UpdateDevice(c gateway.Context) error { return c.NoContent(http.StatusOK) } +func (h *Handler) GetDeviceSettings(c gateway.Context) error { + req := new(requests.DeviceGetSettings) + + if err := c.Bind(req); err != nil { + return err + } + + if c.Tenant() != nil { + req.TenantID = c.Tenant().ID + } + + if err := c.Validate(req); err != nil { + return err + } + + settings, err := h.service.GetDeviceSettings(c.Ctx(), req) + if err != nil { + return err + } + + return c.JSON(http.StatusOK, settings) +} + +func (h *Handler) UpdateDeviceSettings(c gateway.Context) error { + req := new(requests.DeviceUpdateSettings) + + if err := c.Bind(req); err != nil { + return err + } + + if c.Tenant() != nil { + req.TenantID = c.Tenant().ID + } + + if err := c.Validate(req); err != nil { + return err + } + + if err := h.service.UpdateDeviceSettings(c.Ctx(), req); err != nil { + return err + } + + return c.NoContent(http.StatusOK) +} + func (h *Handler) SetDeviceCustomField(c gateway.Context) error { req := new(requests.DeviceSetCustomField) diff --git a/api/routes/device_test.go b/api/routes/device_test.go index e08d0af3ae7..bfb5e1591b1 100644 --- a/api/routes/device_test.go +++ b/api/routes/device_test.go @@ -662,3 +662,111 @@ func TestUpdateDevice(t *testing.T) { }) } } + +func TestGetDeviceSettings(t *testing.T) { + mock := new(mocks.Service) + + settings := &models.SSHSettings{ + AllowPassword: false, + AllowPublicKey: true, + } + + mock. + On("GetDeviceSettings", gomock.Anything, &requests.DeviceGetSettings{ + TenantID: "00000000-0000-4000-0000-000000000000", + DeviceParam: requests.DeviceParam{UID: "1234"}, + }). + Return(settings, nil). + Once() + + req := httptest.NewRequest(http.MethodGet, "/api/devices/1234/settings", nil) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("X-Role", authorizer.RoleOwner.String()) + req.Header.Set("X-Tenant-ID", "00000000-0000-4000-0000-000000000000") + rec := httptest.NewRecorder() + + e := NewRouter(mock) + e.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Result().StatusCode) + + var body *models.SSHSettings + err := json.NewDecoder(rec.Result().Body).Decode(&body) + require.NoError(t, err) + assert.Equal(t, settings, body) +} + +func TestUpdateDeviceSettings(t *testing.T) { + mock := new(mocks.Service) + + cases := []struct { + description string + req requests.DeviceUpdateSettings + requiredMocks func() + expectedStatus int + }{ + { + description: "fails when device settings update cannot find device", + req: requests.DeviceUpdateSettings{ + TenantID: "00000000-0000-4000-0000-000000000000", + DeviceParam: requests.DeviceParam{ + UID: "1234", + }, + }, + requiredMocks: func() { + mock.On("UpdateDeviceSettings", gomock.Anything, &requests.DeviceUpdateSettings{ + TenantID: "00000000-0000-4000-0000-000000000000", + DeviceParam: requests.DeviceParam{UID: "1234"}, + }).Return(svc.ErrNotFound).Once() + }, + expectedStatus: http.StatusNotFound, + }, + { + description: "success when updating a device setting", + req: requests.DeviceUpdateSettings{ + TenantID: "00000000-0000-4000-0000-000000000000", + DeviceParam: requests.DeviceParam{ + UID: "1234", + }, + SSHSettingsUpdate: requests.SSHSettingsUpdate{ + AllowPassword: func() *bool { + v := false + + return &v + }(), + }, + }, + requiredMocks: func() { + v := false + mock.On("UpdateDeviceSettings", gomock.Anything, &requests.DeviceUpdateSettings{ + TenantID: "00000000-0000-4000-0000-000000000000", + DeviceParam: requests.DeviceParam{UID: "1234"}, + SSHSettingsUpdate: requests.SSHSettingsUpdate{ + AllowPassword: &v, + }, + }).Return(nil).Once() + }, + expectedStatus: http.StatusOK, + }, + } + + for _, tc := range cases { + t.Run(tc.description, func(t *testing.T) { + tc.requiredMocks() + + jsonData, err := json.Marshal(tc.req) + require.NoError(t, err) + + req := httptest.NewRequest(http.MethodPatch, fmt.Sprintf("/api/devices/%s/settings", tc.req.UID), strings.NewReader(string(jsonData))) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("X-Role", authorizer.RoleOwner.String()) + req.Header.Set("X-Tenant-ID", "00000000-0000-4000-0000-000000000000") + rec := httptest.NewRecorder() + + e := NewRouter(mock) + e.ServeHTTP(rec, req) + + assert.Equal(t, tc.expectedStatus, rec.Result().StatusCode) + }) + } +} diff --git a/api/routes/routes.go b/api/routes/routes.go index 1049a19108a..8114e432475 100644 --- a/api/routes/routes.go +++ b/api/routes/routes.go @@ -143,8 +143,10 @@ func NewRouter(service services.Service, opts ...Option) *echo.Echo { publicAPI.GET(GetDeviceListURL, routesmiddleware.Authorize(gateway.Handler(handler.GetDeviceList))) publicAPI.GET(GetDeviceURL, routesmiddleware.Authorize(gateway.Handler(handler.GetDevice))) + publicAPI.GET(GetDeviceSettingsURL, routesmiddleware.Authorize(gateway.Handler(handler.GetDeviceSettings))) publicAPI.GET(ResolveDeviceURL, routesmiddleware.Authorize(gateway.Handler(handler.ResolveDevice))) publicAPI.PUT(UpdateDevice, gateway.Handler(handler.UpdateDevice), routesmiddleware.RequiresPermission(authorizer.DeviceUpdate)) + publicAPI.PATCH(UpdateDeviceSettingsURL, gateway.Handler(handler.UpdateDeviceSettings), routesmiddleware.RequiresPermission(authorizer.DeviceUpdate)) publicAPI.PATCH(RenameDeviceURL, gateway.Handler(handler.RenameDevice), routesmiddleware.RequiresPermission(authorizer.DeviceRename)) publicAPI.PATCH(UpdateDeviceStatusURL, gateway.Handler(handler.UpdateDeviceStatus), routesmiddleware.RequiresPermission(authorizer.DeviceAccept)) // TODO: DeviceWrite publicAPI.DELETE(DeleteDeviceURL, gateway.Handler(handler.DeleteDevice), routesmiddleware.RequiresPermission(authorizer.DeviceRemove)) diff --git a/api/services/device.go b/api/services/device.go index 5e7d341742b..79b7a1201c3 100644 --- a/api/services/device.go +++ b/api/services/device.go @@ -62,6 +62,8 @@ type DeviceService interface { OfflineDevice(ctx context.Context, uid models.UID) error UpdateDevice(ctx context.Context, req *requests.DeviceUpdate) error + GetDeviceSettings(ctx context.Context, req *requests.DeviceGetSettings) (*models.SSHSettings, error) + UpdateDeviceSettings(ctx context.Context, req *requests.DeviceUpdateSettings) error // UpdateDeviceStatus updates a device's status. Devices that are already accepted cannot change their status. // // When accepting, if a device with the same MAC address is already accepted within the same namespace, it @@ -398,6 +400,64 @@ func (s *service) UpdateDevice(ctx context.Context, req *requests.DeviceUpdate) return nil } +func (s *service) GetDeviceSettings(ctx context.Context, req *requests.DeviceGetSettings) (*models.SSHSettings, error) { + device, err := s.store.DeviceResolve(ctx, store.DeviceUIDResolver, req.UID, s.store.Options().InNamespace(req.TenantID)) + if err != nil { + return nil, NewErrDeviceNotFound(models.UID(req.UID), err) + } + + if device.SSH == nil { + return models.DefaultSSHSettings(), nil + } + + return device.SSH, nil +} + +func (s *service) UpdateDeviceSettings(ctx context.Context, req *requests.DeviceUpdateSettings) error { + device, err := s.store.DeviceResolve(ctx, store.DeviceUIDResolver, req.UID, s.store.Options().InNamespace(req.TenantID)) + if err != nil { + return NewErrDeviceNotFound(models.UID(req.UID), err) + } + + if device.SSH == nil { + device.SSH = models.DefaultSSHSettings() + } + + if req.AllowPassword != nil { + device.SSH.AllowPassword = *req.AllowPassword + } + if req.AllowPublicKey != nil { + device.SSH.AllowPublicKey = *req.AllowPublicKey + } + if req.AllowRoot != nil { + device.SSH.AllowRoot = *req.AllowRoot + } + if req.AllowEmptyPasswords != nil { + device.SSH.AllowEmptyPasswords = *req.AllowEmptyPasswords + } + if req.AllowTTY != nil { + device.SSH.AllowTTY = *req.AllowTTY + } + if req.AllowTCPForwarding != nil { + device.SSH.AllowTCPForwarding = *req.AllowTCPForwarding + } + if req.AllowWebEndpoints != nil { + device.SSH.AllowWebEndpoints = *req.AllowWebEndpoints + } + if req.AllowSFTP != nil { + device.SSH.AllowSFTP = *req.AllowSFTP + } + if req.AllowAgentForwarding != nil { + device.SSH.AllowAgentForwarding = *req.AllowAgentForwarding + } + + if err := s.store.DeviceUpdateSettings(ctx, req.UID, device.SSH); err != nil { + return err + } + + return nil +} + // maxCustomFieldsPerDevice is the upper bound on the number of custom_fields entries // per device. Enforced server-side to prevent storage abuse. const maxCustomFieldsPerDevice = 20 @@ -451,6 +511,10 @@ func (s *service) mergeDevice(ctx context.Context, tenantID string, oldDevice *m } log.WithFields(logFields).Debug("updating new device name to preserve old device identity") + if oldDevice.SSH != nil { + newDevice.SSH = oldDevice.SSH + } + newDevice.Name = oldDevice.Name if err := s.store.DeviceUpdate(ctx, newDevice); err != nil { log.WithError(err).WithFields(logFields).Error("failed to update new device name") diff --git a/api/services/device_test.go b/api/services/device_test.go index 9425d8fc545..9662794949f 100644 --- a/api/services/device_test.go +++ b/api/services/device_test.go @@ -1749,6 +1749,17 @@ func TestUpdateDeviceStatus(t *testing.T) { TenantID: "00000000-0000-0000-0000-000000000000", Status: models.DeviceStatusAccepted, Identity: &models.DeviceIdentity{MAC: "aa:bb:cc:dd:ee:ff"}, + SSH: &models.SSHSettings{ + AllowPassword: false, + AllowPublicKey: true, + AllowRoot: false, + AllowEmptyPasswords: true, + AllowTTY: false, + AllowTCPForwarding: true, + AllowWebEndpoints: false, + AllowSFTP: true, + AllowAgentForwarding: false, + }, } mergedDevice := &models.Device{ UID: "new-device", @@ -1756,6 +1767,17 @@ func TestUpdateDeviceStatus(t *testing.T) { TenantID: "00000000-0000-0000-0000-000000000000", Status: models.DeviceStatusPending, Identity: &models.DeviceIdentity{MAC: "aa:bb:cc:dd:ee:ff"}, + SSH: &models.SSHSettings{ + AllowPassword: false, + AllowPublicKey: true, + AllowRoot: false, + AllowEmptyPasswords: true, + AllowTTY: false, + AllowTCPForwarding: true, + AllowWebEndpoints: false, + AllowSFTP: true, + AllowAgentForwarding: false, + }, } finalDevice := &models.Device{ UID: "new-device", @@ -1764,6 +1786,17 @@ func TestUpdateDeviceStatus(t *testing.T) { Status: models.DeviceStatusAccepted, StatusUpdatedAt: now, Identity: &models.DeviceIdentity{MAC: "aa:bb:cc:dd:ee:ff"}, + SSH: &models.SSHSettings{ + AllowPassword: false, + AllowPublicKey: true, + AllowRoot: false, + AllowEmptyPasswords: true, + AllowTTY: false, + AllowTCPForwarding: true, + AllowWebEndpoints: false, + AllowSFTP: true, + AllowAgentForwarding: false, + }, } storeMock. @@ -2410,6 +2443,115 @@ func TestDeviceUpdate(t *testing.T) { } } +func TestGetDeviceSettings(t *testing.T) { + storeMock := new(storemock.Store) + queryOptionsMock := new(storemock.QueryOptions) + storeMock.On("Options").Return(queryOptionsMock) + + service := NewService(storeMock, privateKey, publicKey, storecache.NewNullCache(), clientMock) + + t.Run("returns defaults when device settings are nil", func(t *testing.T) { + ctx := context.Background() + req := &requests.DeviceGetSettings{ + TenantID: "00000000-0000-0000-0000-000000000000", + DeviceParam: requests.DeviceParam{UID: "device-id"}, + } + + queryOptionsMock.On("InNamespace", req.TenantID).Return(nil).Once() + storeMock. + On("DeviceResolve", ctx, store.DeviceUIDResolver, req.UID, mock.AnythingOfType("store.QueryOption")). + Return(&models.Device{UID: req.UID}, nil). + Once() + + settings, err := service.GetDeviceSettings(ctx, req) + require.NoError(t, err) + assert.Equal(t, models.DefaultSSHSettings(), settings) + }) +} + +func TestUpdateDeviceSettings(t *testing.T) { + now := time.Now() + storeMock := new(storemock.Store) + queryOptionsMock := new(storemock.QueryOptions) + storeMock.On("Options").Return(queryOptionsMock) + + cases := []struct { + description string + req *requests.DeviceUpdateSettings + requiredMocks func(ctx context.Context) + expected error + }{ + { + description: "fails when device cannot be resolved", + req: &requests.DeviceUpdateSettings{ + TenantID: "00000000-0000-0000-0000-000000000000", + DeviceParam: requests.DeviceParam{UID: "device-id"}, + }, + requiredMocks: func(ctx context.Context) { + queryOptionsMock.On("InNamespace", "00000000-0000-0000-0000-000000000000").Return(nil).Once() + storeMock. + On("DeviceResolve", ctx, store.DeviceUIDResolver, "device-id", mock.AnythingOfType("store.QueryOption")). + Return(nil, errors.New("error", "", 0)). + Once() + }, + expected: NewErrDeviceNotFound(models.UID("device-id"), errors.New("error", "", 0)), + }, + { + description: "updates a single device setting", + req: &requests.DeviceUpdateSettings{ + TenantID: "00000000-0000-0000-0000-000000000000", + DeviceParam: requests.DeviceParam{UID: "device-id"}, + SSHSettingsUpdate: requests.SSHSettingsUpdate{ + AllowPassword: func() *bool { + v := false + + return &v + }(), + }, + }, + requiredMocks: func(ctx context.Context) { + device := &models.Device{ + UID: "device-id", + Name: "device", + DisconnectedAt: &now, + SSH: models.DefaultSSHSettings(), + } + device.SSH.AllowPassword = true + updated := &models.Device{ + UID: "device-id", + Name: "device", + DisconnectedAt: &now, + SSH: models.DefaultSSHSettings(), + } + updated.SSH.AllowPassword = false + + queryOptionsMock.On("InNamespace", "00000000-0000-0000-0000-000000000000").Return(nil).Once() + storeMock. + On("DeviceResolve", ctx, store.DeviceUIDResolver, "device-id", mock.AnythingOfType("store.QueryOption")). + Return(device, nil). + Once() + storeMock. + On("DeviceUpdateSettings", ctx, "device-id", updated.SSH). + Return(nil). + Once() + }, + expected: nil, + }, + } + + service := NewService(storeMock, privateKey, publicKey, storecache.NewNullCache(), clientMock) + + for _, test := range cases { + t.Run(test.description, func(t *testing.T) { + ctx := context.Background() + test.requiredMocks(ctx) + + err := service.UpdateDeviceSettings(ctx, test.req) + assert.Equal(t, test.expected, err) + }) + } +} + func TestSetDeviceCustomField(t *testing.T) { storeMock := new(storemock.Store) queryOptionsMock := new(storemock.QueryOptions) diff --git a/api/services/mocks/services.go b/api/services/mocks/services.go index 0aa288167f9..a00b8a7a4fa 100644 --- a/api/services/mocks/services.go +++ b/api/services/mocks/services.go @@ -1515,6 +1515,52 @@ func (_m *Service) UpdateDevice(ctx context.Context, req *requests.DeviceUpdate) return r0 } +// UpdateDeviceSettings provides a mock function with given fields: ctx, req +func (_m *Service) UpdateDeviceSettings(ctx context.Context, req *requests.DeviceUpdateSettings) error { + ret := _m.Called(ctx, req) + + if len(ret) == 0 { + panic("no return value specified for UpdateDeviceSettings") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, *requests.DeviceUpdateSettings) error); ok { + r0 = rf(ctx, req) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// GetDeviceSettings provides a mock function with given fields: ctx, req +func (_m *Service) GetDeviceSettings(ctx context.Context, req *requests.DeviceGetSettings) (*models.SSHSettings, error) { + ret := _m.Called(ctx, req) + + if len(ret) == 0 { + panic("no return value specified for GetDeviceSettings") + } + + var r0 *models.SSHSettings + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *requests.DeviceGetSettings) (*models.SSHSettings, error)); ok { + return rf(ctx, req) + } + if rf, ok := ret.Get(0).(func(context.Context, *requests.DeviceGetSettings) *models.SSHSettings); ok { + r0 = rf(ctx, req) + } else if ret.Get(0) != nil { + r0 = ret.Get(0).(*models.SSHSettings) + } + + if rf, ok := ret.Get(1).(func(context.Context, *requests.DeviceGetSettings) error); ok { + r1 = rf(ctx, req) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // UpdateDeviceStatus provides a mock function with given fields: ctx, req func (_m *Service) UpdateDeviceStatus(ctx context.Context, req *requests.DeviceUpdateStatus) error { ret := _m.Called(ctx, req) diff --git a/api/services/namespace.go b/api/services/namespace.go index 77ce8dfbdbb..95c00eb2f3d 100644 --- a/api/services/namespace.go +++ b/api/services/namespace.go @@ -66,6 +66,15 @@ func (s *service) CreateNamespace(ctx context.Context, req *requests.NamespaceCr Settings: &models.NamespaceSettings{ SessionRecord: true, ConnectionAnnouncement: "", + AllowPassword: true, + AllowPublicKey: true, + AllowRoot: true, + AllowEmptyPasswords: true, + AllowTTY: true, + AllowTCPForwarding: true, + AllowWebEndpoints: true, + AllowSFTP: true, + AllowAgentForwarding: true, }, TenantID: req.TenantID, Type: models.NewDefaultType(), @@ -196,6 +205,42 @@ func (s *service) EditNamespace(ctx context.Context, req *requests.NamespaceEdit namespace.Settings.ConnectionAnnouncement = *req.Settings.ConnectionAnnouncement } + if req.Settings.AllowPassword != nil { + namespace.Settings.AllowPassword = *req.Settings.AllowPassword + } + + if req.Settings.AllowPublicKey != nil { + namespace.Settings.AllowPublicKey = *req.Settings.AllowPublicKey + } + + if req.Settings.AllowRoot != nil { + namespace.Settings.AllowRoot = *req.Settings.AllowRoot + } + + if req.Settings.AllowEmptyPasswords != nil { + namespace.Settings.AllowEmptyPasswords = *req.Settings.AllowEmptyPasswords + } + + if req.Settings.AllowTTY != nil { + namespace.Settings.AllowTTY = *req.Settings.AllowTTY + } + + if req.Settings.AllowTCPForwarding != nil { + namespace.Settings.AllowTCPForwarding = *req.Settings.AllowTCPForwarding + } + + if req.Settings.AllowWebEndpoints != nil { + namespace.Settings.AllowWebEndpoints = *req.Settings.AllowWebEndpoints + } + + if req.Settings.AllowSFTP != nil { + namespace.Settings.AllowSFTP = *req.Settings.AllowSFTP + } + + if req.Settings.AllowAgentForwarding != nil { + namespace.Settings.AllowAgentForwarding = *req.Settings.AllowAgentForwarding + } + if err := s.store.NamespaceUpdate(ctx, namespace); err != nil { return nil, err } diff --git a/api/services/namespace_test.go b/api/services/namespace_test.go index 6a13361ee02..dcc54e904cb 100644 --- a/api/services/namespace_test.go +++ b/api/services/namespace_test.go @@ -641,6 +641,15 @@ func TestCreateNamespace(t *testing.T) { Settings: &models.NamespaceSettings{ SessionRecord: true, ConnectionAnnouncement: models.DefaultAnnouncementMessage, + AllowPassword: true, + AllowPublicKey: true, + AllowRoot: true, + AllowEmptyPasswords: true, + AllowTTY: true, + AllowTCPForwarding: true, + AllowWebEndpoints: true, + AllowSFTP: true, + AllowAgentForwarding: true, }, MaxDevices: -1, }, @@ -717,6 +726,15 @@ func TestCreateNamespace(t *testing.T) { Settings: &models.NamespaceSettings{ SessionRecord: true, ConnectionAnnouncement: models.DefaultAnnouncementMessage, + AllowPassword: true, + AllowPublicKey: true, + AllowRoot: true, + AllowEmptyPasswords: true, + AllowTTY: true, + AllowTCPForwarding: true, + AllowWebEndpoints: true, + AllowSFTP: true, + AllowAgentForwarding: true, }, MaxDevices: -1, }, @@ -740,6 +758,15 @@ func TestCreateNamespace(t *testing.T) { Settings: &models.NamespaceSettings{ SessionRecord: true, ConnectionAnnouncement: models.DefaultAnnouncementMessage, + AllowPassword: true, + AllowPublicKey: true, + AllowRoot: true, + AllowEmptyPasswords: true, + AllowTTY: true, + AllowTCPForwarding: true, + AllowWebEndpoints: true, + AllowSFTP: true, + AllowAgentForwarding: true, }, MaxDevices: -1, }, @@ -814,6 +841,15 @@ func TestCreateNamespace(t *testing.T) { Settings: &models.NamespaceSettings{ SessionRecord: true, ConnectionAnnouncement: models.DefaultAnnouncementMessage, + AllowPassword: true, + AllowPublicKey: true, + AllowRoot: true, + AllowEmptyPasswords: true, + AllowTTY: true, + AllowTCPForwarding: true, + AllowWebEndpoints: true, + AllowSFTP: true, + AllowAgentForwarding: true, }, MaxDevices: -1, }, @@ -837,6 +873,15 @@ func TestCreateNamespace(t *testing.T) { Settings: &models.NamespaceSettings{ SessionRecord: true, ConnectionAnnouncement: models.DefaultAnnouncementMessage, + AllowPassword: true, + AllowPublicKey: true, + AllowRoot: true, + AllowEmptyPasswords: true, + AllowTTY: true, + AllowTCPForwarding: true, + AllowWebEndpoints: true, + AllowSFTP: true, + AllowAgentForwarding: true, }, MaxDevices: -1, }, @@ -908,6 +953,15 @@ func TestCreateNamespace(t *testing.T) { Settings: &models.NamespaceSettings{ SessionRecord: true, ConnectionAnnouncement: "", + AllowPassword: true, + AllowPublicKey: true, + AllowRoot: true, + AllowEmptyPasswords: true, + AllowTTY: true, + AllowTCPForwarding: true, + AllowWebEndpoints: true, + AllowSFTP: true, + AllowAgentForwarding: true, }, MaxDevices: -1, }, @@ -931,6 +985,15 @@ func TestCreateNamespace(t *testing.T) { Settings: &models.NamespaceSettings{ SessionRecord: true, ConnectionAnnouncement: "", + AllowPassword: true, + AllowPublicKey: true, + AllowRoot: true, + AllowEmptyPasswords: true, + AllowTTY: true, + AllowTCPForwarding: true, + AllowWebEndpoints: true, + AllowSFTP: true, + AllowAgentForwarding: true, }, MaxDevices: -1, }, @@ -1002,6 +1065,15 @@ func TestCreateNamespace(t *testing.T) { Settings: &models.NamespaceSettings{ SessionRecord: true, ConnectionAnnouncement: "", + AllowPassword: true, + AllowPublicKey: true, + AllowRoot: true, + AllowEmptyPasswords: true, + AllowTTY: true, + AllowTCPForwarding: true, + AllowWebEndpoints: true, + AllowSFTP: true, + AllowAgentForwarding: true, }, MaxDevices: 3, }, @@ -1025,6 +1097,15 @@ func TestCreateNamespace(t *testing.T) { Settings: &models.NamespaceSettings{ SessionRecord: true, ConnectionAnnouncement: "", + AllowPassword: true, + AllowPublicKey: true, + AllowRoot: true, + AllowEmptyPasswords: true, + AllowTTY: true, + AllowTCPForwarding: true, + AllowWebEndpoints: true, + AllowSFTP: true, + AllowAgentForwarding: true, }, MaxDevices: 3, }, @@ -1062,7 +1143,13 @@ func TestEditNamespace(t *testing.T) { requiredMocks func() tenantID string namespaceName string - expected Expected + settings struct { + SessionRecord *bool + ConnectionAnnouncement *string + AllowPassword *bool + AllowPublicKey *bool + } + expected Expected }{ { description: "fails when namespace does not exist", @@ -1147,10 +1234,112 @@ func TestEditNamespace(t *testing.T) { nil, }, }, + { + description: "succeeds changing AllowPassword", + tenantID: "xxxxx", + settings: struct { + SessionRecord *bool + ConnectionAnnouncement *string + AllowPassword *bool + AllowPublicKey *bool + }{ + AllowPassword: func(b bool) *bool { return &b }(true), + }, + requiredMocks: func() { + namespace := &models.Namespace{ + TenantID: "xxxxx", + Name: "oldname", + Settings: &models.NamespaceSettings{AllowPassword: false}, + } + storeMock. + On("NamespaceResolve", ctx, store.NamespaceTenantIDResolver, "xxxxx"). + Return(namespace, nil). + Once() + + expectedNamespace := *namespace + expectedNamespace.Settings.AllowPassword = true + storeMock. + On("NamespaceUpdate", ctx, &expectedNamespace). + Return(nil). + Once() + + finalNamespace := &models.Namespace{ + TenantID: "xxxxx", + Name: "oldname", + Settings: &models.NamespaceSettings{AllowPassword: true}, + } + storeMock. + On("NamespaceResolve", ctx, store.NamespaceTenantIDResolver, "xxxxx"). + Return(finalNamespace, nil). + Once() + }, + expected: Expected{ + &models.Namespace{ + TenantID: "xxxxx", + Name: "oldname", + Settings: &models.NamespaceSettings{AllowPassword: true}, + }, + nil, + }, + }, + { + description: "succeeds changing AllowPublicKey", + tenantID: "xxxxx", + settings: struct { + SessionRecord *bool + ConnectionAnnouncement *string + AllowPassword *bool + AllowPublicKey *bool + }{ + AllowPublicKey: func(b bool) *bool { return &b }(true), + }, + requiredMocks: func() { + namespace := &models.Namespace{ + TenantID: "xxxxx", + Name: "oldname", + Settings: &models.NamespaceSettings{AllowPublicKey: false}, + } + storeMock. + On("NamespaceResolve", ctx, store.NamespaceTenantIDResolver, "xxxxx"). + Return(namespace, nil). + Once() + + expectedNamespace := *namespace + expectedNamespace.Settings.AllowPublicKey = true + storeMock. + On("NamespaceUpdate", ctx, &expectedNamespace). + Return(nil). + Once() + + finalNamespace := &models.Namespace{ + TenantID: "xxxxx", + Name: "oldname", + Settings: &models.NamespaceSettings{AllowPublicKey: true}, + } + storeMock. + On("NamespaceResolve", ctx, store.NamespaceTenantIDResolver, "xxxxx"). + Return(finalNamespace, nil). + Once() + }, + expected: Expected{ + &models.Namespace{ + TenantID: "xxxxx", + Name: "oldname", + Settings: &models.NamespaceSettings{AllowPublicKey: true}, + }, + nil, + }, + }, { description: "succeeds", namespaceName: "newname", tenantID: "xxxxx", + settings: struct { + SessionRecord *bool + ConnectionAnnouncement *string + AllowPassword *bool + AllowPublicKey *bool + }{}, requiredMocks: func() { namespace := &models.Namespace{ TenantID: "xxxxx", @@ -1199,6 +1388,11 @@ func TestEditNamespace(t *testing.T) { TenantParam: requests.TenantParam{Tenant: tc.tenantID}, Name: tc.namespaceName, } + req.Settings.SessionRecord = tc.settings.SessionRecord + req.Settings.ConnectionAnnouncement = tc.settings.ConnectionAnnouncement + req.Settings.AllowPassword = tc.settings.AllowPassword + req.Settings.AllowPublicKey = tc.settings.AllowPublicKey + namespace, err := service.EditNamespace(ctx, req) assert.Equal(t, tc.expected, Expected{namespace, err}) diff --git a/api/services/setup.go b/api/services/setup.go index 1d5d99193a1..015a135c603 100644 --- a/api/services/setup.go +++ b/api/services/setup.go @@ -77,6 +77,15 @@ func (s *service) Setup(ctx context.Context, req requests.Setup) error { Settings: &models.NamespaceSettings{ SessionRecord: false, ConnectionAnnouncement: models.DefaultAnnouncementMessage, + AllowPassword: true, + AllowPublicKey: true, + AllowRoot: true, + AllowEmptyPasswords: true, + AllowTTY: true, + AllowTCPForwarding: true, + AllowWebEndpoints: true, + AllowSFTP: true, + AllowAgentForwarding: true, }, } diff --git a/api/services/setup_test.go b/api/services/setup_test.go index 7eea9a43660..ffeb898877d 100644 --- a/api/services/setup_test.go +++ b/api/services/setup_test.go @@ -197,6 +197,15 @@ func TestSetup(t *testing.T) { Settings: &models.NamespaceSettings{ SessionRecord: false, ConnectionAnnouncement: models.DefaultAnnouncementMessage, + AllowPassword: true, + AllowPublicKey: true, + AllowRoot: true, + AllowEmptyPasswords: true, + AllowTTY: true, + AllowTCPForwarding: true, + AllowWebEndpoints: true, + AllowSFTP: true, + AllowAgentForwarding: true, }, CreatedAt: now, } @@ -287,6 +296,15 @@ func TestSetup(t *testing.T) { Settings: &models.NamespaceSettings{ SessionRecord: false, ConnectionAnnouncement: models.DefaultAnnouncementMessage, + AllowPassword: true, + AllowPublicKey: true, + AllowRoot: true, + AllowEmptyPasswords: true, + AllowTTY: true, + AllowTCPForwarding: true, + AllowWebEndpoints: true, + AllowSFTP: true, + AllowAgentForwarding: true, }, CreatedAt: now, } @@ -354,6 +372,15 @@ func TestSetup(t *testing.T) { Settings: &models.NamespaceSettings{ SessionRecord: false, ConnectionAnnouncement: models.DefaultAnnouncementMessage, + AllowPassword: true, + AllowPublicKey: true, + AllowRoot: true, + AllowEmptyPasswords: true, + AllowTTY: true, + AllowTCPForwarding: true, + AllowWebEndpoints: true, + AllowSFTP: true, + AllowAgentForwarding: true, }, CreatedAt: now, } diff --git a/api/store/device.go b/api/store/device.go index d6a92506704..533a7fc0fd1 100644 --- a/api/store/device.go +++ b/api/store/device.go @@ -49,6 +49,8 @@ type DeviceStore interface { // DeviceUpdate updates a device. It returns [ErrNoDocuments] if none device is found. DeviceUpdate(ctx context.Context, device *models.Device) error + // DeviceUpdateSettings updates a device's SSH settings. It returns [ErrNoDocuments] if no device is found. + DeviceUpdateSettings(ctx context.Context, uid string, settings *models.SSHSettings) error // DeviceHeartbeat updates the last_seen timestamp and sets disconnected_at to nil for multiple devices. // It returns the number of modified devices and an error if any. DeviceHeartbeat(ctx context.Context, uids []string, lastSeen time.Time) (modifiedCount int64, err error) diff --git a/api/store/mocks/store.go b/api/store/mocks/store.go index e40b5baec4d..c796bfa2265 100644 --- a/api/store/mocks/store.go +++ b/api/store/mocks/store.go @@ -558,6 +558,24 @@ func (_m *Store) DeviceUpdate(ctx context.Context, device *models.Device) error return r0 } +// DeviceUpdateSettings provides a mock function with given fields: ctx, uid, settings +func (_m *Store) DeviceUpdateSettings(ctx context.Context, uid string, settings *models.SSHSettings) error { + ret := _m.Called(ctx, uid, settings) + + if len(ret) == 0 { + panic("no return value specified for DeviceUpdateSettings") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, *models.SSHSettings) error); ok { + r0 = rf(ctx, uid, settings) + } else { + r0 = ret.Error(0) + } + + return r0 +} + // GetStats provides a mock function with given fields: ctx, tenantID func (_m *Store) GetStats(ctx context.Context, tenantID string) (*models.Stats, error) { ret := _m.Called(ctx, tenantID) diff --git a/api/store/pg/device.go b/api/store/pg/device.go index 756dc10f232..674cf70aca5 100644 --- a/api/store/pg/device.go +++ b/api/store/pg/device.go @@ -18,8 +18,33 @@ func (pg *Pg) DeviceCreate(ctx context.Context, device *models.Device) (string, device.CreatedAt = clock.Now() e := entity.DeviceFromModel(device) - if _, err := db.NewInsert().Model(e).Exec(ctx); err != nil { - return "", fromSQLError(err) + exec := func(db bun.IDB) error { + if _, err := db.NewInsert().Model(e).Exec(ctx); err != nil { + return fromSQLError(err) + } + + if device.SSH != nil { + settings := entity.DeviceSettingsFromModel(device.SSH, device.UID) + if _, err := db.NewInsert().Model(&settings).Exec(ctx); err != nil { + return fromSQLError(err) + } + } + + return nil + } + + if _, ok := db.(bun.Tx); ok { + if err := exec(db); err != nil { + return "", err + } + + return e.ID, nil + } + + if err := pg.WithTransaction(ctx, func(txCtx context.Context) error { + return exec(pg.GetConnection(txCtx)) + }); err != nil { + return "", err } return e.ID, nil @@ -75,6 +100,7 @@ func (pg *Pg) DeviceList(ctx context.Context, acceptable store.DeviceAcceptable, Model(&entities). Column("device.*"). Relation("Namespace"). + Relation("Settings"). Relation("Tags"). ColumnExpr(onlineExpr, onlineThreshold). ColumnExpr(deviceExprAcceptable(acceptable)) @@ -117,6 +143,7 @@ func (pg *Pg) DeviceResolve(ctx context.Context, resolver store.DeviceResolver, Where("? = ?", bun.Ident("device."+column), val). Column("device.*"). Relation("Namespace"). + Relation("Settings"). Relation("Tags"). ColumnExpr(onlineExpr, onlineThreshold) @@ -140,16 +167,64 @@ func (pg *Pg) DeviceUpdate(ctx context.Context, device *models.Device) error { d := entity.DeviceFromModel(device) d.UpdatedAt = clock.Now() - r, err := db.NewUpdate().Model(d).Where("id = ?", d.ID).Where("namespace_id = ?", d.NamespaceID).Exec(ctx) - if err != nil { - return fromSQLError(err) + exec := func(db bun.IDB) error { + r, err := db.NewUpdate().Model(d).Where("id = ?", d.ID).Where("namespace_id = ?", d.NamespaceID).Exec(ctx) + if err != nil { + return fromSQLError(err) + } + + if rowsAffected, err := r.RowsAffected(); err != nil || rowsAffected == 0 { + return store.ErrNoDocuments + } + + return nil } - if rowsAffected, err := r.RowsAffected(); err != nil || rowsAffected == 0 { - return store.ErrNoDocuments + if _, ok := db.(bun.Tx); ok { + return exec(db) } - return nil + return pg.WithTransaction(ctx, func(txCtx context.Context) error { + return exec(pg.GetConnection(txCtx)) + }) +} + +func (pg *Pg) DeviceUpdateSettings(ctx context.Context, uid string, ssh *models.SSHSettings) error { + db := pg.GetConnection(ctx) + + exec := func(db bun.IDB) error { + exists, err := db.NewSelect(). + Model((*entity.Device)(nil)). + Where("id = ?", uid). + Exists(ctx) + if err != nil { + return fromSQLError(err) + } + if !exists { + return store.ErrNoDocuments + } + + settings := entity.DeviceSettingsFromModel(ssh, uid) + settings.UpdatedAt = clock.Now() + + _, err = db.NewInsert(). + On("conflict (device_id) do update set updated_at = excluded.updated_at, allow_password = excluded.allow_password, allow_public_key = excluded.allow_public_key, allow_root = excluded.allow_root, allow_empty_passwords = excluded.allow_empty_passwords, allow_tty = excluded.allow_tty, allow_tcp_forwarding = excluded.allow_tcp_forwarding, allow_web_endpoints = excluded.allow_web_endpoints, allow_sftp = excluded.allow_sftp, allow_agent_forwarding = excluded.allow_agent_forwarding"). + Model(&settings). + Exec(ctx) + if err != nil { + return fromSQLError(err) + } + + return nil + } + + if _, ok := db.(bun.Tx); ok { + return exec(db) + } + + return pg.WithTransaction(ctx, func(txCtx context.Context) error { + return exec(pg.GetConnection(txCtx)) + }) } func (pg *Pg) DeviceSetCustomField(ctx context.Context, uid, key, value string) error { @@ -256,6 +331,13 @@ func (pg *Pg) DeviceDeleteMany(ctx context.Context, uids []string) (int64, error func (pg *Pg) deviceDeleteManyFn(ctx context.Context, uids []string) func(tx bun.Tx) (int64, error) { return func(tx bun.Tx) (int64, error) { + if _, err := tx.NewDelete(). + Model((*entity.DeviceSettings)(nil)). + Where("device_id IN (?)", bun.List(uids)). + Exec(ctx); err != nil { + return 0, fromSQLError(err) + } + r, err := tx.NewDelete().Model((*entity.Device)(nil)).Where("id IN (?)", bun.List(uids)).Exec(ctx) if err != nil { return 0, fromSQLError(err) diff --git a/api/store/pg/entity/device.go b/api/store/pg/entity/device.go index bc8ce92637a..8b126af594d 100644 --- a/api/store/pg/entity/device.go +++ b/api/store/pg/entity/device.go @@ -33,8 +33,9 @@ type Device struct { Latitude float64 `bun:"latitude,type:numeric"` CustomFields map[string]string `bun:"custom_fields,type:jsonb,nullzero,default:'{}'"` - Namespace *Namespace `bun:"rel:belongs-to,join:namespace_id=id"` - Tags []*Tag `bun:"m2m:device_tags,join:Device=Tag"` + Namespace *Namespace `bun:"rel:belongs-to,join:namespace_id=id"` + Tags []*Tag `bun:"m2m:device_tags,join:Device=Tag"` + Settings *DeviceSettings `bun:"rel:has-one,join:id=device_id"` } func DeviceFromModel(model *models.Device) *Device { @@ -152,5 +153,11 @@ func DeviceToModel(entity *Device) *models.Device { } } + if entity.Settings != nil { + device.SSH = DeviceSettingsToModel(entity.Settings) + } else { + device.SSH = models.DefaultSSHSettings() + } + return device } diff --git a/api/store/pg/entity/device_settings.go b/api/store/pg/entity/device_settings.go new file mode 100644 index 00000000000..8602507ae6e --- /dev/null +++ b/api/store/pg/entity/device_settings.go @@ -0,0 +1,77 @@ +package entity + +import ( + "time" + + "github.com/shellhub-io/shellhub/pkg/models" + "github.com/shellhub-io/shellhub/pkg/uuid" + "github.com/uptrace/bun" +) + +type DeviceSettings struct { + bun.BaseModel `bun:"table:device_settings"` + + ID string `bun:"id,pk,type:uuid,nullzero,default:gen_random_uuid()"` + DeviceID string `bun:"device_id,type:varchar,unique"` + AllowPassword bool `bun:"allow_password"` + AllowPublicKey bool `bun:"allow_public_key"` + AllowRoot bool `bun:"allow_root"` + AllowEmptyPasswords bool `bun:"allow_empty_passwords"` + AllowTTY bool `bun:"allow_tty"` + AllowTCPForwarding bool `bun:"allow_tcp_forwarding"` + AllowWebEndpoints bool `bun:"allow_web_endpoints"` + AllowSFTP bool `bun:"allow_sftp"` + AllowAgentForwarding bool `bun:"allow_agent_forwarding"` + CreatedAt time.Time `bun:"created_at"` + UpdatedAt time.Time `bun:"updated_at"` +} + +func DeviceSettingsFromModel(ssh *models.SSHSettings, deviceID string) DeviceSettings { + if ssh == nil { + return DeviceSettings{ + ID: uuid.Generate(), + DeviceID: deviceID, + AllowPassword: true, + AllowPublicKey: true, + AllowRoot: true, + AllowEmptyPasswords: true, + AllowTTY: true, + AllowTCPForwarding: true, + AllowWebEndpoints: true, + AllowSFTP: true, + AllowAgentForwarding: true, + } + } + + return DeviceSettings{ + ID: uuid.Generate(), + DeviceID: deviceID, + AllowPassword: ssh.AllowPassword, + AllowPublicKey: ssh.AllowPublicKey, + AllowRoot: ssh.AllowRoot, + AllowEmptyPasswords: ssh.AllowEmptyPasswords, + AllowTTY: ssh.AllowTTY, + AllowTCPForwarding: ssh.AllowTCPForwarding, + AllowWebEndpoints: ssh.AllowWebEndpoints, + AllowSFTP: ssh.AllowSFTP, + AllowAgentForwarding: ssh.AllowAgentForwarding, + } +} + +func DeviceSettingsToModel(settings *DeviceSettings) *models.SSHSettings { + if settings == nil { + return nil + } + + return &models.SSHSettings{ + AllowPassword: settings.AllowPassword, + AllowPublicKey: settings.AllowPublicKey, + AllowRoot: settings.AllowRoot, + AllowEmptyPasswords: settings.AllowEmptyPasswords, + AllowTTY: settings.AllowTTY, + AllowTCPForwarding: settings.AllowTCPForwarding, + AllowWebEndpoints: settings.AllowWebEndpoints, + AllowSFTP: settings.AllowSFTP, + AllowAgentForwarding: settings.AllowAgentForwarding, + } +} diff --git a/api/store/pg/entity/device_settings_test.go b/api/store/pg/entity/device_settings_test.go new file mode 100644 index 00000000000..1246ea77771 --- /dev/null +++ b/api/store/pg/entity/device_settings_test.go @@ -0,0 +1,35 @@ +package entity + +import "testing" + +func TestDeviceSettingsFromModelNilDefaultsToEnabled(t *testing.T) { + settings := DeviceSettingsFromModel(nil, "device-id") + + if !settings.AllowPassword { + t.Fatal("expected AllowPassword to default to true") + } + if !settings.AllowPublicKey { + t.Fatal("expected AllowPublicKey to default to true") + } + if !settings.AllowRoot { + t.Fatal("expected AllowRoot to default to true") + } + if !settings.AllowEmptyPasswords { + t.Fatal("expected AllowEmptyPasswords to default to true") + } + if !settings.AllowTTY { + t.Fatal("expected AllowTTY to default to true") + } + if !settings.AllowTCPForwarding { + t.Fatal("expected AllowTCPForwarding to default to true") + } + if !settings.AllowWebEndpoints { + t.Fatal("expected AllowWebEndpoints to default to true") + } + if !settings.AllowSFTP { + t.Fatal("expected AllowSFTP to default to true") + } + if !settings.AllowAgentForwarding { + t.Fatal("expected AllowAgentForwarding to default to true") + } +} diff --git a/api/store/pg/entity/device_test.go b/api/store/pg/entity/device_test.go index ae1a15304fb..64d90f7b7be 100644 --- a/api/store/pg/entity/device_test.go +++ b/api/store/pg/entity/device_test.go @@ -281,6 +281,25 @@ func TestDeviceToModel(t *testing.T) { assert.Nil(t, result.TagIDs) }, }, + { + name: "nil settings default to allow all", + entity: &Device{ + ID: "device-uid-7", + Status: "accepted", + }, + check: func(t *testing.T, result *models.Device) { + require.NotNil(t, result.SSH) + assert.True(t, result.SSH.AllowPassword) + assert.True(t, result.SSH.AllowPublicKey) + assert.True(t, result.SSH.AllowRoot) + assert.True(t, result.SSH.AllowEmptyPasswords) + assert.True(t, result.SSH.AllowTTY) + assert.True(t, result.SSH.AllowTCPForwarding) + assert.True(t, result.SSH.AllowWebEndpoints) + assert.True(t, result.SSH.AllowSFTP) + assert.True(t, result.SSH.AllowAgentForwarding) + }, + }, } for _, tt := range tests { diff --git a/api/store/pg/entity/entity.go b/api/store/pg/entity/entity.go index ce625a48ff8..7c2a2b6cd32 100644 --- a/api/store/pg/entity/entity.go +++ b/api/store/pg/entity/entity.go @@ -8,8 +8,10 @@ func Entities() []any { (*APIKey)(nil), (*Device)(nil), + (*DeviceSettings)(nil), (*Membership)(nil), (*Namespace)(nil), + (*NamespaceSettings)(nil), (*PrivateKey)(nil), (*PublicKey)(nil), (*Session)(nil), diff --git a/api/store/pg/entity/namespace.go b/api/store/pg/entity/namespace.go index 2fd4548f2cb..d114f73cf56 100644 --- a/api/store/pg/entity/namespace.go +++ b/api/store/pg/entity/namespace.go @@ -4,31 +4,49 @@ import ( "time" "github.com/shellhub-io/shellhub/pkg/models" + "github.com/shellhub-io/shellhub/pkg/uuid" "github.com/uptrace/bun" ) type Namespace struct { bun.BaseModel `bun:"table:namespaces"` - ID string `bun:"id,pk,type:uuid"` - CreatedAt time.Time `bun:"created_at"` - UpdatedAt time.Time `bun:"updated_at"` - Type string `bun:"scope"` - Name string `bun:"name"` - OwnerID string `bun:"owner_id"` // TODO: Remove this column in the future, owner should be determined by membership role - Memberships []Membership `json:"members" bun:"rel:has-many,join:id=namespace_id"` - Settings NamespaceSettings `bun:"embed:"` - DevicesAcceptedCount int64 `bun:"devices_accepted_count"` - DevicesPendingCount int64 `bun:"devices_pending_count"` - DevicesRejectedCount int64 `bun:"devices_rejected_count"` - DevicesRemovedCount int64 `bun:"devices_removed_count"` + ID string `bun:"id,pk,type:uuid"` + CreatedAt time.Time `bun:"created_at"` + UpdatedAt time.Time `bun:"updated_at"` + Type string `bun:"scope"` + Name string `bun:"name"` + OwnerID string `bun:"owner_id"` + SessionRecord bool `bun:"record_sessions"` + ConnectionAnnouncement string `bun:"connection_announcement,type:text"` + DeviceAutoAccept bool `bun:"device_auto_accept"` + Memberships []Membership `json:"members" bun:"rel:has-many,join:id=namespace_id"` + Settings *NamespaceSettings `bun:"rel:has-one,join:id=namespace_id"` + MaxDevices int `bun:"max_devices"` + DevicesAcceptedCount int64 `bun:"devices_accepted_count"` + DevicesPendingCount int64 `bun:"devices_pending_count"` + DevicesRejectedCount int64 `bun:"devices_rejected_count"` + DevicesRemovedCount int64 `bun:"devices_removed_count"` } type NamespaceSettings struct { - MaxDevices int `bun:"max_devices"` - SessionRecord bool `bun:"record_sessions"` - ConnectionAnnouncement string `bun:"connection_announcement,type:text"` - DeviceAutoAccept bool `bun:"device_auto_accept"` + bun.BaseModel `bun:"table:namespace_settings"` + + ID string `bun:"id,pk,type:uuid,nullzero,default:gen_random_uuid()"` + NamespaceID string `bun:"namespace_id,type:uuid,unique"` + SessionRecord bool `bun:"record_sessions"` + ConnectionAnnouncement string `bun:"connection_announcement,type:text"` + AllowPassword bool `bun:"allow_password"` + AllowPublicKey bool `bun:"allow_public_key"` + AllowRoot bool `bun:"allow_root"` + AllowEmptyPasswords bool `bun:"allow_empty_passwords"` + AllowTTY bool `bun:"allow_tty"` + AllowTCPForwarding bool `bun:"allow_tcp_forwarding"` + AllowWebEndpoints bool `bun:"allow_web_endpoints"` + AllowSFTP bool `bun:"allow_sftp"` + AllowAgentForwarding bool `bun:"allow_agent_forwarding"` + CreatedAt time.Time `bun:"created_at"` + UpdatedAt time.Time `bun:"updated_at"` } func NamespaceFromModel(model *models.Namespace) *Namespace { @@ -39,24 +57,40 @@ func NamespaceFromModel(model *models.Namespace) *Namespace { } namespace := &Namespace{ - ID: model.TenantID, - CreatedAt: model.CreatedAt, - Type: namespaceType, - Name: model.Name, - OwnerID: model.Owner, - DevicesAcceptedCount: model.DevicesAcceptedCount, - DevicesPendingCount: model.DevicesPendingCount, - DevicesRejectedCount: model.DevicesRejectedCount, - DevicesRemovedCount: model.DevicesRemovedCount, - Settings: NamespaceSettings{ - MaxDevices: model.MaxDevices, - }, + ID: model.TenantID, + CreatedAt: model.CreatedAt, + Type: namespaceType, + Name: model.Name, + OwnerID: model.Owner, + SessionRecord: false, + ConnectionAnnouncement: "", + DeviceAutoAccept: false, + MaxDevices: model.MaxDevices, + DevicesAcceptedCount: model.DevicesAcceptedCount, + DevicesPendingCount: model.DevicesPendingCount, + DevicesRejectedCount: model.DevicesRejectedCount, + DevicesRemovedCount: model.DevicesRemovedCount, } if model.Settings != nil { - namespace.Settings.SessionRecord = model.Settings.SessionRecord - namespace.Settings.ConnectionAnnouncement = model.Settings.ConnectionAnnouncement - namespace.Settings.DeviceAutoAccept = model.Settings.DeviceAutoAccept + namespace.SessionRecord = model.Settings.SessionRecord + namespace.ConnectionAnnouncement = model.Settings.ConnectionAnnouncement + namespace.DeviceAutoAccept = model.Settings.DeviceAutoAccept + namespace.Settings = &NamespaceSettings{ + ID: uuid.Generate(), + NamespaceID: model.TenantID, + SessionRecord: model.Settings.SessionRecord, + ConnectionAnnouncement: model.Settings.ConnectionAnnouncement, + AllowPassword: model.Settings.AllowPassword, + AllowPublicKey: model.Settings.AllowPublicKey, + AllowRoot: model.Settings.AllowRoot, + AllowEmptyPasswords: model.Settings.AllowEmptyPasswords, + AllowTTY: model.Settings.AllowTTY, + AllowTCPForwarding: model.Settings.AllowTCPForwarding, + AllowWebEndpoints: model.Settings.AllowWebEndpoints, + AllowSFTP: model.Settings.AllowSFTP, + AllowAgentForwarding: model.Settings.AllowAgentForwarding, + } } namespace.Memberships = make([]Membership, len(model.Members)) @@ -79,16 +113,28 @@ func NamespaceToModel(entity *Namespace) *models.Namespace { Owner: entity.OwnerID, CreatedAt: entity.CreatedAt, Type: models.Type(entity.Type), - MaxDevices: entity.Settings.MaxDevices, + MaxDevices: entity.MaxDevices, DevicesAcceptedCount: entity.DevicesAcceptedCount, DevicesPendingCount: entity.DevicesPendingCount, DevicesRejectedCount: entity.DevicesRejectedCount, DevicesRemovedCount: entity.DevicesRemovedCount, - Settings: &models.NamespaceSettings{ + } + + if entity.Settings != nil { + namespace.Settings = &models.NamespaceSettings{ SessionRecord: entity.Settings.SessionRecord, ConnectionAnnouncement: entity.Settings.ConnectionAnnouncement, - DeviceAutoAccept: entity.Settings.DeviceAutoAccept, - }, + DeviceAutoAccept: entity.DeviceAutoAccept, + AllowPassword: entity.Settings.AllowPassword, + AllowPublicKey: entity.Settings.AllowPublicKey, + AllowRoot: entity.Settings.AllowRoot, + AllowEmptyPasswords: entity.Settings.AllowEmptyPasswords, + AllowTTY: entity.Settings.AllowTTY, + AllowTCPForwarding: entity.Settings.AllowTCPForwarding, + AllowWebEndpoints: entity.Settings.AllowWebEndpoints, + AllowSFTP: entity.Settings.AllowSFTP, + AllowAgentForwarding: entity.Settings.AllowAgentForwarding, + } } namespace.Members = make([]models.Member, len(entity.Memberships)) diff --git a/api/store/pg/entity/namespace_test.go b/api/store/pg/entity/namespace_test.go index 00ee60d3161..3927991322b 100644 --- a/api/store/pg/entity/namespace_test.go +++ b/api/store/pg/entity/namespace_test.go @@ -49,12 +49,12 @@ func TestNamespaceFromModel(t *testing.T) { CreatedAt: now, }, expected: &Namespace{ - ID: "ns-id-1", - Name: "my-namespace", - OwnerID: "owner-id-1", - Type: "team", - Settings: NamespaceSettings{ - MaxDevices: 10, + ID: "ns-id-1", + Name: "my-namespace", + OwnerID: "owner-id-1", + Type: "team", + MaxDevices: 10, + Settings: &NamespaceSettings{ SessionRecord: true, ConnectionAnnouncement: "Welcome!", }, @@ -108,13 +108,12 @@ func TestNamespaceFromModel(t *testing.T) { Members: []models.Member{}, }, expected: &Namespace{ - ID: "ns-id-3", - Name: "no-settings", - OwnerID: "owner-id-3", - Type: "personal", - Settings: NamespaceSettings{ - MaxDevices: 15, - }, + ID: "ns-id-3", + Name: "no-settings", + OwnerID: "owner-id-3", + Type: "personal", + MaxDevices: 15, + Settings: nil, Memberships: []Membership{}, }, }, @@ -144,9 +143,11 @@ func TestNamespaceFromModel(t *testing.T) { assert.Equal(t, tt.expected.Name, result.Name) assert.Equal(t, tt.expected.OwnerID, result.OwnerID) assert.Equal(t, tt.expected.Type, result.Type) - assert.Equal(t, tt.expected.Settings.MaxDevices, result.Settings.MaxDevices) - assert.Equal(t, tt.expected.Settings.SessionRecord, result.Settings.SessionRecord) - assert.Equal(t, tt.expected.Settings.ConnectionAnnouncement, result.Settings.ConnectionAnnouncement) + assert.Equal(t, tt.expected.MaxDevices, result.MaxDevices) + if tt.expected.Settings != nil { + assert.Equal(t, tt.expected.Settings.SessionRecord, result.Settings.SessionRecord) + assert.Equal(t, tt.expected.Settings.ConnectionAnnouncement, result.Settings.ConnectionAnnouncement) + } assert.Equal(t, tt.expected.DevicesAcceptedCount, result.DevicesAcceptedCount) assert.Equal(t, tt.expected.DevicesPendingCount, result.DevicesPendingCount) assert.Equal(t, tt.expected.DevicesRejectedCount, result.DevicesRejectedCount) @@ -174,12 +175,12 @@ func TestNamespaceToModel(t *testing.T) { { name: "full fields", entity: &Namespace{ - ID: "ns-id-1", - Name: "my-namespace", - OwnerID: "owner-id-1", - Type: "team", - Settings: NamespaceSettings{ - MaxDevices: 10, + ID: "ns-id-1", + Name: "my-namespace", + OwnerID: "owner-id-1", + Type: "team", + MaxDevices: 10, + Settings: &NamespaceSettings{ SessionRecord: true, ConnectionAnnouncement: "Hello!", }, @@ -234,7 +235,7 @@ func TestNamespaceToModel(t *testing.T) { Name: "empty-ns", Owner: "owner-id-2", Type: models.TypePersonal, - Settings: &models.NamespaceSettings{}, + Settings: nil, Members: []models.Member{}, }, }, @@ -248,9 +249,13 @@ func TestNamespaceToModel(t *testing.T) { assert.Equal(t, tt.expected.Owner, result.Owner) assert.Equal(t, tt.expected.Type, result.Type) assert.Equal(t, tt.expected.MaxDevices, result.MaxDevices) - require.NotNil(t, result.Settings, "Settings should never be nil") - assert.Equal(t, tt.expected.Settings.SessionRecord, result.Settings.SessionRecord) - assert.Equal(t, tt.expected.Settings.ConnectionAnnouncement, result.Settings.ConnectionAnnouncement) + if tt.expected.Settings == nil { + assert.Nil(t, result.Settings) + } else { + require.NotNil(t, result.Settings, "Settings should not be nil") + assert.Equal(t, tt.expected.Settings.SessionRecord, result.Settings.SessionRecord) + assert.Equal(t, tt.expected.Settings.ConnectionAnnouncement, result.Settings.ConnectionAnnouncement) + } assert.Equal(t, tt.expected.DevicesAcceptedCount, result.DevicesAcceptedCount) assert.Equal(t, tt.expected.DevicesPendingCount, result.DevicesPendingCount) assert.Equal(t, tt.expected.DevicesRejectedCount, result.DevicesRejectedCount) diff --git a/api/store/pg/migrations/005_create_ssh_settings_tables.tx.down.sql b/api/store/pg/migrations/005_create_ssh_settings_tables.tx.down.sql new file mode 100644 index 00000000000..85dc1a95621 --- /dev/null +++ b/api/store/pg/migrations/005_create_ssh_settings_tables.tx.down.sql @@ -0,0 +1,17 @@ +-- Down migration 004: Restore columns and drop new tables + +-- Migrate data back from namespace_settings +UPDATE namespaces SET + record_sessions = ns.record_sessions, + connection_announcement = ns.connection_announcement +FROM namespace_settings ns +WHERE namespaces.id = ns.namespace_id; + +-- Drop tables +DROP TRIGGER IF EXISTS namespace_settings_updated_at ON namespace_settings; +DROP FUNCTION IF EXISTS update_namespace_settings_updated_at(); +DROP TABLE IF EXISTS namespace_settings; + +DROP TRIGGER IF EXISTS device_settings_updated_at ON device_settings; +DROP FUNCTION IF EXISTS update_device_settings_updated_at(); +DROP TABLE IF EXISTS device_settings; diff --git a/api/store/pg/migrations/005_create_ssh_settings_tables.tx.up.sql b/api/store/pg/migrations/005_create_ssh_settings_tables.tx.up.sql new file mode 100644 index 00000000000..fa34df860e1 --- /dev/null +++ b/api/store/pg/migrations/005_create_ssh_settings_tables.tx.up.sql @@ -0,0 +1,82 @@ +-- Migration 004: Create device_settings table +CREATE TABLE IF NOT EXISTS device_settings ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + device_id character varying NOT NULL UNIQUE REFERENCES devices(id) ON DELETE CASCADE, + allow_password BOOLEAN DEFAULT TRUE, + allow_public_key BOOLEAN DEFAULT TRUE, + allow_root BOOLEAN DEFAULT TRUE, + allow_empty_passwords BOOLEAN DEFAULT TRUE, + allow_tty BOOLEAN DEFAULT TRUE, + allow_tcp_forwarding BOOLEAN DEFAULT TRUE, + allow_web_endpoints BOOLEAN DEFAULT TRUE, + allow_sftp BOOLEAN DEFAULT TRUE, + allow_agent_forwarding BOOLEAN DEFAULT TRUE, + created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(), + updated_at TIMESTAMP WITH TIME ZONE DEFAULT NOW() +); + +CREATE INDEX IF NOT EXISTS idx_device_settings_device_id ON device_settings(device_id); + +CREATE OR REPLACE FUNCTION update_device_settings_updated_at() +RETURNS TRIGGER AS $$ +BEGIN + NEW.updated_at = NOW(); + RETURN NEW; +END; +$$ LANGUAGE plpgsql; + +CREATE TRIGGER device_settings_updated_at + BEFORE UPDATE ON device_settings + FOR EACH ROW + EXECUTE FUNCTION update_device_settings_updated_at(); + +CREATE TABLE IF NOT EXISTS namespace_settings ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + namespace_id UUID NOT NULL UNIQUE REFERENCES namespaces(id) ON DELETE CASCADE, + record_sessions BOOLEAN DEFAULT TRUE, + connection_announcement TEXT DEFAULT '', + allow_password BOOLEAN DEFAULT TRUE, + allow_public_key BOOLEAN DEFAULT TRUE, + allow_root BOOLEAN DEFAULT TRUE, + allow_empty_passwords BOOLEAN DEFAULT TRUE, + allow_tty BOOLEAN DEFAULT TRUE, + allow_tcp_forwarding BOOLEAN DEFAULT TRUE, + allow_web_endpoints BOOLEAN DEFAULT TRUE, + allow_sftp BOOLEAN DEFAULT TRUE, + allow_agent_forwarding BOOLEAN DEFAULT TRUE, + created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(), + updated_at TIMESTAMP WITH TIME ZONE DEFAULT NOW() +); + +CREATE INDEX IF NOT EXISTS idx_namespace_settings_namespace_id ON namespace_settings(namespace_id); + +CREATE OR REPLACE FUNCTION update_namespace_settings_updated_at() +RETURNS TRIGGER AS $$ +BEGIN + NEW.updated_at = NOW(); + RETURN NEW; +END; +$$ LANGUAGE plpgsql; + +CREATE TRIGGER namespace_settings_updated_at + BEFORE UPDATE ON namespace_settings + FOR EACH ROW + EXECUTE FUNCTION update_namespace_settings_updated_at(); + +-- Migrate data from namespaces table to namespace_settings table +INSERT INTO namespace_settings (namespace_id, record_sessions, connection_announcement, allow_password, allow_public_key, allow_root, allow_empty_passwords, allow_tty, allow_tcp_forwarding, allow_web_endpoints, allow_sftp, allow_agent_forwarding) +SELECT + id, + COALESCE(record_sessions, true), + COALESCE(connection_announcement, ''), + TRUE, + TRUE, + TRUE, + TRUE, + TRUE, + TRUE, + TRUE, + TRUE, + TRUE +FROM namespaces +ON CONFLICT (namespace_id) DO NOTHING; diff --git a/api/store/pg/namespace.go b/api/store/pg/namespace.go index f38a9e620d8..46c28b64a60 100644 --- a/api/store/pg/namespace.go +++ b/api/store/pg/namespace.go @@ -20,14 +20,38 @@ func (pg *Pg) NamespaceCreate(ctx context.Context, namespace *models.Namespace) } nsEntity := entity.NamespaceFromModel(namespace) - if _, err := db.NewInsert().Model(nsEntity).Exec(ctx); err != nil { - return "", fromSQLError(err) + exec := func(db bun.IDB) error { + if _, err := db.NewInsert().Model(nsEntity).Exec(ctx); err != nil { + return fromSQLError(err) + } + + if nsEntity.Settings != nil { + if _, err := db.NewInsert().Model(nsEntity.Settings).Exec(ctx); err != nil { + return fromSQLError(err) + } + } + + if len(nsEntity.Memberships) > 0 { + if _, err := db.NewInsert().Model(&nsEntity.Memberships).Exec(ctx); err != nil { + return fromSQLError(err) + } + } + + return nil } - if len(nsEntity.Memberships) > 0 { - if _, err := db.NewInsert().Model(&nsEntity.Memberships).Exec(ctx); err != nil { - return "", fromSQLError(err) + if _, ok := db.(bun.Tx); ok { + if err := exec(db); err != nil { + return "", err } + + return namespace.TenantID, nil + } + + if err := pg.WithTransaction(ctx, func(txCtx context.Context) error { + return exec(pg.GetConnection(txCtx)) + }); err != nil { + return "", err } return namespace.TenantID, nil @@ -75,7 +99,7 @@ func (pg *Pg) NamespaceList(ctx context.Context, opts ...store.QueryOption) ([]m db := pg.GetConnection(ctx) entities := make([]entity.Namespace, 0) - query := db.NewSelect().Model(&entities) + query := db.NewSelect().Model(&entities).Relation("Settings") var err error query, err = applyOptions(ctx, query, opts...) @@ -114,7 +138,7 @@ func (pg *Pg) NamespaceResolve(ctx context.Context, resolver store.NamespaceReso } ns := new(entity.Namespace) - query := db.NewSelect().Model(ns).Relation("Memberships.User").Where("? = ?", bun.Ident(column), val) + query := db.NewSelect().Model(ns).Relation("Memberships.User").Relation("Settings").Where("? = ?", bun.Ident(column), val) if err := query.Scan(ctx); err != nil { return nil, fromSQLError(err) } @@ -129,6 +153,7 @@ func (pg *Pg) NamespaceGetPreferred(ctx context.Context, userID string) (*models if err := db.NewSelect(). Model(ns). Relation("Memberships.User"). + Relation("Settings"). Join("JOIN users"). JoinOn("namespace.id = users.preferred_namespace_id OR namespace.id IN (SELECT namespace_id FROM memberships WHERE user_id = users.id)"). Where("users.id = ?", userID). @@ -144,28 +169,49 @@ func (pg *Pg) NamespaceGetPreferred(ctx context.Context, userID string) (*models func (pg *Pg) NamespaceUpdate(ctx context.Context, namespace *models.Namespace) error { db := pg.GetConnection(ctx) - // First check if namespace exists - exists, err := db.NewSelect().Model((*entity.Namespace)(nil)).Where("id = ?", namespace.TenantID).Exists(ctx) - if err != nil { - return fromSQLError(err) - } - if !exists { - return store.ErrNoDocuments - } - n := entity.NamespaceFromModel(namespace) n.UpdatedAt = clock.Now() - r, err := db.NewUpdate().Model(n).WherePK().Exec(ctx) - if err != nil { - return fromSQLError(err) + exec := func(db bun.IDB) error { + // First check if namespace exists. + exists, err := db.NewSelect().Model((*entity.Namespace)(nil)).Where("id = ?", namespace.TenantID).Exists(ctx) + if err != nil { + return fromSQLError(err) + } + if !exists { + return store.ErrNoDocuments + } + + r, err := db.NewUpdate().Model(n).WherePK().Exec(ctx) + if err != nil { + return fromSQLError(err) + } + + if rowsAffected, err := r.RowsAffected(); err != nil || rowsAffected == 0 { + return store.ErrNoDocuments + } + + if n.Settings != nil { + n.Settings.UpdatedAt = clock.Now() + _, err = db.NewInsert(). + On("conflict (namespace_id) do update set updated_at = excluded.updated_at, record_sessions = excluded.record_sessions, connection_announcement = excluded.connection_announcement, allow_password = excluded.allow_password, allow_public_key = excluded.allow_public_key, allow_root = excluded.allow_root, allow_empty_passwords = excluded.allow_empty_passwords, allow_tty = excluded.allow_tty, allow_tcp_forwarding = excluded.allow_tcp_forwarding, allow_web_endpoints = excluded.allow_web_endpoints, allow_sftp = excluded.allow_sftp, allow_agent_forwarding = excluded.allow_agent_forwarding"). + Model(n.Settings). + Exec(ctx) + if err != nil { + return fromSQLError(err) + } + } + + return nil } - if rowsAffected, err := r.RowsAffected(); err != nil || rowsAffected == 0 { - return store.ErrNoDocuments + if _, ok := db.(bun.Tx); ok { + return exec(db) } - return nil + return pg.WithTransaction(ctx, func(txCtx context.Context) error { + return exec(pg.GetConnection(txCtx)) + }) } func (pg *Pg) NamespaceIncrementDeviceCount(ctx context.Context, tenantID string, status models.DeviceStatus, count int64) error { @@ -327,9 +373,9 @@ func namespaceExprPreferredOrder() string { func NamespaceResolverToString(resolver store.NamespaceResolver) (string, error) { switch resolver { case store.NamespaceTenantIDResolver: - return "id", nil + return "namespace.id", nil case store.NamespaceNameResolver: - return "name", nil + return "namespace.name", nil default: return "", store.ErrResolverNotFound } diff --git a/api/store/storetest/device_tests.go b/api/store/storetest/device_tests.go index 278331ca55b..594d73242b3 100644 --- a/api/store/storetest/device_tests.go +++ b/api/store/storetest/device_tests.go @@ -644,6 +644,126 @@ func (s *Suite) TestDeviceUpdate(t *testing.T) { assert.True(t, newStatusUpdatedAt.Equal(updated.StatusUpdatedAt), "updated StatusUpdatedAt should match: expected %v, got %v", newStatusUpdatedAt, updated.StatusUpdatedAt) assert.Equal(t, models.DeviceStatusAccepted, updated.Status) }) + + t.Run("does not alter ssh settings", func(t *testing.T) { + require.NoError(t, s.provider.CleanDatabase(t)) + + tenantID := s.CreateNamespace(t) + deviceUID := s.CreateDevice(t, + WithDeviceName("original-name"), + WithTenantID(tenantID), + ) + + err := st.DeviceUpdateSettings(ctx, string(deviceUID), &models.SSHSettings{ + AllowPassword: false, + AllowPublicKey: true, + AllowRoot: true, + AllowEmptyPasswords: true, + AllowTTY: true, + AllowTCPForwarding: true, + AllowWebEndpoints: true, + AllowSFTP: true, + AllowAgentForwarding: true, + }) + require.NoError(t, err) + + err = st.DeviceUpdate(ctx, &models.Device{ + UID: string(deviceUID), + TenantID: tenantID, + Name: "updated-name", + }) + require.NoError(t, err) + + device, err := st.DeviceResolve(ctx, store.DeviceUIDResolver, string(deviceUID)) + require.NoError(t, err) + require.NotNil(t, device.SSH) + assert.Equal(t, "updated-name", device.Name) + assert.False(t, device.SSH.AllowPassword) + assert.True(t, device.SSH.AllowPublicKey) + }) +} + +func (s *Suite) TestDeviceUpdateSettings(t *testing.T) { + ctx := context.Background() + st := s.provider.Store() + + t.Run("fails when device is not found", func(t *testing.T) { + require.NoError(t, s.provider.CleanDatabase(t)) + + err := st.DeviceUpdateSettings(ctx, "nonexistent", models.DefaultSSHSettings()) + assert.ErrorIs(t, err, store.ErrNoDocuments) + }) + + t.Run("creates settings when absent", func(t *testing.T) { + require.NoError(t, s.provider.CleanDatabase(t)) + + tenantID := s.CreateNamespace(t) + deviceUID := s.CreateDevice(t, WithTenantID(tenantID)) + + err := st.DeviceUpdateSettings(ctx, string(deviceUID), &models.SSHSettings{ + AllowPassword: false, + AllowPublicKey: true, + AllowRoot: false, + AllowEmptyPasswords: true, + AllowTTY: false, + AllowTCPForwarding: true, + AllowWebEndpoints: false, + AllowSFTP: true, + AllowAgentForwarding: false, + }) + require.NoError(t, err) + + device, err := st.DeviceResolve(ctx, store.DeviceUIDResolver, string(deviceUID)) + require.NoError(t, err) + require.NotNil(t, device.SSH) + assert.False(t, device.SSH.AllowPassword) + assert.False(t, device.SSH.AllowRoot) + assert.False(t, device.SSH.AllowTTY) + assert.False(t, device.SSH.AllowWebEndpoints) + assert.False(t, device.SSH.AllowAgentForwarding) + }) + + t.Run("updates existing settings in place", func(t *testing.T) { + require.NoError(t, s.provider.CleanDatabase(t)) + + tenantID := s.CreateNamespace(t) + deviceUID := s.CreateDevice(t, WithTenantID(tenantID)) + + err := st.DeviceUpdateSettings(ctx, string(deviceUID), &models.SSHSettings{ + AllowPassword: false, + AllowPublicKey: true, + AllowRoot: true, + AllowEmptyPasswords: true, + AllowTTY: true, + AllowTCPForwarding: true, + AllowWebEndpoints: true, + AllowSFTP: true, + AllowAgentForwarding: true, + }) + require.NoError(t, err) + + err = st.DeviceUpdateSettings(ctx, string(deviceUID), &models.SSHSettings{ + AllowPassword: true, + AllowPublicKey: false, + AllowRoot: true, + AllowEmptyPasswords: false, + AllowTTY: true, + AllowTCPForwarding: false, + AllowWebEndpoints: true, + AllowSFTP: false, + AllowAgentForwarding: true, + }) + require.NoError(t, err) + + device, err := st.DeviceResolve(ctx, store.DeviceUIDResolver, string(deviceUID)) + require.NoError(t, err) + require.NotNil(t, device.SSH) + assert.True(t, device.SSH.AllowPassword) + assert.False(t, device.SSH.AllowPublicKey) + assert.False(t, device.SSH.AllowEmptyPasswords) + assert.False(t, device.SSH.AllowTCPForwarding) + assert.False(t, device.SSH.AllowSFTP) + }) } // TestDeviceHeartbeat tests device heartbeat updates diff --git a/cli/services/namespaces.go b/cli/services/namespaces.go index 6ac492892e5..ec47c185468 100644 --- a/cli/services/namespaces.go +++ b/cli/services/namespaces.go @@ -56,6 +56,15 @@ func (s *service) NamespaceCreate(ctx context.Context, input *inputs.NamespaceCr Settings: &models.NamespaceSettings{ SessionRecord: true, ConnectionAnnouncement: models.DefaultAnnouncementMessage, + AllowPassword: true, + AllowPublicKey: true, + AllowRoot: true, + AllowEmptyPasswords: true, + AllowTTY: true, + AllowTCPForwarding: true, + AllowWebEndpoints: true, + AllowSFTP: true, + AllowAgentForwarding: true, }, CreatedAt: clock.Now(), Type: models.NewDefaultType(), diff --git a/cli/services/namespaces_test.go b/cli/services/namespaces_test.go index c4c25304e57..617fe4a6c6c 100644 --- a/cli/services/namespaces_test.go +++ b/cli/services/namespaces_test.go @@ -18,6 +18,22 @@ import ( "github.com/stretchr/testify/assert" ) +func defaultNamespaceSettings(announcement string) *models.NamespaceSettings { + return &models.NamespaceSettings{ + SessionRecord: true, + ConnectionAnnouncement: announcement, + AllowPassword: true, + AllowPublicKey: true, + AllowRoot: true, + AllowEmptyPasswords: true, + AllowTTY: true, + AllowTCPForwarding: true, + AllowWebEndpoints: true, + AllowSFTP: true, + AllowAgentForwarding: true, + } +} + func TestNamespaceCreate(t *testing.T) { type Expected struct { namespace *models.Namespace @@ -86,10 +102,7 @@ func TestNamespaceCreate(t *testing.T) { AddedAt: now, }, }, - Settings: &models.NamespaceSettings{ - SessionRecord: true, - ConnectionAnnouncement: models.DefaultAnnouncementMessage, - }, + Settings: defaultNamespaceSettings(models.DefaultAnnouncementMessage), MaxDevices: MaxNumberDevicesUnlimited, CreatedAt: now, } @@ -129,10 +142,7 @@ func TestNamespaceCreate(t *testing.T) { AddedAt: now, }, }, - Settings: &models.NamespaceSettings{ - SessionRecord: true, - ConnectionAnnouncement: models.DefaultAnnouncementMessage, - }, + Settings: defaultNamespaceSettings(models.DefaultAnnouncementMessage), MaxDevices: MaxNumberDevicesUnlimited, CreatedAt: now, } @@ -150,10 +160,7 @@ func TestNamespaceCreate(t *testing.T) { AddedAt: now, }, }, - Settings: &models.NamespaceSettings{ - SessionRecord: true, - ConnectionAnnouncement: models.DefaultAnnouncementMessage, - }, + Settings: defaultNamespaceSettings(models.DefaultAnnouncementMessage), MaxDevices: MaxNumberDevicesUnlimited, CreatedAt: now, }, nil}, @@ -190,10 +197,7 @@ func TestNamespaceCreate(t *testing.T) { AddedAt: now, }, }, - Settings: &models.NamespaceSettings{ - SessionRecord: true, - ConnectionAnnouncement: models.DefaultAnnouncementMessage, - }, + Settings: defaultNamespaceSettings(models.DefaultAnnouncementMessage), MaxDevices: MaxNumberDevicesLimited, CreatedAt: now, } @@ -211,10 +215,7 @@ func TestNamespaceCreate(t *testing.T) { AddedAt: now, }, }, - Settings: &models.NamespaceSettings{ - SessionRecord: true, - ConnectionAnnouncement: models.DefaultAnnouncementMessage, - }, + Settings: defaultNamespaceSettings(models.DefaultAnnouncementMessage), MaxDevices: MaxNumberDevicesLimited, CreatedAt: now, }, nil}, @@ -251,10 +252,7 @@ func TestNamespaceCreate(t *testing.T) { AddedAt: now, }, }, - Settings: &models.NamespaceSettings{ - SessionRecord: true, - ConnectionAnnouncement: models.DefaultAnnouncementMessage, - }, + Settings: defaultNamespaceSettings(models.DefaultAnnouncementMessage), MaxDevices: MaxNumberDevicesLimited, CreatedAt: now, } @@ -272,10 +270,7 @@ func TestNamespaceCreate(t *testing.T) { AddedAt: now, }, }, - Settings: &models.NamespaceSettings{ - SessionRecord: true, - ConnectionAnnouncement: models.DefaultAnnouncementMessage, - }, + Settings: defaultNamespaceSettings(models.DefaultAnnouncementMessage), MaxDevices: MaxNumberDevicesLimited, CreatedAt: now, }, nil}, @@ -312,10 +307,7 @@ func TestNamespaceCreate(t *testing.T) { AddedAt: now, }, }, - Settings: &models.NamespaceSettings{ - SessionRecord: true, - ConnectionAnnouncement: models.DefaultAnnouncementMessage, - }, + Settings: defaultNamespaceSettings(models.DefaultAnnouncementMessage), MaxDevices: MaxNumberDevicesUnlimited, CreatedAt: now, } @@ -333,10 +325,7 @@ func TestNamespaceCreate(t *testing.T) { AddedAt: now, }, }, - Settings: &models.NamespaceSettings{ - SessionRecord: true, - ConnectionAnnouncement: models.DefaultAnnouncementMessage, - }, + Settings: defaultNamespaceSettings(models.DefaultAnnouncementMessage), MaxDevices: MaxNumberDevicesUnlimited, CreatedAt: now, }, nil}, @@ -373,10 +362,7 @@ func TestNamespaceCreate(t *testing.T) { AddedAt: now, }, }, - Settings: &models.NamespaceSettings{ - SessionRecord: true, - ConnectionAnnouncement: models.DefaultAnnouncementMessage, - }, + Settings: defaultNamespaceSettings(models.DefaultAnnouncementMessage), MaxDevices: MaxNumberDevicesUnlimited, CreatedAt: now, } @@ -394,10 +380,7 @@ func TestNamespaceCreate(t *testing.T) { AddedAt: now, }, }, - Settings: &models.NamespaceSettings{ - SessionRecord: true, - ConnectionAnnouncement: models.DefaultAnnouncementMessage, - }, + Settings: defaultNamespaceSettings(models.DefaultAnnouncementMessage), MaxDevices: MaxNumberDevicesUnlimited, CreatedAt: now, }, nil}, diff --git a/openapi/spec/components/schemas/announcementShort.yaml b/openapi/spec/components/schemas/announcementShort.yaml index 7af72e1a18d..24e2c5e84ce 100644 --- a/openapi/spec/components/schemas/announcementShort.yaml +++ b/openapi/spec/components/schemas/announcementShort.yaml @@ -12,5 +12,6 @@ required: - title - date example: + uuid: 3dd0d1f8-8246-4519-b11a-a3dd33717f65 title: Lorem ipsum dolor sit amet, consectetur adipiscing elit. date: 2017-07-21T17:32:28Z diff --git a/openapi/spec/components/schemas/device.yaml b/openapi/spec/components/schemas/device.yaml index f59151d4fdc..20ea729477c 100644 --- a/openapi/spec/components/schemas/device.yaml +++ b/openapi/spec/components/schemas/device.yaml @@ -68,6 +68,46 @@ properties: example: env: production owner: team-a + settings: + description: Device's SSH configuration settings + type: object + properties: + allow_password: + description: Allow password authentication + type: boolean + example: true + allow_public_key: + description: Allow public key authentication + type: boolean + example: true + allow_root: + description: Allow root user login + type: boolean + example: true + allow_empty_passwords: + description: Allow empty passwords + type: boolean + example: false + allow_tty: + description: Allow TTY allocation + type: boolean + example: true + allow_tcp_forwarding: + description: Allow TCP port forwarding + type: boolean + example: true + allow_web_endpoints: + description: Allow web endpoints access via HTTP proxy + type: boolean + example: true + allow_sftp: + description: Allow SFTP subsystem + type: boolean + example: true + allow_agent_forwarding: + description: Allow SSH agent forwarding + type: boolean + example: false public_url: $ref: devicePublicURL.yaml acceptable: diff --git a/openapi/spec/components/schemas/deviceSSH.yaml b/openapi/spec/components/schemas/deviceSSH.yaml new file mode 100644 index 00000000000..fd9ffeeee0e --- /dev/null +++ b/openapi/spec/components/schemas/deviceSSH.yaml @@ -0,0 +1,39 @@ +type: object +description: Device's SSH configuration settings +properties: + allow_password: + description: Allow password authentication + type: boolean + example: true + allow_public_key: + description: Allow public key authentication + type: boolean + example: true + allow_root: + description: Allow root user login + type: boolean + example: true + allow_empty_passwords: + description: Allow empty passwords + type: boolean + example: false + allow_tty: + description: Allow TTY allocation + type: boolean + example: true + allow_tcp_forwarding: + description: Allow TCP port forwarding + type: boolean + example: true + allow_web_endpoints: + description: Allow web endpoints access via HTTP proxy + type: boolean + example: true + allow_sftp: + description: Allow SFTP subsystem + type: boolean + example: true + allow_agent_forwarding: + description: Allow SSH agent forwarding + type: boolean + example: false \ No newline at end of file diff --git a/openapi/spec/components/schemas/namespaceSettings.yaml b/openapi/spec/components/schemas/namespaceSettings.yaml index b1b23bfcd6e..e68d7b06955 100644 --- a/openapi/spec/components/schemas/namespaceSettings.yaml +++ b/openapi/spec/components/schemas/namespaceSettings.yaml @@ -16,6 +16,42 @@ properties: description: When enabled, new devices connecting to the namespace are automatically accepted instead of going to a pending state. type: boolean example: false + allow_password: + description: Allow password authentication at namespace level + type: boolean + example: true + allow_public_key: + description: Allow public key authentication at namespace level + type: boolean + example: true + allow_root: + description: Allow root user login at namespace level + type: boolean + example: true + allow_empty_passwords: + description: Allow empty passwords at namespace level + type: boolean + example: false + allow_tty: + description: Allow TTY allocation at namespace level + type: boolean + example: true + allow_tcp_forwarding: + description: Allow TCP port forwarding at namespace level + type: boolean + example: true + allow_web_endpoints: + description: Allow web endpoints access via HTTP proxy at namespace level + type: boolean + example: true + allow_sftp: + description: Allow SFTP subsystem at namespace level + type: boolean + example: true + allow_agent_forwarding: + description: Allow SSH agent forwarding at namespace level + type: boolean + example: false required: - session_record - connection_announcement diff --git a/openapi/spec/openapi.yaml b/openapi/spec/openapi.yaml index 8d881cb4d8a..39c389a4f2c 100644 --- a/openapi/spec/openapi.yaml +++ b/openapi/spec/openapi.yaml @@ -118,6 +118,8 @@ paths: $ref: paths/api@devices.yaml /api/devices/{uid}: $ref: paths/api@devices@{uid}.yaml + /api/devices/{uid}/settings: + $ref: paths/api@devices@{uid}@settings.yaml /api/devices/resolve: $ref: paths/api@devices@resolve.yaml /api/devices/{uid}/{status}: diff --git a/openapi/spec/paths/api@devices@{uid}@settings.yaml b/openapi/spec/paths/api@devices@{uid}@settings.yaml new file mode 100644 index 00000000000..acca9756906 --- /dev/null +++ b/openapi/spec/paths/api@devices@{uid}@settings.yaml @@ -0,0 +1,53 @@ +parameters: + - $ref: ../components/parameters/path/deviceUIDPath.yaml +get: + operationId: getDeviceSettings + summary: Get device settings + description: Get a device's SSH settings. + tags: + - community + - devices + security: + - jwt: [] + - api-key: [] + responses: + '200': + description: Success to get device settings. + content: + application/json: + schema: + $ref: ../components/schemas/deviceSSH.yaml + '401': + $ref: ../components/responses/401.yaml + '404': + $ref: ../components/responses/404.yaml + '500': + $ref: ../components/responses/500.yaml +patch: + operationId: updateDeviceSettings + summary: Update device settings + description: Partially update a device's SSH settings. + tags: + - community + - devices + security: + - jwt: [] + - api-key: [] + requestBody: + content: + application/json: + schema: + $ref: ../components/schemas/deviceSSH.yaml + responses: + '200': + description: Success to update device settings. + '400': + $ref: ../components/responses/400.yaml + '401': + $ref: ../components/responses/401.yaml + '403': + $ref: ../components/responses/403.yaml + '404': + $ref: ../components/responses/404.yaml + '500': + $ref: ../components/responses/500.yaml diff --git a/pkg/api/requests/device.go b/pkg/api/requests/device.go index e7db3f13cf6..5933a0ee1cd 100644 --- a/pkg/api/requests/device.go +++ b/pkg/api/requests/device.go @@ -14,9 +14,33 @@ type DeviceList struct { } type DeviceUpdate struct { + TenantID string `header:"X-Tenant-ID"` + UID string `param:"uid" validate:"required"` + Name string `json:"name" validate:"device_name,omitempty"` + CustomFields *map[string]string `json:"custom_fields" validate:"omitempty,max=20,dive,keys,min=1,max=64,endkeys,max=256"` +} + +type SSHSettingsUpdate struct { + AllowPassword *bool `json:"allow_password" validate:"omitempty"` + AllowPublicKey *bool `json:"allow_public_key" validate:"omitempty"` + AllowRoot *bool `json:"allow_root" validate:"omitempty"` + AllowEmptyPasswords *bool `json:"allow_empty_passwords" validate:"omitempty"` + AllowTTY *bool `json:"allow_tty" validate:"omitempty"` + AllowTCPForwarding *bool `json:"allow_tcp_forwarding" validate:"omitempty"` + AllowWebEndpoints *bool `json:"allow_web_endpoints" validate:"omitempty"` + AllowSFTP *bool `json:"allow_sftp" validate:"omitempty"` + AllowAgentForwarding *bool `json:"allow_agent_forwarding" validate:"omitempty"` +} + +type DeviceGetSettings struct { TenantID string `header:"X-Tenant-ID"` - UID string `param:"uid" validate:"required"` - Name string `json:"name" validate:"omitempty,device_name"` + DeviceParam +} + +type DeviceUpdateSettings struct { + TenantID string `header:"X-Tenant-ID"` + DeviceParam + SSHSettingsUpdate } type DeviceSetCustomField struct { diff --git a/pkg/api/requests/namespace.go b/pkg/api/requests/namespace.go index b3eff819434..1af9849ea8b 100644 --- a/pkg/api/requests/namespace.go +++ b/pkg/api/requests/namespace.go @@ -55,6 +55,15 @@ type NamespaceEdit struct { Settings struct { SessionRecord *bool `json:"session_record" validate:"omitempty"` ConnectionAnnouncement *string `json:"connection_announcement" validate:"omitempty,min=0,max=4096"` + AllowPassword *bool `json:"allow_password" validate:"omitempty"` + AllowPublicKey *bool `json:"allow_public_key" validate:"omitempty"` + AllowRoot *bool `json:"allow_root" validate:"omitempty"` + AllowEmptyPasswords *bool `json:"allow_empty_passwords" validate:"omitempty"` + AllowTTY *bool `json:"allow_tty" validate:"omitempty"` + AllowTCPForwarding *bool `json:"allow_tcp_forwarding" validate:"omitempty"` + AllowWebEndpoints *bool `json:"allow_web_endpoints" validate:"omitempty"` + AllowSFTP *bool `json:"allow_sftp" validate:"omitempty"` + AllowAgentForwarding *bool `json:"allow_agent_forwarding" validate:"omitempty"` } `json:"settings"` } diff --git a/pkg/models/device.go b/pkg/models/device.go index f0f021089ac..9e9ade111da 100644 --- a/pkg/models/device.go +++ b/pkg/models/device.go @@ -52,6 +52,7 @@ type Device struct { CustomFields map[string]string `json:"custom_fields,omitempty" bson:"custom_fields"` Taggable `json:",inline" bson:",inline"` + SSH *SSHSettings `json:"settings" bson:"ssh,omitempty"` } type DeviceAuthRequest struct { @@ -101,6 +102,32 @@ type DeviceTag struct { Tag string `validate:"required,min=3,max=255,alphanum,ascii,excludes=/@&:"` } +type SSHSettings struct { + AllowPassword bool `json:"allow_password" bson:"allow_password"` + AllowPublicKey bool `json:"allow_public_key" bson:"allow_public_key"` + AllowRoot bool `json:"allow_root" bson:"allow_root"` + AllowEmptyPasswords bool `json:"allow_empty_passwords" bson:"allow_empty_passwords"` + AllowTTY bool `json:"allow_tty" bson:"allow_tty"` + AllowTCPForwarding bool `json:"allow_tcp_forwarding" bson:"allow_tcp_forwarding"` + AllowWebEndpoints bool `json:"allow_web_endpoints" bson:"allow_web_endpoints"` + AllowSFTP bool `json:"allow_sftp" bson:"allow_sftp"` + AllowAgentForwarding bool `json:"allow_agent_forwarding" bson:"allow_agent_forwarding"` +} + +func DefaultSSHSettings() *SSHSettings { + return &SSHSettings{ + AllowPassword: true, + AllowPublicKey: true, + AllowRoot: true, + AllowEmptyPasswords: true, + AllowTTY: true, + AllowTCPForwarding: true, + AllowWebEndpoints: true, + AllowSFTP: true, + AllowAgentForwarding: true, + } +} + func NewDeviceTag(tag string) DeviceTag { return DeviceTag{ Tag: tag, diff --git a/pkg/models/device_test.go b/pkg/models/device_test.go new file mode 100644 index 00000000000..e9c0bcd1f75 --- /dev/null +++ b/pkg/models/device_test.go @@ -0,0 +1,35 @@ +package models + +import "testing" + +func TestDefaultSSHSettings(t *testing.T) { + settings := DefaultSSHSettings() + + if !settings.AllowPassword { + t.Fatal("expected AllowPassword to default to true") + } + if !settings.AllowPublicKey { + t.Fatal("expected AllowPublicKey to default to true") + } + if !settings.AllowRoot { + t.Fatal("expected AllowRoot to default to true") + } + if !settings.AllowEmptyPasswords { + t.Fatal("expected AllowEmptyPasswords to default to true") + } + if !settings.AllowTTY { + t.Fatal("expected AllowTTY to default to true") + } + if !settings.AllowTCPForwarding { + t.Fatal("expected AllowTCPForwarding to default to true") + } + if !settings.AllowWebEndpoints { + t.Fatal("expected AllowWebEndpoints to default to true") + } + if !settings.AllowSFTP { + t.Fatal("expected AllowSFTP to default to true") + } + if !settings.AllowAgentForwarding { + t.Fatal("expected AllowAgentForwarding to default to true") + } +} diff --git a/pkg/models/namespace.go b/pkg/models/namespace.go index 6dbf072598f..73ec422f669 100644 --- a/pkg/models/namespace.go +++ b/pkg/models/namespace.go @@ -52,6 +52,15 @@ type NamespaceSettings struct { SessionRecord bool `json:"session_record" bson:"session_record,omitempty"` ConnectionAnnouncement string `json:"connection_announcement" bson:"connection_announcement"` DeviceAutoAccept bool `json:"device_auto_accept" bson:"device_auto_accept"` + AllowPassword bool `json:"allow_password" bson:"allow_password"` + AllowPublicKey bool `json:"allow_public_key" bson:"allow_public_key"` + AllowRoot bool `json:"allow_root" bson:"allow_root"` + AllowEmptyPasswords bool `json:"allow_empty_passwords" bson:"allow_empty_passwords"` + AllowTTY bool `json:"allow_tty" bson:"allow_tty"` + AllowTCPForwarding bool `json:"allow_tcp_forwarding" bson:"allow_tcp_forwarding"` + AllowWebEndpoints bool `json:"allow_web_endpoints" bson:"allow_web_endpoints"` + AllowSFTP bool `json:"allow_sftp" bson:"allow_sftp"` + AllowAgentForwarding bool `json:"allow_agent_forwarding" bson:"allow_agent_forwarding"` } // default Announcement Message for the shellhub namespace diff --git a/ssh/http/handlers.go b/ssh/http/handlers.go index 138a0ce1943..fa2cc53710e 100644 --- a/ssh/http/handlers.go +++ b/ssh/http/handlers.go @@ -106,6 +106,38 @@ func (h *Handlers) HandleHTTPProxy(c echo.Context) error { return c.JSON(http.StatusForbidden, NewMessageFromError(ErrWebEndpointForbidden)) } + // Check if device allows web endpoints + device, err := h.Client.GetDevice(c.Request().Context(), endpoint.DeviceUID) + if err != nil { + log.WithError(err).Error("failed to get device") + + return c.JSON(http.StatusForbidden, NewMessageFromError(ErrWebEndpointForbidden)) + } + + // Check namespace setting first + namespace, err := h.Client.NamespaceLookup(c.Request().Context(), endpoint.Namespace) + if err != nil { + log.WithError(err).Error("failed to get namespace") + + return c.JSON(http.StatusForbidden, NewMessageFromError(ErrWebEndpointForbidden)) + } + if namespace.Settings != nil && !namespace.Settings.AllowWebEndpoints { + log.WithFields(log.Fields{ + "namespace": endpoint.Namespace, + }).Warn("web endpoints disabled for namespace") + + return c.JSON(http.StatusForbidden, NewMessageFromError(ErrWebEndpointForbidden)) + } + + // Check device's SSH settings for AllowWebEndpoints (default: true if not set) + if device.SSH != nil && !device.SSH.AllowWebEndpoints { + log.WithFields(log.Fields{ + "device": endpoint.DeviceUID, + }).Warn("web endpoints disabled for device") + + return c.JSON(http.StatusForbidden, NewMessageFromError(ErrWebEndpointForbidden)) + } + logger := log.WithFields(log.Fields{ "request-id": requestID, "namespace": endpoint.Namespace, diff --git a/ssh/server/auth/password.go b/ssh/server/auth/password.go index 7cea216af98..d8a00b67cc0 100644 --- a/ssh/server/auth/password.go +++ b/ssh/server/auth/password.go @@ -1,6 +1,7 @@ package auth import ( + "errors" "net" gliderssh "github.com/gliderlabs/ssh" @@ -31,6 +32,12 @@ func PasswordHandler(ctx gliderssh.Context, passwd string) bool { } if err := sess.Auth(ctx, session.AuthPassword(passwd)); err != nil { + if errors.Is(err, session.ErrPasswordDisabled) { + logger.Warn("password authentication is disabled for this namespace") + + return false + } + logger.Warn("failed to authenticate on device using password") return false diff --git a/ssh/server/auth/publickey.go b/ssh/server/auth/publickey.go index 7d9fc28a6dc..c8d5e16a39f 100644 --- a/ssh/server/auth/publickey.go +++ b/ssh/server/auth/publickey.go @@ -1,6 +1,7 @@ package auth import ( + "errors" "net" gliderssh "github.com/gliderlabs/ssh" @@ -33,6 +34,12 @@ func PublicKeyHandler(ctx gliderssh.Context, publicKey gliderssh.PublicKey) bool } if err := sess.Auth(ctx, session.AuthPublicKey(publicKey)); err != nil { + if errors.Is(err, session.ErrPublicKeyDisabled) { + logger.Warn("public key authentication is disabled for this namespace") + + return false + } + logger.Warn("failed to authenticate on device using public key") return false diff --git a/ssh/server/channels/session.go b/ssh/server/channels/session.go index 80e9370fae9..846e4444faa 100644 --- a/ssh/server/channels/session.go +++ b/ssh/server/channels/session.go @@ -268,6 +268,18 @@ func DefaultSessionHandler() gliderssh.ChannelHandler { return } + denyRequest := func(msg string) bool { + logger.Warn(msg) + + if req.WantReply { + if err := req.Reply(false, nil); err != nil { + logger.WithError(err).Error("failed to deny request from client") + } + } + + return true + } + switch req.Type { case ShellRequestType: if seat, ok := sess.Seats.Get(seat); ok && seat.HasPty { @@ -278,10 +290,49 @@ func DefaultSessionHandler() gliderssh.ChannelHandler { sess.Event(req.Type, req.Payload, seat) case ExecRequestType, SubsystemRequestType: + if req.Type == SubsystemRequestType { + var subsystem struct { + Subsystem string `ssh:"subsystem"` + } + if err := gossh.Unmarshal(req.Payload, &subsystem); err != nil { + reject(nil, "failed to decode subsystem request") + + return + } + if subsystem.Subsystem == "sftp" { + // Check namespace setting first + if sess.Namespace.Settings != nil && !sess.Namespace.Settings.AllowSFTP { + if denyRequest("SFTP is disabled for this namespace") { + continue + } + } + // Check device override + if sess.Device.SSH != nil && !sess.Device.SSH.AllowSFTP { + if denyRequest("SFTP is disabled for this device") { + continue + } + } + } + } + session.Event[models.SSHCommand](sess, req.Type, req.Payload, seat) sess.Type = ExecRequestType case PtyRequestType: + // Check namespace setting first + if sess.Namespace.Settings != nil && !sess.Namespace.Settings.AllowTTY { + if denyRequest("TTY allocation is disabled for this namespace") { + continue + } + } + + // Check device override + if sess.Device.SSH != nil && !sess.Device.SSH.AllowTTY { + if denyRequest("TTY allocation is disabled for this device") { + continue + } + } + var pty models.SSHPty if err := gossh.Unmarshal(req.Payload, &pty); err != nil { @@ -300,6 +351,20 @@ func DefaultSessionHandler() gliderssh.ChannelHandler { sess.Event(req.Type, dimensions, seat) //nolint:errcheck case AuthRequestOpenSSHRequest: + // Check namespace setting first + if sess.Namespace.Settings != nil && !sess.Namespace.Settings.AllowAgentForwarding { + if denyRequest("Agent forwarding is disabled for this namespace") { + continue + } + } + + // Check device override + if sess.Device.SSH != nil && !sess.Device.SSH.AllowAgentForwarding { + if denyRequest("Agent forwarding is disabled for this device") { + continue + } + } + gliderssh.SetAgentRequested(ctx) sess.Event(req.Type, req.Payload, seat) diff --git a/ssh/server/server.go b/ssh/server/server.go index 1eaa8d6e660..e5daa3637b6 100644 --- a/ssh/server/server.go +++ b/ssh/server/server.go @@ -110,7 +110,20 @@ func NewServer(dialer *dialer.Dialer, cache cache.Cache, opts *Options) *Server channels.SessionChannel: channels.DefaultSessionHandler(), channels.DirectTCPIPChannel: channels.DefaultDirectTCPIPHandler, }, - LocalPortForwardingCallback: func(_ gliderssh.Context, _ string, _ uint32) bool { + LocalPortForwardingCallback: func(ctx gliderssh.Context, _ string, _ uint32) bool { + sess, _ := session.ObtainSession(ctx) + if sess == nil || sess.Device == nil { + return false + } + + if sess.Namespace.Settings != nil && !sess.Namespace.Settings.AllowTCPForwarding { + return false + } + + if sess.Device.SSH != nil && !sess.Device.SSH.AllowTCPForwarding { + return false + } + return true }, ReversePortForwardingCallback: func(_ gliderssh.Context, _ string, _ uint32) bool { diff --git a/ssh/session/auther.go b/ssh/session/auther.go index 0bf9b3b5c7d..a6316ec0ebb 100644 --- a/ssh/session/auther.go +++ b/ssh/session/auther.go @@ -76,11 +76,29 @@ func (*publicKeyAuth) Auth() authFunc { } func (p *publicKeyAuth) Evaluate(session *Session) error { + if session.Namespace.Settings != nil && !session.Namespace.Settings.AllowPublicKey { + return ErrPublicKeyDisabled + } + + if session.Device != nil && session.Device.SSH != nil && !session.Device.SSH.AllowPublicKey { + return ErrPublicKeyDisabled + } + + if session.Namespace.Settings != nil && !session.Namespace.Settings.AllowRoot { + if session.Target.Username == "root" { + return ErrRootDisabled + } + } + + if session.Device != nil && session.Device.SSH != nil && !session.Device.SSH.AllowRoot && session.Target.Username == "root" { + return ErrRootDisabled + } + // Versions earlier than 0.6.0 do not validate the user when receiving a public key // authentication request. This implies that requests with invalid users are // treated as "authenticated" because the connection does not raise any error. // Moreover, the agent panics after the connection ends. To avoid this, connections - // with public key are not permitted when agent version is 0.5.x or earlier + // with public key are not permitted when agent version is 0.5.x or earlier. if !sshconf.AllowPublickeyAccessBelow060 { version := session.Device.Info.Version if version != "latest" { @@ -137,7 +155,32 @@ func (p *passwordAuth) Auth() authFunc { } } -func (*passwordAuth) Evaluate(*Session) error { - // We don't need (yet) to do any evaluation when authenticating with password. +func (p *passwordAuth) Evaluate(session *Session) error { + if session.Namespace.Settings != nil && !session.Namespace.Settings.AllowPassword { + return ErrPasswordDisabled + } + + if session.Device != nil && session.Device.SSH != nil && !session.Device.SSH.AllowPassword { + return ErrPasswordDisabled + } + + if session.Namespace.Settings != nil && !session.Namespace.Settings.AllowRoot { + if session.Target.Username == "root" { + return ErrRootDisabled + } + } + + if session.Device != nil && session.Device.SSH != nil && !session.Device.SSH.AllowRoot && session.Target.Username == "root" { + return ErrRootDisabled + } + + if session.Namespace.Settings != nil && !session.Namespace.Settings.AllowEmptyPasswords && p.pwd == "" { + return ErrEmptyPasswordNotPermitted + } + + if session.Device != nil && session.Device.SSH != nil && !session.Device.SSH.AllowEmptyPasswords && p.pwd == "" { + return ErrEmptyPasswordNotPermitted + } + return nil } diff --git a/ssh/session/auther_test.go b/ssh/session/auther_test.go new file mode 100644 index 00000000000..ad435b4b98b --- /dev/null +++ b/ssh/session/auther_test.go @@ -0,0 +1,114 @@ +package session + +import ( + "crypto/rand" + "crypto/rsa" + "testing" + + "github.com/shellhub-io/shellhub/pkg/api/internalclient/mocks" + "github.com/shellhub-io/shellhub/pkg/models" + "github.com/shellhub-io/shellhub/ssh/pkg/target" + "github.com/stretchr/testify/assert" + testifymock "github.com/stretchr/testify/mock" + gossh "golang.org/x/crypto/ssh" +) + +func TestPasswordAuthEvaluate(t *testing.T) { + cases := []struct { + name string + allowPassword bool + expectedError error + }{ + { + name: "password auth enabled", + allowPassword: true, + expectedError: nil, + }, + { + name: "password auth disabled", + allowPassword: false, + expectedError: ErrPasswordDisabled, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + sess := &Session{ + Data: Data{ + Target: &target.Target{Username: "user"}, + Namespace: &models.Namespace{ + Settings: &models.NamespaceSettings{ + AllowPassword: tc.allowPassword, + }, + }, + }, + } + + auth := AuthPassword("password") + err := auth.Evaluate(sess) + + assert.Equal(t, tc.expectedError, err) + }) + } +} + +func TestPublicKeyAuthEvaluate(t *testing.T) { + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + assert.NoError(t, err) + + publicKey, err := gossh.NewPublicKey(&privateKey.PublicKey) + assert.NoError(t, err) + + fingerprint := gossh.FingerprintLegacyMD5(publicKey) + + cases := []struct { + name string + allowPublicKey bool + mockSetup func(*mocks.Client) + expectedError error + }{ + { + name: "public key auth enabled", + allowPublicKey: true, + mockSetup: func(m *mocks.Client) { + m.On("GetPublicKey", testifymock.Anything, fingerprint, "tenant-1").Return(nil, nil) + m.On("EvaluateKey", testifymock.Anything, fingerprint, testifymock.Anything, testifymock.Anything).Return(true, nil) + }, + expectedError: nil, + }, + { + name: "public key auth disabled", + allowPublicKey: false, + mockSetup: func(*mocks.Client) {}, + expectedError: ErrPublicKeyDisabled, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + mockClient := &mocks.Client{} + tc.mockSetup(mockClient) + + sess := &Session{ + Data: Data{ + Target: &target.Target{Username: "user"}, + Device: &models.Device{ + Info: &models.DeviceInfo{Version: "latest"}, + TenantID: "tenant-1", + }, + Namespace: &models.Namespace{ + Settings: &models.NamespaceSettings{ + AllowPublicKey: tc.allowPublicKey, + }, + }, + }, + api: mockClient, + } + + auth := AuthPublicKey(publicKey) + err := auth.Evaluate(sess) + + assert.Equal(t, tc.expectedError, err) + }) + } +} diff --git a/ssh/session/errors.go b/ssh/session/errors.go index efdae1f3212..347183e21a3 100644 --- a/ssh/session/errors.go +++ b/ssh/session/errors.go @@ -4,17 +4,22 @@ import "fmt" // Errors returned by the NewSession to the client. var ( - ErrBillingBlock = fmt.Errorf("Connection to this device is not available as your current namespace doesn't qualify for the free plan. To gain access, you'll need to contact the namespace owner to initiate an upgrade.\n\nFor a detailed estimate of costs based on your use-cases with ShellHub Cloud, visit our pricing page at https://www.shellhub.io/pricing. If you wish to upgrade immediately, navigate to https://cloud.shellhub.io/settings/billing. Your cooperation is appreciated.") //nolint:all - ErrFirewallBlock = fmt.Errorf("you cannot connect to this device because a firewall rule block your connection") - ErrFirewallConnection = fmt.Errorf("failed to communicate to the firewall") - ErrFirewallUnknown = fmt.Errorf("failed to evaluate the firewall rule") - ErrHost = fmt.Errorf("failed to get the device address") - ErrFindDevice = fmt.Errorf("failed to find the device") - ErrDial = fmt.Errorf("failed to connect to device agent, please check the device connection") - ErrInvalidVersion = fmt.Errorf("failed to parse device version") - ErrUnsuportedPublicKeyAuth = fmt.Errorf("connections using public keys are not permitted when the agent version is 0.5.x or earlier") - ErrUnexpectedAuthMethod = fmt.Errorf("failed to authenticate the session due to a unexpected method") - ErrEvaluatePublicKey = fmt.Errorf("failed to evaluate the provided public key") - ErrSeatAlreadySet = fmt.Errorf("this seat was already set") - ErrLicenseBlock = fmt.Errorf("Connection blocked: your ShellHub instance has exceeded the maximum number of devices allowed by your license. Please contact support or remove unused devices.") //nolint:all + ErrBillingBlock = fmt.Errorf("Connection to this device is not available as your current namespace doesn't qualify for the free plan. To gain access, you'll need to contact the namespace owner to initiate an upgrade.\n\nFor a detailed estimate of costs based on your use-cases with ShellHub Cloud, visit our pricing page at https://www.shellhub.io/pricing. If you wish to upgrade immediately, navigate to https://cloud.shellhub.io/settings/billing. Your cooperation is appreciated.") //nolint:all + ErrFirewallBlock = fmt.Errorf("you cannot connect to this device because a firewall rule block your connection") + ErrFirewallConnection = fmt.Errorf("failed to communicate to the firewall") + ErrFirewallUnknown = fmt.Errorf("failed to evaluate the firewall rule") + ErrHost = fmt.Errorf("failed to get the device address") + ErrFindDevice = fmt.Errorf("failed to find the device") + ErrDial = fmt.Errorf("failed to connect to device agent, please check the device connection") + ErrInvalidVersion = fmt.Errorf("failed to parse device version") + ErrUnsuportedPublicKeyAuth = fmt.Errorf("connections using public keys are not permitted when the agent version is 0.5.x or earlier") + ErrUnexpectedAuthMethod = fmt.Errorf("failed to authenticate the session due to a unexpected method") + ErrEvaluatePublicKey = fmt.Errorf("failed to evaluate the provided public key") + ErrPasswordDisabled = fmt.Errorf("password authentication is disabled for this namespace") + ErrPublicKeyDisabled = fmt.Errorf("public key authentication is disabled for this namespace") + ErrPublicKeyNotFound = fmt.Errorf("public key not found") + ErrSeatAlreadySet = fmt.Errorf("this seat was already set") + ErrLicenseBlock = fmt.Errorf("Connection blocked: your ShellHub instance has exceeded the maximum number of devices allowed by your license. Please contact support or remove unused devices.") //nolint:all + ErrRootDisabled = fmt.Errorf("root login is disabled for this device") + ErrEmptyPasswordNotPermitted = fmt.Errorf("empty passwords are not permitted for this device") ) diff --git a/ssh/session/session.go b/ssh/session/session.go index 1b9c1288f9a..56faec8645a 100644 --- a/ssh/session/session.go +++ b/ssh/session/session.go @@ -656,9 +656,18 @@ func (s *Session) Auth(ctx gliderssh.Context, auth Auth) error { } snap.save(sess, StateRegistered) + if err := sess.connect(ctx, auth.Auth()); err != nil { + return err + } - fallthrough + if err := sess.authenticate(ctx); err != nil { + return err + } case StateRegistered: + if err := auth.Evaluate(sess); err != nil { + return err + } + if err := sess.connect(ctx, auth.Auth()); err != nil { return err } diff --git a/ui/apps/console/openapi-ts.config.d.ts b/ui/apps/console/openapi-ts.config.d.ts new file mode 100644 index 00000000000..41687f89e48 --- /dev/null +++ b/ui/apps/console/openapi-ts.config.d.ts @@ -0,0 +1,3 @@ +declare const _default: Promise; +export default _default; +//# sourceMappingURL=openapi-ts.config.d.ts.map \ No newline at end of file diff --git a/ui/apps/console/src/components/common/SettingToggle.tsx b/ui/apps/console/src/components/common/SettingToggle.tsx new file mode 100644 index 00000000000..5c78f7f1c0b --- /dev/null +++ b/ui/apps/console/src/components/common/SettingToggle.tsx @@ -0,0 +1,70 @@ +import { useState } from "react"; + +const TOGGLE_STYLES = { + primary: { + on: "bg-primary/15 text-primary border border-primary/25", + off: "bg-hover-strong text-text-secondary border border-border-light", + }, + success: { + on: "bg-accent-green/15 text-accent-green border border-accent-green/25", + off: "bg-hover-strong text-text-secondary border border-border-light", + }, +} as const; + +export type SettingToggleTone = keyof typeof TOGGLE_STYLES; + +interface SettingToggleProps { + checked: boolean; + disabled?: boolean; + tone?: SettingToggleTone; + onChange: (checked: boolean) => Promise | void; +} + +export default function SettingToggle({ + checked, + disabled = false, + tone = "primary", + onChange, +}: SettingToggleProps) { + const [loading, setLoading] = useState(false); + const styles = TOGGLE_STYLES[tone]; + + const handleToggle = async (value: boolean) => { + if (loading || disabled) return; + setLoading(true); + try { + await onChange(value); + } finally { + setLoading(false); + } + }; + + return ( +
+ + +
+ ); +} diff --git a/ui/apps/console/src/components/layout/__tests__/AppLayout.test.tsx b/ui/apps/console/src/components/layout/__tests__/AppLayout.test.tsx index bec805be939..2fe3f50d23e 100644 --- a/ui/apps/console/src/components/layout/__tests__/AppLayout.test.tsx +++ b/ui/apps/console/src/components/layout/__tests__/AppLayout.test.tsx @@ -20,6 +20,14 @@ vi.mock("@/hooks/useNamespaces", () => ({ useInitRole: () => {}, })); +vi.mock("@/hooks/useChatwoot", () => ({ + ChatwootContext: { Provider: ({ children }: { children: React.ReactNode }) => <>{children} }, + useChatwoot: () => ({ + status: "unavailable", + openWidget: vi.fn(), + }), +})); + vi.mock("@/hooks/useSidebarLayout", () => ({ useSidebarLayout: () => ({ expanded: false, diff --git a/ui/apps/console/src/hooks/useDeviceMutations.ts b/ui/apps/console/src/hooks/useDeviceMutations.ts index df2e52d8018..12b933e0de2 100644 --- a/ui/apps/console/src/hooks/useDeviceMutations.ts +++ b/ui/apps/console/src/hooks/useDeviceMutations.ts @@ -5,6 +5,7 @@ import { updateDeviceStatusMutation, deleteDeviceMutation, updateDeviceMutation, + updateDeviceSettingsMutation, pullTagFromDeviceMutation, setDeviceCustomFieldMutation, deleteDeviceCustomFieldMutation, @@ -60,6 +61,14 @@ export function useDeleteDeviceCustomField() { }); } +export function useUpdateDeviceSSH() { + const invalidate = useInvalidateByIds("getDevices", "getDevice", "getStatusDevices"); + return useMutation({ + ...updateDeviceSettingsMutation(), + onSuccess: invalidate, + }); +} + export function useAddDeviceTag() { const invalidate = useInvalidateByIds("getDevices", "getDevice", "getStatusDevices", "getTags"); return useMutation({ diff --git a/ui/apps/console/src/hooks/useNamespaces.ts b/ui/apps/console/src/hooks/useNamespaces.ts index af21961ff41..8c5d7fb16c7 100644 --- a/ui/apps/console/src/hooks/useNamespaces.ts +++ b/ui/apps/console/src/hooks/useNamespaces.ts @@ -60,7 +60,7 @@ export function useNamespace(tenantId: string) { }); return { - namespace: (result.data ?? null), + namespace: result.data ?? null, isLoading: result.isLoading, error: result.error, refetch: result.refetch, diff --git a/ui/apps/console/src/pages/BannerEdit.tsx b/ui/apps/console/src/pages/BannerEdit.tsx index 607de2b0c1c..e6b4231cd4f 100644 --- a/ui/apps/console/src/pages/BannerEdit.tsx +++ b/ui/apps/console/src/pages/BannerEdit.tsx @@ -9,6 +9,7 @@ import { useAuthStore } from "../stores/authStore"; import { useHasPermission } from "../hooks/useHasPermission"; import Spinner from "@/components/common/Spinner"; import PageLoader from "@/components/common/PageLoader"; +import { normalizeNamespaceSettings } from "../utils/namespaceSettings"; const MAX_LENGTH = 4096; @@ -33,13 +34,7 @@ function BannerEditor({ ns, canEdit }: { ns: Namespace; canEdit: boolean }) { try { await editNs.mutateAsync({ path: { tenant: ns.tenant_id }, - body: { - settings: { - connection_announcement: text, - session_record: ns.settings?.session_record ?? false, - device_auto_accept: ns.settings?.device_auto_accept ?? false, - }, - }, + body: { settings: normalizeNamespaceSettings({ ...ns.settings, connection_announcement: text }) }, }); void navigate("/settings"); } catch { diff --git a/ui/apps/console/src/pages/DeviceDetails.tsx b/ui/apps/console/src/pages/DeviceDetails.tsx index 0844b5691fb..600506f5c28 100644 --- a/ui/apps/console/src/pages/DeviceDetails.tsx +++ b/ui/apps/console/src/pages/DeviceDetails.tsx @@ -8,9 +8,12 @@ import { ClockIcon, CpuChipIcon, ChevronDoubleRightIcon, + LockOpenIcon, + LockClosedIcon, } from "@heroicons/react/24/outline"; import { useDevice } from "../hooks/useDevice"; -import { useRemoveDevice } from "../hooks/useDeviceMutations"; +import { useRemoveDevice, useUpdateDeviceSSH } from "../hooks/useDeviceMutations"; +import { useHasPermission } from "../hooks/useHasPermission"; import { useNamespace } from "../hooks/useNamespaces"; import { useAuthStore } from "../stores/authStore"; import { useTerminalStore } from "../stores/terminalStore"; @@ -20,6 +23,7 @@ import ConnectDrawer from "../components/ConnectDrawer"; import ConfirmDialog from "../components/common/ConfirmDialog"; import CopyButton from "../components/common/CopyButton"; import PlatformBadge from "../components/common/PlatformBadge"; +import SettingToggle from "../components/common/SettingToggle"; import { formatDateFull, formatRelative } from "../utils/date"; import { buildSshid } from "../utils/sshid"; import RestrictedAction from "../components/common/RestrictedAction"; @@ -29,11 +33,72 @@ import InfoItem from "./devices/InfoItem"; import TagsSection from "./devices/TagsSection"; import RenameSection from "./devices/RenameSection"; import CustomFieldsSection from "./devices/CustomFieldsSection"; +import type { Device } from "../client"; /* ─── Shared styles ─── */ const LABEL = "text-2xs font-mono font-semibold uppercase tracking-label text-text-muted"; const VALUE = "text-sm text-text-primary font-medium mt-0.5"; +type DeviceSSHSettings = NonNullable; +type DeviceSSHSettingKey = keyof DeviceSSHSettings; + +const DEVICE_SSH_SETTINGS: Array<{ + key: DeviceSSHSettingKey; + title: string; + description: string; +}> = [ + { + key: "allow_password", + title: "Allow Password Authentication", + description: "Allow SSH connections using password for this device", + }, + { + key: "allow_public_key", + title: "Allow Public Key Authentication", + description: "Allow SSH connections using public key for this device", + }, + { + key: "allow_root", + title: "Allow Root Login", + description: "Allow SSH connections as root user for this device", + }, + { + key: "allow_empty_passwords", + title: "Allow Empty Passwords", + description: "Allow SSH connections with empty passwords for this device", + }, + { + key: "allow_tty", + title: "Allow TTY Allocation", + description: "Allow terminal (TTY) allocation for this device", + }, + { + key: "allow_tcp_forwarding", + title: "Allow TCP Forwarding", + description: "Allow TCP port forwarding for this device", + }, + { + key: "allow_web_endpoints", + title: "Allow Web Endpoints", + description: "Allow HTTP/HTTPS access via ShellHub proxy", + }, + { + key: "allow_sftp", + title: "Allow SFTP", + description: "Allow SFTP subsystem for this device", + }, + { + key: "allow_agent_forwarding", + title: "Allow Agent Forwarding", + description: "Allow SSH agent forwarding for this device", + }, +]; + +function ToggleStateIcon({ enabled }: { enabled: boolean }) { + return enabled + ? + : ; +} /* ─── Page ─── */ export default function DeviceDetails() { @@ -42,8 +107,11 @@ export default function DeviceDetails() { const [searchParams] = useSearchParams(); const { device, isLoading } = useDevice(uid ?? ""); const removeMutation = useRemoveDevice(); + const updateSSH = useUpdateDeviceSSH(); + const canUpdateDeviceSettings = useHasPermission("device:update"); const tenantId = useAuthStore((s) => s.tenant) ?? ""; const { namespace: currentNamespace } = useNamespace(tenantId); + const deviceSettings = device?.settings ?? {}; const existingSession = useTerminalStore((s) => s.sessions.find((sess) => sess.deviceUid === uid), ); @@ -57,11 +125,21 @@ export default function DeviceDetails() { } | null>(null); const [billingWarningOpen, setBillingWarningOpen] = useState(false); - // Auto-open connect drawer if ?connect=true (adjust during render) - const shouldAutoConnect = - searchParams.get("connect") === "true" && - device?.online && - !existingSession; + const updateDeviceSetting = async (settings: Partial) => { + if (!device) { + return; + } + + await updateSSH.mutateAsync({ + path: { uid: device.uid }, + body: settings, + }); + }; + + const shouldAutoConnect + = searchParams.get("connect") === "true" + && device?.online + && !existingSession; const [autoConnectDone, setAutoConnectDone] = useState(false); if (shouldAutoConnect && !autoConnectDone) { @@ -72,7 +150,6 @@ export default function DeviceDetails() { setAutoConnectDone(false); } - // Restore existing terminal session (side effect only, no setState) useEffect(() => { if ( searchParams.get("connect") === "true" && @@ -367,6 +444,47 @@ export default function DeviceDetails() { + {/* Settings */} +
+

+ + Settings +

+
+ {DEVICE_SSH_SETTINGS.map((setting) => { + const enabled = deviceSettings[setting.key] ?? true; + + return ( +
+
+ + + +
+

+ {setting.title} +

+

+ {setting.description} +

+
+
+
+ { + return updateDeviceSetting({ [setting.key]: checked }); + }} + /> +
+
+ ); + })} +
+
+ {/* Delete Dialog */} ; +type NamespaceSSHSettingKey = Exclude< + keyof NamespaceSettings, + "connection_announcement" | "device_auto_accept" | "session_record" +>; /* ─── Settings Card ─── */ @@ -96,6 +107,74 @@ function SettingsRow({ ); } +const NAMESPACE_SSH_SETTINGS: Array<{ + key: NamespaceSSHSettingKey; + title: string; + description: string; + permission: Action; +}> = [ + { + key: "allow_password", + title: "Allow Password Authentication", + description: "Allow SSH connections using password for all devices in this namespace", + permission: "namespace:updateAllowPassword", + }, + { + key: "allow_public_key", + title: "Allow Public Key Authentication", + description: "Allow SSH connections using public key for all devices in this namespace", + permission: "namespace:updateAllowPublicKey", + }, + { + key: "allow_root", + title: "Allow Root Login", + description: "Allow SSH connections to devices using the root user", + permission: "namespace:updateAllowRoot", + }, + { + key: "allow_empty_passwords", + title: "Allow Empty Passwords", + description: "Allow SSH logins with empty passwords for devices in this namespace", + permission: "namespace:updateAllowEmptyPasswords", + }, + { + key: "allow_tty", + title: "Allow TTY Allocation", + description: "Allow SSH sessions to allocate a TTY", + permission: "namespace:updateAllowTTY", + }, + { + key: "allow_tcp_forwarding", + title: "Allow TCP Forwarding", + description: "Allow SSH TCP port forwarding for devices in this namespace", + permission: "namespace:updateAllowTcpForwarding", + }, + { + key: "allow_web_endpoints", + title: "Allow Web Endpoints", + description: "Allow access to web endpoints through the HTTP proxy", + permission: "namespace:updateAllowWebEndpoints", + }, + { + key: "allow_sftp", + title: "Allow SFTP", + description: "Allow the SFTP subsystem for devices in this namespace", + permission: "namespace:updateAllowSFTP", + }, + { + key: "allow_agent_forwarding", + title: "Allow Agent Forwarding", + description: "Allow SSH agent forwarding for devices in this namespace", + permission: "namespace:updateAllowAgentForwarding", + }, +]; + +function ToggleStateIcon({ enabled }: { enabled: boolean }) { + return enabled + ? + : ; +} + /* ─── Edit Name Drawer ─── */ function EditNameDrawer({ @@ -385,48 +464,40 @@ function BannerPreview({ export default function Settings() { const { tenant: tenantId } = useAuthStore(); + const role = useAuthStore((s) => s.role); const { namespace: ns } = useNamespace(tenantId ?? ""); const editNs = useEditNamespace(); const setDeviceAutoAccept = useSetDeviceAutoAccept(); const [editNameOpen, setEditNameOpen] = useState(false); const [deleteOpen, setDeleteOpen] = useState(false); const [leaveOpen, setLeaveOpen] = useState(false); - const [togglingRecord, setTogglingRecord] = useState(false); + const [togglingSetting, setTogglingSetting] = useState(null); const [togglingAutoAccept, setTogglingAutoAccept] = useState(false); - const canRename = useHasPermission("namespace:rename"); - const canUpdateRecording = useHasPermission( - "namespace:updateSessionRecording", - ); - const canUpdateAutoAccept = useHasPermission( - "namespace:updateDeviceAutoAccept", - ); - const canEditBanner = useHasPermission("namespace:editBanner"); - const canDelete = useHasPermission("namespace:delete"); - - const settings = ns?.settings; - const sessionRecord = settings?.session_record ?? false; - const deviceAutoAccept = settings?.device_auto_accept ?? false; - const banner = settings?.connection_announcement ?? ""; - - const handleToggleRecord = async () => { - if (!tenantId || togglingRecord) return; - setTogglingRecord(true); + const can = (action: Action) => hasPermission(role, action); + const canRename = can("namespace:rename"); + const canUpdateRecording = can("namespace:updateSessionRecording"); + const canUpdateAutoAccept = can("namespace:updateDeviceAutoAccept"); + const canEditBanner = can("namespace:editBanner"); + const canDelete = can("namespace:delete"); + + const settings = normalizeNamespaceSettings(ns?.settings); + const sessionRecord = settings.session_record; + const deviceAutoAccept = settings.device_auto_accept; + const banner = settings.connection_announcement; + + const updateNamespaceSettings = async (patch: Partial, key: keyof NamespaceSettings) => { + if (!tenantId || togglingSetting) return; + setTogglingSetting(key); try { await editNs.mutateAsync({ path: { tenant: tenantId }, - body: { - settings: { - session_record: !sessionRecord, - connection_announcement: banner, - device_auto_accept: deviceAutoAccept, - }, - }, + body: { settings: normalizeNamespaceSettings({ ...settings, ...patch }) }, }); } catch { /* state didn't change */ } finally { - setTogglingRecord(false); + setTogglingSetting(null); } }; @@ -445,6 +516,14 @@ export default function Settings() { } }; + const handleToggleSetting = async ( + key: NamespaceSSHSettingKey | "session_record", + checked: boolean, + ) => { + await updateNamespaceSettings({ [key]: checked }, key); + }; + + const isUpdatingSetting = (key: keyof NamespaceSettings) => togglingSetting === key; if (!ns) { return ; } @@ -522,41 +601,37 @@ export default function Settings() { title="Session Recording" description="Record SSH sessions for audit and playback" > -
- - -
+ handleToggleSetting("session_record", checked)} + /> )} {/* SSH Banner */} + + {NAMESPACE_SSH_SETTINGS.map((setting) => { + const enabled = settings[setting.key] ?? true; + + return ( + } + title={setting.title} + description={setting.description} + > + handleToggleSetting(setting.key, checked)} + /> + + ); + })} {/* ── Devices ── */} @@ -566,36 +641,12 @@ export default function Settings() { title="Auto-Accept Devices" description="Automatically accept new devices when they connect for the first time" > -
- - -
+ handleToggleAutoAccept()} + /> diff --git a/ui/apps/console/src/pages/admin/namespaces/EditNamespaceDrawer.tsx b/ui/apps/console/src/pages/admin/namespaces/EditNamespaceDrawer.tsx index 92014431cd2..5841f8aec83 100644 --- a/ui/apps/console/src/pages/admin/namespaces/EditNamespaceDrawer.tsx +++ b/ui/apps/console/src/pages/admin/namespaces/EditNamespaceDrawer.tsx @@ -9,6 +9,7 @@ import CheckboxField from "@/components/common/fields/CheckboxField"; import { validateNamespaceName } from "@/utils/validation"; import type { Namespace } from "@/client"; import Spinner from "@/components/common/Spinner"; +import { normalizeNamespaceSettings } from "@/utils/namespaceSettings"; interface EditNamespaceDrawerProps { open: boolean; @@ -28,14 +29,12 @@ export default function EditNamespaceDrawer({ String(namespace?.max_devices ?? -1), ); const [sessionRecord, setSessionRecord] = useState(false); - const [deviceAutoAccept, setDeviceAutoAccept] = useState(false); const [error, setError] = useState(""); useResetOnOpen(open, () => { setName(namespace?.name ?? ""); setMaxDevices(String(namespace?.max_devices ?? -1)); setSessionRecord(namespace?.settings?.session_record ?? false); - setDeviceAutoAccept(namespace?.settings?.device_auto_accept ?? false); setError(""); }); @@ -51,17 +50,15 @@ export default function EditNamespaceDrawer({ try { await editNamespace.mutateAsync({ path: { tenantID: namespace.tenant_id }, - // The SDK types body as full Namespace; we spread the original - // to satisfy the type while only changing the editable fields. body: { ...namespace, name: name.trim(), max_devices: parseInt(maxDevices, 10), settings: { - connection_announcement: - namespace.settings?.connection_announcement ?? "", - session_record: sessionRecord, - device_auto_accept: deviceAutoAccept, + ...normalizeNamespaceSettings({ + ...namespace.settings, + session_record: sessionRecord, + }), }, }, }); @@ -137,14 +134,6 @@ export default function EditNamespaceDrawer({ checked={sessionRecord} onChange={setSessionRecord} /> - - - {error && (

{error} diff --git a/ui/apps/console/src/pages/admin/namespaces/__tests__/EditNamespaceDrawer.test.tsx b/ui/apps/console/src/pages/admin/namespaces/__tests__/EditNamespaceDrawer.test.tsx index 47ddfa8d1aa..9d455dc2e8d 100644 --- a/ui/apps/console/src/pages/admin/namespaces/__tests__/EditNamespaceDrawer.test.tsx +++ b/ui/apps/console/src/pages/admin/namespaces/__tests__/EditNamespaceDrawer.test.tsx @@ -15,7 +15,7 @@ vi.mock("@/components/common/Drawer", async () => ({ const mockMutateAsync = vi.fn(); -const mockNamespace: Namespace = { +const mockNamespace = { name: "my-namespace", owner: "owner-1", tenant_id: "tenant-abc", @@ -24,6 +24,8 @@ const mockNamespace: Namespace = { session_record: true, connection_announcement: "hello", device_auto_accept: false, + allow_password: true, + allow_public_key: true, }, max_devices: 10, created_at: "2024-01-01T00:00:00Z", @@ -31,7 +33,7 @@ const mockNamespace: Namespace = { devices_pending_count: 0, devices_accepted_count: 3, devices_rejected_count: 0, -}; +} as unknown as Namespace; beforeEach(() => { vi.clearAllMocks(); diff --git a/ui/apps/console/src/pages/devices/__tests__/DeviceDetails.test.tsx b/ui/apps/console/src/pages/devices/__tests__/DeviceDetails.test.tsx index e981ddd8b75..734552e3517 100644 --- a/ui/apps/console/src/pages/devices/__tests__/DeviceDetails.test.tsx +++ b/ui/apps/console/src/pages/devices/__tests__/DeviceDetails.test.tsx @@ -11,6 +11,7 @@ vi.mock("@/hooks/useDevice", () => ({ useDevice: vi.fn(), })); +const mockUpdateSSH = vi.fn(); const mockSetCustomField = vi.fn(); const mockDeleteCustomField = vi.fn(); @@ -19,6 +20,7 @@ vi.mock("@/hooks/useDeviceMutations", () => ({ useAddDeviceTag: () => ({ mutateAsync: vi.fn() }), useRemoveDeviceTag: () => ({ mutateAsync: vi.fn() }), useRemoveDevice: () => ({ mutateAsync: vi.fn() }), + useUpdateDeviceSSH: () => ({ mutateAsync: mockUpdateSSH }), useSetDeviceCustomField: () => ({ mutateAsync: mockSetCustomField }), useDeleteDeviceCustomField: () => ({ mutateAsync: mockDeleteCustomField }), })); @@ -124,6 +126,7 @@ function renderPage() { describe("DeviceDetails", () => { beforeEach(() => { + mockUpdateSSH.mockReset().mockResolvedValue({}); mockSetCustomField.mockReset().mockResolvedValue({}); mockDeleteCustomField.mockReset().mockResolvedValue({}); vi.mocked(useDevice).mockReturnValue({ @@ -284,7 +287,7 @@ describe("DeviceDetails", () => { expect(mockSetCustomField).toHaveBeenCalledWith( expect.objectContaining({ path: expect.objectContaining({ uid: "test-uid", key: "region" }), - body: { value: "us-east" }, + body: expect.objectContaining({ value: "us-east" }), }), ); }); diff --git a/ui/apps/console/src/utils/__tests__/namespaceSettings.test.ts b/ui/apps/console/src/utils/__tests__/namespaceSettings.test.ts new file mode 100644 index 00000000000..c4aa54c46ed --- /dev/null +++ b/ui/apps/console/src/utils/__tests__/namespaceSettings.test.ts @@ -0,0 +1,53 @@ +import { describe, expect, it } from "vitest"; +import { normalizeNamespaceSettings } from "../namespaceSettings"; + +describe("normalizeNamespaceSettings", () => { + it("fills all namespace settings defaults", () => { + expect(normalizeNamespaceSettings()).toEqual({ + session_record: false, + connection_announcement: "", + device_auto_accept: false, + allow_password: true, + allow_public_key: true, + allow_root: true, + allow_empty_passwords: true, + allow_tty: true, + allow_tcp_forwarding: true, + allow_web_endpoints: true, + allow_sftp: true, + allow_agent_forwarding: true, + }); + }); + + it("preserves provided values", () => { + expect( + normalizeNamespaceSettings({ + session_record: true, + connection_announcement: "hello", + device_auto_accept: true, + allow_password: false, + allow_public_key: false, + allow_root: false, + allow_empty_passwords: false, + allow_tty: false, + allow_tcp_forwarding: false, + allow_web_endpoints: false, + allow_sftp: false, + allow_agent_forwarding: false, + }), + ).toEqual({ + session_record: true, + connection_announcement: "hello", + device_auto_accept: true, + allow_password: false, + allow_public_key: false, + allow_root: false, + allow_empty_passwords: false, + allow_tty: false, + allow_tcp_forwarding: false, + allow_web_endpoints: false, + allow_sftp: false, + allow_agent_forwarding: false, + }); + }); +}); diff --git a/ui/apps/console/src/utils/__tests__/permission.test.ts b/ui/apps/console/src/utils/__tests__/permission.test.ts index d8012367717..e6fa87a1f69 100644 --- a/ui/apps/console/src/utils/__tests__/permission.test.ts +++ b/ui/apps/console/src/utils/__tests__/permission.test.ts @@ -22,6 +22,11 @@ const ADMINISTRATOR_ACTIONS: Action[] = [ "namespace:addMember", "namespace:editMember", "namespace:removeMember", "namespace:editInvitation", "namespace:cancelInvitation", "namespace:updateSessionRecording", "namespace:editBanner", + "namespace:updateAllowPassword", "namespace:updateAllowPublicKey", + "namespace:updateAllowRoot", "namespace:updateAllowEmptyPasswords", + "namespace:updateAllowTTY", "namespace:updateAllowTcpForwarding", + "namespace:updateAllowWebEndpoints", "namespace:updateAllowSFTP", + "namespace:updateAllowAgentForwarding", "publicKey:create", "publicKey:edit", "publicKey:remove", "firewall:create", "firewall:edit", "firewall:remove", "webEndpoint:create", "webEndpoint:delete", diff --git a/ui/apps/console/src/utils/namespaceSettings.ts b/ui/apps/console/src/utils/namespaceSettings.ts new file mode 100644 index 00000000000..851c5076be6 --- /dev/null +++ b/ui/apps/console/src/utils/namespaceSettings.ts @@ -0,0 +1,20 @@ +import type { NamespaceSettings } from "../client"; + +export function normalizeNamespaceSettings( + settings?: Partial | null, +): NamespaceSettings { + return { + session_record: settings?.session_record ?? false, + connection_announcement: settings?.connection_announcement ?? "", + device_auto_accept: settings?.device_auto_accept ?? false, + allow_password: settings?.allow_password ?? true, + allow_public_key: settings?.allow_public_key ?? true, + allow_root: settings?.allow_root ?? true, + allow_empty_passwords: settings?.allow_empty_passwords ?? true, + allow_tty: settings?.allow_tty ?? true, + allow_tcp_forwarding: settings?.allow_tcp_forwarding ?? true, + allow_web_endpoints: settings?.allow_web_endpoints ?? true, + allow_sftp: settings?.allow_sftp ?? true, + allow_agent_forwarding: settings?.allow_agent_forwarding ?? true, + }; +} diff --git a/ui/apps/console/src/utils/permission.ts b/ui/apps/console/src/utils/permission.ts index e8f3051c3c3..402a6222d8b 100644 --- a/ui/apps/console/src/utils/permission.ts +++ b/ui/apps/console/src/utils/permission.ts @@ -25,6 +25,7 @@ const permissions = { "device:accept": RoleLevel.OPERATOR, "device:reject": RoleLevel.OPERATOR, "device:rename": RoleLevel.OPERATOR, + "device:update": RoleLevel.OPERATOR, "device:customField:update": RoleLevel.OPERATOR, "device:remove": RoleLevel.ADMINISTRATOR, "device:choose": RoleLevel.OWNER, @@ -46,6 +47,15 @@ const permissions = { "namespace:cancelInvitation": RoleLevel.ADMINISTRATOR, "namespace:updateSessionRecording": RoleLevel.ADMINISTRATOR, "namespace:updateDeviceAutoAccept": RoleLevel.ADMINISTRATOR, + "namespace:updateAllowPassword": RoleLevel.ADMINISTRATOR, + "namespace:updateAllowPublicKey": RoleLevel.ADMINISTRATOR, + "namespace:updateAllowRoot": RoleLevel.ADMINISTRATOR, + "namespace:updateAllowEmptyPasswords": RoleLevel.ADMINISTRATOR, + "namespace:updateAllowTTY": RoleLevel.ADMINISTRATOR, + "namespace:updateAllowTcpForwarding": RoleLevel.ADMINISTRATOR, + "namespace:updateAllowWebEndpoints": RoleLevel.ADMINISTRATOR, + "namespace:updateAllowSFTP": RoleLevel.ADMINISTRATOR, + "namespace:updateAllowAgentForwarding": RoleLevel.ADMINISTRATOR, "namespace:delete": RoleLevel.OWNER, // Tags