Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions agent/installer.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ func registerInstallerCommands(rootCmd *cobra.Command) {
installCmd.Flags().String("preferred-identity", "", "Preferred device identity")
installCmd.Flags().Uint("keepalive-interval", 30, "Keepalive interval in seconds")
installCmd.MarkFlagRequired("server-address") //nolint:errcheck
installCmd.MarkFlagRequired("tenant-id") //nolint:errcheck
installCmd.MarkFlagRequired("tenant-id") //nolint:errcheck

rootCmd.AddCommand(installCmd)

Expand Down Expand Up @@ -169,7 +169,7 @@ func writeAgentEnvFile(cfg installerConfig) error {
fmt.Fprintf(&buf, "SHELLHUB_KEEPALIVE_INTERVAL=%d\n", cfg.KeepaliveInterval)
}

return os.WriteFile(agentEnvFile, buf.Bytes(), 0600)
return os.WriteFile(agentEnvFile, buf.Bytes(), 0o600)
}

func writeAgentServiceFile(binaryPath string) error {
Expand All @@ -183,7 +183,7 @@ func writeAgentServiceFile(binaryPath string) error {
return err
}

return os.WriteFile(agentServiceFile, buf.Bytes(), 0644)
return os.WriteFile(agentServiceFile, buf.Bytes(), 0o644)
}

func agentUninstall() error {
Expand Down
51 changes: 51 additions & 0 deletions api/routes/device.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ const (
LookupDeviceURL = "/device/lookup"
UpdateDeviceStatusURL = "/devices/:uid/:status"
UpdateDevice = "/devices/:uid"
GetDeviceSettingsURL = "/devices/:uid/settings"
UpdateDeviceSettingsURL = "/devices/:uid/settings"
SetDeviceCustomFieldURL = "/devices/:uid/custom_fields/:key"
DeleteDeviceCustomFieldURL = "/devices/:uid/custom_fields/:key"
)
Expand Down Expand Up @@ -263,13 +265,62 @@ func (h *Handler) UpdateDevice(c gateway.Context) error {
return err
}

if c.Tenant() != nil {
req.TenantID = c.Tenant().ID
}

if err := h.service.UpdateDevice(c.Ctx(), req); err != nil {
return err
}

return c.NoContent(http.StatusOK)
}

func (h *Handler) GetDeviceSettings(c gateway.Context) error {
req := new(requests.DeviceGetSettings)

if err := c.Bind(req); err != nil {
return err
}

if c.Tenant() != nil {
req.TenantID = c.Tenant().ID
}

if err := c.Validate(req); err != nil {
return err
}

settings, err := h.service.GetDeviceSettings(c.Ctx(), req)
if err != nil {
return err
}

return c.JSON(http.StatusOK, settings)
}

func (h *Handler) UpdateDeviceSettings(c gateway.Context) error {
req := new(requests.DeviceUpdateSettings)

if err := c.Bind(req); err != nil {
return err
}

if c.Tenant() != nil {
req.TenantID = c.Tenant().ID
}

if err := c.Validate(req); err != nil {
return err
}

if err := h.service.UpdateDeviceSettings(c.Ctx(), req); err != nil {
return err
}

return c.NoContent(http.StatusOK)
}

func (h *Handler) SetDeviceCustomField(c gateway.Context) error {
req := new(requests.DeviceSetCustomField)

Expand Down
108 changes: 108 additions & 0 deletions api/routes/device_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -662,3 +662,111 @@ func TestUpdateDevice(t *testing.T) {
})
}
}

func TestGetDeviceSettings(t *testing.T) {
mock := new(mocks.Service)

settings := &models.SSHSettings{
AllowPassword: false,
AllowPublicKey: true,
}

mock.
On("GetDeviceSettings", gomock.Anything, &requests.DeviceGetSettings{
TenantID: "00000000-0000-4000-0000-000000000000",
DeviceParam: requests.DeviceParam{UID: "1234"},
}).
Return(settings, nil).
Once()

req := httptest.NewRequest(http.MethodGet, "/api/devices/1234/settings", nil)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("X-Role", authorizer.RoleOwner.String())
req.Header.Set("X-Tenant-ID", "00000000-0000-4000-0000-000000000000")
rec := httptest.NewRecorder()

e := NewRouter(mock)
e.ServeHTTP(rec, req)

assert.Equal(t, http.StatusOK, rec.Result().StatusCode)

var body *models.SSHSettings
err := json.NewDecoder(rec.Result().Body).Decode(&body)
require.NoError(t, err)
assert.Equal(t, settings, body)
}

