Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
147 changes: 94 additions & 53 deletions controller/passkey.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package controller

import (
"encoding/base64"
"errors"
"fmt"
"net/http"
Expand Down Expand Up @@ -40,26 +41,26 @@ func PasskeyRegisterBegin(c *gin.Context) {
return
}

credential, err := model.GetPasskeyByUserID(user.Id)
if err != nil && !errors.Is(err, model.ErrPasskeyNotFound) {
credentials, err := model.GetPasskeyByUserID(user.Id)
if err != nil {
common.ApiError(c, err)
return
}
if errors.Is(err, model.ErrPasskeyNotFound) {
credential = nil
}

wa, err := passkeysvc.BuildWebAuthn(c.Request)
if err != nil {
common.ApiError(c, err)
return
}

waUser := passkeysvc.NewWebAuthnUser(user, credential)
waUser := passkeysvc.NewWebAuthnUser(user, credentials...)
var options []webauthnlib.RegistrationOption
if credential != nil {
descriptor := credential.ToWebAuthnCredential().Descriptor()
options = append(options, webauthnlib.WithExclusions([]protocol.CredentialDescriptor{descriptor}))
if len(credentials) > 0 {
descriptors := make([]protocol.CredentialDescriptor, 0, len(credentials))
for _, cred := range credentials {
descriptors = append(descriptors, cred.ToWebAuthnCredential().Descriptor())
}
options = append(options, webauthnlib.WithExclusions(descriptors))
}

creation, sessionData, err := wa.BeginRegistration(waUser, options...)
Expand Down Expand Up @@ -110,22 +111,19 @@ func PasskeyRegisterFinish(c *gin.Context) {
return
}

credentialRecord, err := model.GetPasskeyByUserID(user.Id)
if err != nil && !errors.Is(err, model.ErrPasskeyNotFound) {
credentials, err := model.GetPasskeyByUserID(user.Id)
if err != nil {
common.ApiError(c, err)
return
}
if errors.Is(err, model.ErrPasskeyNotFound) {
credentialRecord = nil
}

sessionData, err := passkeysvc.PopSessionData(c, passkeysvc.RegistrationSessionKey)
if err != nil {
common.ApiError(c, err)
return
}

waUser := passkeysvc.NewWebAuthnUser(user, credentialRecord)
waUser := passkeysvc.NewWebAuthnUser(user, credentials...)
credential, err := wa.FinishRegistration(waUser, *sessionData, c.Request)
if err != nil {
common.ApiError(c, err)
Expand Down Expand Up @@ -159,11 +157,17 @@ func PasskeyDelete(c *gin.Context) {
return
}

credentialID := c.Param("credential_id")
if credentialID == "" {
common.ApiErrorMsg(c, "无效的凭证 ID")
return
}

if !requirePasskeyDeleteVerification(c, user.Id) {
return
}

if err := model.DeletePasskeyByUserID(user.Id); err != nil {
if err := model.DeletePasskeyByCredentialID(credentialID, user.Id); err != nil {
common.ApiError(c, err)
return
}
Expand All @@ -184,31 +188,43 @@ func PasskeyStatus(c *gin.Context) {
return
}

credential, err := model.GetPasskeyByUserID(user.Id)
if errors.Is(err, model.ErrPasskeyNotFound) {
credentials, err := model.GetPasskeyByUserID(user.Id)
if err != nil {
common.ApiError(c, err)
return
}

if len(credentials) == 0 {
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": gin.H{
"enabled": false,
"credentials": []gin.H{},
},
})
return
}
if err != nil {
common.ApiError(c, err)
return
}

data := gin.H{
"enabled": true,
"last_used_at": credential.LastUsedAt,
credList := make([]gin.H, 0, len(credentials))
for _, cred := range credentials {
credList = append(credList, gin.H{
"credential_id": cred.CredentialID,
"created_at": cred.CreatedAt,
"last_used_at": cred.LastUsedAt,
"backup_eligible": cred.BackupEligible,
"backup_state": cred.BackupState,
"attachment": cred.Attachment,
})
}

c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": data,
"data": gin.H{
"enabled": true,
"credentials": credList,
},
})
}

Expand Down Expand Up @@ -351,17 +367,18 @@ func AdminResetPasskey(c *gin.Context) {
return
}

if _, err := model.GetPasskeyByUserID(user.Id); err != nil {
if errors.Is(err, model.ErrPasskeyNotFound) {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "该用户尚未绑定 Passkey",
})
return
}
credentials, err := model.GetPasskeyByUserID(user.Id)
if err != nil {
common.ApiError(c, err)
return
}
if len(credentials) == 0 {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "该用户尚未绑定 Passkey",
})
return
}

if err := model.DeletePasskeyByUserID(user.Id); err != nil {
common.ApiError(c, err)
Expand Down Expand Up @@ -392,22 +409,29 @@ func PasskeyVerifyBegin(c *gin.Context) {
return
}

credential, err := model.GetPasskeyByUserID(user.Id)
credentials, err := model.GetPasskeyByUserID(user.Id)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "该用户尚未绑定 Passkey",
})
return
}
if len(credentials) == 0 {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "该用户尚未绑定 Passkey",
})
return
}

