Skip to content

Commit 8dc77ea

Browse files
committed
HYPERFLEET-995 - refactor: remove pointer-to-interface on SessionFactory
Go interfaces are already reference types. Storing *db.SessionFactory and dereferencing with (*d.sessionFactory).New(ctx) adds unnecessary indirection. Change all DAOs to store db.SessionFactory directly.
1 parent 8561b0b commit 8dc77ea

9 files changed

Lines changed: 48 additions & 48 deletions

File tree

pkg/dao/adapter_status.go

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -29,15 +29,15 @@ type AdapterStatusDao interface {
2929
var _ AdapterStatusDao = &sqlAdapterStatusDao{}
3030

3131
type sqlAdapterStatusDao struct {
32-
sessionFactory *db.SessionFactory
32+
sessionFactory db.SessionFactory
3333
}
3434

35-
func NewAdapterStatusDao(sessionFactory *db.SessionFactory) AdapterStatusDao {
35+
func NewAdapterStatusDao(sessionFactory db.SessionFactory) AdapterStatusDao {
3636
return &sqlAdapterStatusDao{sessionFactory: sessionFactory}
3737
}
3838

3939
func (d *sqlAdapterStatusDao) Get(ctx context.Context, id string) (*api.AdapterStatus, error) {
40-
g2 := (*d.sessionFactory).New(ctx)
40+
g2 := d.sessionFactory.New(ctx)
4141
var adapterStatus api.AdapterStatus
4242
if err := g2.Take(&adapterStatus, "id = ?", id).Error; err != nil {
4343
return nil, err
@@ -48,7 +48,7 @@ func (d *sqlAdapterStatusDao) Get(ctx context.Context, id string) (*api.AdapterS
4848
func (d *sqlAdapterStatusDao) Create(
4949
ctx context.Context, adapterStatus *api.AdapterStatus,
5050
) (*api.AdapterStatus, error) {
51-
g2 := (*d.sessionFactory).New(ctx)
51+
g2 := d.sessionFactory.New(ctx)
5252
if err := g2.Omit(clause.Associations).Create(adapterStatus).Error; err != nil {
5353
db.MarkForRollback(ctx, err)
5454
return nil, err
@@ -59,7 +59,7 @@ func (d *sqlAdapterStatusDao) Create(
5959
func (d *sqlAdapterStatusDao) Upsert(
6060
ctx context.Context, adapterStatus *api.AdapterStatus, existing *api.AdapterStatus,
6161
) (*api.AdapterStatus, error) {
62-
g2 := (*d.sessionFactory).New(ctx)
62+
g2 := d.sessionFactory.New(ctx)
6363

6464
if existing != nil {
6565
updateResult := g2.Model(&api.AdapterStatus{}).
@@ -104,7 +104,7 @@ func (d *sqlAdapterStatusDao) Upsert(
104104

105105
// Delete permanently removes the adapter status row from the database.
106106
func (d *sqlAdapterStatusDao) Delete(ctx context.Context, id string) error {
107-
g2 := (*d.sessionFactory).New(ctx)
107+
g2 := d.sessionFactory.New(ctx)
108108
adapterStatus := &api.AdapterStatus{Meta: api.Meta{ID: id}}
109109
if err := g2.Omit(clause.Associations).Delete(adapterStatus).Error; err != nil {
110110
db.MarkForRollback(ctx, err)
@@ -114,7 +114,7 @@ func (d *sqlAdapterStatusDao) Delete(ctx context.Context, id string) error {
114114
}
115115

116116
func (d *sqlAdapterStatusDao) DeleteByResource(ctx context.Context, resourceType, resourceID string) error {
117-
g2 := (*d.sessionFactory).New(ctx)
117+
g2 := d.sessionFactory.New(ctx)
118118
if err := g2.Where("resource_type = ? AND resource_id = ?", resourceType, resourceID).
119119
Delete(&api.AdapterStatus{}).Error; err != nil {
120120
db.MarkForRollback(ctx, err)
@@ -126,7 +126,7 @@ func (d *sqlAdapterStatusDao) DeleteByResource(ctx context.Context, resourceType
126126
func (d *sqlAdapterStatusDao) FindByResource(
127127
ctx context.Context, resourceType, resourceID string,
128128
) (api.AdapterStatusList, error) {
129-
g2 := (*d.sessionFactory).New(ctx)
129+
g2 := d.sessionFactory.New(ctx)
130130
statuses := api.AdapterStatusList{}
131131
query := g2.Where("resource_type = ? AND resource_id = ?", resourceType, resourceID)
132132
if err := query.Find(&statuses).Error; err != nil {
@@ -138,7 +138,7 @@ func (d *sqlAdapterStatusDao) FindByResource(
138138
func (d *sqlAdapterStatusDao) FindByResourceIDs(
139139
ctx context.Context, resourceType string, resourceIDs []string,
140140
) (api.AdapterStatusList, error) {
141-
g2 := (*d.sessionFactory).New(ctx)
141+
g2 := d.sessionFactory.New(ctx)
142142
statuses := api.AdapterStatusList{}
143143
if len(resourceIDs) == 0 {
144144
return statuses, nil
@@ -153,7 +153,7 @@ func (d *sqlAdapterStatusDao) FindByResourceIDs(
153153
func (d *sqlAdapterStatusDao) FindByResourcePaginated(
154154
ctx context.Context, resourceType, resourceID string, offset, limit int,
155155
) (api.AdapterStatusList, int64, error) {
156-
g2 := (*d.sessionFactory).New(ctx)
156+
g2 := d.sessionFactory.New(ctx)
157157
statuses := api.AdapterStatusList{}
158158
var total int64
159159

@@ -176,7 +176,7 @@ func (d *sqlAdapterStatusDao) FindByResourcePaginated(
176176
func (d *sqlAdapterStatusDao) FindByResourceAndAdapter(
177177
ctx context.Context, resourceType, resourceID, adapter string,
178178
) (*api.AdapterStatus, error) {
179-
g2 := (*d.sessionFactory).New(ctx)
179+
g2 := d.sessionFactory.New(ctx)
180180
var adapterStatus api.AdapterStatus
181181
query := g2.Where("resource_type = ? AND resource_id = ? AND adapter = ?", resourceType, resourceID, adapter)
182182
if err := query.Take(&adapterStatus).Error; err != nil {
@@ -186,7 +186,7 @@ func (d *sqlAdapterStatusDao) FindByResourceAndAdapter(
186186
}
187187

188188
func (d *sqlAdapterStatusDao) All(ctx context.Context) (api.AdapterStatusList, error) {
189-
g2 := (*d.sessionFactory).New(ctx)
189+
g2 := d.sessionFactory.New(ctx)
190190
statuses := api.AdapterStatusList{}
191191
if err := g2.Find(&statuses).Error; err != nil {
192192
return nil, err

pkg/dao/cluster.go

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -23,15 +23,15 @@ type ClusterDao interface {
2323
var _ ClusterDao = &sqlClusterDao{}
2424

2525
type sqlClusterDao struct {
26-
sessionFactory *db.SessionFactory
26+
sessionFactory db.SessionFactory
2727
}
2828

29-
func NewClusterDao(sessionFactory *db.SessionFactory) ClusterDao {
29+
func NewClusterDao(sessionFactory db.SessionFactory) ClusterDao {
3030
return &sqlClusterDao{sessionFactory: sessionFactory}
3131
}
3232

3333
func (d *sqlClusterDao) Get(ctx context.Context, id string) (*api.Cluster, error) {
34-
g2 := (*d.sessionFactory).New(ctx)
34+
g2 := d.sessionFactory.New(ctx)
3535
var cluster api.Cluster
3636
if err := g2.Take(&cluster, "id = ?", id).Error; err != nil {
3737
return nil, err
@@ -40,7 +40,7 @@ func (d *sqlClusterDao) Get(ctx context.Context, id string) (*api.Cluster, error
4040
}
4141

4242
func (d *sqlClusterDao) GetForUpdate(ctx context.Context, id string) (*api.Cluster, error) {
43-
g2 := (*d.sessionFactory).New(ctx)
43+
g2 := d.sessionFactory.New(ctx)
4444
var cluster api.Cluster
4545
if err := g2.Clauses(clause.Locking{Strength: "UPDATE"}).Take(&cluster, "id = ?", id).Error; err != nil {
4646
return nil, err
@@ -49,7 +49,7 @@ func (d *sqlClusterDao) GetForUpdate(ctx context.Context, id string) (*api.Clust
4949
}
5050

5151
func (d *sqlClusterDao) Create(ctx context.Context, cluster *api.Cluster) (*api.Cluster, error) {
52-
g2 := (*d.sessionFactory).New(ctx)
52+
g2 := d.sessionFactory.New(ctx)
5353
if err := g2.Omit(clause.Associations).Create(cluster).Error; err != nil {
5454
db.MarkForRollback(ctx, err)
5555
return nil, err
@@ -58,7 +58,7 @@ func (d *sqlClusterDao) Create(ctx context.Context, cluster *api.Cluster) (*api.
5858
}
5959

6060
func (d *sqlClusterDao) Save(ctx context.Context, cluster *api.Cluster) error {
61-
g2 := (*d.sessionFactory).New(ctx)
61+
g2 := d.sessionFactory.New(ctx)
6262
if err := g2.Omit(clause.Associations).Save(cluster).Error; err != nil {
6363
db.MarkForRollback(ctx, err)
6464
return err
@@ -67,7 +67,7 @@ func (d *sqlClusterDao) Save(ctx context.Context, cluster *api.Cluster) error {
6767
}
6868

6969
func (d *sqlClusterDao) SaveStatusConditions(ctx context.Context, id string, statusConditions []byte) error {
70-
g2 := (*d.sessionFactory).New(ctx)
70+
g2 := d.sessionFactory.New(ctx)
7171
result := g2.Model(&api.Cluster{}).Where("id = ?", id).Update("status_conditions", statusConditions)
7272
if result.Error != nil {
7373
db.MarkForRollback(ctx, result.Error)
@@ -77,7 +77,7 @@ func (d *sqlClusterDao) SaveStatusConditions(ctx context.Context, id string, sta
7777
}
7878

7979
func (d *sqlClusterDao) Delete(ctx context.Context, id string) error {
80-
g2 := (*d.sessionFactory).New(ctx)
80+
g2 := d.sessionFactory.New(ctx)
8181
if err := g2.Omit(clause.Associations).Delete(&api.Cluster{Meta: api.Meta{ID: id}}).Error; err != nil {
8282
db.MarkForRollback(ctx, err)
8383
return err
@@ -86,7 +86,7 @@ func (d *sqlClusterDao) Delete(ctx context.Context, id string) error {
8686
}
8787

8888
func (d *sqlClusterDao) FindByIDs(ctx context.Context, ids []string) (api.ClusterList, error) {
89-
g2 := (*d.sessionFactory).New(ctx)
89+
g2 := d.sessionFactory.New(ctx)
9090
clusters := api.ClusterList{}
9191
if err := g2.Where("id in (?)", ids).Find(&clusters).Error; err != nil {
9292
return nil, err
@@ -95,7 +95,7 @@ func (d *sqlClusterDao) FindByIDs(ctx context.Context, ids []string) (api.Cluste
9595
}
9696

9797
func (d *sqlClusterDao) All(ctx context.Context) (api.ClusterList, error) {
98-
g2 := (*d.sessionFactory).New(ctx)
98+
g2 := d.sessionFactory.New(ctx)
9999
clusters := api.ClusterList{}
100100
if err := g2.Find(&clusters).Error; err != nil {
101101
return nil, err

pkg/dao/generic.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ type GenericDao interface {
4141
var _ GenericDao = &sqlGenericDao{}
4242

4343
type sqlGenericDao struct {
44-
sessionFactory *db.SessionFactory
44+
sessionFactory db.SessionFactory
4545
g2 *gorm.DB
4646
}
4747

@@ -54,14 +54,14 @@ type TableRelation struct {
5454
ForeignColumnName string
5555
}
5656

57-
func NewGenericDao(sessionFactory *db.SessionFactory) GenericDao {
57+
func NewGenericDao(sessionFactory db.SessionFactory) GenericDao {
5858
return &sqlGenericDao{sessionFactory: sessionFactory}
5959
}
6060

6161
func (d *sqlGenericDao) GetInstanceDao(ctx context.Context, model interface{}) GenericDao {
6262
return &sqlGenericDao{
6363
sessionFactory: d.sessionFactory,
64-
g2: (*d.sessionFactory).New(ctx).Model(model),
64+
g2: d.sessionFactory.New(ctx).Model(model),
6565
}
6666
}
6767

pkg/dao/node_pool.go

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -27,15 +27,15 @@ type NodePoolDao interface {
2727
var _ NodePoolDao = &sqlNodePoolDao{}
2828

2929
type sqlNodePoolDao struct {
30-
sessionFactory *db.SessionFactory
30+
sessionFactory db.SessionFactory
3131
}
3232

33-
func NewNodePoolDao(sessionFactory *db.SessionFactory) NodePoolDao {
33+
func NewNodePoolDao(sessionFactory db.SessionFactory) NodePoolDao {
3434
return &sqlNodePoolDao{sessionFactory: sessionFactory}
3535
}
3636

3737
func (d *sqlNodePoolDao) Get(ctx context.Context, id string) (*api.NodePool, error) {
38-
g2 := (*d.sessionFactory).New(ctx)
38+
g2 := d.sessionFactory.New(ctx)
3939
var nodePool api.NodePool
4040
if err := g2.Take(&nodePool, "id = ?", id).Error; err != nil {
4141
return nil, err
@@ -44,7 +44,7 @@ func (d *sqlNodePoolDao) Get(ctx context.Context, id string) (*api.NodePool, err
4444
}
4545

4646
func (d *sqlNodePoolDao) GetByIDAndOwner(ctx context.Context, id string, ownerID string) (*api.NodePool, error) {
47-
g2 := (*d.sessionFactory).New(ctx)
47+
g2 := d.sessionFactory.New(ctx)
4848
var nodePool api.NodePool
4949
if err := g2.Take(&nodePool, "id = ? AND owner_id = ?", id, ownerID).Error; err != nil {
5050
return nil, err
@@ -53,7 +53,7 @@ func (d *sqlNodePoolDao) GetByIDAndOwner(ctx context.Context, id string, ownerID
5353
}
5454

5555
func (d *sqlNodePoolDao) GetForUpdate(ctx context.Context, id string) (*api.NodePool, error) {
56-
g2 := (*d.sessionFactory).New(ctx)
56+
g2 := d.sessionFactory.New(ctx)
5757
var nodePool api.NodePool
5858
if err := g2.Clauses(clause.Locking{Strength: "UPDATE"}).Take(&nodePool, "id = ?", id).Error; err != nil {
5959
return nil, err
@@ -62,7 +62,7 @@ func (d *sqlNodePoolDao) GetForUpdate(ctx context.Context, id string) (*api.Node
6262
}
6363

6464
func (d *sqlNodePoolDao) SaveStatusConditions(ctx context.Context, id string, statusConditions []byte) error {
65-
g2 := (*d.sessionFactory).New(ctx)
65+
g2 := d.sessionFactory.New(ctx)
6666
result := g2.Model(&api.NodePool{}).Where("id = ?", id).Update("status_conditions", statusConditions)
6767
if result.Error != nil {
6868
db.MarkForRollback(ctx, result.Error)
@@ -72,7 +72,7 @@ func (d *sqlNodePoolDao) SaveStatusConditions(ctx context.Context, id string, st
7272
}
7373

7474
func (d *sqlNodePoolDao) Create(ctx context.Context, nodePool *api.NodePool) (*api.NodePool, error) {
75-
g2 := (*d.sessionFactory).New(ctx)
75+
g2 := d.sessionFactory.New(ctx)
7676
if err := g2.Omit(clause.Associations).Create(nodePool).Error; err != nil {
7777
db.MarkForRollback(ctx, err)
7878
return nil, err
@@ -81,7 +81,7 @@ func (d *sqlNodePoolDao) Create(ctx context.Context, nodePool *api.NodePool) (*a
8181
}
8282

8383
func (d *sqlNodePoolDao) Save(ctx context.Context, nodePool *api.NodePool) error {
84-
g2 := (*d.sessionFactory).New(ctx)
84+
g2 := d.sessionFactory.New(ctx)
8585
if err := g2.Omit(clause.Associations).Save(nodePool).Error; err != nil {
8686
db.MarkForRollback(ctx, err)
8787
return err
@@ -90,7 +90,7 @@ func (d *sqlNodePoolDao) Save(ctx context.Context, nodePool *api.NodePool) error
9090
}
9191

9292
func (d *sqlNodePoolDao) Delete(ctx context.Context, id string) error {
93-
g2 := (*d.sessionFactory).New(ctx)
93+
g2 := d.sessionFactory.New(ctx)
9494
if err := g2.Omit(clause.Associations).Delete(&api.NodePool{Meta: api.Meta{ID: id}}).Error; err != nil {
9595
db.MarkForRollback(ctx, err)
9696
return err
@@ -99,7 +99,7 @@ func (d *sqlNodePoolDao) Delete(ctx context.Context, id string) error {
9999
}
100100

101101
func (d *sqlNodePoolDao) FindByIDs(ctx context.Context, ids []string) (api.NodePoolList, error) {
102-
g2 := (*d.sessionFactory).New(ctx)
102+
g2 := d.sessionFactory.New(ctx)
103103
nodePools := api.NodePoolList{}
104104
if err := g2.Where("id in (?)", ids).Find(&nodePools).Error; err != nil {
105105
return nil, err
@@ -108,7 +108,7 @@ func (d *sqlNodePoolDao) FindByIDs(ctx context.Context, ids []string) (api.NodeP
108108
}
109109

110110
func (d *sqlNodePoolDao) FindByOwner(ctx context.Context, ownerID string) (api.NodePoolList, error) {
111-
g2 := (*d.sessionFactory).New(ctx)
111+
g2 := d.sessionFactory.New(ctx)
112112
var nodePools api.NodePoolList
113113
if err := g2.Where("owner_id = ?", ownerID).Find(&nodePools).Error; err != nil {
114114
return nil, err
@@ -120,7 +120,7 @@ func (d *sqlNodePoolDao) SaveAll(ctx context.Context, nodePools api.NodePoolList
120120
if len(nodePools) == 0 {
121121
return nil
122122
}
123-
g2 := (*d.sessionFactory).New(ctx)
123+
g2 := d.sessionFactory.New(ctx)
124124
if err := g2.Omit(clause.Associations).Save(nodePools).Error; err != nil {
125125
db.MarkForRollback(ctx, err)
126126
return err
@@ -129,7 +129,7 @@ func (d *sqlNodePoolDao) SaveAll(ctx context.Context, nodePools api.NodePoolList
129129
}
130130

131131
func (d *sqlNodePoolDao) ExistsByOwner(ctx context.Context, ownerID string) (bool, error) {
132-
g2 := (*d.sessionFactory).New(ctx)
132+
g2 := d.sessionFactory.New(ctx)
133133
var count int64
134134
if err := g2.Model(&api.NodePool{}).Where("owner_id = ?", ownerID).Limit(1).Count(&count).Error; err != nil {
135135
return false, err
@@ -138,7 +138,7 @@ func (d *sqlNodePoolDao) ExistsByOwner(ctx context.Context, ownerID string) (boo
138138
}
139139

140140
func (d *sqlNodePoolDao) All(ctx context.Context) (api.NodePoolList, error) {
141-
g2 := (*d.sessionFactory).New(ctx)
141+
g2 := d.sessionFactory.New(ctx)
142142
nodePools := api.NodePoolList{}
143143
if err := g2.Find(&nodePools).Error; err != nil {
144144
return nil, err

pkg/services/generic_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ func TestSQLTranslation(t *testing.T) {
2222
var dbFactory db.SessionFactory = dbmocks.NewMockSessionFactory()
2323
defer dbFactory.Close() //nolint:errcheck
2424

25-
g := dao.NewGenericDao(&dbFactory)
25+
g := dao.NewGenericDao(dbFactory)
2626
genericService := sqlGenericService{genericDao: g}
2727

2828
// ill-formatted search or disallowed fields should be rejected

plugins/adapterStatus/plugin.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ type ServiceLocator func() services.AdapterStatusService
1313
func NewServiceLocator(env *environments.Env) ServiceLocator {
1414
return func() services.AdapterStatusService {
1515
return services.NewAdapterStatusService(
16-
dao.NewAdapterStatusDao(&env.Database.SessionFactory),
16+
dao.NewAdapterStatusDao(env.Database.SessionFactory),
1717
)
1818
}
1919
}

plugins/clusters/plugin.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,10 @@ type ServiceLocator func() services.ClusterService
2424
func NewServiceLocator(env *environments.Env) ServiceLocator {
2525
return func() services.ClusterService {
2626
return services.NewClusterService(
27-
dao.NewClusterDao(&env.Database.SessionFactory),
28-
dao.NewNodePoolDao(&env.Database.SessionFactory),
27+
dao.NewClusterDao(env.Database.SessionFactory),
28+
dao.NewNodePoolDao(env.Database.SessionFactory),
2929
nodePools.Service(&env.Services),
30-
dao.NewAdapterStatusDao(&env.Database.SessionFactory),
30+
dao.NewAdapterStatusDao(env.Database.SessionFactory),
3131
env.Config.Adapters,
3232
)
3333
}

plugins/generic/plugin.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ type ServiceLocator func() services.GenericService
1212

1313
func NewServiceLocator(env *environments.Env) ServiceLocator {
1414
return func() services.GenericService {
15-
return services.NewGenericService(dao.NewGenericDao(&env.Database.SessionFactory))
15+
return services.NewGenericService(dao.NewGenericDao(env.Database.SessionFactory))
1616
}
1717
}
1818

plugins/nodePools/plugin.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,9 @@ type ServiceLocator func() services.NodePoolService
2222
func NewServiceLocator(env *environments.Env) ServiceLocator {
2323
return func() services.NodePoolService {
2424
return services.NewNodePoolService(
25-
dao.NewNodePoolDao(&env.Database.SessionFactory),
26-
dao.NewClusterDao(&env.Database.SessionFactory),
27-
dao.NewAdapterStatusDao(&env.Database.SessionFactory),
25+
dao.NewNodePoolDao(env.Database.SessionFactory),
26+
dao.NewClusterDao(env.Database.SessionFactory),
27+
dao.NewAdapterStatusDao(env.Database.SessionFactory),
2828
env.Config.Adapters,
2929
generic.Service(&env.Services),
3030
)

0 commit comments

Comments
 (0)