diff --git a/docs/ARCHITECTURE.md b/docs/ARCHITECTURE.md index 1193041d4..992403f51 100644 --- a/docs/ARCHITECTURE.md +++ b/docs/ARCHITECTURE.md @@ -43,6 +43,15 @@ Reference material for Zaparoo Core's architecture, APIs, and subsystems. For de - **Thread-safe**: `config.Instance` uses `syncutil.RWMutex` - Maintain backward compatibility — use migrations for breaking changes +## Profiles + +Device profiles are named buckets of preferences and limits, with no passwords or accounts. See `pkg/service/profiles/`. + +- **Active profile**: one per device, held as a snapshot in service state (`pkg/service/state/`) and persisted in the UserDB `DeviceState` table so it survives restarts. No active profile = pre-profiles behavior exactly. +- **Switching**: via API (`profiles.switch`, PIN-checked) or by scanning a card containing `**profile.switch:`. The switch ID is a word phrase (e.g. `corn-arm-truck`) generated from an embedded wordlist — a selector, never a credential. Card scans bypass the PIN: possession of the card is the authorization. PINs gate entry only; deactivating is always free. +- **Playtime limits**: profiles can override the global daily/session limits. `pkg/service/playtime.LimitsManager` reads limits through a `LimitsProvider`; the profile-aware resolver (`pkg/service/profiles.LimitsResolver`) layers the active profile's overrides over global config. Daily usage accounting is scoped to the active profile via the `ProfileID` column on `MediaHistory` (rows are attributed at launch time). Switching profiles resets the limit session. +- **Require-profile gate**: the `[profiles] require_for_launch` config setting blocks media launches while no profile is active (profile switch commands still run, so scanning a card unparks the device). + ## Reader Auto-Detection 10 reader types: acr122pcsc, externaldrive, file, libnfc, mqtt, opticaldrive, pn532, rs232barcode, simpleserial, tty2oled diff --git a/docs/api/methods.md b/docs/api/methods.md index 5e4e74720..1ba7a52e6 100644 --- a/docs/api/methods.md +++ b/docs/api/methods.md @@ -2447,6 +2447,124 @@ Returns `null` on success. } ``` +## Profiles + +Profiles are lightweight device profiles: named buckets of preferences and limits, with no passwords or accounts. One profile is active per device at a time, switched via the API or by scanning an NFC card containing the profile's switch ID (`**profile.switch:`). + +A profile may have an optional 4-8 digit PIN. Switching to a PIN-protected profile via the API requires the PIN; scanning the profile's physical card bypasses it (possession of the card is the authorization). Leaving a profile is always free — PINs gate entry only. To prevent a profile-less device from being an escape hatch, enable the `profilesRequireForLaunch` setting (see [settings](#settings)), which blocks media launches while no profile is active. + +When no profile is active the device behaves exactly as it did before profiles existed: global playtime limits apply and history is unattributed. + +##### Profile object + +| Key | Type | Required | Description | +| :------------ | :------ | :------- | :------------------------------------------------------------------------------------------------------- | +| profileId | string | Yes | Unique identifier of the profile. | +| name | string | Yes | Display name, e.g. "Dad" or "Kid A". | +| switchId | string | Yes | Word phrase written to profile switch cards, e.g. `corn-arm-truck`. A selector, not a secret. | +| hasPin | boolean | Yes | True when the profile has a PIN set. The PIN itself is never returned. | +| limitsEnabled | boolean | No | Playtime limits enabled override. Omitted = inherit the global setting. | +| dailyLimit | string | No | Daily playtime limit override as a duration string (e.g. `2h30m`). Omitted = inherit; `0` = unlimited. | +| sessionLimit | string | No | Session playtime limit override as a duration string. Omitted = inherit; `0` = unlimited. | +| createdAt | number | Yes | Unix timestamp of profile creation. | +| lastUpdatedAt | number | Yes | Unix timestamp of last modification. | + +### profiles + +List all profiles. + +#### Parameters + +None. + +#### Result + +| Key | Type | Required | Description | +| :------- | :--------------------------- | :------- | :---------------- | +| profiles | [Profile](#profile-object)[] | Yes | List of profiles. | + +### profiles.new + +Create a new profile. The switch ID is generated automatically and returned in the result — write it to a card as `**profile.switch:`. + +#### Parameters + +| Key | Type | Required | Description | +| :------------ | :------ | :------- | :---------------------------------------------------------------- | +| name | string | Yes | Display name. | +| pin | string | No | Optional 4-8 digit PIN required to switch to this profile via API. | +| limitsEnabled | boolean | No | Playtime limits enabled override. | +| dailyLimit | string | No | Daily limit duration override. | +| sessionLimit | string | No | Session limit duration override. | + +#### Result + +The created [profile object](#profile-object). + +### profiles.update + +Update a profile. Omitted fields are unchanged. If the updated profile is currently active, its limit changes apply immediately. + +#### Parameters + +| Key | Type | Required | Description | +| :----------------- | :------ | :------- | :---------------------------------------------------------------------- | +| profileId | string | Yes | Profile to update. | +| name | string | No | New display name. | +| pin | string | No | Set or replace the PIN. | +| clearPin | boolean | No | Remove the PIN. | +| limitsEnabled | boolean | No | Playtime limits enabled override. | +| dailyLimit | string | No | Daily limit duration override. | +| sessionLimit | string | No | Session limit duration override. | +| clearLimits | boolean | No | Reset all limit overrides back to inheriting the global config. | +| regenerateSwitchId | boolean | No | Issue a new switch ID (lost-card replacement). Old cards stop working. | + +#### Result + +The updated [profile object](#profile-object). + +### profiles.delete + +Delete a profile. If it is the active profile, the device deactivates. Past play history keeps its attribution to the deleted profile. + +#### Parameters + +| Key | Type | Required | Description | +| :-------- | :----- | :------- | :----------------- | +| profileId | string | Yes | Profile to delete. | + +#### Result + +Null. + +### profiles.active + +Get the device's currently active profile. + +#### Parameters + +None. + +#### Result + +The active profile (a subset of the [profile object](#profile-object) without `switchId` and timestamps), or null when no profile is active. + +### profiles.switch + +Switch the device's active profile. Switching to a PIN-protected profile requires its PIN, whether selected by `profileId` or `switchId` — only physical card scans bypass the PIN. Calling with neither `profileId` nor `switchId` deactivates the current profile, which never requires a PIN. + +#### Parameters + +| Key | Type | Required | Description | +| :-------- | :----- | :------- | :------------------------------------------------------ | +| profileId | string | No | Profile to activate, by ID. | +| switchId | string | No | Profile to activate, by switch ID. | +| pin | string | No | The profile's PIN, when one is set. | + +#### Result + +The new active profile, or null when deactivated. + ## Mappings Mappings are used to modify the contents of tokens before they're launched, based on different types of matching parameters. Stored mappings are queried before every launch and applied to the token if there's a match. This allows, for example, adding ZapScript to a read-only NFC tag based on its UID. diff --git a/docs/api/notifications.md b/docs/api/notifications.md index 06ade847c..5aca4f379 100644 --- a/docs/api/notifications.md +++ b/docs/api/notifications.md @@ -408,3 +408,35 @@ Sent when a new inbox message is added to the server. } } ``` + +## Profiles + +### profiles.active + +Sent when the device's active profile changes, including deactivation. + +#### Parameters + +| Key | Type | Required | Description | +| :------ | :----- | :------- | :------------------------------------------------------------------ | +| profile | object | Yes | The new active profile, or null when the device deactivated. | + +The profile object contains `profileId`, `name`, `hasPin` and any playtime limit overrides (`limitsEnabled`, `dailyLimit`, `sessionLimit`). + +#### Example + +```json +{ + "jsonrpc": "2.0", + "method": "profiles.active", + "params": { + "profile": { + "profileId": "1ad28b9a-7aef-11ef-9817-020304050607", + "name": "Kid A", + "hasPin": true, + "limitsEnabled": true, + "dailyLimit": "2h" + } + } +} +``` diff --git a/go.mod b/go.mod index 2e9ffe635..dbf04596f 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,7 @@ require ( github.com/DATA-DOG/go-sqlmock v1.5.2 github.com/Microsoft/go-winio v0.6.2 github.com/ZaparooProject/go-pn532 v0.22.1 - github.com/ZaparooProject/go-zapscript v0.13.0 + github.com/ZaparooProject/go-zapscript v0.14.0 github.com/adrg/xdg v0.5.3 github.com/andygrunwald/vdf v1.1.0 github.com/bendahl/uinput v1.7.0 diff --git a/go.sum b/go.sum index 2791c9f67..5a73bd076 100644 --- a/go.sum +++ b/go.sum @@ -16,8 +16,8 @@ github.com/TheTitanrain/w32 v0.0.0-20200114052255-2654d97dbd3d h1:2xp1BQbqcDDaik github.com/TheTitanrain/w32 v0.0.0-20200114052255-2654d97dbd3d/go.mod h1:peYoMncQljjNS6tZwI9WVyQB3qZS6u79/N3mBOcnd3I= github.com/ZaparooProject/go-pn532 v0.22.1 h1:Dtuc+sXYZtuNZP+8/DQv68V1MOHUxRAOMVYsqvVFAfQ= github.com/ZaparooProject/go-pn532 v0.22.1/go.mod h1:NwYx5IE0zAU70ZikNpoPiOF5MUlDn3fD8xImpZixW1k= -github.com/ZaparooProject/go-zapscript v0.13.0 h1:qiYhSoVzenFvmAeU+b0AwFYT5jAmiC7RMwNroYtBl2o= -github.com/ZaparooProject/go-zapscript v0.13.0/go.mod h1:Z3rFyQq/GA+ESpYUtCOA/2Xyftbygv4MfDCajOVDmag= +github.com/ZaparooProject/go-zapscript v0.14.0 h1:DJp4KsbqDN2My/mwH3DBc4MXBfmeabe8PWjurlRpRnM= +github.com/ZaparooProject/go-zapscript v0.14.0/go.mod h1:Z3rFyQq/GA+ESpYUtCOA/2Xyftbygv4MfDCajOVDmag= github.com/adrg/xdg v0.5.3 h1:xRnxJXne7+oWDatRhR1JLnvuccuIeCoBu2rtuLqQB78= github.com/adrg/xdg v0.5.3/go.mod h1:nlTsY+NNiCBGCK2tpm09vRqfVzrc2fLmXGpBLF0zlTQ= github.com/andygrunwald/vdf v1.1.0 h1:gmstp0R7DOepIZvWoSJY97ix7QOrsxpGPU6KusKXqvw= diff --git a/pkg/api/methods/profiles.go b/pkg/api/methods/profiles.go new file mode 100644 index 000000000..4feb7897e --- /dev/null +++ b/pkg/api/methods/profiles.go @@ -0,0 +1,212 @@ +// Zaparoo Core +// Copyright (c) 2026 The Zaparoo Project Contributors. +// SPDX-License-Identifier: GPL-3.0-or-later +// +// This file is part of Zaparoo Core. +// +// Zaparoo Core is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// Zaparoo Core is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with Zaparoo Core. If not, see . + +package methods + +import ( + "errors" + + "github.com/ZaparooProject/zaparoo-core/v2/pkg/api/models" + "github.com/ZaparooProject/zaparoo-core/v2/pkg/api/models/requests" + "github.com/ZaparooProject/zaparoo-core/v2/pkg/api/validation" + "github.com/ZaparooProject/zaparoo-core/v2/pkg/database" + "github.com/ZaparooProject/zaparoo-core/v2/pkg/service/profiles" + "github.com/rs/zerolog/log" +) + +// errProfilesUnavailable is returned when the profiles service was not +// wired into the request environment (should not happen in production). +var errProfilesUnavailable = errors.New("profiles service not available") + +func profileResponse(p *database.Profile) models.ProfileResponse { + return models.ProfileResponse{ + ProfileID: p.ProfileID, + Name: p.Name, + SwitchID: p.SwitchID, + HasPIN: p.PINHash != "", + LimitsEnabled: p.LimitsEnabled, + DailyLimit: p.DailyLimit, + SessionLimit: p.SessionLimit, + CreatedAt: p.CreatedAt, + LastUpdatedAt: p.UpdatedAt, + } +} + +// profileError maps service errors to client errors where the client is at +// fault (bad PIN, unknown profile), passing other errors through. +func profileError(err error) error { + switch { + case errors.Is(err, profiles.ErrPINRequired), + errors.Is(err, profiles.ErrPINIncorrect), + errors.Is(err, profiles.ErrPINRateLimited), + errors.Is(err, profiles.ErrInvalidPINFormat), + errors.Is(err, profiles.ErrNotFound): + return models.ClientErrf("%w", err) + default: + return err + } +} + +// HandleProfiles lists all profiles. +// +//nolint:gocritic // single-use parameter in API handler +func HandleProfiles(env requests.RequestEnv) (any, error) { + log.Info().Msg("received profiles list request") + if env.Profiles == nil { + return nil, errProfilesUnavailable + } + + list, err := env.Profiles.List() + if err != nil { + log.Error().Err(err).Msg("error listing profiles") + return nil, errors.New("error listing profiles") + } + + resp := models.ProfilesResponse{ + Profiles: make([]models.ProfileResponse, len(list)), + } + for i := range list { + resp.Profiles[i] = profileResponse(&list[i]) + } + return resp, nil +} + +// HandleProfilesNew creates a new profile. +// +//nolint:gocritic // single-use parameter in API handler +func HandleProfilesNew(env requests.RequestEnv) (any, error) { + log.Info().Msg("received profiles new request") + if env.Profiles == nil { + return nil, errProfilesUnavailable + } + + var params models.NewProfileParams + if err := validation.ValidateAndUnmarshal(env.Params, ¶ms); err != nil { + log.Warn().Err(err).Msg("invalid params") + return nil, models.ClientErrf("invalid params: %w", err) + } + + p, err := env.Profiles.Create(¶ms) + if err != nil { + return nil, profileError(err) + } + return profileResponse(p), nil +} + +// HandleProfilesUpdate updates an existing profile. +// +//nolint:gocritic // single-use parameter in API handler +func HandleProfilesUpdate(env requests.RequestEnv) (any, error) { + log.Info().Msg("received profiles update request") + if env.Profiles == nil { + return nil, errProfilesUnavailable + } + + var params models.UpdateProfileParams + if err := validation.ValidateAndUnmarshal(env.Params, ¶ms); err != nil { + log.Warn().Err(err).Msg("invalid params") + return nil, models.ClientErrf("invalid params: %w", err) + } + + p, err := env.Profiles.Update(¶ms) + if err != nil { + return nil, profileError(err) + } + return profileResponse(p), nil +} + +// HandleProfilesDelete removes a profile, deactivating it first if it is +// the active profile. +// +//nolint:gocritic // single-use parameter in API handler +func HandleProfilesDelete(env requests.RequestEnv) (any, error) { + log.Info().Msg("received profiles delete request") + if env.Profiles == nil { + return nil, errProfilesUnavailable + } + + var params models.DeleteProfileParams + if err := validation.ValidateAndUnmarshal(env.Params, ¶ms); err != nil { + log.Warn().Err(err).Msg("invalid params") + return nil, models.ClientErrf("invalid params: %w", err) + } + + if err := env.Profiles.Delete(params.ProfileID); err != nil { + return nil, profileError(err) + } + return NoContent{}, nil +} + +// HandleProfilesActive returns the active profile, or null when none. +// +//nolint:gocritic // single-use parameter in API handler +func HandleProfilesActive(env requests.RequestEnv) (any, error) { + log.Info().Msg("received profiles active request") + if env.Profiles == nil { + return nil, errProfilesUnavailable + } + return env.Profiles.Active(), nil +} + +// HandleProfilesSwitch switches the active profile. Switching to a +// PIN-protected profile requires its PIN, by profile ID or switch ID +// equally — only physical card scans bypass the PIN. Passing neither +// profileId nor switchId deactivates, which is always free (PINs gate +// entry only). +// +//nolint:gocritic // single-use parameter in API handler +func HandleProfilesSwitch(env requests.RequestEnv) (any, error) { + log.Info().Msg("received profiles switch request") + if env.Profiles == nil { + return nil, errProfilesUnavailable + } + + var params models.SwitchProfileParams + if len(env.Params) > 0 { + if err := validation.ValidateAndUnmarshal(env.Params, ¶ms); err != nil { + log.Warn().Err(err).Msg("invalid params") + return nil, models.ClientErrf("invalid params: %w", err) + } + } + + pin := "" + if params.PIN != nil { + pin = *params.PIN + } + + switch { + case params.ProfileID != nil && *params.ProfileID != "": + active, err := env.Profiles.ActivateByID(*params.ProfileID, pin) + if err != nil { + return nil, profileError(err) + } + return active, nil + case params.SwitchID != nil && *params.SwitchID != "": + active, err := env.Profiles.ActivateBySwitchIDChecked(*params.SwitchID, pin) + if err != nil { + return nil, profileError(err) + } + return active, nil + default: + if err := env.Profiles.Deactivate(); err != nil { + return nil, profileError(err) + } + return (*models.ActiveProfile)(nil), nil + } +} diff --git a/pkg/api/methods/profiles_test.go b/pkg/api/methods/profiles_test.go new file mode 100644 index 000000000..6cdc07d5f --- /dev/null +++ b/pkg/api/methods/profiles_test.go @@ -0,0 +1,246 @@ +// Zaparoo Core +// Copyright (c) 2026 The Zaparoo Project Contributors. +// SPDX-License-Identifier: GPL-3.0-or-later +// +// This file is part of Zaparoo Core. +// +// Zaparoo Core is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// Zaparoo Core is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with Zaparoo Core. If not, see . + +package methods + +import ( + "context" + "encoding/json" + "testing" + + "github.com/ZaparooProject/zaparoo-core/v2/pkg/api/models" + "github.com/ZaparooProject/zaparoo-core/v2/pkg/api/models/requests" + "github.com/ZaparooProject/zaparoo-core/v2/pkg/database" + "github.com/ZaparooProject/zaparoo-core/v2/pkg/service/profiles" + "github.com/ZaparooProject/zaparoo-core/v2/pkg/service/state" + "github.com/ZaparooProject/zaparoo-core/v2/pkg/testing/helpers" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" +) + +// newProfilesEnv builds a RequestEnv with a real profiles service over a +// mock user DB. +func newProfilesEnv(t *testing.T) (env requests.RequestEnv, mockUserDB *helpers.MockUserDBI, st *state.State) { + t.Helper() + mockUserDB = helpers.NewMockUserDBI() + st, ns := state.NewState(nil, "boot") + t.Cleanup(func() { + for { + select { + case <-ns: + default: + return + } + } + }) + db := &database.Database{UserDB: mockUserDB, MediaDB: nil} + env = requests.RequestEnv{ + Context: context.Background(), + Database: db, + State: st, + Profiles: profiles.NewService(db, st), + } + return env, mockUserDB, st +} + +func testProfileRow(t *testing.T, pin string) *database.Profile { + t.Helper() + p := &database.Profile{ + ProfileID: "profile-1", + Name: "Kid A", + SwitchID: "corn-arm-truck", + CreatedAt: 1700000000, + UpdatedAt: 1700000000, + } + if pin != "" { + hash, err := profiles.HashPIN(pin) + require.NoError(t, err) + p.PINHash = hash + } + return p +} + +func TestHandleProfiles_List(t *testing.T) { + t.Parallel() + env, mockUserDB, _ := newProfilesEnv(t) + + mockUserDB.On("ListProfiles").Return([]database.Profile{*testProfileRow(t, "1234")}, nil) + + result, err := HandleProfiles(env) + require.NoError(t, err) + resp, ok := result.(models.ProfilesResponse) + require.True(t, ok) + require.Len(t, resp.Profiles, 1) + assert.Equal(t, "profile-1", resp.Profiles[0].ProfileID) + assert.Equal(t, "corn-arm-truck", resp.Profiles[0].SwitchID) + assert.True(t, resp.Profiles[0].HasPIN) + + // The PIN hash must never appear in the serialized response. + raw, err := json.Marshal(resp) + require.NoError(t, err) + assert.NotContains(t, string(raw), "pbkdf2") +} + +func TestHandleProfilesNew(t *testing.T) { + t.Parallel() + env, mockUserDB, _ := newProfilesEnv(t) + + mockUserDB.On("CreateProfile", mock.Anything).Return(nil) + + env.Params = json.RawMessage(`{"name": "Kid A", "pin": "1234", "dailyLimit": "2h"}`) + result, err := HandleProfilesNew(env) + require.NoError(t, err) + resp, ok := result.(models.ProfileResponse) + require.True(t, ok) + assert.Equal(t, "Kid A", resp.Name) + assert.NotEmpty(t, resp.SwitchID) + assert.True(t, resp.HasPIN) + require.NotNil(t, resp.DailyLimit) + assert.Equal(t, "2h", *resp.DailyLimit) +} + +func TestHandleProfilesNew_InvalidParams(t *testing.T) { + t.Parallel() + env, _, _ := newProfilesEnv(t) + + // Missing required name. + env.Params = json.RawMessage(`{"pin": "1234"}`) + _, err := HandleProfilesNew(env) + require.Error(t, err) + + // Non-numeric PIN. + env.Params = json.RawMessage(`{"name": "Kid A", "pin": "abcd"}`) + _, err = HandleProfilesNew(env) + require.Error(t, err) + + // Bad duration. + env.Params = json.RawMessage(`{"name": "Kid A", "dailyLimit": "2 hours"}`) + _, err = HandleProfilesNew(env) + require.Error(t, err) +} + +func TestHandleProfilesSwitch_PINFlow(t *testing.T) { + t.Parallel() + env, mockUserDB, st := newProfilesEnv(t) + + mockUserDB.On("GetProfile", "profile-1").Return(testProfileRow(t, "1234"), nil) + + // Missing PIN. + env.Params = json.RawMessage(`{"profileId": "profile-1"}`) + _, err := HandleProfilesSwitch(env) + require.Error(t, err) + assert.Contains(t, err.Error(), "PIN") + + // Wrong PIN. + env.Params = json.RawMessage(`{"profileId": "profile-1", "pin": "9999"}`) + _, err = HandleProfilesSwitch(env) + require.Error(t, err) + + assert.Nil(t, st.ActiveProfile()) + + // Correct PIN. + mockUserDB.On("SetDeviceState", database.DeviceStateKeyActiveProfile, "profile-1").Return(nil) + env.Params = json.RawMessage(`{"profileId": "profile-1", "pin": "1234"}`) + result, err := HandleProfilesSwitch(env) + require.NoError(t, err) + active, ok := result.(*models.ActiveProfile) + require.True(t, ok) + assert.Equal(t, "profile-1", active.ProfileID) + require.NotNil(t, st.ActiveProfile()) +} + +func TestHandleProfilesSwitch_BySwitchIDStillRequiresPIN(t *testing.T) { + t.Parallel() + env, mockUserDB, _ := newProfilesEnv(t) + + mockUserDB.On("GetProfileBySwitchID", "corn-arm-truck").Return(testProfileRow(t, "1234"), nil) + + // Knowing the switch ID is not possession of the card: the API path + // still enforces the PIN. + env.Params = json.RawMessage(`{"switchId": "corn-arm-truck"}`) + _, err := HandleProfilesSwitch(env) + require.Error(t, err) + assert.Contains(t, err.Error(), "PIN") +} + +func TestHandleProfilesSwitch_DeactivateIsFree(t *testing.T) { + t.Parallel() + env, mockUserDB, st := newProfilesEnv(t) + + st.SetActiveProfile(&models.ActiveProfile{ProfileID: "profile-1", Name: "Kid A", HasPIN: true}) + mockUserDB.On("DeleteDeviceState", database.DeviceStateKeyActiveProfile).Return(nil) + + // No params at all = deactivate; PINs gate entry only. + env.Params = nil + result, err := HandleProfilesSwitch(env) + require.NoError(t, err) + active, ok := result.(*models.ActiveProfile) + require.True(t, ok) + assert.Nil(t, active) + assert.Nil(t, st.ActiveProfile()) +} + +func TestHandleProfilesActive(t *testing.T) { + t.Parallel() + env, _, st := newProfilesEnv(t) + + result, err := HandleProfilesActive(env) + require.NoError(t, err) + assert.Nil(t, result.(*models.ActiveProfile)) + + st.SetActiveProfile(&models.ActiveProfile{ProfileID: "profile-1", Name: "Kid A"}) + result, err = HandleProfilesActive(env) + require.NoError(t, err) + active, ok := result.(*models.ActiveProfile) + require.True(t, ok) + assert.Equal(t, "profile-1", active.ProfileID) +} + +func TestHandleProfilesDelete(t *testing.T) { + t.Parallel() + env, mockUserDB, st := newProfilesEnv(t) + + st.SetActiveProfile(&models.ActiveProfile{ProfileID: "profile-1", Name: "Kid A"}) + mockUserDB.On("DeleteProfile", "profile-1").Return(nil) + + env.Params = json.RawMessage(`{"profileId": "profile-1"}`) + result, err := HandleProfilesDelete(env) + require.NoError(t, err) + assert.Equal(t, NoContent{}, result) + assert.Nil(t, st.ActiveProfile(), "deleting the active profile deactivates it") +} + +func TestHandleProfilesUpdate_ClearPIN(t *testing.T) { + t.Parallel() + env, mockUserDB, _ := newProfilesEnv(t) + + mockUserDB.On("GetProfile", "profile-1").Return(testProfileRow(t, "1234"), nil) + mockUserDB.On("UpdateProfile", mock.MatchedBy(func(p *database.Profile) bool { + return p.PINHash == "" + })).Return(nil) + + env.Params = json.RawMessage(`{"profileId": "profile-1", "clearPin": true}`) + result, err := HandleProfilesUpdate(env) + require.NoError(t, err) + resp, ok := result.(models.ProfileResponse) + require.True(t, ok) + assert.False(t, resp.HasPIN) + mockUserDB.AssertExpectations(t) +} diff --git a/pkg/api/methods/settings.go b/pkg/api/methods/settings.go index 6e522042a..89a1a58ef 100644 --- a/pkg/api/methods/settings.go +++ b/pkg/api/methods/settings.go @@ -74,6 +74,7 @@ func HandleSettings(env requests.RequestEnv) (any, error) { //nolint:gocritic // LaunchGuardTimeout: env.Config.LaunchGuardTimeout(), LaunchGuardDelay: env.Config.LaunchGuardDelay(), LaunchGuardRequireConfirm: env.Config.LaunchGuardRequireConfirm(), + ProfilesRequireForLaunch: env.Config.ProfilesRequireForLaunch(), } resp.ReadersScanIgnoreSystem = append(resp.ReadersScanIgnoreSystem, env.Config.ReadersScan().IgnoreSystem...) @@ -223,6 +224,11 @@ func HandleSettingsUpdate(env requests.RequestEnv) (any, error) { env.Config.SetLaunchGuardRequireConfirm(*params.LaunchGuardRequireConfirm) } + if params.ProfilesRequireForLaunch != nil { + log.Debug().Bool("profilesRequireForLaunch", *params.ProfilesRequireForLaunch).Msg("updating setting") + env.Config.SetProfilesRequireForLaunch(*params.ProfilesRequireForLaunch) + } + if params.ReadersConnect != nil { log.Debug().Int("count", len(*params.ReadersConnect)).Msg("updating readers.connect") connections := make([]config.ReadersConnect, 0, len(*params.ReadersConnect)) diff --git a/pkg/api/models/models.go b/pkg/api/models/models.go index c12f156ed..ca4dbda2a 100644 --- a/pkg/api/models/models.go +++ b/pkg/api/models/models.go @@ -39,6 +39,7 @@ const ( NotificationPlaytimeLimitWarning = "playtime.limit.warning" NotificationInboxAdded = "inbox.added" NotificationClientsPaired = "clients.paired" + NotificationProfilesActive = "profiles.active" ) const ( @@ -88,6 +89,12 @@ const ( MethodClientsDelete = "clients.delete" MethodClientsPairStart = "clients.pair.start" MethodClientsPairCancel = "clients.pair.cancel" + MethodProfiles = "profiles" + MethodProfilesNew = "profiles.new" + MethodProfilesUpdate = "profiles.update" + MethodProfilesDelete = "profiles.delete" + MethodProfilesActive = "profiles.active" + MethodProfilesSwitch = "profiles.switch" MethodSystems = "systems" MethodLaunchers = "launchers" MethodLaunchersRefresh = "launchers.refresh" diff --git a/pkg/api/models/params.go b/pkg/api/models/params.go index 8febe8998..cae42ef4e 100644 --- a/pkg/api/models/params.go +++ b/pkg/api/models/params.go @@ -130,6 +130,7 @@ type UpdateSettingsParams struct { LaunchGuardTimeout *float32 `json:"launchGuardTimeout" validate:"omitempty,gte=-1"` LaunchGuardDelay *float32 `json:"launchGuardDelay" validate:"omitempty,gte=0"` LaunchGuardRequireConfirm *bool `json:"launchGuardRequireConfirm"` + ProfilesRequireForLaunch *bool `json:"profilesRequireForLaunch"` } type UpdatePlaytimeLimitsParams struct { @@ -149,6 +150,45 @@ type DeleteClientParams struct { ID string `json:"id" validate:"required,min=1"` } +// NewProfileParams creates a profile. Nil limit fields inherit the global +// config; a "0" duration means explicitly unlimited. +type NewProfileParams struct { + PIN *string `json:"pin" validate:"omitempty,numeric,min=4,max=8"` + LimitsEnabled *bool `json:"limitsEnabled"` + DailyLimit *string `json:"dailyLimit" validate:"omitempty,duration"` + SessionLimit *string `json:"sessionLimit" validate:"omitempty,duration"` + Name string `json:"name" validate:"required,min=1,max=255"` +} + +// UpdateProfileParams updates a profile. Omitted fields are unchanged. +// ClearPIN removes the PIN; ClearLimits resets all limit overrides back to +// inheriting global config; RegenerateSwitchID issues a new switch ID +// (lost-card replacement). +type UpdateProfileParams struct { + Name *string `json:"name" validate:"omitempty,min=1,max=255"` + PIN *string `json:"pin" validate:"omitempty,numeric,min=4,max=8"` + LimitsEnabled *bool `json:"limitsEnabled"` + DailyLimit *string `json:"dailyLimit" validate:"omitempty,duration"` + SessionLimit *string `json:"sessionLimit" validate:"omitempty,duration"` + ProfileID string `json:"profileId" validate:"required,min=1"` + ClearPIN bool `json:"clearPin"` + ClearLimits bool `json:"clearLimits"` + RegenerateSwitchID bool `json:"regenerateSwitchId"` +} + +type DeleteProfileParams struct { + ProfileID string `json:"profileId" validate:"required,min=1"` +} + +// SwitchProfileParams switches the device's active profile. Exactly one of +// ProfileID or SwitchID selects the target; both omitted (or null) means +// deactivate. PIN is required when the target profile has one set. +type SwitchProfileParams struct { + ProfileID *string `json:"profileId"` + SwitchID *string `json:"switchId"` + PIN *string `json:"pin"` +} + type MediaStartedParams struct { SystemID string `json:"systemId" validate:"required"` SystemName string `json:"systemName" validate:"required"` diff --git a/pkg/api/models/requests/requests.go b/pkg/api/models/requests/requests.go index 5632c75be..dffa7e6fe 100644 --- a/pkg/api/models/requests/requests.go +++ b/pkg/api/models/requests/requests.go @@ -30,6 +30,7 @@ import ( "github.com/ZaparooProject/zaparoo-core/v2/pkg/helpers/syncutil" "github.com/ZaparooProject/zaparoo-core/v2/pkg/platforms" "github.com/ZaparooProject/zaparoo-core/v2/pkg/service/playtime" + "github.com/ZaparooProject/zaparoo-core/v2/pkg/service/profiles" "github.com/ZaparooProject/zaparoo-core/v2/pkg/service/state" "github.com/ZaparooProject/zaparoo-core/v2/pkg/service/tokens" ) @@ -45,6 +46,7 @@ type RequestEnv struct { State *state.State Database *database.Database LimitsManager *playtime.LimitsManager + Profiles *profiles.Service LauncherCache *helpers.LauncherCache Player audio.Player PlaybackManager audio.PlaybackManager diff --git a/pkg/api/models/responses.go b/pkg/api/models/responses.go index 14fdc72c4..3d24aab4e 100644 --- a/pkg/api/models/responses.go +++ b/pkg/api/models/responses.go @@ -97,6 +97,7 @@ type SettingsResponse struct { ErrorReporting bool `json:"errorReporting"` LaunchGuardEnabled bool `json:"launchGuardEnabled"` LaunchGuardRequireConfirm bool `json:"launchGuardRequireConfirm"` + ProfilesRequireForLaunch bool `json:"profilesRequireForLaunch"` } type PlaytimeLimitsResponse struct { @@ -585,6 +586,46 @@ type ClientsPairedNotification struct { ClientName string `json:"clientName"` } +// ProfileResponse represents a device profile in API responses. The PIN +// hash is never exposed — only whether a PIN is set. SwitchID is included: +// it is a card selector, not a credential (API switching by SwitchID still +// enforces the PIN; only physical card scans bypass it). +type ProfileResponse struct { + LimitsEnabled *bool `json:"limitsEnabled,omitempty"` + DailyLimit *string `json:"dailyLimit,omitempty"` + SessionLimit *string `json:"sessionLimit,omitempty"` + ProfileID string `json:"profileId"` + Name string `json:"name"` + SwitchID string `json:"switchId"` + CreatedAt int64 `json:"createdAt"` + LastUpdatedAt int64 `json:"lastUpdatedAt"` + HasPIN bool `json:"hasPin"` +} + +// ProfilesResponse is the response for the profiles RPC method. +type ProfilesResponse struct { + Profiles []ProfileResponse `json:"profiles"` +} + +// ActiveProfile is a snapshot of the device's active profile, held in +// service state and broadcast on the profiles.active notification. It +// carries the resolved limit overrides so the playtime hot path never +// touches the database. Nil limit fields mean "inherit global config". +type ActiveProfile struct { + LimitsEnabled *bool `json:"limitsEnabled,omitempty"` + DailyLimit *string `json:"dailyLimit,omitempty"` + SessionLimit *string `json:"sessionLimit,omitempty"` + ProfileID string `json:"profileId"` + Name string `json:"name"` + HasPIN bool `json:"hasPin"` +} + +// ProfilesActiveNotification is the payload for the profiles.active +// notification. Profile is null when the device has no active profile. +type ProfilesActiveNotification struct { + Profile *ActiveProfile `json:"profile"` +} + type SettingsAuthClaimResponse struct { Domains []string `json:"domains"` } diff --git a/pkg/api/notifications/notifications.go b/pkg/api/notifications/notifications.go index ca7c5726d..35de93b40 100644 --- a/pkg/api/notifications/notifications.go +++ b/pkg/api/notifications/notifications.go @@ -132,3 +132,9 @@ func InboxAdded(ns chan<- models.Notification, payload *models.InboxMessage) { func ClientsPaired(ns chan<- models.Notification, payload models.ClientsPairedNotification) { sendNotification(ns, models.NotificationClientsPaired, payload) } + +// ProfilesActiveChanged broadcasts a change of the device's active profile. +// The payload profile is null when the device deactivated to no profile. +func ProfilesActiveChanged(ns chan<- models.Notification, payload models.ProfilesActiveNotification) { + sendNotification(ns, models.NotificationProfilesActive, payload) +} diff --git a/pkg/api/server.go b/pkg/api/server.go index c807668f2..2262fe548 100644 --- a/pkg/api/server.go +++ b/pkg/api/server.go @@ -52,6 +52,7 @@ import ( "github.com/ZaparooProject/zaparoo-core/v2/pkg/platforms" "github.com/ZaparooProject/zaparoo-core/v2/pkg/service/broker" "github.com/ZaparooProject/zaparoo-core/v2/pkg/service/playtime" + "github.com/ZaparooProject/zaparoo-core/v2/pkg/service/profiles" "github.com/ZaparooProject/zaparoo-core/v2/pkg/service/state" "github.com/ZaparooProject/zaparoo-core/v2/pkg/service/tokens" "github.com/ZaparooProject/zaparoo-core/v2/pkg/service/updater" @@ -321,6 +322,13 @@ func NewMethodMap() *MethodMap { // clients (paired API clients) models.MethodClients: methods.HandleClients, models.MethodClientsDelete: methods.HandleClientsDelete, + + models.MethodProfiles: methods.HandleProfiles, + models.MethodProfilesNew: methods.HandleProfilesNew, + models.MethodProfilesUpdate: methods.HandleProfilesUpdate, + models.MethodProfilesDelete: methods.HandleProfilesDelete, + models.MethodProfilesActive: methods.HandleProfilesActive, + models.MethodProfilesSwitch: methods.HandleProfilesSwitch, // auth models.MethodSettingsAuthClaim: func(env requests.RequestEnv) (any, error) { return methods.HandleSettingsAuthClaim(env, zapscript.FetchWellKnown) @@ -965,6 +973,7 @@ func handleWSMessage( confirmQueue chan<- chan error, db *database.Database, limitsManager *playtime.LimitsManager, + profilesSvc *profiles.Service, player audio.Player, playbackManager audio.PlaybackManager, indexPauser *syncutil.Pauser, @@ -1075,6 +1084,7 @@ func handleWSMessage( State: st, Database: db, LimitsManager: limitsManager, + Profiles: profilesSvc, LauncherCache: helpers.GlobalLauncherCache, Player: player, PlaybackManager: playbackManager, @@ -1275,6 +1285,7 @@ func handlePostRequest( confirmQueue chan<- chan error, db *database.Database, limitsManager *playtime.LimitsManager, + profilesSvc *profiles.Service, player audio.Player, playbackManager audio.PlaybackManager, indexPauser *syncutil.Pauser, @@ -1330,6 +1341,7 @@ func handlePostRequest( State: st, Database: db, LimitsManager: limitsManager, + Profiles: profilesSvc, LauncherCache: helpers.GlobalLauncherCache, Player: player, PlaybackManager: playbackManager, @@ -1399,6 +1411,7 @@ func Start( confirmQueue chan<- chan error, db *database.Database, limitsManager *playtime.LimitsManager, + profilesSvc *profiles.Service, notifBroker *broker.Broker, mdnsHostname string, player audio.Player, @@ -1408,7 +1421,7 @@ func Start( tracker RequestTracker, ) error { return StartWithReady( - platform, cfg, st, inTokenQueue, confirmQueue, db, limitsManager, + platform, cfg, st, inTokenQueue, confirmQueue, db, limitsManager, profilesSvc, notifBroker, mdnsHostname, player, playbackManager, indexPauser, scrapePauser, tracker, nil, ) } @@ -1424,6 +1437,7 @@ func StartWithReady( confirmQueue chan<- chan error, db *database.Database, limitsManager *playtime.LimitsManager, + profilesSvc *profiles.Service, notifBroker *broker.Broker, mdnsHostname string, player audio.Player, @@ -1694,7 +1708,7 @@ func StartWithReady( postHandler := handlePostRequest( methodMap, platform, cfg, st, inTokenQueue, confirmQueue, - db, limitsManager, player, playbackManager, + db, limitsManager, profilesSvc, player, playbackManager, indexPauser, scrapePauser, tracker, ) r.Post("/api", postHandler) @@ -1737,7 +1751,7 @@ func StartWithReady( rateLimiter, handleWSMessage( methodMap, platform, cfg, st, inTokenQueue, confirmQueue, - db, limitsManager, player, playbackManager, indexPauser, scrapePauser, encGateway, + db, limitsManager, profilesSvc, player, playbackManager, indexPauser, scrapePauser, encGateway, lastSeenTracker, tracker, ), )) diff --git a/pkg/api/server_post_test.go b/pkg/api/server_post_test.go index 68bcb29b7..e6c58c709 100644 --- a/pkg/api/server_post_test.go +++ b/pkg/api/server_post_test.go @@ -117,7 +117,7 @@ func createTestPostHandler(t *testing.T) (http.HandlerFunc, *MethodMap, *fakeReq handler := handlePostRequest( methodMap, platform, cfg, st, tokenQueue, confirmQueue, db, - nil, nil, playbackManager, nil, nil, tracker, + nil, nil, nil, playbackManager, nil, nil, tracker, ) return handler, methodMap, tracker } @@ -145,7 +145,7 @@ func TestHandlePostRequest_InjectsPlaybackManager(t *testing.T) { confirmQueue := make(chan chan error, 1) handler := handlePostRequest( methodMap, platform, cfg, st, tokenQueue, confirmQueue, db, - nil, nil, playbackManager, nil, nil, nil, + nil, nil, nil, playbackManager, nil, nil, nil, ) reqBody := `{"jsonrpc":"2.0","id":"` + uuid.New().String() + `","method":"test.playback"}` diff --git a/pkg/api/server_startup_test.go b/pkg/api/server_startup_test.go index 50a88c890..fc1d8365d 100644 --- a/pkg/api/server_startup_test.go +++ b/pkg/api/server_startup_test.go @@ -78,7 +78,7 @@ func TestStartWithReadyReportsBindFailure(t *testing.T) { go func() { serverErr <- StartWithReady( platform, cfg, st, tokenQueue, nil, db, - nil, notifBroker, "", nil, nil, nil, nil, nil, ready, + nil, nil, notifBroker, "", nil, nil, nil, nil, nil, ready, ) }() @@ -141,7 +141,7 @@ func TestServerStartupConcurrency(t *testing.T) { defer close(serverDone) serverErr <- Start( platform, cfg, st, tokenQueue, nil, db, - nil, notifBroker, "", nil, nil, nil, nil, nil, + nil, nil, notifBroker, "", nil, nil, nil, nil, nil, ) }() // Cleanup: stop service first, then wait for server goroutine to fully exit @@ -213,7 +213,7 @@ func TestServerStartupImmediateConnection(t *testing.T) { serverErr := make(chan error, 1) go func() { defer close(serverDone) - serverErr <- Start(platform, cfg, st, tokenQueue, nil, db, nil, notifBroker, "", nil, nil, nil, nil, nil) + serverErr <- Start(platform, cfg, st, tokenQueue, nil, db, nil, nil, notifBroker, "", nil, nil, nil, nil, nil) }() // Cleanup: stop service first, then wait for server goroutine to fully exit defer func() { @@ -298,7 +298,7 @@ func TestServerListenContextCancellation(t *testing.T) { go func() { defer close(done) - serverErr <- Start(platform, cfg, st, tokenQueue, nil, db, nil, notifBroker, "", nil, nil, nil, nil, nil) + serverErr <- Start(platform, cfg, st, tokenQueue, nil, db, nil, nil, notifBroker, "", nil, nil, nil, nil, nil) }() // Wait for completion or timeout @@ -662,7 +662,9 @@ func TestServerBindFailureStopsService(t *testing.T) { server1Err := make(chan error, 1) go func() { defer close(server1Done) - server1Err <- Start(platform1, cfg1, st1, tokenQueue1, nil, db1, nil, notifBroker1, "", nil, nil, nil, nil, nil) + server1Err <- Start( + platform1, cfg1, st1, tokenQueue1, nil, db1, nil, nil, notifBroker1, "", nil, nil, nil, nil, nil, + ) }() // Wait for first server to be ready @@ -706,7 +708,9 @@ func TestServerBindFailureStopsService(t *testing.T) { server2Err := make(chan error, 1) go func() { defer close(server2Done) - server2Err <- Start(platform2, cfg2, st2, tokenQueue2, nil, db2, nil, notifBroker2, "", nil, nil, nil, nil, nil) + server2Err <- Start( + platform2, cfg2, st2, tokenQueue2, nil, db2, nil, nil, notifBroker2, "", nil, nil, nil, nil, nil, + ) }() // Wait for the second server's context to be cancelled (StopService called) @@ -979,7 +983,7 @@ func TestSSE_ReceivesNotifications(t *testing.T) { serverErr := make(chan error, 1) go func() { defer close(serverDone) - serverErr <- Start(platform, cfg, st, tokenQueue, nil, db, nil, notifBroker, "", nil, nil, nil, nil, nil) + serverErr <- Start(platform, cfg, st, tokenQueue, nil, db, nil, nil, notifBroker, "", nil, nil, nil, nil, nil) }() defer func() { st.StopService() diff --git a/pkg/api/server_ws_e2e_test.go b/pkg/api/server_ws_e2e_test.go index 8f6b6c431..3f449c874 100644 --- a/pkg/api/server_ws_e2e_test.go +++ b/pkg/api/server_ws_e2e_test.go @@ -155,7 +155,7 @@ func TestWSInjectsPlaybackManager(t *testing.T) { m := newWebSocketSession() m.HandleMessage(handleWSMessage( methodMap, platform, cfg, st, make(chan tokens.Token, 1), make(chan chan error, 1), db, - nil, nil, playbackManager, nil, nil, nil, nil, nil, + nil, nil, nil, playbackManager, nil, nil, nil, nil, nil, )) mux := http.NewServeMux() diff --git a/pkg/api/ws_dispatcher_test.go b/pkg/api/ws_dispatcher_test.go index 9f7a979b4..fa28ab82e 100644 --- a/pkg/api/ws_dispatcher_test.go +++ b/pkg/api/ws_dispatcher_test.go @@ -61,7 +61,7 @@ func startPriorityWSServer(t *testing.T, methodMap *MethodMap) (wsURL string, cl }) m.HandleMessage(handleWSMessage( methodMap, nil, cfg, st, nil, nil, - nil, nil, nil, nil, nil, + nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, )) diff --git a/pkg/config/config.go b/pkg/config/config.go index eed465d7c..35d00ab98 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -63,6 +63,7 @@ type Values struct { Service Service `toml:"service,omitempty"` Launchers Launchers `toml:"launchers,omitempty"` Playtime Playtime `toml:"playtime,omitempty"` + Profiles Profiles `toml:"profiles,omitempty"` Media Media `toml:"media,omitempty"` ZapScript ZapScript `toml:"zapscript,omitempty"` Mappings Mappings `toml:"mappings,omitempty"` diff --git a/pkg/config/configprofiles.go b/pkg/config/configprofiles.go new file mode 100644 index 000000000..c0b6bc853 --- /dev/null +++ b/pkg/config/configprofiles.go @@ -0,0 +1,45 @@ +// Zaparoo Core +// Copyright (c) 2026 The Zaparoo Project Contributors. +// SPDX-License-Identifier: GPL-3.0-or-later +// +// This file is part of Zaparoo Core. +// +// Zaparoo Core is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// Zaparoo Core is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with Zaparoo Core. If not, see . + +package config + +// Profiles configures device profile behavior. +type Profiles struct { + RequireForLaunch *bool `toml:"require_for_launch,omitempty"` +} + +// ProfilesRequireForLaunch returns true when media launches are blocked +// while no profile is active. Defaults to false: a profile-less device +// behaves exactly as before profiles existed. +func (c *Instance) ProfilesRequireForLaunch() bool { + c.mu.RLock() + defer c.mu.RUnlock() + if c.vals.Profiles.RequireForLaunch == nil { + return false + } + return *c.vals.Profiles.RequireForLaunch +} + +// SetProfilesRequireForLaunch enables or disables the require-profile +// launch gate. +func (c *Instance) SetProfilesRequireForLaunch(required bool) { + c.mu.Lock() + defer c.mu.Unlock() + c.vals.Profiles.RequireForLaunch = &required +} diff --git a/pkg/config/configprofiles_test.go b/pkg/config/configprofiles_test.go new file mode 100644 index 000000000..33e7df6ff --- /dev/null +++ b/pkg/config/configprofiles_test.go @@ -0,0 +1,54 @@ +// Zaparoo Core +// Copyright (c) 2026 The Zaparoo Project Contributors. +// SPDX-License-Identifier: GPL-3.0-or-later +// +// This file is part of Zaparoo Core. +// +// Zaparoo Core is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// Zaparoo Core is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with Zaparoo Core. If not, see . + +package config + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestProfilesRequireForLaunch(t *testing.T) { + t.Parallel() + + cfg, err := NewConfig(t.TempDir(), BaseDefaults) + require.NoError(t, err) + + // Default is false: profiles are purely additive. + assert.False(t, cfg.ProfilesRequireForLaunch()) + + cfg.SetProfilesRequireForLaunch(true) + assert.True(t, cfg.ProfilesRequireForLaunch()) + + cfg.SetProfilesRequireForLaunch(false) + assert.False(t, cfg.ProfilesRequireForLaunch()) +} + +func TestProfilesRequireForLaunch_TOML(t *testing.T) { + t.Parallel() + + cfg, err := NewConfig(t.TempDir(), BaseDefaults) + require.NoError(t, err) + + require.NoError(t, cfg.LoadTOML(`[profiles] +require_for_launch = true`)) + assert.True(t, cfg.ProfilesRequireForLaunch()) +} diff --git a/pkg/database/database.go b/pkg/database/database.go index d5ee5229f..e6e159cdb 100644 --- a/pkg/database/database.go +++ b/pkg/database/database.go @@ -74,6 +74,7 @@ type MediaHistoryEntry struct { EndTime *time.Time `json:"endTime,omitempty"` SyncedAt *time.Time `json:"syncedAt,omitempty"` DeviceID *string `json:"deviceId,omitempty"` + ProfileID *string `json:"profileId,omitempty"` BootUUID string `json:"bootUuid,omitempty"` ClockSource string `json:"clockSource,omitempty"` SystemID string `json:"systemId"` @@ -123,6 +124,28 @@ type InboxMessage struct { ProfileID int64 `json:"profileId"` } +// Profile represents a device profile: a named bucket of preferences and +// limits with no credentials. PINHash is hidden from JSON +// (API uses models.ProfileResponse instead). Nil limit fields mean +// "inherit the global config value"; a "0" duration string means +// "explicitly unlimited". +type Profile struct { + LimitsEnabled *bool `json:"limitsEnabled,omitempty"` + DailyLimit *string `json:"dailyLimit,omitempty"` + SessionLimit *string `json:"sessionLimit,omitempty"` + ProfileID string `json:"profileId"` + Name string `json:"name"` + SwitchID string `json:"switchId"` + PINHash string `json:"-"` + DBID int64 `json:"-"` + CreatedAt int64 `json:"createdAt"` + UpdatedAt int64 `json:"updatedAt"` +} + +// DeviceStateKeyActiveProfile is the DeviceState key holding the +// ProfileID of the device's active profile. +const DeviceStateKeyActiveProfile = "active_profile_id" + // Client represents a paired API client. AuthToken and PairingKey are // hidden from JSON (API uses models.PairedClient instead). type Client struct { @@ -584,6 +607,16 @@ type UserDBI interface { DeleteClient(clientID string) error UpdateClientLastSeen(authToken string, lastSeenAt int64) error CountClients() (int, error) + CreateProfile(p *Profile) error + GetProfile(profileID string) (*Profile, error) + GetProfileBySwitchID(switchID string) (*Profile, error) + ListProfiles() ([]Profile, error) + UpdateProfile(p *Profile) error + DeleteProfile(profileID string) error + GetMediaHistoryByProfile(profileID string, lastID int64, limit int) ([]MediaHistoryEntry, error) + SetDeviceState(key, value string) error + GetDeviceState(key string) (string, bool, error) + DeleteDeviceState(key string) error } type MediaDBI interface { diff --git a/pkg/database/userdb/media_history.go b/pkg/database/userdb/media_history.go index 0a42bb252..8f12f9113 100644 --- a/pkg/database/userdb/media_history.go +++ b/pkg/database/userdb/media_history.go @@ -61,7 +61,18 @@ func (db *UserDB) GetMediaHistory(systemIDs []string, lastID int64, limit int) ( if db.sql == nil { return nil, ErrNullSQL } - return sqlGetMediaHistory(db.ctx, db.sql, systemIDs, lastID, limit) + return sqlGetMediaHistory(db.ctx, db.sql, systemIDs, nil, lastID, limit) +} + +// GetMediaHistoryByProfile retrieves media history entries attributed to a +// specific profile, with pagination. +func (db *UserDB) GetMediaHistoryByProfile( + profileID string, lastID int64, limit int, +) ([]database.MediaHistoryEntry, error) { + if db.sql == nil { + return nil, ErrNullSQL + } + return sqlGetMediaHistory(db.ctx, db.sql, nil, &profileID, lastID, limit) } // GetLatestMediaHistory retrieves the most recent media history entry with no enrichment. @@ -118,8 +129,8 @@ func sqlAddMediaHistory(ctx context.Context, db *sql.DB, entry *database.MediaHi INSERT INTO MediaHistory( ID, StartTime, SystemID, SystemName, MediaPath, MediaName, LauncherID, PlayTime, BootUUID, MonotonicStart, DurationSec, WallDuration, TimeSkewFlag, - ClockReliable, ClockSource, CreatedAt, UpdatedAt, DeviceID - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?); + ClockReliable, ClockSource, CreatedAt, UpdatedAt, DeviceID, ProfileID + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?); `) if err != nil { return 0, fmt.Errorf("failed to prepare media history insert statement: %w", err) @@ -134,6 +145,10 @@ func sqlAddMediaHistory(ctx context.Context, db *sql.DB, entry *database.MediaHi if entry.DeviceID != nil { deviceID = *entry.DeviceID } + var profileID any + if entry.ProfileID != nil { + profileID = *entry.ProfileID + } result, err := stmt.ExecContext(ctx, entry.ID, @@ -154,6 +169,7 @@ func sqlAddMediaHistory(ctx context.Context, db *sql.DB, entry *database.MediaHi entry.CreatedAt.Unix(), entry.UpdatedAt.Unix(), deviceID, + profileID, ) if err != nil { return 0, fmt.Errorf("failed to execute media history insert: %w", err) @@ -215,7 +231,7 @@ func sqlCloseMediaHistory(ctx context.Context, db *sql.DB, dbid int64, endTime t } func sqlGetMediaHistory( - ctx context.Context, db *sql.DB, systemIDs []string, lastID int64, limit int, + ctx context.Context, db *sql.DB, systemIDs []string, profileID *string, lastID int64, limit int, ) ([]database.MediaHistoryEntry, error) { if limit <= 0 { limit = 25 @@ -232,7 +248,7 @@ func sqlGetMediaHistory( } conditions := []string{"DBID < ?"} - args := make([]any, 0, len(systemIDs)+2) + args := make([]any, 0, len(systemIDs)+3) args = append(args, lastID) if len(systemIDs) == 1 { @@ -247,6 +263,11 @@ func sqlGetMediaHistory( conditions = append(conditions, "SystemID IN ("+strings.Join(placeholders, ", ")+")") } + if profileID != nil { + conditions = append(conditions, "ProfileID = ?") + args = append(args, *profileID) + } + where := strings.Join(conditions, " AND ") args = append(args, limit) queryStarted := time.Now() @@ -257,7 +278,7 @@ func sqlGetMediaHistory( DBID, ID, StartTime, EndTime, SystemID, SystemName, MediaPath, MediaName, LauncherID, PlayTime, BootUUID, MonotonicStart, DurationSec, WallDuration, TimeSkewFlag, - ClockReliable, ClockSource, CreatedAt, UpdatedAt, DeviceID + ClockReliable, ClockSource, CreatedAt, UpdatedAt, DeviceID, ProfileID FROM MediaHistory WHERE %s ORDER BY DBID DESC @@ -290,7 +311,7 @@ func sqlGetMediaHistory( var endTimeUnix sql.NullInt64 var createdAtUnix, updatedAtUnix int64 var id, clockSource sql.NullString - var deviceID sql.NullString + var deviceID, rowProfileID sql.NullString err = rows.Scan( &entry.DBID, @@ -313,6 +334,7 @@ func sqlGetMediaHistory( &createdAtUnix, &updatedAtUnix, &deviceID, + &rowProfileID, ) if err != nil { return list, fmt.Errorf("failed to scan media history row: %w", err) @@ -328,6 +350,10 @@ func sqlGetMediaHistory( deviceStr := deviceID.String entry.DeviceID = &deviceStr } + if rowProfileID.Valid { + profileStr := rowProfileID.String + entry.ProfileID = &profileStr + } entry.StartTime = time.Unix(startTimeUnix, 0) if endTimeUnix.Valid { diff --git a/pkg/database/userdb/media_history_property_test.go b/pkg/database/userdb/media_history_property_test.go index 09a9cb1b7..a237cdfa6 100644 --- a/pkg/database/userdb/media_history_property_test.go +++ b/pkg/database/userdb/media_history_property_test.go @@ -237,7 +237,7 @@ func TestPropertyGetMediaHistoryLimitClamping(t *testing.T) { MediaPath TEXT, MediaName TEXT, LauncherID TEXT, PlayTime INTEGER, BootUUID TEXT, MonotonicStart INTEGER, DurationSec INTEGER, WallDuration INTEGER, TimeSkewFlag INTEGER, ClockReliable INTEGER, ClockSource TEXT, - CreatedAt INTEGER, UpdatedAt INTEGER, DeviceID TEXT + CreatedAt INTEGER, UpdatedAt INTEGER, DeviceID TEXT, ProfileID TEXT ) `) require.NoError(t, err) @@ -246,7 +246,7 @@ func TestPropertyGetMediaHistoryLimitClamping(t *testing.T) { limit := rapid.IntRange(-100, 200).Draw(t, "limit") // The function should clamp limit to valid range - entries, err := sqlGetMediaHistory(ctx, db, nil, 0, limit) + entries, err := sqlGetMediaHistory(ctx, db, nil, nil, 0, limit) require.NoError(t, err) // With empty table, we get empty results regardless of limit @@ -274,7 +274,7 @@ func TestPropertyGetMediaHistoryLastIDPagination(t *testing.T) { MediaPath TEXT, MediaName TEXT, LauncherID TEXT, PlayTime INTEGER, BootUUID TEXT, MonotonicStart INTEGER, DurationSec INTEGER, WallDuration INTEGER, TimeSkewFlag INTEGER, ClockReliable INTEGER, ClockSource TEXT, - CreatedAt INTEGER, UpdatedAt INTEGER, DeviceID TEXT + CreatedAt INTEGER, UpdatedAt INTEGER, DeviceID TEXT, ProfileID TEXT ) `) require.NoError(t, err) @@ -296,7 +296,7 @@ func TestPropertyGetMediaHistoryLastIDPagination(t *testing.T) { lastID := int64(rapid.IntRange(-10, 30).Draw(t, "lastID")) limit := rapid.IntRange(1, 100).Draw(t, "limit") - entries, err := sqlGetMediaHistory(ctx, db, nil, lastID, limit) + entries, err := sqlGetMediaHistory(ctx, db, nil, nil, lastID, limit) require.NoError(t, err) // Verify all returned entries have DBID < lastID (or lastID=0 means all) diff --git a/pkg/database/userdb/media_history_test.go b/pkg/database/userdb/media_history_test.go index 1d7b889ef..3a1921faf 100644 --- a/pkg/database/userdb/media_history_test.go +++ b/pkg/database/userdb/media_history_test.go @@ -83,6 +83,7 @@ func TestSqlAddMediaHistory_Success(t *testing.T) { entry.CreatedAt.Unix(), entry.UpdatedAt.Unix(), nil, + nil, ). WillReturnResult(sqlmock.NewResult(expectedDBID, 1)) @@ -140,6 +141,7 @@ func TestSqlAddMediaHistory_DatabaseError(t *testing.T) { entry.CreatedAt.Unix(), entry.UpdatedAt.Unix(), nil, + nil, ). WillReturnError(sqlmock.ErrCancelled) @@ -246,19 +248,19 @@ func TestSqlGetMediaHistory_Success(t *testing.T) { "DBID", "ID", "StartTime", "EndTime", "SystemID", "SystemName", "MediaPath", "MediaName", "LauncherID", "PlayTime", "BootUUID", "MonotonicStart", "DurationSec", "WallDuration", "TimeSkewFlag", - "ClockReliable", "ClockSource", "CreatedAt", "UpdatedAt", "DeviceID", + "ClockReliable", "ClockSource", "CreatedAt", "UpdatedAt", "DeviceID", "ProfileID", }). AddRow( int64(1), "uuid-1", startTime, endTime, "nes", "Nintendo Entertainment System", "/games/mario.nes", "Super Mario Bros.", "retroarch", 3600, "boot-1", int64(1000), 3600, 3600, false, - true, "system", startTime, startTime, nil, + true, "system", startTime, startTime, nil, nil, ). AddRow( int64(2), "uuid-2", startTime, endTime, "snes", "Super Nintendo", "/games/zelda.sfc", "The Legend of Zelda", "retroarch", 7200, "boot-1", int64(2000), 7200, 7200, false, - true, "system", startTime, startTime, nil, + true, "system", startTime, startTime, nil, nil, ) mock.ExpectPrepare(`SELECT.*FROM MediaHistory.*ORDER BY DBID DESC LIMIT`). @@ -266,7 +268,7 @@ func TestSqlGetMediaHistory_Success(t *testing.T) { WithArgs(int64(math.MaxInt64), limit). // lastID=0 becomes math.MaxInt64 in implementation WillReturnRows(rows) - entries, err := sqlGetMediaHistory(context.Background(), db, nil, lastID, limit) + entries, err := sqlGetMediaHistory(context.Background(), db, nil, nil, lastID, limit) require.NoError(t, err) assert.Len(t, entries, 2) assert.Equal(t, int64(1), entries[0].DBID) @@ -288,7 +290,7 @@ func TestSqlGetMediaHistory_EmptyResult(t *testing.T) { "DBID", "ID", "StartTime", "EndTime", "SystemID", "SystemName", "MediaPath", "MediaName", "LauncherID", "PlayTime", "BootUUID", "MonotonicStart", "DurationSec", "WallDuration", "TimeSkewFlag", - "ClockReliable", "ClockSource", "CreatedAt", "UpdatedAt", "DeviceID", + "ClockReliable", "ClockSource", "CreatedAt", "UpdatedAt", "DeviceID", "ProfileID", }) mock.ExpectPrepare(`SELECT.*FROM MediaHistory.*ORDER BY DBID DESC LIMIT`). @@ -296,7 +298,7 @@ func TestSqlGetMediaHistory_EmptyResult(t *testing.T) { WithArgs(int64(math.MaxInt64), limit). // lastID=0 becomes math.MaxInt64 in implementation WillReturnRows(rows) - entries, err := sqlGetMediaHistory(context.Background(), db, nil, lastID, limit) + entries, err := sqlGetMediaHistory(context.Background(), db, nil, nil, lastID, limit) require.NoError(t, err) assert.Empty(t, entries) assert.NoError(t, mock.ExpectationsWereMet()) @@ -384,7 +386,7 @@ func TestSqlGetMediaHistory_DatabaseError(t *testing.T) { mock.ExpectPrepare(`SELECT.*FROM MediaHistory.*ORDER BY DBID DESC LIMIT`). WillReturnError(sqlmock.ErrCancelled) - entries, err := sqlGetMediaHistory(context.Background(), db, nil, lastID, limit) + entries, err := sqlGetMediaHistory(context.Background(), db, nil, nil, lastID, limit) require.Error(t, err) assert.NotNil(t, entries) // Returns empty slice, not nil assert.Empty(t, entries) @@ -402,7 +404,7 @@ func TestSqlGetMediaHistory_SentinelUsesMaxInt64(t *testing.T) { "DBID", "ID", "StartTime", "EndTime", "SystemID", "SystemName", "MediaPath", "MediaName", "LauncherID", "PlayTime", "BootUUID", "MonotonicStart", "DurationSec", "WallDuration", "TimeSkewFlag", - "ClockReliable", "ClockSource", "CreatedAt", "UpdatedAt", "DeviceID", + "ClockReliable", "ClockSource", "CreatedAt", "UpdatedAt", "DeviceID", "ProfileID", }) // Verify that lastID=0 uses math.MaxInt64 as sentinel, not the old MaxInt32 @@ -411,7 +413,7 @@ func TestSqlGetMediaHistory_SentinelUsesMaxInt64(t *testing.T) { WithArgs(int64(math.MaxInt64), 10). WillReturnRows(rows) - entries, err := sqlGetMediaHistory(context.Background(), db, nil, 0, 10) + entries, err := sqlGetMediaHistory(context.Background(), db, nil, nil, 0, 10) require.NoError(t, err) assert.Empty(t, entries) assert.NoError(t, mock.ExpectationsWereMet()) @@ -431,12 +433,12 @@ func TestSqlGetMediaHistory_LargeLastID(t *testing.T) { "DBID", "ID", "StartTime", "EndTime", "SystemID", "SystemName", "MediaPath", "MediaName", "LauncherID", "PlayTime", "BootUUID", "MonotonicStart", "DurationSec", "WallDuration", "TimeSkewFlag", - "ClockReliable", "ClockSource", "CreatedAt", "UpdatedAt", "DeviceID", + "ClockReliable", "ClockSource", "CreatedAt", "UpdatedAt", "DeviceID", "ProfileID", }).AddRow( int64(math.MaxInt32)+50, "uuid-1", time.Now().Unix(), nil, "nes", "NES", "/games/mario.nes", "Mario", "retroarch", 100, "boot-1", int64(1000), 100, 100, false, - true, "system", time.Now().Unix(), time.Now().Unix(), nil, + true, "system", time.Now().Unix(), time.Now().Unix(), nil, nil, ) mock.ExpectPrepare(`SELECT.*FROM MediaHistory.*ORDER BY DBID DESC LIMIT`). @@ -444,7 +446,7 @@ func TestSqlGetMediaHistory_LargeLastID(t *testing.T) { WithArgs(lastID, limit). WillReturnRows(rows) - entries, err := sqlGetMediaHistory(context.Background(), db, nil, lastID, limit) + entries, err := sqlGetMediaHistory(context.Background(), db, nil, nil, lastID, limit) require.NoError(t, err) assert.Len(t, entries, 1) assert.Equal(t, int64(math.MaxInt32)+50, entries[0].DBID) @@ -465,13 +467,13 @@ func TestSqlGetMediaHistory_SingleSystemFilter(t *testing.T) { "DBID", "ID", "StartTime", "EndTime", "SystemID", "SystemName", "MediaPath", "MediaName", "LauncherID", "PlayTime", "BootUUID", "MonotonicStart", "DurationSec", "WallDuration", "TimeSkewFlag", - "ClockReliable", "ClockSource", "CreatedAt", "UpdatedAt", "DeviceID", + "ClockReliable", "ClockSource", "CreatedAt", "UpdatedAt", "DeviceID", "ProfileID", }). AddRow( int64(1), "uuid-1", startTime, endTime, "SNES", "Super Nintendo", "/games/zelda.sfc", "The Legend of Zelda", "retroarch", 3600, "boot-1", int64(1000), 3600, 3600, false, - true, "system", startTime, startTime, nil, + true, "system", startTime, startTime, nil, nil, ) mock.ExpectPrepare(`SELECT.*FROM MediaHistory.*WHERE DBID < \? AND SystemID = \?.*ORDER BY DBID DESC LIMIT`). @@ -479,7 +481,7 @@ func TestSqlGetMediaHistory_SingleSystemFilter(t *testing.T) { WithArgs(int64(math.MaxInt64), "SNES", 10). WillReturnRows(rows) - entries, err := sqlGetMediaHistory(context.Background(), db, []string{"SNES"}, 0, 10) + entries, err := sqlGetMediaHistory(context.Background(), db, []string{"SNES"}, nil, 0, 10) require.NoError(t, err) assert.Len(t, entries, 1) assert.Equal(t, "SNES", entries[0].SystemID) @@ -500,19 +502,19 @@ func TestSqlGetMediaHistory_MultipleSystemIDs(t *testing.T) { "DBID", "ID", "StartTime", "EndTime", "SystemID", "SystemName", "MediaPath", "MediaName", "LauncherID", "PlayTime", "BootUUID", "MonotonicStart", "DurationSec", "WallDuration", "TimeSkewFlag", - "ClockReliable", "ClockSource", "CreatedAt", "UpdatedAt", "DeviceID", + "ClockReliable", "ClockSource", "CreatedAt", "UpdatedAt", "DeviceID", "ProfileID", }). AddRow( int64(2), "uuid-2", startTime, endTime, "SNES", "Super Nintendo", "/games/zelda.sfc", "Zelda", "retroarch", 3600, "boot-1", int64(1000), 3600, 3600, false, - true, "system", startTime, startTime, nil, + true, "system", startTime, startTime, nil, nil, ). AddRow( int64(1), "uuid-1", startTime, endTime, "NES", "NES", "/games/mario.nes", "Mario", "retroarch", 1800, "boot-1", int64(2000), 1800, 1800, false, - true, "system", startTime, startTime, nil, + true, "system", startTime, startTime, nil, nil, ) mock.ExpectPrepare( @@ -522,7 +524,7 @@ func TestSqlGetMediaHistory_MultipleSystemIDs(t *testing.T) { WithArgs(int64(math.MaxInt64), "SNES", "NES", 10). WillReturnRows(rows) - entries, err := sqlGetMediaHistory(context.Background(), db, []string{"SNES", "NES"}, 0, 10) + entries, err := sqlGetMediaHistory(context.Background(), db, []string{"SNES", "NES"}, nil, 0, 10) require.NoError(t, err) assert.Len(t, entries, 2) assert.Equal(t, "SNES", entries[0].SystemID) @@ -544,13 +546,13 @@ func TestSqlGetMediaHistory_SystemFilterWithPagination(t *testing.T) { "DBID", "ID", "StartTime", "EndTime", "SystemID", "SystemName", "MediaPath", "MediaName", "LauncherID", "PlayTime", "BootUUID", "MonotonicStart", "DurationSec", "WallDuration", "TimeSkewFlag", - "ClockReliable", "ClockSource", "CreatedAt", "UpdatedAt", "DeviceID", + "ClockReliable", "ClockSource", "CreatedAt", "UpdatedAt", "DeviceID", "ProfileID", }). AddRow( int64(8), "uuid-8", startTime, endTime, "SNES", "Super Nintendo", "/games/zelda.sfc", "Zelda", "retroarch", 3600, "boot-1", int64(1000), 3600, 3600, false, - true, "system", startTime, startTime, nil, + true, "system", startTime, startTime, nil, nil, ) // lastID=10 + SystemID filter — both conditions in WHERE clause @@ -559,7 +561,7 @@ func TestSqlGetMediaHistory_SystemFilterWithPagination(t *testing.T) { WithArgs(int64(10), "SNES", 25). WillReturnRows(rows) - entries, err := sqlGetMediaHistory(context.Background(), db, []string{"SNES"}, 10, 25) + entries, err := sqlGetMediaHistory(context.Background(), db, []string{"SNES"}, nil, 10, 25) require.NoError(t, err) assert.Len(t, entries, 1) assert.Equal(t, int64(8), entries[0].DBID) diff --git a/pkg/database/userdb/migrations/20260612000000_create_profiles_table.sql b/pkg/database/userdb/migrations/20260612000000_create_profiles_table.sql new file mode 100644 index 000000000..f4c52af9f --- /dev/null +++ b/pkg/database/userdb/migrations/20260612000000_create_profiles_table.sql @@ -0,0 +1,36 @@ +-- +goose Up +-- +goose StatementBegin + +create table Profiles +( + DBID INTEGER PRIMARY KEY, + ProfileID text not null unique, + Name text not null, + SwitchID text not null unique, + PINHash text, + LimitsEnabled integer, + DailyLimit text, + SessionLimit text, + CreatedAt integer not null, + UpdatedAt integer not null +); + +create table DeviceState +( + Key text primary key, + Value text not null, + UpdatedAt integer not null +); + +ALTER TABLE MediaHistory ADD COLUMN ProfileID TEXT; +CREATE INDEX idx_media_history_profile ON MediaHistory (ProfileID) WHERE ProfileID IS NOT NULL; + +-- +goose StatementEnd + +-- +goose Down +-- +goose StatementBegin +DROP INDEX IF EXISTS idx_media_history_profile; +ALTER TABLE MediaHistory DROP COLUMN ProfileID; +DROP TABLE IF EXISTS DeviceState; +DROP TABLE IF EXISTS Profiles; +-- +goose StatementEnd diff --git a/pkg/database/userdb/profiles.go b/pkg/database/userdb/profiles.go new file mode 100644 index 000000000..0599ad85f --- /dev/null +++ b/pkg/database/userdb/profiles.go @@ -0,0 +1,311 @@ +// Zaparoo Core +// Copyright (c) 2026 The Zaparoo Project Contributors. +// SPDX-License-Identifier: GPL-3.0-or-later +// +// This file is part of Zaparoo Core. +// +// Zaparoo Core is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// Zaparoo Core is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with Zaparoo Core. If not, see . + +package userdb + +import ( + "context" + "database/sql" + "errors" + "fmt" + "time" + + "github.com/ZaparooProject/zaparoo-core/v2/pkg/database" + "github.com/rs/zerolog/log" +) + +// ErrProfileNotFound is returned when a profile lookup matches no row. +var ErrProfileNotFound = errors.New("profile not found") + +func (db *UserDB) CreateProfile(p *database.Profile) error { + if db.sql == nil { + return ErrNullSQL + } + return sqlCreateProfile(db.ctx, db.sql, p) +} + +func (db *UserDB) GetProfile(profileID string) (*database.Profile, error) { + if db.sql == nil { + return nil, ErrNullSQL + } + return sqlGetProfile(db.ctx, db.sql, "ProfileID", profileID) +} + +func (db *UserDB) GetProfileBySwitchID(switchID string) (*database.Profile, error) { + if db.sql == nil { + return nil, ErrNullSQL + } + return sqlGetProfile(db.ctx, db.sql, "SwitchID", switchID) +} + +func (db *UserDB) ListProfiles() ([]database.Profile, error) { + if db.sql == nil { + return nil, ErrNullSQL + } + return sqlListProfiles(db.ctx, db.sql) +} + +func (db *UserDB) UpdateProfile(p *database.Profile) error { + if db.sql == nil { + return ErrNullSQL + } + return sqlUpdateProfile(db.ctx, db.sql, p) +} + +// DeleteProfile removes a profile. If the profile is the device's active +// profile, the active-profile device state is cleared in the same +// transaction. +func (db *UserDB) DeleteProfile(profileID string) error { + if db.sql == nil { + return ErrNullSQL + } + return sqlDeleteProfile(db.ctx, db.sql, profileID) +} + +func (db *UserDB) SetDeviceState(key, value string) error { + if db.sql == nil { + return ErrNullSQL + } + return sqlSetDeviceState(db.ctx, db.sql, key, value) +} + +// GetDeviceState returns the value for key and whether it exists. +func (db *UserDB) GetDeviceState(key string) (value string, found bool, err error) { + if db.sql == nil { + return "", false, ErrNullSQL + } + return sqlGetDeviceState(db.ctx, db.sql, key) +} + +func (db *UserDB) DeleteDeviceState(key string) error { + if db.sql == nil { + return ErrNullSQL + } + return sqlDeleteDeviceState(db.ctx, db.sql, key) +} + +/* + * Internal SQL functions + */ + +const profileColumns = `DBID, ProfileID, Name, SwitchID, PINHash, LimitsEnabled, + DailyLimit, SessionLimit, CreatedAt, UpdatedAt` + +func sqlCreateProfile(ctx context.Context, db *sql.DB, p *database.Profile) error { + var dbid int64 + err := db.QueryRowContext(ctx, ` + INSERT INTO Profiles (ProfileID, Name, SwitchID, PINHash, LimitsEnabled, + DailyLimit, SessionLimit, CreatedAt, UpdatedAt) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) + RETURNING DBID; + `, p.ProfileID, p.Name, p.SwitchID, nullableString(p.PINHash), + nullableBool(p.LimitsEnabled), p.DailyLimit, p.SessionLimit, + p.CreatedAt, p.UpdatedAt).Scan(&dbid) + if err != nil { + return fmt.Errorf("failed to insert profile: %w", err) + } + p.DBID = dbid + return nil +} + +func sqlGetProfile(ctx context.Context, db *sql.DB, column, value string) (*database.Profile, error) { + //nolint:gosec // column is a hardcoded column name, not user input + row := db.QueryRowContext(ctx, ` + SELECT `+profileColumns+` + FROM Profiles + WHERE `+column+` = ?; + `, value) + + p, err := scanProfile(row.Scan) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, fmt.Errorf("%w: %s=%s", ErrProfileNotFound, column, value) + } + return nil, fmt.Errorf("failed to scan profile row: %w", err) + } + return p, nil +} + +func sqlListProfiles(ctx context.Context, db *sql.DB) ([]database.Profile, error) { + list := make([]database.Profile, 0) + + rows, err := db.QueryContext(ctx, ` + SELECT `+profileColumns+` + FROM Profiles + ORDER BY CreatedAt ASC; + `) + if err != nil { + return list, fmt.Errorf("failed to query profiles: %w", err) + } + defer func() { + if closeErr := rows.Close(); closeErr != nil { + log.Warn().Err(closeErr).Msg("failed to close sql rows") + } + }() + + for rows.Next() { + p, scanErr := scanProfile(rows.Scan) + if scanErr != nil { + return list, fmt.Errorf("failed to scan profile row: %w", scanErr) + } + list = append(list, *p) + } + + if err = rows.Err(); err != nil { + return list, fmt.Errorf("error iterating profile rows: %w", err) + } + return list, nil +} + +func sqlUpdateProfile(ctx context.Context, db *sql.DB, p *database.Profile) error { + result, err := db.ExecContext(ctx, ` + UPDATE Profiles + SET Name = ?, SwitchID = ?, PINHash = ?, LimitsEnabled = ?, + DailyLimit = ?, SessionLimit = ?, UpdatedAt = ? + WHERE ProfileID = ?; + `, p.Name, p.SwitchID, nullableString(p.PINHash), nullableBool(p.LimitsEnabled), + p.DailyLimit, p.SessionLimit, p.UpdatedAt, p.ProfileID) + if err != nil { + return fmt.Errorf("failed to execute profile update: %w", err) + } + + rowsAffected, err := result.RowsAffected() + if err != nil { + return fmt.Errorf("failed to get rows affected: %w", err) + } + if rowsAffected == 0 { + return fmt.Errorf("%w: %s", ErrProfileNotFound, p.ProfileID) + } + return nil +} + +func sqlDeleteProfile(ctx context.Context, db *sql.DB, profileID string) error { + tx, err := db.BeginTx(ctx, nil) + if err != nil { + return fmt.Errorf("failed to begin profile delete transaction: %w", err) + } + defer func() { + if rollbackErr := tx.Rollback(); rollbackErr != nil && !errors.Is(rollbackErr, sql.ErrTxDone) { + log.Warn().Err(rollbackErr).Msg("failed to rollback profile delete transaction") + } + }() + + result, err := tx.ExecContext(ctx, `DELETE FROM Profiles WHERE ProfileID = ?;`, profileID) + if err != nil { + return fmt.Errorf("failed to execute profile delete: %w", err) + } + + rowsAffected, err := result.RowsAffected() + if err != nil { + return fmt.Errorf("failed to get rows affected: %w", err) + } + if rowsAffected == 0 { + return fmt.Errorf("%w: %s", ErrProfileNotFound, profileID) + } + + _, err = tx.ExecContext(ctx, ` + DELETE FROM DeviceState WHERE Key = ? AND Value = ?; + `, database.DeviceStateKeyActiveProfile, profileID) + if err != nil { + return fmt.Errorf("failed to clear active profile device state: %w", err) + } + + if err := tx.Commit(); err != nil { + return fmt.Errorf("failed to commit profile delete transaction: %w", err) + } + return nil +} + +func sqlSetDeviceState(ctx context.Context, db *sql.DB, key, value string) error { + _, err := db.ExecContext(ctx, ` + INSERT INTO DeviceState (Key, Value, UpdatedAt) + VALUES (?, ?, ?) + ON CONFLICT(Key) DO UPDATE SET Value = excluded.Value, UpdatedAt = excluded.UpdatedAt; + `, key, value, time.Now().Unix()) + if err != nil { + return fmt.Errorf("failed to set device state: %w", err) + } + return nil +} + +func sqlGetDeviceState(ctx context.Context, db *sql.DB, key string) (value string, found bool, err error) { + err = db.QueryRowContext(ctx, `SELECT Value FROM DeviceState WHERE Key = ?;`, key).Scan(&value) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return "", false, nil + } + return "", false, fmt.Errorf("failed to query device state: %w", err) + } + return value, true, nil +} + +func sqlDeleteDeviceState(ctx context.Context, db *sql.DB, key string) error { + _, err := db.ExecContext(ctx, `DELETE FROM DeviceState WHERE Key = ?;`, key) + if err != nil { + return fmt.Errorf("failed to delete device state: %w", err) + } + return nil +} + +// scanProfile reads a profile row using the given scan function, converting +// nullable columns to their pointer/empty-string Go representations. +func scanProfile(scan func(dest ...any) error) (*database.Profile, error) { + var p database.Profile + var pinHash sql.NullString + var limitsEnabled sql.NullBool + var dailyLimit, sessionLimit sql.NullString + + err := scan( + &p.DBID, &p.ProfileID, &p.Name, &p.SwitchID, &pinHash, + &limitsEnabled, &dailyLimit, &sessionLimit, &p.CreatedAt, &p.UpdatedAt, + ) + if err != nil { + return nil, err //nolint:wrapcheck // callers wrap with query context + } + + if pinHash.Valid { + p.PINHash = pinHash.String + } + if limitsEnabled.Valid { + p.LimitsEnabled = &limitsEnabled.Bool + } + if dailyLimit.Valid { + p.DailyLimit = &dailyLimit.String + } + if sessionLimit.Valid { + p.SessionLimit = &sessionLimit.String + } + return &p, nil +} + +// nullableString stores empty strings as NULL. +func nullableString(s string) any { + if s == "" { + return nil + } + return s +} + +// nullableBool stores nil as NULL and otherwise 0/1. +func nullableBool(b *bool) any { + if b == nil { + return nil + } + return *b +} diff --git a/pkg/database/userdb/profiles_test.go b/pkg/database/userdb/profiles_test.go new file mode 100644 index 000000000..012532a0f --- /dev/null +++ b/pkg/database/userdb/profiles_test.go @@ -0,0 +1,240 @@ +// Zaparoo Core +// Copyright (c) 2026 The Zaparoo Project Contributors. +// SPDX-License-Identifier: GPL-3.0-or-later +// +// This file is part of Zaparoo Core. +// +// Zaparoo Core is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// Zaparoo Core is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with Zaparoo Core. If not, see . + +package userdb + +import ( + "testing" + "time" + + "github.com/ZaparooProject/zaparoo-core/v2/pkg/database" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func newTestProfile(profileID, switchID string) *database.Profile { + return &database.Profile{ + ProfileID: profileID, + Name: "Test Profile", + SwitchID: switchID, + CreatedAt: 1700000000, + UpdatedAt: 1700000000, + } +} + +func TestProfiles_CRUD_Integration(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + db, cleanup := setupTempUserDB(t) + defer cleanup() + + p := newTestProfile("profile-uuid-1", "corn-arm-truck") + limitsEnabled := true + daily := "2h30m" + p.LimitsEnabled = &limitsEnabled + p.DailyLimit = &daily + p.PINHash = "fake-pin-hash" + + require.NoError(t, db.CreateProfile(p)) + assert.Positive(t, p.DBID) + + got, err := db.GetProfile("profile-uuid-1") + require.NoError(t, err) + assert.Equal(t, "Test Profile", got.Name) + assert.Equal(t, "corn-arm-truck", got.SwitchID) + assert.Equal(t, "fake-pin-hash", got.PINHash) + require.NotNil(t, got.LimitsEnabled) + assert.True(t, *got.LimitsEnabled) + require.NotNil(t, got.DailyLimit) + assert.Equal(t, "2h30m", *got.DailyLimit) + assert.Nil(t, got.SessionLimit) + + bySwitch, err := db.GetProfileBySwitchID("corn-arm-truck") + require.NoError(t, err) + assert.Equal(t, "profile-uuid-1", bySwitch.ProfileID) + + list, err := db.ListProfiles() + require.NoError(t, err) + require.Len(t, list, 1) + + got.Name = "Renamed" + got.PINHash = "" + got.LimitsEnabled = nil + got.DailyLimit = nil + got.UpdatedAt = 1700000100 + require.NoError(t, db.UpdateProfile(got)) + + updated, err := db.GetProfile("profile-uuid-1") + require.NoError(t, err) + assert.Equal(t, "Renamed", updated.Name) + assert.Empty(t, updated.PINHash) + assert.Nil(t, updated.LimitsEnabled) + assert.Nil(t, updated.DailyLimit) + + require.NoError(t, db.DeleteProfile("profile-uuid-1")) + _, err = db.GetProfile("profile-uuid-1") + require.ErrorIs(t, err, ErrProfileNotFound) +} + +func TestProfiles_NotFoundErrors(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + db, cleanup := setupTempUserDB(t) + defer cleanup() + + _, err := db.GetProfile("missing") + require.ErrorIs(t, err, ErrProfileNotFound) + + _, err = db.GetProfileBySwitchID("missing-switch") + require.ErrorIs(t, err, ErrProfileNotFound) + + err = db.UpdateProfile(newTestProfile("missing", "a-b-c")) + require.ErrorIs(t, err, ErrProfileNotFound) + + err = db.DeleteProfile("missing") + require.ErrorIs(t, err, ErrProfileNotFound) +} + +func TestProfiles_SwitchIDUnique(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + db, cleanup := setupTempUserDB(t) + defer cleanup() + + require.NoError(t, db.CreateProfile(newTestProfile("p1", "same-switch-id"))) + err := db.CreateProfile(newTestProfile("p2", "same-switch-id")) + require.Error(t, err) + assert.Contains(t, err.Error(), "UNIQUE") +} + +func TestProfiles_DeleteClearsActiveState(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + db, cleanup := setupTempUserDB(t) + defer cleanup() + + require.NoError(t, db.CreateProfile(newTestProfile("p1", "switch-one"))) + require.NoError(t, db.CreateProfile(newTestProfile("p2", "switch-two"))) + require.NoError(t, db.SetDeviceState(database.DeviceStateKeyActiveProfile, "p1")) + + // Deleting a non-active profile keeps the active state. + require.NoError(t, db.DeleteProfile("p2")) + value, found, err := db.GetDeviceState(database.DeviceStateKeyActiveProfile) + require.NoError(t, err) + assert.True(t, found) + assert.Equal(t, "p1", value) + + // Deleting the active profile clears it. + require.NoError(t, db.DeleteProfile("p1")) + _, found, err = db.GetDeviceState(database.DeviceStateKeyActiveProfile) + require.NoError(t, err) + assert.False(t, found) +} + +func TestDeviceState_SetGetDelete(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + db, cleanup := setupTempUserDB(t) + defer cleanup() + + _, found, err := db.GetDeviceState("some_key") + require.NoError(t, err) + assert.False(t, found) + + require.NoError(t, db.SetDeviceState("some_key", "value1")) + value, found, err := db.GetDeviceState("some_key") + require.NoError(t, err) + assert.True(t, found) + assert.Equal(t, "value1", value) + + // Upsert overwrites. + require.NoError(t, db.SetDeviceState("some_key", "value2")) + value, found, err = db.GetDeviceState("some_key") + require.NoError(t, err) + assert.True(t, found) + assert.Equal(t, "value2", value) + + require.NoError(t, db.DeleteDeviceState("some_key")) + _, found, err = db.GetDeviceState("some_key") + require.NoError(t, err) + assert.False(t, found) +} + +func TestMediaHistory_ProfileAttribution(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + db, cleanup := setupTempUserDB(t) + defer cleanup() + + now := time.Now() + profileID := "profile-uuid-1" + + attributed := &database.MediaHistoryEntry{ + StartTime: now, + SystemID: "SNES", + SystemName: "Super Nintendo", + MediaPath: "snes/game.sfc", + MediaName: "Game", + LauncherID: "test", + PlayTime: 120, + CreatedAt: now, + UpdatedAt: now, + ProfileID: &profileID, + } + _, err := db.AddMediaHistory(attributed) + require.NoError(t, err) + + unattributed := &database.MediaHistoryEntry{ + StartTime: now, + SystemID: "NES", + SystemName: "Nintendo", + MediaPath: "nes/other.nes", + MediaName: "Other", + LauncherID: "test", + PlayTime: 60, + CreatedAt: now, + UpdatedAt: now, + } + _, err = db.AddMediaHistory(unattributed) + require.NoError(t, err) + + // Profile-scoped query returns only the attributed row. + scoped, err := db.GetMediaHistoryByProfile(profileID, 0, 10) + require.NoError(t, err) + require.Len(t, scoped, 1) + assert.Equal(t, "SNES", scoped[0].SystemID) + require.NotNil(t, scoped[0].ProfileID) + assert.Equal(t, profileID, *scoped[0].ProfileID) + + // Unscoped query returns everything (device-level accounting). + all, err := db.GetMediaHistory(nil, 0, 10) + require.NoError(t, err) + assert.Len(t, all, 2) + + // Unknown profile matches nothing. + none, err := db.GetMediaHistoryByProfile("unknown", 0, 10) + require.NoError(t, err) + assert.Empty(t, none) +} diff --git a/pkg/platforms/platforms.go b/pkg/platforms/platforms.go index 4943fea81..0c2d553ce 100644 --- a/pkg/platforms/platforms.go +++ b/pkg/platforms/platforms.go @@ -158,11 +158,24 @@ type CmdEnv struct { Unsafe bool } +// ProfileSwitchRequest asks the script runner to change the device's +// active profile. SwitchID selects a profile by its card switch ID; Clear +// deactivates the current profile instead. +type ProfileSwitchRequest struct { + SwitchID string + Clear bool +} + // CmdResult returns a summary of what global side effects may or may not have // happened as a result of a single ZapScript command running. type CmdResult struct { // Playlist is the result of the playlist change. Playlist *playlists.Playlist + // ProfileSwitch requests the active profile be changed. Commands return + // the request as intent; the service layer applies it (same pattern as + // Playlist). The scan path activates without a PIN check — possession + // of the card is the authorization. + ProfileSwitch *ProfileSwitchRequest // Strategy indicates which matching strategy was used for title-based launches. // Empty for non-title commands. Used for testing and debugging title resolution. Strategy string diff --git a/pkg/service/context.go b/pkg/service/context.go index 3b3524628..c14711cb9 100644 --- a/pkg/service/context.go +++ b/pkg/service/context.go @@ -27,6 +27,7 @@ import ( "github.com/ZaparooProject/zaparoo-core/v2/pkg/database" "github.com/ZaparooProject/zaparoo-core/v2/pkg/platforms" "github.com/ZaparooProject/zaparoo-core/v2/pkg/service/playlists" + "github.com/ZaparooProject/zaparoo-core/v2/pkg/service/profiles" "github.com/ZaparooProject/zaparoo-core/v2/pkg/service/state" "github.com/ZaparooProject/zaparoo-core/v2/pkg/service/tokens" ) @@ -38,6 +39,7 @@ type ServiceContext struct { Config *config.Instance State *state.State DB *database.Database + Profiles *profiles.Service PlaybackManager audio.PlaybackManager LaunchSoftwareQueue chan *tokens.Token PlaylistQueue chan *playlists.Playlist diff --git a/pkg/service/media_history_tracker.go b/pkg/service/media_history_tracker.go index bbf6b405d..d9c319e61 100644 --- a/pkg/service/media_history_tracker.go +++ b/pkg/service/media_history_tracker.go @@ -96,6 +96,12 @@ func (t *mediaHistoryTracker) listen(notificationChan <-chan models.Notification CreatedAt: now, UpdatedAt: now, } + // Attribution is fixed at launch time: the row keeps this + // profile even if the active profile switches mid-game. + if activeProfile := t.st.ActiveProfile(); activeProfile != nil { + profileID := activeProfile.ProfileID + entry.ProfileID = &profileID + } dbid, addErr := t.db.UserDB.AddMediaHistory(entry) if addErr != nil { log.Error().Err(addErr).Msg("failed to add media history entry") diff --git a/pkg/service/playtime/limits.go b/pkg/service/playtime/limits.go index 5725187c2..d791d93cb 100644 --- a/pkg/service/playtime/limits.go +++ b/pkg/service/playtime/limits.go @@ -88,6 +88,7 @@ type LimitsManager struct { db *database.Database notificationsSend chan<- models.Notification cfg *config.Instance + limits LimitsProvider player audio.Player cancel context.CancelFunc state SessionState @@ -128,6 +129,7 @@ func NewLimitsManager( db: db, platform: platform, cfg: cfg, + limits: globalProvider{cfg: cfg}, clock: clock, player: player, ctx: ctx, @@ -139,6 +141,12 @@ func NewLimitsManager( } } +// SetLimitsProvider replaces the source of limit values, e.g. with the +// profile-aware resolver. Must be called before Start. +func (tm *LimitsManager) SetLimitsProvider(limits LimitsProvider) { + tm.limits = limits +} + // Broker is the interface for subscribing to notifications. type Broker interface { Subscribe(bufferSize int) (<-chan models.Notification, int) @@ -182,9 +190,13 @@ func (tm *LimitsManager) Stop() { tm.wg.Wait() } -// SetEnabled enables or disables limit enforcement at runtime. -// When disabling, resets the session state completely (clears cooldown and cumulative time). -// When re-enabling, session starts fresh but daily usage from history is still enforced. +// SetEnabled records the runtime enabled state and, when disabling, resets +// the session completely (clears cooldown and cumulative time). Whether +// limits are actually enforced is decided by the LimitsProvider on every +// check (global config, possibly overridden by the active profile) — this +// flag exists for its session-reset side effect when the user toggles +// limits off, and is kept in sync with global config by the settings +// handler. func (tm *LimitsManager) SetEnabled(enabled bool) { tm.enabledMu.Lock() tm.enabled = enabled @@ -214,6 +226,47 @@ func (tm *LimitsManager) SetEnabled(enabled bool) { } } +// ResetSession starts a fresh limit session, called when the active +// profile changes: a different person is playing, so accumulated session +// time belongs to the previous profile. Daily usage is unaffected — it is +// recalculated from the (profile-attributed) history on every check. +// +// If a game is running, tracking restarts from now under the new profile's +// limits rather than stopping: the running game's already-played time was +// the previous profile's. +func (tm *LimitsManager) ResetSession() { + tm.mu.Lock() + defer tm.mu.Unlock() + + if tm.cooldownTimer != nil { + tm.cooldownTimer.Stop() + tm.cooldownTimer = nil + log.Debug().Msg("playtime: cancelled cooldown timer (profile switched)") + } + + switch tm.state { + case StateActive: + log.Info().Msg("playtime: profile switched mid-game, restarting session tracking") + now := tm.clock.Now() + tm.sessionStart = now + tm.sessionStartMono = time.Now() + tm.sessionStartReliable = helpers.IsClockReliable(now) + tm.sessionCumulativeTime = 0 + tm.warningsGiven = make(map[time.Duration]bool) + case StateCooldown: + log.Info().Msg("playtime: profile switched, resetting session state") + tm.transitionTo(StateReset) + tm.sessionStart = time.Time{} + tm.sessionStartMono = time.Time{} + tm.sessionCumulativeTime = 0 + tm.lastStopTime = time.Time{} + tm.sessionStartReliable = false + tm.warningsGiven = make(map[time.Duration]bool) + case StateReset: + tm.sessionCumulativeTime = 0 + } +} + // IsEnabled returns whether limits are currently enforced. func (tm *LimitsManager) IsEnabled() bool { tm.enabledMu.Lock() @@ -264,6 +317,8 @@ func (tm *LimitsManager) handleNotifications(notifChan <-chan models.Notificatio tm.OnMediaStarted() case models.NotificationStopped: tm.OnMediaStopped() + case models.NotificationProfilesActive: + tm.ResetSession() } case <-tm.ctx.Done(): @@ -274,7 +329,7 @@ func (tm *LimitsManager) handleNotifications(notifChan <-chan models.Notificatio // OnMediaStarted handles media.started events and begins time tracking. func (tm *LimitsManager) OnMediaStarted() { - if !tm.cfg.PlaytimeLimitsEnabled() { + if !tm.limits.PlaytimeLimitsEnabled() { return } @@ -436,7 +491,7 @@ func (tm *LimitsManager) checkLoop() { // checkLimits evaluates all rules and handles warnings/limits. func (tm *LimitsManager) checkLimits() { // Respect both config and runtime enabled state - if !tm.cfg.PlaytimeLimitsEnabled() || !tm.IsEnabled() { + if !tm.limits.PlaytimeLimitsEnabled() { return } @@ -582,10 +637,14 @@ func (tm *LimitsManager) buildRuleContext( } // calculateDailyUsage queries the database for today's total play time. +// When a profile is active, only history attributed to that profile is +// counted; otherwise all history counts (device-level accounting). func (tm *LimitsManager) calculateDailyUsage( todayStart time.Time, currentSessionDuration time.Duration, ) (time.Duration, error) { + profileID := tm.limits.ActiveProfileID() + // Query media history for today // Note: GetMediaHistory uses pagination, so we need to fetch all entries var totalUsage time.Duration @@ -593,7 +652,13 @@ func (tm *LimitsManager) calculateDailyUsage( limit := 100 for { - entries, err := tm.db.UserDB.GetMediaHistory(nil, lastID, limit) + var entries []database.MediaHistoryEntry + var err error + if profileID != "" { + entries, err = tm.db.UserDB.GetMediaHistoryByProfile(profileID, lastID, limit) + } else { + entries, err = tm.db.UserDB.GetMediaHistory(nil, lastID, limit) + } if err != nil { return 0, fmt.Errorf("failed to query media history: %w", err) } @@ -654,11 +719,11 @@ done: func (tm *LimitsManager) createRules() []Rule { rules := make([]Rule, 0, 2) - if limit := tm.cfg.SessionLimit(); limit > 0 { + if limit := tm.limits.SessionLimit(); limit > 0 { rules = append(rules, &SessionLimitRule{Limit: limit}) } - if limit := tm.cfg.DailyLimit(); limit > 0 { + if limit := tm.limits.DailyLimit(); limit > 0 { rules = append(rules, &DailyLimitRule{Limit: limit}) } @@ -667,7 +732,7 @@ func (tm *LimitsManager) createRules() []Rule { // handleWarnings checks if warnings should be emitted based on remaining time. func (tm *LimitsManager) handleWarnings(remaining time.Duration) { - intervals := tm.cfg.WarningIntervals() + intervals := tm.limits.WarningIntervals() // Sort intervals in descending order (largest first) sort.Slice(intervals, func(i, j int) bool { @@ -744,7 +809,7 @@ func (tm *LimitsManager) GetStatus() *StatusInfo { // Calculate daily usage/remaining even during reset - this data is valid // regardless of session state (the user has used time today and has // time remaining in their daily allowance) - dailyLimit := tm.cfg.DailyLimit() + dailyLimit := tm.limits.DailyLimit() if dailyLimit > 0 && helpers.IsClockReliable(now) { year, month, day := now.Date() todayStart := time.Date(year, month, day, 0, 0, 0, 0, now.Location()) @@ -776,8 +841,8 @@ func (tm *LimitsManager) GetStatus() *StatusInfo { // Calculate remaining times based on cumulative time var sessionRemaining time.Duration - sessionLimit := tm.cfg.SessionLimit() - dailyLimit := tm.cfg.DailyLimit() + sessionLimit := tm.limits.SessionLimit() + dailyLimit := tm.limits.DailyLimit() if sessionLimit > 0 { sessionRemaining = sessionLimit - cumulativeTime @@ -831,8 +896,8 @@ func (tm *LimitsManager) GetStatus() *StatusInfo { // Calculate session remaining time var sessionRemaining time.Duration - sessionLimit := tm.cfg.SessionLimit() - dailyLimit := tm.cfg.DailyLimit() + sessionLimit := tm.limits.SessionLimit() + dailyLimit := tm.limits.DailyLimit() if sessionLimit > 0 { sessionRemaining = sessionLimit - ctx.SessionDuration @@ -885,12 +950,12 @@ func (tm *LimitsManager) GetStatus() *StatusInfo { // trying to stop games immediately after they launch. func (tm *LimitsManager) CheckBeforeLaunch() error { // Check if limits are enabled (both config and runtime) - if !tm.cfg.PlaytimeLimitsEnabled() || !tm.IsEnabled() { + if !tm.limits.PlaytimeLimitsEnabled() { return nil } - dailyLimit := tm.cfg.DailyLimit() - sessionLimit := tm.cfg.SessionLimit() + dailyLimit := tm.limits.DailyLimit() + sessionLimit := tm.limits.SessionLimit() // If no limits configured, allow launch if dailyLimit == 0 && sessionLimit == 0 { diff --git a/pkg/service/playtime/profile_limits_test.go b/pkg/service/playtime/profile_limits_test.go new file mode 100644 index 000000000..7408db5ef --- /dev/null +++ b/pkg/service/playtime/profile_limits_test.go @@ -0,0 +1,211 @@ +// Zaparoo Core +// Copyright (c) 2026 The Zaparoo Project Contributors. +// SPDX-License-Identifier: GPL-3.0-or-later +// +// This file is part of Zaparoo Core. +// +// Zaparoo Core is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// Zaparoo Core is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with Zaparoo Core. If not, see . + +package playtime + +import ( + "testing" + "time" + + "github.com/ZaparooProject/zaparoo-core/v2/pkg/config" + "github.com/ZaparooProject/zaparoo-core/v2/pkg/database" + testhelpers "github.com/ZaparooProject/zaparoo-core/v2/pkg/testing/helpers" + "github.com/jonboulle/clockwork" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// stubProvider is a fixed-value LimitsProvider for tests. +type stubProvider struct { + profileID string + warnings []time.Duration + daily time.Duration + session time.Duration + enabled bool +} + +func (s stubProvider) PlaytimeLimitsEnabled() bool { return s.enabled } +func (s stubProvider) DailyLimit() time.Duration { return s.daily } +func (s stubProvider) SessionLimit() time.Duration { return s.session } +func (s stubProvider) WarningIntervals() []time.Duration { return s.warnings } +func (s stubProvider) ActiveProfileID() string { return s.profileID } + +func TestCalculateDailyUsage_ProfileScoped(t *testing.T) { + t.Parallel() + + now := time.Date(2026, 6, 12, 12, 0, 0, 0, time.UTC) + todayStart := time.Date(2026, 6, 12, 0, 0, 0, 0, time.UTC) + startTime := now.Add(-2 * time.Hour) + endTime := now.Add(-1 * time.Hour) + profileID := "kid-a" + + mockDB := testhelpers.NewMockUserDBI() + // Only the profile-scoped query may be used; an unscoped GetMediaHistory + // call would fail the mock expectations. + mockDB.On("GetMediaHistoryByProfile", profileID, int64(0), 100). + Return([]database.MediaHistoryEntry{ + { + DBID: 1, + StartTime: startTime, + EndTime: &endTime, + PlayTime: 3600, + ProfileID: &profileID, + }, + }, nil) + + tm := NewLimitsManager( + &database.Database{UserDB: mockDB}, nil, &config.Instance{}, + clockwork.NewFakeClockAt(now), newNoOpMockPlayer(), + ) + tm.SetLimitsProvider(stubProvider{enabled: true, daily: 2 * time.Hour, profileID: profileID}) + + usage, err := tm.calculateDailyUsage(todayStart, 0) + require.NoError(t, err) + assert.Equal(t, time.Hour, usage) + mockDB.AssertExpectations(t) +} + +func TestCalculateDailyUsage_NoProfile_SumsAllHistory(t *testing.T) { + t.Parallel() + + now := time.Date(2026, 6, 12, 12, 0, 0, 0, time.UTC) + todayStart := time.Date(2026, 6, 12, 0, 0, 0, 0, time.UTC) + startTime := now.Add(-2 * time.Hour) + endTime := now.Add(-1 * time.Hour) + + mockDB := testhelpers.NewMockUserDBI() + // Without an active profile, the unscoped query is used — device-level + // accounting is byte-identical to pre-profile behavior. + mockDB.On("GetMediaHistory", []string(nil), int64(0), 100). + Return([]database.MediaHistoryEntry{ + {DBID: 1, StartTime: startTime, EndTime: &endTime, PlayTime: 1800}, + }, nil) + + tm := NewLimitsManager( + &database.Database{UserDB: mockDB}, nil, &config.Instance{}, + clockwork.NewFakeClockAt(now), newNoOpMockPlayer(), + ) + + usage, err := tm.calculateDailyUsage(todayStart, 0) + require.NoError(t, err) + assert.Equal(t, 30*time.Minute, usage) + mockDB.AssertExpectations(t) +} + +func TestCheckBeforeLaunch_ProfileLimitsWithGlobalDisabled(t *testing.T) { + t.Parallel() + + now := time.Date(2026, 6, 12, 12, 0, 0, 0, time.UTC) + startTime := now.Add(-3 * time.Hour) + endTime := now.Add(-1 * time.Hour) + profileID := "kid-a" + + mockDB := testhelpers.NewMockUserDBI() + mockDB.On("GetMediaHistoryByProfile", profileID, int64(0), 100). + Return([]database.MediaHistoryEntry{ + { + DBID: 1, + StartTime: startTime, + EndTime: &endTime, + PlayTime: 7200, // 2 hours played today by this profile + ProfileID: &profileID, + }, + }, nil) + + // Global config has limits disabled — the profile override alone must + // enforce its 1 hour daily limit (the issue #883 use case). + cfg, err := config.NewConfig(t.TempDir(), config.BaseDefaults) + require.NoError(t, err) + cfg.SetPlaytimeLimitsEnabled(false) + + tm := NewLimitsManager( + &database.Database{UserDB: mockDB}, nil, cfg, + clockwork.NewFakeClockAt(now), newNoOpMockPlayer(), + ) + tm.SetLimitsProvider(stubProvider{enabled: true, daily: time.Hour, profileID: profileID}) + + err = tm.CheckBeforeLaunch() + require.Error(t, err) + assert.Contains(t, err.Error(), "daily playtime limit reached") +} + +func TestCheckBeforeLaunch_ProviderDisabledSkipsChecks(t *testing.T) { + t.Parallel() + + mockDB := testhelpers.NewMockUserDBI() + tm := NewLimitsManager( + &database.Database{UserDB: mockDB}, nil, &config.Instance{}, + clockwork.NewFakeClockAt(time.Date(2026, 6, 12, 12, 0, 0, 0, time.UTC)), + newNoOpMockPlayer(), + ) + tm.SetLimitsProvider(stubProvider{enabled: false, daily: time.Nanosecond}) + + require.NoError(t, tm.CheckBeforeLaunch()) + mockDB.AssertExpectations(t) +} + +func TestResetSession_ActiveGameRestartsTracking(t *testing.T) { + t.Parallel() + + now := time.Date(2026, 6, 12, 12, 0, 0, 0, time.UTC) + clock := clockwork.NewFakeClockAt(now) + tm := NewLimitsManager(&database.Database{}, nil, &config.Instance{}, clock, newNoOpMockPlayer()) + + // Simulate an active session with accumulated time. + tm.mu.Lock() + tm.state = StateActive + tm.sessionStart = now.Add(-time.Hour) + tm.sessionStartMono = time.Now().Add(-time.Hour) + tm.sessionCumulativeTime = 45 * time.Minute + tm.warningsGiven[5*time.Minute] = true + tm.mu.Unlock() + + tm.ResetSession() + + tm.mu.Lock() + defer tm.mu.Unlock() + assert.Equal(t, StateActive, tm.state, "running game keeps tracking under the new profile") + assert.True(t, tm.sessionStart.Equal(now), "session restarts from the switch moment") + assert.Equal(t, time.Duration(0), tm.sessionCumulativeTime) + assert.Empty(t, tm.warningsGiven) +} + +func TestResetSession_CooldownResets(t *testing.T) { + t.Parallel() + + now := time.Date(2026, 6, 12, 12, 0, 0, 0, time.UTC) + clock := clockwork.NewFakeClockAt(now) + tm := NewLimitsManager(&database.Database{}, nil, &config.Instance{}, clock, newNoOpMockPlayer()) + + tm.mu.Lock() + tm.state = StateCooldown + tm.sessionCumulativeTime = 30 * time.Minute + tm.lastStopTime = now.Add(-time.Minute) + tm.cooldownTimer = clock.NewTimer(20 * time.Minute) + tm.mu.Unlock() + + tm.ResetSession() + + tm.mu.Lock() + defer tm.mu.Unlock() + assert.Equal(t, StateReset, tm.state) + assert.Equal(t, time.Duration(0), tm.sessionCumulativeTime) + assert.Nil(t, tm.cooldownTimer) + assert.True(t, tm.lastStopTime.IsZero()) +} diff --git a/pkg/service/playtime/provider.go b/pkg/service/playtime/provider.go new file mode 100644 index 000000000..0d97329f9 --- /dev/null +++ b/pkg/service/playtime/provider.go @@ -0,0 +1,58 @@ +// Zaparoo Core +// Copyright (c) 2026 The Zaparoo Project Contributors. +// SPDX-License-Identifier: GPL-3.0-or-later +// +// This file is part of Zaparoo Core. +// +// Zaparoo Core is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// Zaparoo Core is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with Zaparoo Core. If not, see . + +package playtime + +import ( + "time" + + "github.com/ZaparooProject/zaparoo-core/v2/pkg/config" +) + +// LimitsProvider is the source of playtime limit values for the +// LimitsManager. The default implementation reads global config directly; +// the profiles service provides an implementation that layers the active +// profile's overrides over global config (see +// pkg/service/profiles.LimitsResolver). +type LimitsProvider interface { + // PlaytimeLimitsEnabled reports whether limits are enforced. + PlaytimeLimitsEnabled() bool + // DailyLimit returns the daily limit, or 0 for no limit. + DailyLimit() time.Duration + // SessionLimit returns the per-session limit, or 0 for no limit. + SessionLimit() time.Duration + // WarningIntervals returns the remaining-time warning thresholds. + WarningIntervals() []time.Duration + // ActiveProfileID returns the active profile's ID, or "" when no + // profile is active. Daily usage accounting is scoped to this + // profile's attributed history; "" sums all history (device-level). + ActiveProfileID() string +} + +// globalProvider is the default LimitsProvider: global config values, no +// profile scoping. +type globalProvider struct { + cfg *config.Instance +} + +func (g globalProvider) PlaytimeLimitsEnabled() bool { return g.cfg.PlaytimeLimitsEnabled() } +func (g globalProvider) DailyLimit() time.Duration { return g.cfg.DailyLimit() } +func (g globalProvider) SessionLimit() time.Duration { return g.cfg.SessionLimit() } +func (g globalProvider) WarningIntervals() []time.Duration { return g.cfg.WarningIntervals() } +func (globalProvider) ActiveProfileID() string { return "" } diff --git a/pkg/service/profiles/pin.go b/pkg/service/profiles/pin.go new file mode 100644 index 000000000..abd8d11ab --- /dev/null +++ b/pkg/service/profiles/pin.go @@ -0,0 +1,111 @@ +// Zaparoo Core +// Copyright (c) 2026 The Zaparoo Project Contributors. +// SPDX-License-Identifier: GPL-3.0-or-later +// +// This file is part of Zaparoo Core. +// +// Zaparoo Core is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// Zaparoo Core is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with Zaparoo Core. If not, see . + +package profiles + +import ( + "crypto/pbkdf2" + "crypto/rand" + "crypto/sha256" + "crypto/subtle" + "encoding/base64" + "errors" + "fmt" + "strconv" + "strings" +) + +// Profile PINs are short convenience barriers, not credentials. They are +// still stored only as PBKDF2-SHA256 hashes, with +// brute force limited by the in-memory rate limiter in Service. +const ( + pinMinLen = 4 + pinMaxLen = 8 + pinIterations = 600_000 + pinSaltLen = 16 + pinKeyLen = 32 + pinHashPrefix = "pbkdf2-sha256" +) + +// ErrInvalidPINFormat is returned when a PIN is not 4-8 digits. +var ErrInvalidPINFormat = errors.New("PIN must be 4 to 8 digits") + +// HashPIN validates and hashes a PIN for storage. The encoded form is +// "pbkdf2-sha256$$$". +func HashPIN(pin string) (string, error) { + if err := validatePIN(pin); err != nil { + return "", err + } + + salt := make([]byte, pinSaltLen) + if _, err := rand.Read(salt); err != nil { + return "", fmt.Errorf("failed to generate PIN salt: %w", err) + } + + key, err := pbkdf2.Key(sha256.New, pin, salt, pinIterations, pinKeyLen) + if err != nil { + return "", fmt.Errorf("failed to derive PIN hash: %w", err) + } + + encoded := pinHashPrefix + "$" + + strconv.Itoa(pinIterations) + "$" + + base64.RawStdEncoding.EncodeToString(salt) + "$" + + base64.RawStdEncoding.EncodeToString(key) + return encoded, nil +} + +// VerifyPIN reports whether pin matches the encoded hash. Malformed hashes +// verify as false. +func VerifyPIN(pin, encoded string) bool { + parts := strings.Split(encoded, "$") + if len(parts) != 4 || parts[0] != pinHashPrefix { + return false + } + + iterations, err := strconv.Atoi(parts[1]) + if err != nil || iterations <= 0 { + return false + } + salt, err := base64.RawStdEncoding.DecodeString(parts[2]) + if err != nil { + return false + } + expected, err := base64.RawStdEncoding.DecodeString(parts[3]) + if err != nil { + return false + } + + key, err := pbkdf2.Key(sha256.New, pin, salt, iterations, len(expected)) + if err != nil { + return false + } + return subtle.ConstantTimeCompare(key, expected) == 1 +} + +func validatePIN(pin string) error { + if len(pin) < pinMinLen || len(pin) > pinMaxLen { + return ErrInvalidPINFormat + } + for _, r := range pin { + if r < '0' || r > '9' { + return ErrInvalidPINFormat + } + } + return nil +} diff --git a/pkg/service/profiles/pin_test.go b/pkg/service/profiles/pin_test.go new file mode 100644 index 000000000..2d499b334 --- /dev/null +++ b/pkg/service/profiles/pin_test.go @@ -0,0 +1,70 @@ +// Zaparoo Core +// Copyright (c) 2026 The Zaparoo Project Contributors. +// SPDX-License-Identifier: GPL-3.0-or-later +// +// This file is part of Zaparoo Core. +// +// Zaparoo Core is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// Zaparoo Core is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with Zaparoo Core. If not, see . + +package profiles + +import ( + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestHashPIN_RoundTrip(t *testing.T) { + t.Parallel() + + hash, err := HashPIN("1234") + require.NoError(t, err) + assert.True(t, strings.HasPrefix(hash, "pbkdf2-sha256$")) + assert.True(t, VerifyPIN("1234", hash)) + assert.False(t, VerifyPIN("4321", hash)) + assert.False(t, VerifyPIN("", hash)) +} + +func TestHashPIN_SaltsDiffer(t *testing.T) { + t.Parallel() + + hash1, err := HashPIN("12345678") + require.NoError(t, err) + hash2, err := HashPIN("12345678") + require.NoError(t, err) + assert.NotEqual(t, hash1, hash2) + assert.True(t, VerifyPIN("12345678", hash1)) + assert.True(t, VerifyPIN("12345678", hash2)) +} + +func TestHashPIN_RejectsInvalidFormats(t *testing.T) { + t.Parallel() + + for _, pin := range []string{"", "123", "123456789", "12a4", "12.4", "abcd"} { + _, err := HashPIN(pin) + require.ErrorIs(t, err, ErrInvalidPINFormat, "pin %q", pin) + } +} + +func TestVerifyPIN_MalformedHash(t *testing.T) { + t.Parallel() + + assert.False(t, VerifyPIN("1234", "")) + assert.False(t, VerifyPIN("1234", "not-a-hash")) + assert.False(t, VerifyPIN("1234", "pbkdf2-sha256$abc$def$ghi")) + assert.False(t, VerifyPIN("1234", "pbkdf2-sha256$0$AAAA$AAAA")) + assert.False(t, VerifyPIN("1234", "other-scheme$600000$AAAA$AAAA")) +} diff --git a/pkg/service/profiles/profiles.go b/pkg/service/profiles/profiles.go new file mode 100644 index 000000000..e537ac4c3 --- /dev/null +++ b/pkg/service/profiles/profiles.go @@ -0,0 +1,422 @@ +// Zaparoo Core +// Copyright (c) 2026 The Zaparoo Project Contributors. +// SPDX-License-Identifier: GPL-3.0-or-later +// +// This file is part of Zaparoo Core. +// +// Zaparoo Core is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// Zaparoo Core is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with Zaparoo Core. If not, see . + +// Package profiles implements device profiles: named buckets of preferences +// and limits with no credentials. One profile is active per +// device at a time. Switching via the API enforces an optional per-profile +// PIN; switching by scanning a profile's physical card bypasses the PIN +// (possession of the card is the authorization). +package profiles + +import ( + "errors" + "fmt" + "strings" + "time" + + "github.com/ZaparooProject/zaparoo-core/v2/pkg/api/models" + "github.com/ZaparooProject/zaparoo-core/v2/pkg/database" + "github.com/ZaparooProject/zaparoo-core/v2/pkg/database/userdb" + "github.com/ZaparooProject/zaparoo-core/v2/pkg/helpers/syncutil" + "github.com/ZaparooProject/zaparoo-core/v2/pkg/service/state" + "github.com/google/uuid" + "github.com/rs/zerolog/log" +) + +const ( + // pinAttemptWindow and pinAttemptLimit bound PIN guesses per profile: + // at most pinAttemptLimit failures within pinAttemptWindow. + pinAttemptWindow = time.Minute + pinAttemptLimit = 5 + + // switchIDRetries bounds regeneration attempts on the (vanishingly + // unlikely) unique-constraint collision of a generated switch ID. + switchIDRetries = 5 +) + +var ( + // ErrPINRequired is returned when switching to a PIN-protected profile + // without supplying a PIN. + ErrPINRequired = errors.New("profile requires a PIN") + // ErrPINIncorrect is returned when the supplied PIN does not match. + ErrPINIncorrect = errors.New("incorrect PIN") + // ErrPINRateLimited is returned when too many failed PIN attempts have + // been made against a profile. + ErrPINRateLimited = errors.New("too many PIN attempts, try again later") + // ErrNotFound is returned when a profile does not exist. + ErrNotFound = userdb.ErrProfileNotFound +) + +// Service owns the device's profile lifecycle: CRUD, the active-profile +// state, and PIN-checked switching. All activation paths (API, ZapScript +// card scans, boot restore) go through here. +type Service struct { + db *database.Database + st *state.State + now func() time.Time + pinAttempts map[string][]time.Time + mu syncutil.Mutex +} + +// NewService creates a profiles service backed by the user database and +// service state. +func NewService(db *database.Database, st *state.State) *Service { + return &Service{ + db: db, + st: st, + now: time.Now, + pinAttempts: make(map[string][]time.Time), + } +} + +// List returns all profiles. +func (s *Service) List() ([]database.Profile, error) { + profiles, err := s.db.UserDB.ListProfiles() + if err != nil { + return nil, fmt.Errorf("failed to list profiles: %w", err) + } + return profiles, nil +} + +// Get returns a profile by its profile ID. +func (s *Service) Get(profileID string) (*database.Profile, error) { + p, err := s.db.UserDB.GetProfile(profileID) + if err != nil { + return nil, fmt.Errorf("failed to get profile: %w", err) + } + return p, nil +} + +// Create creates a new profile with a generated profile ID and switch ID, +// hashing the PIN if one is given. Limit duration strings must already be +// validated by the caller (API layer) or empty. +func (s *Service) Create(params *models.NewProfileParams) (*database.Profile, error) { + if err := validateLimitDurations(params.DailyLimit, params.SessionLimit); err != nil { + return nil, err + } + + p := &database.Profile{ + ProfileID: uuid.New().String(), + Name: params.Name, + LimitsEnabled: params.LimitsEnabled, + DailyLimit: params.DailyLimit, + SessionLimit: params.SessionLimit, + CreatedAt: s.now().Unix(), + UpdatedAt: s.now().Unix(), + } + + if params.PIN != nil && *params.PIN != "" { + hash, err := HashPIN(*params.PIN) + if err != nil { + return nil, err + } + p.PINHash = hash + } + + if err := s.insertWithSwitchID(p); err != nil { + return nil, err + } + + return p, nil +} + +// Update applies an update to a profile. If the profile is currently +// active, the in-memory snapshot is refreshed so changed limits apply +// immediately. +func (s *Service) Update(params *models.UpdateProfileParams) (*database.Profile, error) { + if err := validateLimitDurations(params.DailyLimit, params.SessionLimit); err != nil { + return nil, err + } + + p, err := s.db.UserDB.GetProfile(params.ProfileID) + if err != nil { + return nil, fmt.Errorf("failed to get profile: %w", err) + } + + if params.Name != nil { + p.Name = *params.Name + } + switch { + case params.ClearPIN: + p.PINHash = "" + case params.PIN != nil && *params.PIN != "": + hash, hashErr := HashPIN(*params.PIN) + if hashErr != nil { + return nil, hashErr + } + p.PINHash = hash + } + if params.ClearLimits { + p.LimitsEnabled = nil + p.DailyLimit = nil + p.SessionLimit = nil + } else { + if params.LimitsEnabled != nil { + p.LimitsEnabled = params.LimitsEnabled + } + if params.DailyLimit != nil { + p.DailyLimit = params.DailyLimit + } + if params.SessionLimit != nil { + p.SessionLimit = params.SessionLimit + } + } + p.UpdatedAt = s.now().Unix() + + if params.RegenerateSwitchID { + if regenErr := s.updateWithNewSwitchID(p); regenErr != nil { + return nil, regenErr + } + } else if updateErr := s.db.UserDB.UpdateProfile(p); updateErr != nil { + return nil, fmt.Errorf("failed to update profile: %w", updateErr) + } + + // Refresh the active snapshot if this profile is active so limit + // changes take effect without a re-switch. + if active := s.st.ActiveProfile(); active != nil && active.ProfileID == p.ProfileID { + s.st.SetActiveProfile(snapshot(p)) + } + + return p, nil +} + +// Delete removes a profile. If it is the active profile, the device +// deactivates (the persisted active state is cleared transactionally by +// the database layer). +func (s *Service) Delete(profileID string) error { + if err := s.db.UserDB.DeleteProfile(profileID); err != nil { + return fmt.Errorf("failed to delete profile: %w", err) + } + + if active := s.st.ActiveProfile(); active != nil && active.ProfileID == profileID { + s.st.SetActiveProfile(nil) + } + + return nil +} + +// ActivateByID switches the device to a profile, enforcing its PIN if one +// is set. This is the API path; card scans use ActivateBySwitchID. +func (s *Service) ActivateByID(profileID, pin string) (*models.ActiveProfile, error) { + p, err := s.db.UserDB.GetProfile(profileID) + if err != nil { + return nil, fmt.Errorf("failed to get profile: %w", err) + } + if err := s.checkPIN(p, pin); err != nil { + return nil, err + } + return s.activate(p) +} + +// ActivateBySwitchIDChecked switches to a profile selected by switch ID, +// enforcing its PIN. Used when a switch ID arrives over the API rather +// than from a physical scan. +func (s *Service) ActivateBySwitchIDChecked(switchID, pin string) (*models.ActiveProfile, error) { + p, err := s.db.UserDB.GetProfileBySwitchID(switchID) + if err != nil { + return nil, fmt.Errorf("failed to get profile by switch ID: %w", err) + } + if err := s.checkPIN(p, pin); err != nil { + return nil, err + } + return s.activate(p) +} + +// ActivateBySwitchID switches to a profile selected by switch ID without a +// PIN check. This is the physical card-scan path: possession of the card +// is the authorization. +func (s *Service) ActivateBySwitchID(switchID string) (*models.ActiveProfile, error) { + p, err := s.db.UserDB.GetProfileBySwitchID(switchID) + if err != nil { + return nil, fmt.Errorf("failed to get profile by switch ID: %w", err) + } + return s.activate(p) +} + +// Deactivate clears the active profile. Leaving a profile is always free +// (PINs gate entry only); restricting what a profile-less device can do is +// handled by the require-profile launch setting. +func (s *Service) Deactivate() error { + if err := s.db.UserDB.DeleteDeviceState(database.DeviceStateKeyActiveProfile); err != nil { + return fmt.Errorf("failed to clear active profile state: %w", err) + } + s.st.SetActiveProfile(nil) + return nil +} + +// Active returns the active profile snapshot, or nil when none is active. +func (s *Service) Active() *models.ActiveProfile { + return s.st.ActiveProfile() +} + +// RestoreOnBoot restores the persisted active profile into service state. +// A dangling reference to a deleted profile is cleaned up silently. +func (s *Service) RestoreOnBoot() error { + profileID, found, err := s.db.UserDB.GetDeviceState(database.DeviceStateKeyActiveProfile) + if err != nil { + return fmt.Errorf("failed to read active profile state: %w", err) + } + if !found { + return nil + } + + p, err := s.db.UserDB.GetProfile(profileID) + if err != nil { + if errors.Is(err, ErrNotFound) { + log.Warn().Str("profileId", profileID). + Msg("persisted active profile no longer exists, clearing") + if delErr := s.db.UserDB.DeleteDeviceState(database.DeviceStateKeyActiveProfile); delErr != nil { + return fmt.Errorf("failed to clear dangling active profile state: %w", delErr) + } + return nil + } + return fmt.Errorf("failed to get persisted active profile: %w", err) + } + + s.st.SetActiveProfile(snapshot(p)) + log.Info().Str("profileId", p.ProfileID).Str("name", p.Name). + Msg("restored active profile") + return nil +} + +func (s *Service) activate(p *database.Profile) (*models.ActiveProfile, error) { + if err := s.db.UserDB.SetDeviceState(database.DeviceStateKeyActiveProfile, p.ProfileID); err != nil { + return nil, fmt.Errorf("failed to persist active profile: %w", err) + } + snap := snapshot(p) + s.st.SetActiveProfile(snap) + log.Info().Str("profileId", p.ProfileID).Str("name", p.Name). + Msg("switched active profile") + return snap, nil +} + +// checkPIN enforces a profile's PIN with per-profile rate limiting. A +// profile without a PIN always passes. +func (s *Service) checkPIN(p *database.Profile, pin string) error { + if p.PINHash == "" { + return nil + } + if pin == "" { + return ErrPINRequired + } + + s.mu.Lock() + now := s.now() + attempts := s.pinAttempts[p.ProfileID] + recent := attempts[:0] + for _, at := range attempts { + if now.Sub(at) < pinAttemptWindow { + recent = append(recent, at) + } + } + if len(recent) >= pinAttemptLimit { + s.pinAttempts[p.ProfileID] = recent + s.mu.Unlock() + return ErrPINRateLimited + } + s.mu.Unlock() + + if !VerifyPIN(pin, p.PINHash) { + s.mu.Lock() + s.pinAttempts[p.ProfileID] = append(recent, now) + s.mu.Unlock() + return ErrPINIncorrect + } + + s.mu.Lock() + delete(s.pinAttempts, p.ProfileID) + s.mu.Unlock() + return nil +} + +// insertWithSwitchID inserts a profile, generating a fresh switch ID and +// retrying on the unlikely unique-constraint collision. +func (s *Service) insertWithSwitchID(p *database.Profile) error { + for range switchIDRetries { + switchID, err := GenerateSwitchID() + if err != nil { + return err + } + p.SwitchID = switchID + err = s.db.UserDB.CreateProfile(p) + if err == nil { + return nil + } + if !isSwitchIDConflict(err) { + return fmt.Errorf("failed to create profile: %w", err) + } + } + return errors.New("failed to generate a unique switch ID") +} + +// updateWithNewSwitchID updates a profile with a regenerated switch ID, +// retrying on collision. +func (s *Service) updateWithNewSwitchID(p *database.Profile) error { + for range switchIDRetries { + switchID, err := GenerateSwitchID() + if err != nil { + return err + } + p.SwitchID = switchID + err = s.db.UserDB.UpdateProfile(p) + if err == nil { + return nil + } + if !isSwitchIDConflict(err) { + return fmt.Errorf("failed to update profile: %w", err) + } + } + return errors.New("failed to generate a unique switch ID") +} + +// isSwitchIDConflict detects a unique-constraint violation on SwitchID. +// SQLite reports these as "UNIQUE constraint failed: Profiles.SwitchID". +func isSwitchIDConflict(err error) bool { + return err != nil && + strings.Contains(err.Error(), "UNIQUE") && + strings.Contains(err.Error(), "SwitchID") +} + +// snapshot builds the in-memory active-profile snapshot from a profile +// row. It carries resolved limit overrides so the playtime hot path never +// touches the database. +func snapshot(p *database.Profile) *models.ActiveProfile { + return &models.ActiveProfile{ + ProfileID: p.ProfileID, + Name: p.Name, + HasPIN: p.PINHash != "", + LimitsEnabled: p.LimitsEnabled, + DailyLimit: p.DailyLimit, + SessionLimit: p.SessionLimit, + } +} + +// validateLimitDurations rejects unparseable limit duration strings. An +// empty string is allowed (it means "clear to inherit" on update). +func validateLimitDurations(durations ...*string) error { + for _, d := range durations { + if d == nil || *d == "" { + continue + } + if _, err := time.ParseDuration(*d); err != nil { + return fmt.Errorf("invalid limit duration %q: %w", *d, err) + } + } + return nil +} diff --git a/pkg/service/profiles/profiles_test.go b/pkg/service/profiles/profiles_test.go new file mode 100644 index 000000000..6df9ca917 --- /dev/null +++ b/pkg/service/profiles/profiles_test.go @@ -0,0 +1,287 @@ +// Zaparoo Core +// Copyright (c) 2026 The Zaparoo Project Contributors. +// SPDX-License-Identifier: GPL-3.0-or-later +// +// This file is part of Zaparoo Core. +// +// Zaparoo Core is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// Zaparoo Core is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with Zaparoo Core. If not, see . + +package profiles + +import ( + "strings" + "testing" + "time" + + "github.com/ZaparooProject/zaparoo-core/v2/pkg/api/models" + "github.com/ZaparooProject/zaparoo-core/v2/pkg/database" + "github.com/ZaparooProject/zaparoo-core/v2/pkg/database/userdb" + "github.com/ZaparooProject/zaparoo-core/v2/pkg/service/state" + "github.com/ZaparooProject/zaparoo-core/v2/pkg/testing/helpers" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" +) + +func newTestService(t *testing.T) (svc *Service, mockDB *helpers.MockUserDBI, st *state.State) { + t.Helper() + mockDB = helpers.NewMockUserDBI() + st, ns := state.NewState(nil, "boot") + t.Cleanup(func() { + for { + select { + case <-ns: + default: + return + } + } + }) + svc = NewService(&database.Database{UserDB: mockDB, MediaDB: nil}, st) + return svc, mockDB, st +} + +func pinProfile(t *testing.T, pin string) *database.Profile { + t.Helper() + p := &database.Profile{ + ProfileID: "profile-1", + Name: "Kid A", + SwitchID: "corn-arm-truck", + } + if pin != "" { + hash, err := HashPIN(pin) + require.NoError(t, err) + p.PINHash = hash + } + return p +} + +func TestActivateByID_NoPIN(t *testing.T) { + t.Parallel() + svc, mockDB, st := newTestService(t) + + mockDB.On("GetProfile", "profile-1").Return(pinProfile(t, ""), nil) + mockDB.On("SetDeviceState", database.DeviceStateKeyActiveProfile, "profile-1").Return(nil) + + snap, err := svc.ActivateByID("profile-1", "") + require.NoError(t, err) + assert.Equal(t, "Kid A", snap.Name) + assert.False(t, snap.HasPIN) + + active := st.ActiveProfile() + require.NotNil(t, active) + assert.Equal(t, "profile-1", active.ProfileID) + mockDB.AssertExpectations(t) +} + +func TestActivateByID_PINEnforced(t *testing.T) { + t.Parallel() + svc, mockDB, st := newTestService(t) + + mockDB.On("GetProfile", "profile-1").Return(pinProfile(t, "1234"), nil) + + _, err := svc.ActivateByID("profile-1", "") + require.ErrorIs(t, err, ErrPINRequired) + + _, err = svc.ActivateByID("profile-1", "9999") + require.ErrorIs(t, err, ErrPINIncorrect) + + assert.Nil(t, st.ActiveProfile()) + + mockDB.On("SetDeviceState", database.DeviceStateKeyActiveProfile, "profile-1").Return(nil) + snap, err := svc.ActivateByID("profile-1", "1234") + require.NoError(t, err) + assert.True(t, snap.HasPIN) + require.NotNil(t, st.ActiveProfile()) +} + +func TestActivateByID_RateLimited(t *testing.T) { + t.Parallel() + svc, mockDB, _ := newTestService(t) + + now := time.Date(2026, 6, 12, 12, 0, 0, 0, time.UTC) + svc.now = func() time.Time { return now } + + mockDB.On("GetProfile", "profile-1").Return(pinProfile(t, "1234"), nil) + + for range pinAttemptLimit { + _, err := svc.ActivateByID("profile-1", "9999") + require.ErrorIs(t, err, ErrPINIncorrect) + } + + // Even the correct PIN is rejected while rate limited. + _, err := svc.ActivateByID("profile-1", "1234") + require.ErrorIs(t, err, ErrPINRateLimited) + + // After the window passes, attempts work again. + now = now.Add(pinAttemptWindow + time.Second) + mockDB.On("SetDeviceState", database.DeviceStateKeyActiveProfile, "profile-1").Return(nil) + _, err = svc.ActivateByID("profile-1", "1234") + require.NoError(t, err) +} + +func TestActivateBySwitchID_BypassesPIN(t *testing.T) { + t.Parallel() + svc, mockDB, st := newTestService(t) + + mockDB.On("GetProfileBySwitchID", "corn-arm-truck").Return(pinProfile(t, "1234"), nil) + mockDB.On("SetDeviceState", database.DeviceStateKeyActiveProfile, "profile-1").Return(nil) + + snap, err := svc.ActivateBySwitchID("corn-arm-truck") + require.NoError(t, err) + assert.Equal(t, "profile-1", snap.ProfileID) + require.NotNil(t, st.ActiveProfile()) +} + +func TestActivateBySwitchIDChecked_EnforcesPIN(t *testing.T) { + t.Parallel() + svc, mockDB, _ := newTestService(t) + + mockDB.On("GetProfileBySwitchID", "corn-arm-truck").Return(pinProfile(t, "1234"), nil) + + _, err := svc.ActivateBySwitchIDChecked("corn-arm-truck", "") + require.ErrorIs(t, err, ErrPINRequired) +} + +func TestDeactivate(t *testing.T) { + t.Parallel() + svc, mockDB, st := newTestService(t) + + st.SetActiveProfile(&models.ActiveProfile{ProfileID: "profile-1", Name: "Kid A"}) + mockDB.On("DeleteDeviceState", database.DeviceStateKeyActiveProfile).Return(nil) + + require.NoError(t, svc.Deactivate()) + assert.Nil(t, st.ActiveProfile()) +} + +func TestRestoreOnBoot_Restores(t *testing.T) { + t.Parallel() + svc, mockDB, st := newTestService(t) + + mockDB.On("GetDeviceState", database.DeviceStateKeyActiveProfile).Return("profile-1", true, nil) + mockDB.On("GetProfile", "profile-1").Return(pinProfile(t, ""), nil) + + require.NoError(t, svc.RestoreOnBoot()) + active := st.ActiveProfile() + require.NotNil(t, active) + assert.Equal(t, "profile-1", active.ProfileID) +} + +func TestRestoreOnBoot_NothingPersisted(t *testing.T) { + t.Parallel() + svc, mockDB, st := newTestService(t) + + mockDB.On("GetDeviceState", database.DeviceStateKeyActiveProfile).Return("", false, nil) + + require.NoError(t, svc.RestoreOnBoot()) + assert.Nil(t, st.ActiveProfile()) +} + +func TestRestoreOnBoot_DanglingCleared(t *testing.T) { + t.Parallel() + svc, mockDB, st := newTestService(t) + + mockDB.On("GetDeviceState", database.DeviceStateKeyActiveProfile).Return("deleted-profile", true, nil) + mockDB.On("GetProfile", "deleted-profile").Return(nil, userdb.ErrProfileNotFound) + mockDB.On("DeleteDeviceState", database.DeviceStateKeyActiveProfile).Return(nil) + + require.NoError(t, svc.RestoreOnBoot()) + assert.Nil(t, st.ActiveProfile()) + mockDB.AssertExpectations(t) +} + +func TestCreate_GeneratesSwitchIDAndHashesPIN(t *testing.T) { + t.Parallel() + svc, mockDB, _ := newTestService(t) + + pin := "1234" + mockDB.On("CreateProfile", mock.MatchedBy(func(p *database.Profile) bool { + return p.ProfileID != "" && + len(strings.Split(p.SwitchID, "-")) == switchIDWords && + strings.HasPrefix(p.PINHash, "pbkdf2-sha256$") + })).Return(nil) + + p, err := svc.Create(&models.NewProfileParams{Name: "Kid A", PIN: &pin}) + require.NoError(t, err) + assert.Equal(t, "Kid A", p.Name) + assert.True(t, VerifyPIN("1234", p.PINHash)) + mockDB.AssertExpectations(t) +} + +func TestCreate_RejectsBadDuration(t *testing.T) { + t.Parallel() + svc, _, _ := newTestService(t) + + bad := "2 hours" + _, err := svc.Create(&models.NewProfileParams{Name: "Kid A", DailyLimit: &bad}) + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid limit duration") +} + +func TestUpdate_RefreshesActiveSnapshot(t *testing.T) { + t.Parallel() + svc, mockDB, st := newTestService(t) + + existing := pinProfile(t, "") + st.SetActiveProfile(&models.ActiveProfile{ProfileID: existing.ProfileID, Name: existing.Name}) + + mockDB.On("GetProfile", "profile-1").Return(existing, nil) + mockDB.On("UpdateProfile", mock.Anything).Return(nil) + + daily := "2h" + enabled := true + _, err := svc.Update(&models.UpdateProfileParams{ + ProfileID: "profile-1", + DailyLimit: &daily, + LimitsEnabled: &enabled, + }) + require.NoError(t, err) + + active := st.ActiveProfile() + require.NotNil(t, active) + require.NotNil(t, active.DailyLimit) + assert.Equal(t, "2h", *active.DailyLimit) + require.NotNil(t, active.LimitsEnabled) + assert.True(t, *active.LimitsEnabled) +} + +func TestUpdate_ClearLimits(t *testing.T) { + t.Parallel() + svc, mockDB, _ := newTestService(t) + + existing := pinProfile(t, "") + enabled := true + daily := "1h" + existing.LimitsEnabled = &enabled + existing.DailyLimit = &daily + + mockDB.On("GetProfile", "profile-1").Return(existing, nil) + mockDB.On("UpdateProfile", mock.MatchedBy(func(p *database.Profile) bool { + return p.LimitsEnabled == nil && p.DailyLimit == nil && p.SessionLimit == nil + })).Return(nil) + + _, err := svc.Update(&models.UpdateProfileParams{ProfileID: "profile-1", ClearLimits: true}) + require.NoError(t, err) + mockDB.AssertExpectations(t) +} + +func TestDelete_ActiveProfileDeactivates(t *testing.T) { + t.Parallel() + svc, mockDB, st := newTestService(t) + + st.SetActiveProfile(&models.ActiveProfile{ProfileID: "profile-1", Name: "Kid A"}) + mockDB.On("DeleteProfile", "profile-1").Return(nil) + + require.NoError(t, svc.Delete("profile-1")) + assert.Nil(t, st.ActiveProfile()) +} diff --git a/pkg/service/profiles/resolver.go b/pkg/service/profiles/resolver.go new file mode 100644 index 000000000..794e8e49d --- /dev/null +++ b/pkg/service/profiles/resolver.go @@ -0,0 +1,103 @@ +// Zaparoo Core +// Copyright (c) 2026 The Zaparoo Project Contributors. +// SPDX-License-Identifier: GPL-3.0-or-later +// +// This file is part of Zaparoo Core. +// +// Zaparoo Core is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// Zaparoo Core is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with Zaparoo Core. If not, see . + +package profiles + +import ( + "time" + + "github.com/ZaparooProject/zaparoo-core/v2/pkg/config" + "github.com/ZaparooProject/zaparoo-core/v2/pkg/service/state" +) + +// LimitsResolver layers the active profile's playtime limit overrides over +// the global config. It satisfies the playtime.LimitsProvider interface. +// Reads come from the in-memory active-profile snapshot, never the +// database, so it is safe on the limit-check hot path. +// +// An explicit (non-nil) profile field wins; a nil field inherits the +// global config value. A "0" duration string means explicitly unlimited. +// Warning intervals are device UX, not per-person policy, and always come +// from global config. +type LimitsResolver struct { + cfg *config.Instance + st *state.State +} + +// NewLimitsResolver creates a resolver over the global config and the +// service state holding the active profile. +func NewLimitsResolver(cfg *config.Instance, st *state.State) *LimitsResolver { + return &LimitsResolver{cfg: cfg, st: st} +} + +// PlaytimeLimitsEnabled returns the active profile's enabled override, or +// the global config value when no profile is active or it has no override. +func (r *LimitsResolver) PlaytimeLimitsEnabled() bool { + if p := r.st.ActiveProfile(); p != nil && p.LimitsEnabled != nil { + return *p.LimitsEnabled + } + return r.cfg.PlaytimeLimitsEnabled() +} + +// DailyLimit returns the active profile's daily limit override, or the +// global config value. Returns 0 for "no limit". +func (r *LimitsResolver) DailyLimit() time.Duration { + if p := r.st.ActiveProfile(); p != nil && p.DailyLimit != nil { + return parseLimit(*p.DailyLimit) + } + return r.cfg.DailyLimit() +} + +// SessionLimit returns the active profile's session limit override, or the +// global config value. Returns 0 for "no limit". +func (r *LimitsResolver) SessionLimit() time.Duration { + if p := r.st.ActiveProfile(); p != nil && p.SessionLimit != nil { + return parseLimit(*p.SessionLimit) + } + return r.cfg.SessionLimit() +} + +// WarningIntervals always returns the global config warning intervals. +func (r *LimitsResolver) WarningIntervals() []time.Duration { + return r.cfg.WarningIntervals() +} + +// ActiveProfileID returns the active profile's ID, or "" when no profile +// is active. The playtime limits manager uses this to scope daily usage +// accounting to the active profile's attributed history. +func (r *LimitsResolver) ActiveProfileID() string { + if p := r.st.ActiveProfile(); p != nil { + return p.ProfileID + } + return "" +} + +// parseLimit parses a stored limit duration string. Strings are validated +// at write time; an unparseable value degrades to 0 (no limit), matching +// the config accessors' behavior. +func parseLimit(s string) time.Duration { + if s == "" { + return 0 + } + d, err := time.ParseDuration(s) + if err != nil { + return 0 + } + return d +} diff --git a/pkg/service/profiles/resolver_test.go b/pkg/service/profiles/resolver_test.go new file mode 100644 index 000000000..920710ac6 --- /dev/null +++ b/pkg/service/profiles/resolver_test.go @@ -0,0 +1,136 @@ +// Zaparoo Core +// Copyright (c) 2026 The Zaparoo Project Contributors. +// SPDX-License-Identifier: GPL-3.0-or-later +// +// This file is part of Zaparoo Core. +// +// Zaparoo Core is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// Zaparoo Core is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with Zaparoo Core. If not, see . + +package profiles + +import ( + "testing" + "time" + + "github.com/ZaparooProject/zaparoo-core/v2/pkg/api/models" + "github.com/ZaparooProject/zaparoo-core/v2/pkg/config" + "github.com/ZaparooProject/zaparoo-core/v2/pkg/service/state" + "github.com/ZaparooProject/zaparoo-core/v2/pkg/testing/helpers" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func newResolver(t *testing.T) (resolver *LimitsResolver, cfg *config.Instance, st *state.State) { + t.Helper() + fs := helpers.NewMemoryFS() + cfg, err := helpers.NewTestConfig(fs, t.TempDir()) + require.NoError(t, err) + st, ns := state.NewState(nil, "boot") + t.Cleanup(func() { + for { + select { + case <-ns: + default: + return + } + } + }) + return NewLimitsResolver(cfg, st), cfg, st +} + +func strPtr(s string) *string { return &s } + +func boolPtr(b bool) *bool { return &b } + +func TestLimitsResolver_NoProfileInheritsGlobal(t *testing.T) { + t.Parallel() + resolver, cfg, _ := newResolver(t) + + cfg.SetPlaytimeLimitsEnabled(true) + require.NoError(t, cfg.SetDailyLimit("2h")) + require.NoError(t, cfg.SetSessionLimit("45m")) + + assert.True(t, resolver.PlaytimeLimitsEnabled()) + assert.Equal(t, 2*time.Hour, resolver.DailyLimit()) + assert.Equal(t, 45*time.Minute, resolver.SessionLimit()) + assert.Empty(t, resolver.ActiveProfileID()) +} + +func TestLimitsResolver_ProfileOverrides(t *testing.T) { + t.Parallel() + resolver, cfg, st := newResolver(t) + + cfg.SetPlaytimeLimitsEnabled(false) + require.NoError(t, cfg.SetDailyLimit("8h")) + + st.SetActiveProfile(&models.ActiveProfile{ + ProfileID: "kid-a", + Name: "Kid A", + LimitsEnabled: boolPtr(true), + DailyLimit: strPtr("1h30m"), + SessionLimit: strPtr("30m"), + }) + + assert.True(t, resolver.PlaytimeLimitsEnabled()) + assert.Equal(t, 90*time.Minute, resolver.DailyLimit()) + assert.Equal(t, 30*time.Minute, resolver.SessionLimit()) + assert.Equal(t, "kid-a", resolver.ActiveProfileID()) +} + +func TestLimitsResolver_PartialOverridesInheritRest(t *testing.T) { + t.Parallel() + resolver, cfg, st := newResolver(t) + + cfg.SetPlaytimeLimitsEnabled(true) + require.NoError(t, cfg.SetDailyLimit("4h")) + require.NoError(t, cfg.SetSessionLimit("1h")) + + // Profile overrides only the daily limit; everything else inherits. + st.SetActiveProfile(&models.ActiveProfile{ + ProfileID: "kid-b", + Name: "Kid B", + DailyLimit: strPtr("2h"), + }) + + assert.True(t, resolver.PlaytimeLimitsEnabled()) + assert.Equal(t, 2*time.Hour, resolver.DailyLimit()) + assert.Equal(t, time.Hour, resolver.SessionLimit()) +} + +func TestLimitsResolver_ZeroMeansUnlimited(t *testing.T) { + t.Parallel() + resolver, cfg, st := newResolver(t) + + cfg.SetPlaytimeLimitsEnabled(true) + require.NoError(t, cfg.SetDailyLimit("2h")) + + // "0" is an explicit override to unlimited, distinct from nil/inherit. + st.SetActiveProfile(&models.ActiveProfile{ + ProfileID: "parent", + Name: "Parent", + DailyLimit: strPtr("0"), + }) + + assert.Equal(t, time.Duration(0), resolver.DailyLimit()) +} + +func TestLimitsResolver_WarningsAlwaysGlobal(t *testing.T) { + t.Parallel() + resolver, cfg, st := newResolver(t) + + require.NoError(t, cfg.SetWarningIntervals([]string{"10m", "1m"})) + st.SetActiveProfile(&models.ActiveProfile{ProfileID: "kid-a", Name: "Kid A"}) + + assert.Equal(t, []time.Duration{10 * time.Minute, time.Minute}, resolver.WarningIntervals()) +} diff --git a/pkg/service/profiles/switchid.go b/pkg/service/profiles/switchid.go new file mode 100644 index 000000000..af07b7543 --- /dev/null +++ b/pkg/service/profiles/switchid.go @@ -0,0 +1,59 @@ +// Zaparoo Core +// Copyright (c) 2026 The Zaparoo Project Contributors. +// SPDX-License-Identifier: GPL-3.0-or-later +// +// This file is part of Zaparoo Core. +// +// Zaparoo Core is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// Zaparoo Core is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with Zaparoo Core. If not, see . + +package profiles + +import ( + "crypto/rand" + _ "embed" + "fmt" + "math/big" + "strings" +) + +// wordlistRaw is derived from the EFF short wordlist #1 +// (https://www.eff.org/dice, CC-BY 3.0), with hyphenated words removed so +// every word is a plain lowercase token. Switch IDs are selectors, not +// secrets — the list only needs to be large enough to avoid accidental +// collisions and easy to read aloud or print on a card. +// +//go:embed wordlist.txt +var wordlistRaw string + +// switchIDWords is the number of words in a generated switch ID. +const switchIDWords = 3 + +//nolint:gochecknoglobals // immutable parsed copy of the embedded wordlist +var wordlist = strings.Fields(wordlistRaw) + +// GenerateSwitchID returns a new random word-phrase switch ID, e.g. +// "corn-arm-truck". Uniqueness is enforced by the database; callers should +// retry on a unique-constraint conflict. +func GenerateSwitchID() (string, error) { + parts := make([]string, switchIDWords) + maxIndex := big.NewInt(int64(len(wordlist))) + for i := range parts { + n, err := rand.Int(rand.Reader, maxIndex) + if err != nil { + return "", fmt.Errorf("failed to generate switch ID word: %w", err) + } + parts[i] = wordlist[n.Int64()] + } + return strings.Join(parts, "-"), nil +} diff --git a/pkg/service/profiles/switchid_test.go b/pkg/service/profiles/switchid_test.go new file mode 100644 index 000000000..d283aa7a9 --- /dev/null +++ b/pkg/service/profiles/switchid_test.go @@ -0,0 +1,53 @@ +// Zaparoo Core +// Copyright (c) 2026 The Zaparoo Project Contributors. +// SPDX-License-Identifier: GPL-3.0-or-later +// +// This file is part of Zaparoo Core. +// +// Zaparoo Core is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// Zaparoo Core is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with Zaparoo Core. If not, see . + +package profiles + +import ( + "regexp" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestWordlist_Loaded(t *testing.T) { + t.Parallel() + + assert.Len(t, wordlist, 1295) + wordRe := regexp.MustCompile(`^[a-z]+$`) + for _, w := range wordlist { + assert.True(t, wordRe.MatchString(w), "word %q must be plain lowercase", w) + } +} + +func TestGenerateSwitchID_Format(t *testing.T) { + t.Parallel() + + for range 50 { + id, err := GenerateSwitchID() + require.NoError(t, err) + parts := strings.Split(id, "-") + require.Len(t, parts, switchIDWords, "switch ID %q", id) + for _, part := range parts { + assert.Contains(t, wordlist, part) + } + } +} diff --git a/pkg/service/profiles/wordlist.txt b/pkg/service/profiles/wordlist.txt new file mode 100644 index 000000000..028ca87e7 --- /dev/null +++ b/pkg/service/profiles/wordlist.txt @@ -0,0 +1,1295 @@ +acid +acorn +acre +acts +afar +affix +aged +agent +agile +aging +agony +ahead +aide +aids +aim +ajar +alarm +alias +alibi +alien +alike +alive +aloe +aloft +aloha +alone +amend +amino +ample +amuse +angel +anger +angle +ankle +apple +april +apron +aqua +area +arena +argue +arise +armed +armor +army +aroma +array +arson +art +ashen +ashes +atlas +atom +attic +audio +avert +avoid +awake +award +awoke +axis +bacon +badge +bagel +baggy +baked +baker +balmy +banjo +barge +barn +bash +basil +bask +batch +bath +baton +bats +blade +blank +blast +blaze +bleak +blend +bless +blimp +blink +bloat +blob +blog +blot +blunt +blurt +blush +boast +boat +body +boil +bok +bolt +boned +boney +bonus +bony +book +booth +boots +boss +botch +both +boxer +breed +bribe +brick +bride +brim +bring +brink +brisk +broad +broil +broke +brook +broom +brush +buck +bud +buggy +bulge +bulk +bully +bunch +bunny +bunt +bush +bust +busy +buzz +cable +cache +cadet +cage +cake +calm +cameo +canal +candy +cane +canon +cape +card +cargo +carol +carry +carve +case +cash +cause +cedar +chain +chair +chant +chaos +charm +chase +cheek +cheer +chef +chess +chest +chew +chief +chili +chill +chip +chomp +chop +chow +chuck +chump +chunk +churn +chute +cider +cinch +city +civic +civil +clad +claim +clamp +clap +clash +clasp +class +claw +clay +clean +clear +cleat +cleft +clerk +click +cling +clink +clip +cloak +clock +clone +cloth +cloud +clump +coach +coast +coat +cod +coil +coke +cola +cold +colt +coma +come +comic +comma +cone +cope +copy +coral +cork +cost +cot +couch +cough +cover +cozy +craft +cramp +crane +crank +crate +crave +crawl +crazy +creme +crepe +crept +crib +cried +crisp +crook +crop +cross +crowd +crown +crumb +crush +crust +cub +cult +cupid +cure +curl +curry +curse +curve +curvy +cushy +cut +cycle +dab +dad +daily +dairy +daisy +dance +dandy +darn +dart +dash +data +date +dawn +deaf +deal +dean +debit +debt +debug +decaf +decal +decay +deck +decor +decoy +deed +delay +denim +dense +dent +depth +derby +desk +dial +diary +dice +dig +dill +dime +dimly +diner +dingy +disco +dish +disk +ditch +ditzy +dizzy +dock +dodge +doing +doll +dome +donor +donut +dose +dot +dove +down +dowry +doze +drab +drama +drank +draw +dress +dried +drift +drill +drive +drone +droop +drove +drown +drum +dry +duck +duct +dude +dug +duke +duo +dusk +dust +duty +dwarf +dwell +eagle +early +earth +easel +east +eaten +eats +ebay +ebony +ebook +echo +edge +eel +eject +elbow +elder +elf +elk +elm +elope +elude +elves +email +emit +empty +emu +enter +entry +envoy +equal +erase +error +erupt +essay +etch +evade +even +evict +evil +evoke +exact +exit +fable +faced +fact +fade +fall +false +fancy +fang +fax +feast +feed +femur +fence +fend +ferry +fetal +fetch +fever +fiber +fifth +fifty +film +filth +final +finch +fit +five +flag +flaky +flame +flap +flask +fled +flick +fling +flint +flip +flirt +float +flock +flop +floss +flyer +foam +foe +fog +foil +folic +folk +food +fool +found +fox +foyer +frail +frame +fray +fresh +fried +frill +frisk +from +front +frost +froth +frown +froze +fruit +gag +gains +gala +game +gap +gas +gave +gear +gecko +geek +gem +genre +gift +gig +gills +given +giver +glad +glass +glide +gloss +glove +glow +glue +goal +going +golf +gong +good +gooey +goofy +gore +gown +grab +grain +grant +grape +graph +grasp +grass +grave +gravy +gray +green +greet +grew +grid +grief +grill +grip +grit +groom +grope +growl +grub +grunt +guide +gulf +gulp +gummy +guru +gush +gut +guy +habit +half +halo +halt +happy +harm +hash +hasty +hatch +hate +haven +hazel +hazy +heap +heat +heave +hedge +hefty +help +herbs +hers +hub +hug +hula +hull +human +humid +hump +hung +hunk +hunt +hurry +hurt +hush +hut +ice +icing +icon +icy +igloo +image +ion +iron +islam +issue +item +ivory +ivy +jab +jam +jaws +jazz +jeep +jelly +jet +jiffy +job +jog +jolly +jolt +jot +joy +judge +juice +juicy +july +jumbo +jump +junky +juror +jury +keep +keg +kept +kick +kilt +king +kite +kitty +kiwi +knee +knelt +koala +kung +ladle +lady +lair +lake +lance +land +lapel +large +lash +lasso +last +latch +late +lazy +left +legal +lemon +lend +lens +lent +level +lever +lid +life +lift +lilac +lily +limb +limes +line +lint +lion +lip +list +lived +liver +lunar +lunch +lung +lurch +lure +lurk +lying +lyric +mace +maker +malt +mama +mango +manor +many +map +march +mardi +marry +mash +match +mate +math +moan +mocha +moist +mold +mom +moody +mop +morse +most +motor +motto +mount +mouse +mousy +mouth +move +movie +mower +mud +mug +mulch +mule +mull +mumbo +mummy +mural +muse +music +musky +mute +nacho +nag +nail +name +nanny +nap +navy +near +neat +neon +nerd +nest +net +next +niece +ninth +nutty +oak +oasis +oat +ocean +oil +old +olive +omen +onion +only +ooze +opal +open +opera +opt +otter +ouch +ounce +outer +oval +oven +owl +ozone +pace +pagan +pager +palm +panda +panic +pants +panty +paper +park +party +pasta +patch +path +patio +payer +pecan +penny +pep +perch +perky +perm +pest +petal +petri +petty +photo +plank +plant +plaza +plead +plot +plow +pluck +plug +plus +poach +pod +poem +poet +pogo +point +poise +poker +polar +polio +polka +polo +pond +pony +poppy +pork +poser +pouch +pound +pout +power +prank +press +print +prior +prism +prize +probe +prong +proof +props +prude +prune +pry +pug +pull +pulp +pulse +puma +punch +punk +pupil +puppy +purr +purse +push +putt +quack +quake +query +quiet +quill +quilt +quit +quota +quote +rabid +race +rack +radar +radio +raft +rage +raid +rail +rake +rally +ramp +ranch +range +rank +rant +rash +raven +reach +react +ream +rebel +recap +relax +relay +relic +remix +repay +repel +reply +rerun +reset +rhyme +rice +rich +ride +rigid +rigor +rinse +riot +ripen +rise +risk +ritzy +rival +river +roast +robe +robin +rock +rogue +roman +romp +rope +rover +royal +ruby +rug +ruin +rule +runny +rush +rust +rut +sadly +sage +said +saint +salad +salon +salsa +salt +same +sandy +santa +satin +sauna +saved +savor +sax +say +scale +scam +scan +scare +scarf +scary +scoff +scold +scoop +scoot +scope +score +scorn +scout +scowl +scrap +scrub +scuba +scuff +sect +sedan +self +send +sepia +serve +set +seven +shack +shade +shady +shaft +shaky +sham +shape +share +sharp +shed +sheep +sheet +shelf +shell +shine +shiny +ship +shirt +shock +shop +shore +shout +shove +shown +showy +shred +shrug +shun +shush +shut +shy +sift +silk +silly +silo +sip +siren +sixth +size +skate +skew +skid +skier +skies +skip +skirt +skit +sky +slab +slack +slain +slam +slang +slash +slate +slaw +sled +sleek +sleep +sleet +slept +slice +slick +slimy +sling +slip +slit +slob +slot +slug +slum +slurp +slush +small +smash +smell +smile +smirk +smog +snack +snap +snare +snarl +sneak +sneer +sniff +snore +snort +snout +snowy +snub +snuff +speak +speed +spend +spent +spew +spied +spill +spiny +spoil +spoke +spoof +spool +spoon +sport +spot +spout +spray +spree +spur +squad +squat +squid +stack +staff +stage +stain +stall +stamp +stand +stank +stark +start +stash +state +stays +steam +steep +stem +step +stew +stick +sting +stir +stock +stole +stomp +stony +stood +stool +stoop +stop +storm +stout +stove +straw +stray +strut +stuck +stud +stuff +stump +stung +stunt +suds +sugar +sulk +surf +sushi +swab +swan +swarm +sway +swear +sweat +sweep +swell +swept +swim +swing +swipe +swirl +swoop +swore +syrup +tacky +taco +tag +take +tall +talon +tamer +tank +taper +taps +tarot +tart +task +taste +tasty +taunt +thank +thaw +theft +theme +thigh +thing +think +thong +thorn +those +throb +thud +thumb +thump +thus +tiara +tidal +tidy +tiger +tile +tilt +tint +tiny +trace +track +trade +train +trait +trap +trash +tray +treat +tree +trek +trend +trial +tribe +trick +trio +trout +truce +truck +trump +trunk +try +tug +tulip +tummy +turf +tusk +tutor +tutu +tux +tweak +tweet +twice +twine +twins +twirl +twist +uncle +uncut +undo +unify +union +unit +untie +upon +upper +urban +used +user +usher +utter +value +vapor +vegan +venue +verse +vest +veto +vice +video +view +viral +virus +visa +visor +vixen +vocal +voice +void +volt +voter +vowel +wad +wafer +wager +wages +wagon +wake +walk +wand +wasp +watch +water +wavy +wheat +whiff +whole +whoop +wick +widen +widow +width +wife +wifi +wilt +wimp +wind +wing +wink +wipe +wired +wiry +wise +wish +wispy +wok +wolf +womb +wool +woozy +word +work +worry +wound +woven +wrath +wreck +wrist +xerox +yahoo +yam +yard +year +yeast +yelp +yield +yodel +yoga +yoyo +yummy +zebra +zero +zesty +zippy +zone +zoom diff --git a/pkg/service/queues.go b/pkg/service/queues.go index 6c41cd273..b1f11a93a 100644 --- a/pkg/service/queues.go +++ b/pkg/service/queues.go @@ -233,6 +233,12 @@ func runTokenZapScript( } } + if result.ProfileSwitch != nil { + if profileErr := applyProfileSwitch(svc, result.ProfileSwitch); profileErr != nil { + return profileErr + } + } + if result.Unsafe { log.Warn().Msg("token has been flagged as unsafe") token.Unsafe = true @@ -249,6 +255,25 @@ func runTokenZapScript( return nil } +// applyProfileSwitch applies a profile switch requested by a ZapScript +// command. This is the physical-scan path, so activation bypasses any +// profile PIN — possession of the card is the authorization. +func applyProfileSwitch(svc *ServiceContext, req *platforms.ProfileSwitchRequest) error { + if svc.Profiles == nil { + return errors.New("profiles service not available") + } + if req.Clear { + if err := svc.Profiles.Deactivate(); err != nil { + return fmt.Errorf("failed to clear active profile: %w", err) + } + return nil + } + if _, err := svc.Profiles.ActivateBySwitchID(req.SwitchID); err != nil { + return fmt.Errorf("failed to switch profile: %w", err) + } + return nil +} + func stopNativePlaybackBeforePrimaryCommand( svc *ServiceContext, cmd gozapscript.Command, @@ -587,6 +612,24 @@ func processTokenQueue( // Check if any command in the script launches media hasMediaLaunchCmd := parseErr == nil && scriptHasMediaLaunchingCommand(&script) + // When require_for_launch is enabled, media launches are blocked + // until a profile is active (profile switch commands still run — + // scanning a profile card is how the device gets unparked). + if hasMediaLaunchCmd && svc.Config.ProfilesRequireForLaunch() && svc.State.ActiveProfile() == nil { + log.Warn().Msg("profiles: launch blocked, no active profile and require_for_launch is set") + + path, enabled := svc.Config.FailSoundPath(helpers.DataDir(svc.Platform)) + helpers.PlayConfiguredSound(player, path, enabled, assets.FailSound, "fail") + + he.Success = false + if histErr := svc.DB.UserDB.AddHistory(&he); histErr != nil { + log.Error().Err(histErr).Msgf("error adding history") + } + + // Skip launch + continue + } + // Only check playtime limits if the script contains media-launching commands if hasMediaLaunchCmd { if limitErr := limitsManager.CheckBeforeLaunch(); limitErr != nil { diff --git a/pkg/service/require_profile_test.go b/pkg/service/require_profile_test.go new file mode 100644 index 000000000..0fd4067bf --- /dev/null +++ b/pkg/service/require_profile_test.go @@ -0,0 +1,114 @@ +// Zaparoo Core +// Copyright (c) 2026 The Zaparoo Project Contributors. +// SPDX-License-Identifier: GPL-3.0-or-later +// +// This file is part of Zaparoo Core. +// +// Zaparoo Core is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// Zaparoo Core is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with Zaparoo Core. If not, see . + +package service + +import ( + "testing" + "time" + + "github.com/ZaparooProject/zaparoo-core/v2/pkg/api/models" + "github.com/ZaparooProject/zaparoo-core/v2/pkg/database" + "github.com/stretchr/testify/require" +) + +func TestScanBehavior_RequireProfile_BlocksLaunchWithoutProfile(t *testing.T) { + t.Parallel() + env := setupScanBehavior(t, "tap", 0) + + env.cfg.SetProfilesRequireForLaunch(true) + + env.sendGameScan("card1", env.gamePath("game1.gba")) + env.expectNoLaunch(t) +} + +func TestScanBehavior_RequireProfile_AllowsLaunchWithActiveProfile(t *testing.T) { + t.Parallel() + env := setupScanBehavior(t, "tap", 0) + + env.cfg.SetProfilesRequireForLaunch(true) + env.st.SetActiveProfile(&models.ActiveProfile{ProfileID: "profile-1", Name: "Dad"}) + + env.sendGameScan("card1", env.gamePath("game1.gba")) + env.waitForLaunch(t) +} + +func TestScanBehavior_RequireProfile_OffByDefault(t *testing.T) { + t.Parallel() + env := setupScanBehavior(t, "tap", 0) + + // No profile active and require_for_launch unset: launches work exactly + // as they did before profiles existed. + env.sendGameScan("card1", env.gamePath("game1.gba")) + env.waitForLaunch(t) +} + +// TestScanBehavior_ProfileSwitchCard covers the signature card interaction: +// scanning a **profile.switch token activates the profile with no PIN +// check, and **profile.clear deactivates it. +func TestScanBehavior_ProfileSwitchCard(t *testing.T) { + t.Parallel() + env := setupScanBehavior(t, "tap", 0) + + profile := &database.Profile{ + ProfileID: "profile-1", + Name: "Kid A", + SwitchID: "corn-arm-truck", + PINHash: "pbkdf2-sha256$1$AAAA$AAAA", // PIN set, but card scans bypass it + } + env.userDB.On("GetProfileBySwitchID", "corn-arm-truck").Return(profile, nil) + env.userDB.On("SetDeviceState", database.DeviceStateKeyActiveProfile, "profile-1").Return(nil) + env.userDB.On("DeleteDeviceState", database.DeviceStateKeyActiveProfile).Return(nil) + + env.sendCommandScan("switch-card", "**profile.switch:corn-arm-truck") + env.waitForActiveProfile(t, "profile-1") + + env.sendCommandScan("clear-card", "**profile.clear") + env.waitForNoActiveProfile(t) +} + +func (env *scanBehaviorEnv) waitForActiveProfile(t *testing.T, profileID string) { + t.Helper() + deadline := time.After(behaviorTimeout) + for { + if active := env.st.ActiveProfile(); active != nil && active.ProfileID == profileID { + return + } + select { + case <-deadline: + require.FailNow(t, "timed out waiting for active profile", "profileID=%s", profileID) + case <-time.After(time.Millisecond): + } + } +} + +func (env *scanBehaviorEnv) waitForNoActiveProfile(t *testing.T) { + t.Helper() + deadline := time.After(behaviorTimeout) + for { + if env.st.ActiveProfile() == nil { + return + } + select { + case <-deadline: + require.FailNow(t, "timed out waiting for profile deactivation") + case <-time.After(time.Millisecond): + } + } +} diff --git a/pkg/service/scan_behavior_test.go b/pkg/service/scan_behavior_test.go index e6805d3d8..8de636771 100644 --- a/pkg/service/scan_behavior_test.go +++ b/pkg/service/scan_behavior_test.go @@ -33,6 +33,7 @@ import ( "github.com/ZaparooProject/zaparoo-core/v2/pkg/readers" "github.com/ZaparooProject/zaparoo-core/v2/pkg/service/playlists" "github.com/ZaparooProject/zaparoo-core/v2/pkg/service/playtime" + "github.com/ZaparooProject/zaparoo-core/v2/pkg/service/profiles" "github.com/ZaparooProject/zaparoo-core/v2/pkg/service/state" "github.com/ZaparooProject/zaparoo-core/v2/pkg/service/tokens" testhelpers "github.com/ZaparooProject/zaparoo-core/v2/pkg/testing/helpers" @@ -52,6 +53,7 @@ const ( type scanBehaviorEnv struct { st *state.State cfg *config.Instance + userDB *testhelpers.MockUserDBI scanQueue chan readers.Scan clock *clockwork.FakeClock launchCh chan string @@ -168,6 +170,7 @@ mode = "unrestricted"`)) Config: cfg, State: st, DB: db, + Profiles: profiles.NewService(db, st), LaunchSoftwareQueue: lsq, PlaylistQueue: plq, } @@ -199,6 +202,7 @@ mode = "unrestricted"`)) return &scanBehaviorEnv{ st: st, cfg: cfg, + userDB: mockUserDB, scanQueue: scanQueue, clock: fakeClock, romsDir: romsDir, diff --git a/pkg/service/service.go b/pkg/service/service.go index 10abda521..b41173e88 100644 --- a/pkg/service/service.go +++ b/pkg/service/service.go @@ -46,6 +46,7 @@ import ( "github.com/ZaparooProject/zaparoo-core/v2/pkg/service/inbox" "github.com/ZaparooProject/zaparoo-core/v2/pkg/service/playlists" "github.com/ZaparooProject/zaparoo-core/v2/pkg/service/playtime" + "github.com/ZaparooProject/zaparoo-core/v2/pkg/service/profiles" "github.com/ZaparooProject/zaparoo-core/v2/pkg/service/state" "github.com/ZaparooProject/zaparoo-core/v2/pkg/service/tokens" "github.com/ZaparooProject/zaparoo-core/v2/pkg/service/updater" @@ -240,11 +241,21 @@ func Start( log.Info().Msg("initializing inbox service") st.SetInbox(inbox.NewService(db.UserDB, st.Notifications)) + // Initialize profiles and restore the persisted active profile before + // the limits manager starts, so limit checks see the right profile. + log.Info().Msg("initializing profiles service") + profilesSvc := profiles.NewService(db, st) + if restoreErr := profilesSvc.RestoreOnBoot(); restoreErr != nil { + log.Error().Err(restoreErr).Msg("error restoring active profile") + } + // Initialize playtime limits system (always create for runtime enable/disable) log.Info().Msg("initializing playtime limits") limitsManager := playtime.NewLimitsManager(db, pl, cfg, clockwork.NewRealClock(), player) + limitsResolver := profiles.NewLimitsResolver(cfg, st) + limitsManager.SetLimitsProvider(limitsResolver) limitsManager.Start(notifBroker, st.Notifications) - if cfg.PlaytimeLimitsEnabled() { + if limitsResolver.PlaytimeLimitsEnabled() { limitsManager.SetEnabled(true) } @@ -253,6 +264,7 @@ func Start( Config: cfg, State: st, DB: db, + Profiles: profilesSvc, PlaybackManager: playbackManager, LaunchSoftwareQueue: lsq, PlaylistQueue: plq, @@ -315,7 +327,7 @@ func Start( apiDone := make(chan error, 1) go func() { apiDone <- api.StartWithReady( - pl, cfg, st, itq, cfq, db, limitsManager, + pl, cfg, st, itq, cfq, db, limitsManager, profilesSvc, notifBroker, discoveryService.InstanceName(), player, playbackManager, indexPauser, scrapePauser, idleSched, apiReady, ) diff --git a/pkg/service/state/active_profile_test.go b/pkg/service/state/active_profile_test.go new file mode 100644 index 000000000..cc3dca44e --- /dev/null +++ b/pkg/service/state/active_profile_test.go @@ -0,0 +1,87 @@ +// Zaparoo Core +// Copyright (c) 2026 The Zaparoo Project Contributors. +// SPDX-License-Identifier: GPL-3.0-or-later +// +// This file is part of Zaparoo Core. +// +// Zaparoo Core is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// Zaparoo Core is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with Zaparoo Core. If not, see . + +package state_test + +import ( + "encoding/json" + "testing" + + "github.com/ZaparooProject/zaparoo-core/v2/pkg/api/models" + "github.com/ZaparooProject/zaparoo-core/v2/pkg/service/state" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestSetActiveProfile_StoresCopy(t *testing.T) { + t.Parallel() + + st, _ := state.NewState(nil, "boot") + assert.Nil(t, st.ActiveProfile()) + + profile := &models.ActiveProfile{ + ProfileID: "profile-1", + Name: "Kid A", + HasPIN: true, + } + st.SetActiveProfile(profile) + + // Mutating the caller's struct must not affect stored state. + profile.Name = "mutated" + + got := st.ActiveProfile() + require.NotNil(t, got) + assert.Equal(t, "Kid A", got.Name) + assert.Equal(t, "profile-1", got.ProfileID) + assert.True(t, got.HasPIN) + + // Mutating the returned copy must not affect stored state either. + got.Name = "also mutated" + assert.Equal(t, "Kid A", st.ActiveProfile().Name) +} + +func TestSetActiveProfile_Notifications(t *testing.T) { + t.Parallel() + + st, ns := state.NewState(nil, "boot") + + st.SetActiveProfile(&models.ActiveProfile{ProfileID: "profile-1", Name: "Kid A"}) + activated := <-ns + assert.Equal(t, models.NotificationProfilesActive, activated.Method) + var payload models.ProfilesActiveNotification + require.NoError(t, json.Unmarshal(activated.Params, &payload)) + require.NotNil(t, payload.Profile) + assert.Equal(t, "profile-1", payload.Profile.ProfileID) + + st.SetActiveProfile(nil) + deactivated := <-ns + assert.Equal(t, models.NotificationProfilesActive, deactivated.Method) + var nilPayload models.ProfilesActiveNotification + require.NoError(t, json.Unmarshal(deactivated.Params, &nilPayload)) + assert.Nil(t, nilPayload.Profile) + assert.Nil(t, st.ActiveProfile()) + + // Duplicate deactivation emits no second notification. + st.SetActiveProfile(nil) + select { + case n := <-ns: + t.Fatalf("unexpected notification: %s", n.Method) + default: + } +} diff --git a/pkg/service/state/state.go b/pkg/service/state/state.go index 1ee780a72..49c0a1412 100644 --- a/pkg/service/state/state.go +++ b/pkg/service/state/state.go @@ -69,6 +69,7 @@ type State struct { Notifications chan<- models.Notification activeMedia *models.ActiveMedia backgroundMedia *models.ActiveMedia + activeProfile *models.ActiveProfile activePlaylist *playlists.Playlist backgroundPlaylist *playlists.Playlist activeMediaReadyCh chan struct{} @@ -145,6 +146,49 @@ func (s *State) GetActiveCard() tokens.Token { return s.activeToken } +// SetActiveProfile sets or clears (nil) the device's active profile and +// broadcasts a profiles.active notification. The snapshot is stored by +// value internally so callers cannot mutate state through the pointer. +func (s *State) SetActiveProfile(profile *models.ActiveProfile) { + s.mu.Lock() + + if profile == nil && s.activeProfile == nil { + // ignore duplicate deactivations + s.mu.Unlock() + return + } + + var stored *models.ActiveProfile + if profile != nil { + profileCopy := *profile + stored = &profileCopy + } + s.activeProfile = stored + + // Prepare notification payload inside lock, send outside + var payload *models.ActiveProfile + if stored != nil { + payloadCopy := *stored + payload = &payloadCopy + } + + s.mu.Unlock() + + notifications.ProfilesActiveChanged(s.Notifications, models.ProfilesActiveNotification{Profile: payload}) +} + +// ActiveProfile returns a copy of the device's active profile snapshot, or +// nil when no profile is active. +func (s *State) ActiveProfile() *models.ActiveProfile { + s.mu.RLock() + defer s.mu.RUnlock() + if s.activeProfile == nil { + return nil + } + profileCopy := *s.activeProfile + return &profileCopy +} + func (s *State) GetLastScanned() tokens.Token { s.mu.RLock() defer s.mu.RUnlock() diff --git a/pkg/testing/helpers/db_mocks.go b/pkg/testing/helpers/db_mocks.go index 8283f1c6c..f64ebf6a3 100644 --- a/pkg/testing/helpers/db_mocks.go +++ b/pkg/testing/helpers/db_mocks.go @@ -493,6 +493,114 @@ func (m *MockUserDBI) CountClients() (int, error) { return count, nil } +func (m *MockUserDBI) CreateProfile(p *database.Profile) error { + args := m.Called(p) + if err := args.Error(0); err != nil { + return fmt.Errorf("mock UserDBI create profile failed: %w", err) + } + return nil +} + +func (m *MockUserDBI) GetProfile(profileID string) (*database.Profile, error) { + args := m.Called(profileID) + if result, ok := args.Get(0).(*database.Profile); ok { + if err := args.Error(1); err != nil { + return nil, fmt.Errorf("mock UserDBI get profile failed: %w", err) + } + return result, nil + } + if err := args.Error(1); err != nil { + return nil, fmt.Errorf("mock UserDBI get profile failed: %w", err) + } + return nil, nil //nolint:nilnil // mock returns nil when no profile is configured +} + +func (m *MockUserDBI) GetProfileBySwitchID(switchID string) (*database.Profile, error) { + args := m.Called(switchID) + if result, ok := args.Get(0).(*database.Profile); ok { + if err := args.Error(1); err != nil { + return nil, fmt.Errorf("mock UserDBI get profile by switch ID failed: %w", err) + } + return result, nil + } + if err := args.Error(1); err != nil { + return nil, fmt.Errorf("mock UserDBI get profile by switch ID failed: %w", err) + } + return nil, nil //nolint:nilnil // mock returns nil when no profile is configured +} + +func (m *MockUserDBI) ListProfiles() ([]database.Profile, error) { + args := m.Called() + if profiles, ok := args.Get(0).([]database.Profile); ok { + if err := args.Error(1); err != nil { + return profiles, fmt.Errorf("mock UserDBI list profiles failed: %w", err) + } + return profiles, nil + } + if err := args.Error(1); err != nil { + return nil, fmt.Errorf("mock UserDBI list profiles failed: %w", err) + } + return nil, nil +} + +func (m *MockUserDBI) UpdateProfile(p *database.Profile) error { + args := m.Called(p) + if err := args.Error(0); err != nil { + return fmt.Errorf("mock UserDBI update profile failed: %w", err) + } + return nil +} + +func (m *MockUserDBI) DeleteProfile(profileID string) error { + args := m.Called(profileID) + if err := args.Error(0); err != nil { + return fmt.Errorf("mock UserDBI delete profile failed: %w", err) + } + return nil +} + +func (m *MockUserDBI) GetMediaHistoryByProfile( + profileID string, lastID int64, limit int, +) ([]database.MediaHistoryEntry, error) { + args := m.Called(profileID, lastID, limit) + history, ok := args.Get(0).([]database.MediaHistoryEntry) + if !ok { + history = []database.MediaHistoryEntry{} + } + if err := args.Error(1); err != nil { + return history, fmt.Errorf("mock UserDBI get media history by profile failed: %w", err) + } + return history, nil +} + +func (m *MockUserDBI) SetDeviceState(key, value string) error { + args := m.Called(key, value) + if err := args.Error(0); err != nil { + return fmt.Errorf("mock UserDBI set device state failed: %w", err) + } + return nil +} + +func (m *MockUserDBI) GetDeviceState(key string) (value string, found bool, err error) { + args := m.Called(key) + if v, ok := args.Get(0).(string); ok { + value = v + } + found = args.Bool(1) + if err := args.Error(2); err != nil { + return value, found, fmt.Errorf("mock UserDBI get device state failed: %w", err) + } + return value, found, nil +} + +func (m *MockUserDBI) DeleteDeviceState(key string) error { + args := m.Called(key) + if err := args.Error(0); err != nil { + return fmt.Errorf("mock UserDBI delete device state failed: %w", err) + } + return nil +} + // MockMediaDBI is a mock implementation of the MediaDBI interface using testify/mock type MockMediaDBI struct { mock.Mock diff --git a/pkg/zapscript/commands.go b/pkg/zapscript/commands.go index e3f117a59..c3e2591ef 100644 --- a/pkg/zapscript/commands.go +++ b/pkg/zapscript/commands.go @@ -109,6 +109,9 @@ func lookupCmd(name string) (cmdFunc, bool) { zapscript.ZapScriptCmdControl: cmdControl, zapscript.ZapScriptCmdScreenshot: cmdScreenshot, + zapscript.ZapScriptCmdProfileSwitch: cmdProfileSwitch, + zapscript.ZapScriptCmdProfileClear: cmdProfileClear, + zapscript.ZapScriptCmdMisterINI: forwardCmd, zapscript.ZapScriptCmdMisterCore: forwardCmd, zapscript.ZapScriptCmdMisterScript: forwardCmd, diff --git a/pkg/zapscript/profile.go b/pkg/zapscript/profile.go new file mode 100644 index 000000000..f34bc096a --- /dev/null +++ b/pkg/zapscript/profile.go @@ -0,0 +1,65 @@ +// Zaparoo Core +// Copyright (c) 2026 The Zaparoo Project Contributors. +// SPDX-License-Identifier: GPL-3.0-or-later +// +// This file is part of Zaparoo Core. +// +// Zaparoo Core is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// Zaparoo Core is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with Zaparoo Core. If not, see . + +package zapscript + +import ( + "errors" + "fmt" + + "github.com/ZaparooProject/zaparoo-core/v2/pkg/platforms" +) + +// cmdProfileSwitch handles **profile.switch: — the card-scan path +// for changing the device's active profile. The switch ID is resolved here +// so an unknown card fails the script (and plays the fail sound); the +// actual activation is applied by the service layer from the returned +// CmdResult. No PIN is checked on this path: possession of the card is the +// authorization. +// +//nolint:gocritic // single-use parameter in command handler +func cmdProfileSwitch(_ platforms.Platform, env platforms.CmdEnv) (platforms.CmdResult, error) { + if len(env.Cmd.Args) != 1 || env.Cmd.Args[0] == "" { + return platforms.CmdResult{}, ErrArgCount + } + switchID := env.Cmd.Args[0] + + if env.Database == nil || env.Database.UserDB == nil { + return platforms.CmdResult{}, errors.New("user database not available") + } + if _, err := env.Database.UserDB.GetProfileBySwitchID(switchID); err != nil { + return platforms.CmdResult{}, fmt.Errorf("unknown profile switch ID: %w", err) + } + + return platforms.CmdResult{ + ProfileSwitch: &platforms.ProfileSwitchRequest{SwitchID: switchID}, + }, nil +} + +// cmdProfileClear handles **profile.clear — deactivates the active profile. +// +//nolint:gocritic // single-use parameter in command handler +func cmdProfileClear(_ platforms.Platform, env platforms.CmdEnv) (platforms.CmdResult, error) { + if len(env.Cmd.Args) > 0 { + return platforms.CmdResult{}, ErrArgCount + } + return platforms.CmdResult{ + ProfileSwitch: &platforms.ProfileSwitchRequest{Clear: true}, + }, nil +} diff --git a/pkg/zapscript/profile_test.go b/pkg/zapscript/profile_test.go new file mode 100644 index 000000000..6c6bd4824 --- /dev/null +++ b/pkg/zapscript/profile_test.go @@ -0,0 +1,108 @@ +// Zaparoo Core +// Copyright (c) 2026 The Zaparoo Project Contributors. +// SPDX-License-Identifier: GPL-3.0-or-later +// +// This file is part of Zaparoo Core. +// +// Zaparoo Core is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// Zaparoo Core is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with Zaparoo Core. If not, see . + +package zapscript + +import ( + "testing" + + "github.com/ZaparooProject/go-zapscript" + "github.com/ZaparooProject/zaparoo-core/v2/pkg/database" + "github.com/ZaparooProject/zaparoo-core/v2/pkg/database/userdb" + "github.com/ZaparooProject/zaparoo-core/v2/pkg/platforms" + "github.com/ZaparooProject/zaparoo-core/v2/pkg/testing/helpers" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func profileCmdEnv(mockDB *helpers.MockUserDBI, name string, args []string) platforms.CmdEnv { + return platforms.CmdEnv{ + Cmd: zapscript.Command{ + Name: name, + Args: args, + }, + Database: &database.Database{UserDB: mockDB, MediaDB: nil}, + } +} + +func TestCmdProfileSwitch_Success(t *testing.T) { + t.Parallel() + + mockDB := helpers.NewMockUserDBI() + mockDB.On("GetProfileBySwitchID", "corn-arm-truck"). + Return(&database.Profile{ProfileID: "p1", Name: "Kid A", SwitchID: "corn-arm-truck"}, nil) + + result, err := cmdProfileSwitch(nil, profileCmdEnv(mockDB, "profile.switch", []string{"corn-arm-truck"})) + require.NoError(t, err) + require.NotNil(t, result.ProfileSwitch) + assert.Equal(t, "corn-arm-truck", result.ProfileSwitch.SwitchID) + assert.False(t, result.ProfileSwitch.Clear) + assert.False(t, result.MediaChanged, "profile switch must not count as a media change") +} + +func TestCmdProfileSwitch_UnknownSwitchID(t *testing.T) { + t.Parallel() + + mockDB := helpers.NewMockUserDBI() + mockDB.On("GetProfileBySwitchID", "no-such-card").Return(nil, userdb.ErrProfileNotFound) + + _, err := cmdProfileSwitch(nil, profileCmdEnv(mockDB, "profile.switch", []string{"no-such-card"})) + require.Error(t, err) + assert.Contains(t, err.Error(), "unknown profile switch ID") +} + +func TestCmdProfileSwitch_ArgValidation(t *testing.T) { + t.Parallel() + + mockDB := helpers.NewMockUserDBI() + + _, err := cmdProfileSwitch(nil, profileCmdEnv(mockDB, "profile.switch", nil)) + require.ErrorIs(t, err, ErrArgCount) + + _, err = cmdProfileSwitch(nil, profileCmdEnv(mockDB, "profile.switch", []string{""})) + require.ErrorIs(t, err, ErrArgCount) + + _, err = cmdProfileSwitch(nil, profileCmdEnv(mockDB, "profile.switch", []string{"a", "b"})) + require.ErrorIs(t, err, ErrArgCount) +} + +func TestCmdProfileClear(t *testing.T) { + t.Parallel() + + mockDB := helpers.NewMockUserDBI() + + result, err := cmdProfileClear(nil, profileCmdEnv(mockDB, "profile.clear", nil)) + require.NoError(t, err) + require.NotNil(t, result.ProfileSwitch) + assert.True(t, result.ProfileSwitch.Clear) + + _, err = cmdProfileClear(nil, profileCmdEnv(mockDB, "profile.clear", []string{"extra"})) + require.ErrorIs(t, err, ErrArgCount) +} + +func TestProfileCommands_NotMediaLaunching(t *testing.T) { + t.Parallel() + + // Profile switching must never be blocked by playtime limits — a kid + // who has hit their limit can still hand the device to a parent card. + assert.False(t, IsMediaLaunchingCommand(zapscript.ZapScriptCmdProfileSwitch)) + assert.False(t, IsMediaLaunchingCommand(zapscript.ZapScriptCmdProfileClear)) + assert.False(t, IsMediaDisruptingCommand(zapscript.ZapScriptCmdProfileSwitch)) + assert.False(t, IsMediaDisruptingCommand(zapscript.ZapScriptCmdProfileClear)) +}