wa, err := passkeysvc.BuildWebAuthn(c.Request)
if err != nil {
common.ApiError(c, err)
return
}

waUser := passkeysvc.NewWebAuthnUser(user, credential)
waUser := passkeysvc.NewWebAuthnUser(user, credentials...)
assertion, sessionData, err := wa.BeginLogin(waUser)
if err != nil {
common.ApiError(c, err)
Expand Down Expand Up @@ -452,34 +476,51 @@ func PasskeyVerifyFinish(c *gin.Context) {
return
}

credential, err := model.GetPasskeyByUserID(user.Id)
credentials, err := model.GetPasskeyByUserID(user.Id)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "该用户尚未绑定 Passkey",
})
return
}
if len(credentials) == 0 {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "该用户尚未绑定 Passkey",
})
return
}

sessionData, err := passkeysvc.PopSessionData(c, passkeysvc.VerifySessionKey)
if err != nil {
common.ApiError(c, err)
return
}

waUser := passkeysvc.NewWebAuthnUser(user, credential)
_, err = wa.FinishLogin(waUser, *sessionData, c.Request)
waUser := passkeysvc.NewWebAuthnUser(user, credentials...)
credential, err := wa.FinishLogin(waUser, *sessionData, c.Request)
if err != nil {
common.ApiError(c, err)
return
}

// 更新凭证的最后使用时间
now := time.Now()
credential.LastUsedAt = &now
if err := model.UpsertPasskeyCredential(credential); err != nil {
common.ApiError(c, err)
return
// 更新匹配的凭证的最后使用时间
credIDStr := base64.StdEncoding.EncodeToString(credential.ID)
var matched *model.PasskeyCredential
for _, cred := range credentials {
if cred.CredentialID == credIDStr {
matched = cred
break
}
}
if matched != nil {
now := time.Now()
matched.LastUsedAt = &now
if err := model.UpsertPasskeyCredential(matched); err != nil {
common.ApiError(c, err)
return
}
}

session := sessions.Default(c)
Expand Down Expand Up @@ -540,18 +581,18 @@ func requirePasskeyDeleteVerification(c *gin.Context, userID int) bool {
return requireSecureVerificationMethod(c, secureVerificationMethod2FA)
}

_, err = model.GetPasskeyByUserID(userID)
credentials, err := model.GetPasskeyByUserID(userID)
if err != nil {
if errors.Is(err, model.ErrPasskeyNotFound) {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "该用户尚未绑定 Passkey",
})
return false
}
common.ApiError(c, err)
return false
}
if len(credentials) == 0 {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "该用户尚未绑定 Passkey",
})
return false
}

return requireSecureVerificationMethod(c, secureVerificationMethodPasskey)
}
Expand Down
4 changes: 2 additions & 2 deletions controller/secure_verification.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,8 @@ func UniversalVerify(c *gin.Context) {
twoFA, _ := model.GetTwoFAByUserId(userId)
has2FA := twoFA != nil && twoFA.IsEnabled

passkey, passkeyErr := model.GetPasskeyByUserID(userId)
hasPasskey := passkeyErr == nil && passkey != nil
passkeys, passkeyErr := model.GetPasskeyByUserID(userId)
hasPasskey := passkeyErr == nil && len(passkeys) > 0

if !has2FA && !hasPasskey {
common.ApiError(c, fmt.Errorf("用户未启用2FA或Passkey"))
Expand Down
26 changes: 26 additions & 0 deletions model/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,8 @@ func migrateDB() error {
if err := migrateTokenModelLimitsToText(); err != nil {
return err
}
// Migrate passkey_credentials: drop old unique index on user_id to allow multiple passkeys per user
migratePasskeyUserIDUniqueIndex()

err := DB.AutoMigrate(
&Channel{},
Expand Down Expand Up @@ -704,3 +706,27 @@ func PingDB() error {
common.SysLog("Database pinged successfully")
return nil
}


func migratePasskeyUserIDUniqueIndex() {
m := DB.Migrator()
indexes, err := m.GetIndexes(&PasskeyCredential{})
if err != nil {
common.SysLog(fmt.Sprintf("failed to get indexes for passkey_credentials: %v", err))
return
}
for _, idx := range indexes {
unique, ok := idx.Unique()
if !ok || !unique {
continue
}
cols := idx.Columns()
if len(cols) == 1 && cols[0] == "user_id" {
if err := m.DropIndex(&PasskeyCredential{}, idx.Name()); err != nil {
common.SysLog(fmt.Sprintf("failed to drop index %s: %v", idx.Name(), err))
} else {
common.SysLog(fmt.Sprintf("dropped old unique index %s on passkey_credentials.user_id", idx.Name()))
}
}
}
}
Loading