func TestUpdateDeviceSettings(t *testing.T) {
mock := new(mocks.Service)

cases := []struct {
description string
req requests.DeviceUpdateSettings
requiredMocks func()
expectedStatus int
}{
{
description: "fails when device settings update cannot find device",
req: requests.DeviceUpdateSettings{
TenantID: "00000000-0000-4000-0000-000000000000",
DeviceParam: requests.DeviceParam{
UID: "1234",
},
},
requiredMocks: func() {
mock.On("UpdateDeviceSettings", gomock.Anything, &requests.DeviceUpdateSettings{
TenantID: "00000000-0000-4000-0000-000000000000",
DeviceParam: requests.DeviceParam{UID: "1234"},
}).Return(svc.ErrNotFound).Once()
},
expectedStatus: http.StatusNotFound,
},
{
description: "success when updating a device setting",
req: requests.DeviceUpdateSettings{
TenantID: "00000000-0000-4000-0000-000000000000",
DeviceParam: requests.DeviceParam{
UID: "1234",
},
SSHSettingsUpdate: requests.SSHSettingsUpdate{
AllowPassword: func() *bool {
v := false

return &v
}(),
},
},
requiredMocks: func() {
v := false
mock.On("UpdateDeviceSettings", gomock.Anything, &requests.DeviceUpdateSettings{
TenantID: "00000000-0000-4000-0000-000000000000",
DeviceParam: requests.DeviceParam{UID: "1234"},
SSHSettingsUpdate: requests.SSHSettingsUpdate{
AllowPassword: &v,
},
}).Return(nil).Once()
},
expectedStatus: http.StatusOK,
},
}

for _, tc := range cases {
t.Run(tc.description, func(t *testing.T) {
tc.requiredMocks()

jsonData, err := json.Marshal(tc.req)
require.NoError(t, err)

req := httptest.NewRequest(http.MethodPatch, fmt.Sprintf("/api/devices/%s/settings", tc.req.UID), strings.NewReader(string(jsonData)))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("X-Role", authorizer.RoleOwner.String())
req.Header.Set("X-Tenant-ID", "00000000-0000-4000-0000-000000000000")
rec := httptest.NewRecorder()

e := NewRouter(mock)
e.ServeHTTP(rec, req)

assert.Equal(t, tc.expectedStatus, rec.Result().StatusCode)
})
}
}
2 changes: 2 additions & 0 deletions api/routes/routes.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,8 +143,10 @@ func NewRouter(service services.Service, opts ...Option) *echo.Echo {

publicAPI.GET(GetDeviceListURL, routesmiddleware.Authorize(gateway.Handler(handler.GetDeviceList)))
publicAPI.GET(GetDeviceURL, routesmiddleware.Authorize(gateway.Handler(handler.GetDevice)))
publicAPI.GET(GetDeviceSettingsURL, routesmiddleware.Authorize(gateway.Handler(handler.GetDeviceSettings)))
publicAPI.GET(ResolveDeviceURL, routesmiddleware.Authorize(gateway.Handler(handler.ResolveDevice)))
publicAPI.PUT(UpdateDevice, gateway.Handler(handler.UpdateDevice), routesmiddleware.RequiresPermission(authorizer.DeviceUpdate))
publicAPI.PATCH(UpdateDeviceSettingsURL, gateway.Handler(handler.UpdateDeviceSettings), routesmiddleware.RequiresPermission(authorizer.DeviceUpdate))
publicAPI.PATCH(RenameDeviceURL, gateway.Handler(handler.RenameDevice), routesmiddleware.RequiresPermission(authorizer.DeviceRename))
publicAPI.PATCH(UpdateDeviceStatusURL, gateway.Handler(handler.UpdateDeviceStatus), routesmiddleware.RequiresPermission(authorizer.DeviceAccept)) // TODO: DeviceWrite
publicAPI.DELETE(DeleteDeviceURL, gateway.Handler(handler.DeleteDevice), routesmiddleware.RequiresPermission(authorizer.DeviceRemove))
Expand Down
64 changes: 64 additions & 0 deletions api/services/device.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ type DeviceService interface {
OfflineDevice(ctx context.Context, uid models.UID) error

UpdateDevice(ctx context.Context, req *requests.DeviceUpdate) error
GetDeviceSettings(ctx context.Context, req *requests.DeviceGetSettings) (*models.SSHSettings, error)
UpdateDeviceSettings(ctx context.Context, req *requests.DeviceUpdateSettings) error
// UpdateDeviceStatus updates a device's status. Devices that are already accepted cannot change their status.
//
// When accepting, if a device with the same MAC address is already accepted within the same namespace, it
Expand Down Expand Up @@ -398,6 +400,64 @@ func (s *service) UpdateDevice(ctx context.Context, req *requests.DeviceUpdate)
return nil
}

func (s *service) GetDeviceSettings(ctx context.Context, req *requests.DeviceGetSettings) (*models.SSHSettings, error) {
device, err := s.store.DeviceResolve(ctx, store.DeviceUIDResolver, req.UID, s.store.Options().InNamespace(req.TenantID))
if err != nil {
return nil, NewErrDeviceNotFound(models.UID(req.UID), err)
}

if device.SSH == nil {
return models.DefaultSSHSettings(), nil
}

return device.SSH, nil
}

func (s *service) UpdateDeviceSettings(ctx context.Context, req *requests.DeviceUpdateSettings) error {
device, err := s.store.DeviceResolve(ctx, store.DeviceUIDResolver, req.UID, s.store.Options().InNamespace(req.TenantID))
if err != nil {
return NewErrDeviceNotFound(models.UID(req.UID), err)
}

if device.SSH == nil {
device.SSH = models.DefaultSSHSettings()
}

if req.AllowPassword != nil {
device.SSH.AllowPassword = *req.AllowPassword
}
if req.AllowPublicKey != nil {
device.SSH.AllowPublicKey = *req.AllowPublicKey
}
if req.AllowRoot != nil {
device.SSH.AllowRoot = *req.AllowRoot
}
if req.AllowEmptyPasswords != nil {
device.SSH.AllowEmptyPasswords = *req.AllowEmptyPasswords
}
if req.AllowTTY != nil {
device.SSH.AllowTTY = *req.AllowTTY
}
if req.AllowTCPForwarding != nil {
device.SSH.AllowTCPForwarding = *req.AllowTCPForwarding
}
if req.AllowWebEndpoints != nil {
device.SSH.AllowWebEndpoints = *req.AllowWebEndpoints
}
if req.AllowSFTP != nil {
device.SSH.AllowSFTP = *req.AllowSFTP
}
if req.AllowAgentForwarding != nil {
device.SSH.AllowAgentForwarding = *req.AllowAgentForwarding
}

if err := s.store.DeviceUpdateSettings(ctx, req.UID, device.SSH); err != nil {
return err
}

return nil
}

// maxCustomFieldsPerDevice is the upper bound on the number of custom_fields entries
// per device. Enforced server-side to prevent storage abuse.
const maxCustomFieldsPerDevice = 20
Expand Down Expand Up @@ -451,6 +511,10 @@ func (s *service) mergeDevice(ctx context.Context, tenantID string, oldDevice *m
}

log.WithFields(logFields).Debug("updating new device name to preserve old device identity")
if oldDevice.SSH != nil {
newDevice.SSH = oldDevice.SSH
}

newDevice.Name = oldDevice.Name
if err := s.store.DeviceUpdate(ctx, newDevice); err != nil {
log.WithError(err).WithFields(logFields).Error("failed to update new device name")
Expand Down
Loading
Loading