Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion internal/db/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ var db *gorm.DB

func Init(d *gorm.DB) {
db = d
err := AutoMigrate(new(model.Storage), new(model.User), new(model.Meta), new(model.SettingItem), new(model.SearchNode), new(model.TaskItem), new(model.SSHPublicKey), new(model.Role), new(model.Label), new(model.LabelFileBinding), new(model.ObjFile), new(model.Session))
err := AutoMigrate(new(model.Storage), new(model.User), new(model.Meta), new(model.SettingItem), new(model.SearchNode), new(model.TaskItem), new(model.SSHPublicKey), new(model.Role), new(model.Label), new(model.LabelFileBinding), new(model.ObjFile), new(model.Session), new(model.Share))
if err != nil {
log.Fatalf("failed migrate database: %s", err.Error())
}
Expand Down
124 changes: 124 additions & 0 deletions internal/db/share.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
package db

import (
"time"

"github.com/alist-org/alist/v3/internal/model"
"gorm.io/gorm"
"gorm.io/gorm/clause"
)

func GetShareByShareID(shareID string) (*model.Share, error) {
var share model.Share
if err := db.Where("share_id = ?", shareID).Take(&share).Error; err != nil {
return nil, err
}
return &share, nil
}

func GetShareByCreatorAndShareID(creatorID uint, shareID string) (*model.Share, error) {
var share model.Share
if err := db.Where("creator_id = ? AND share_id = ?", creatorID, shareID).Take(&share).Error; err != nil {
return nil, err
}
return &share, nil
}

func ShareIDExists(shareID string) (bool, error) {
var count int64
if err := db.Model(&model.Share{}).Where("share_id = ?", shareID).Count(&count).Error; err != nil {
return false, err
}
return count > 0, nil
}

func ShareIDExistsExceptID(shareID string, id uint) (bool, error) {
var count int64
if err := db.Model(&model.Share{}).Where("share_id = ? AND id <> ?", shareID, id).Count(&count).Error; err != nil {
return false, err
}
return count > 0, nil
}

func CreateShare(share *model.Share) error {
return db.Create(share).Error
}

func UpdateShare(share *model.Share) error {
return db.Save(share).Error
}

func GetSharesByCreator(creatorID uint, pageIndex, pageSize int) (shares []model.Share, count int64, err error) {
tx := db.Model(&model.Share{}).Where("creator_id = ?", creatorID)
err = tx.Count(&count).Error
if err != nil {
return nil, 0, err
}
err = tx.Order("created_at desc").Offset((pageIndex - 1) * pageSize).Limit(pageSize).Find(&shares).Error
return
}

func DeleteShareByShareID(creatorID uint, shareID string) error {
return db.Where("creator_id = ? AND share_id = ?", creatorID, shareID).Delete(&model.Share{}).Error
}

func DisableShareByShareID(creatorID uint, shareID string) error {
return db.Model(&model.Share{}).
Where("creator_id = ? AND share_id = ?", creatorID, shareID).
Update("enabled", false).Error
}

func TouchShareView(shareID string) error {
now := time.Now()
return db.Model(&model.Share{}).
Where("share_id = ?", shareID).
UpdateColumns(map[string]interface{}{
"last_access_at": now,
"view_count": gorm.Expr("view_count + ?", 1),
}).Error
}

func TouchShareDownload(shareID string) error {
now := time.Now()
return db.Model(&model.Share{}).
Where("share_id = ?", shareID).
UpdateColumns(map[string]interface{}{
"last_access_at": now,
"download_count": gorm.Expr("download_count + ?", 1),
}).Error
}

func RecordShareAccess(shareID string) (*model.Share, error) {
var updated model.Share
err := db.Transaction(func(tx *gorm.DB) error {
if err := tx.Clauses(clause.Locking{Strength: "UPDATE"}).
Where("share_id = ?", shareID).
Take(&updated).Error; err != nil {
return err
}

now := time.Now()
updated.AccessCount++
updated.LastAccessAt = &now
updates := map[string]interface{}{
"access_count": updated.AccessCount,
"last_access_at": now,
}

limit := updated.EffectiveAccessLimit()
if limit > 0 && updated.AccessCount >= limit {
updated.Enabled = false
updated.ConsumedAt = &now
updates["enabled"] = false
updates["consumed_at"] = now
}

return tx.Model(&model.Share{}).
Where("id = ?", updated.ID).
Updates(updates).Error
})
if err != nil {
return nil, err
}
return &updated, nil
}
62 changes: 62 additions & 0 deletions internal/model/share.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
package model

import "time"

type Share struct {
ID uint `json:"id" gorm:"primaryKey"`
ShareID string `json:"share_id" gorm:"uniqueIndex;size:32;not null"`
CreatorID uint `json:"creator_id" gorm:"index;not null"`
Name string `json:"name" gorm:"size:255;not null"`
RootPath string `json:"root_path" gorm:"size:4096;not null"`
IsDir bool `json:"is_dir"`
PasswordHash string `json:"-" gorm:"size:64"`
PasswordSalt string `json:"-" gorm:"size:32"`
BurnAfterRead bool `json:"burn_after_read" gorm:"default:false"`
AccessLimit int64 `json:"access_limit"`
AccessCount int64 `json:"access_count"`
AllowPreview bool `json:"allow_preview" gorm:"default:true"`
AllowDownload bool `json:"allow_download" gorm:"default:true"`
Enabled bool `json:"enabled" gorm:"default:true;index"`
ViewCount int64 `json:"view_count"`
DownloadCount int64 `json:"download_count"`
LastAccessAt *time.Time `json:"last_access_at"`
ConsumedAt *time.Time `json:"consumed_at"`
ExpiresAt *time.Time `json:"expires_at"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}

func (s Share) HasPassword() bool {
return s.PasswordHash != ""
}

func (s Share) EffectiveAccessLimit() int64 {
if s.AccessLimit > 0 {
return s.AccessLimit
}
if s.BurnAfterRead {
return 1
}
return 0
}

func (s Share) RemainingAccesses() int64 {
limit := s.EffectiveAccessLimit()
if limit <= 0 {
return 0
}
remaining := limit - s.AccessCount
if remaining < 0 {
return 0
}
return remaining
}

func (s Share) IsConsumed() bool {
limit := s.EffectiveAccessLimit()
return s.ConsumedAt != nil || (limit > 0 && s.AccessCount >= limit)
}

func (s Share) IsExpired(now time.Time) bool {
return s.ExpiresAt != nil && !s.ExpiresAt.After(now)
}
31 changes: 31 additions & 0 deletions internal/share/access.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package share

import (
"fmt"
"time"

"github.com/alist-org/alist/v3/internal/conf"
"github.com/alist-org/alist/v3/internal/model"
"github.com/alist-org/alist/v3/internal/setting"
signPkg "github.com/alist-org/alist/v3/pkg/sign"
)

func tokenPayload(share *model.Share) string {
updatedAt := int64(0)
if !share.UpdatedAt.IsZero() {
updatedAt = share.UpdatedAt.Unix()
}
return fmt.Sprintf("%s:%s:%d", share.ShareID, share.PasswordHash, updatedAt)
}

func signer() signPkg.Sign {
return signPkg.NewHMACSign([]byte(setting.GetStr(conf.Token) + "-share-access"))
}

func SignAccess(share *model.Share, d time.Duration) string {
return signer().Sign(tokenPayload(share), time.Now().Add(d).Unix())
}

func VerifyAccess(share *model.Share, token string) error {
return signer().Verify(tokenPayload(share), token)
}
Loading
Loading