Skip to content

Commit 198636e

Browse files
committed
refactor: rename SampleWeights to Weights in ctr.Dataset
1 parent f07f4e1 commit 198636e

2 files changed

Lines changed: 9 additions & 9 deletions

File tree

model/ctr/data.go

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ type Dataset struct {
152152
PositiveCount int
153153
NegativeCount int
154154
// Weight support
155-
SampleWeights []float32 // Computed weight for each sample (set by tasks)
155+
Weights []float32 // Computed weight for each sample (set by tasks)
156156
}
157157

158158
// CountUsers returns the number of users.
@@ -259,8 +259,8 @@ func (dataset *Dataset) Get(i int) ([]int32, []float32, [][]float32, float32) {
259259
// GetWeight returns the weight for the i-th sample.
260260
// Returns 1.0 if no weight is set (default behavior).
261261
func (dataset *Dataset) GetWeight(i int) float32 {
262-
if dataset.SampleWeights != nil && i < len(dataset.SampleWeights) {
263-
return dataset.SampleWeights[i]
262+
if dataset.Weights != nil && i < len(dataset.Weights) {
263+
return dataset.Weights[i]
264264
}
265265
return 1.0
266266
}
@@ -368,8 +368,8 @@ func (dataset *Dataset) Split(ratio float32, seed int64) (*Dataset, *Dataset) {
368368
}
369369
testSet.Target = append(testSet.Target, dataset.Target[i])
370370

371-
if dataset.SampleWeights != nil {
372-
testSet.SampleWeights = append(testSet.SampleWeights, dataset.SampleWeights[i])
371+
if dataset.Weights != nil {
372+
testSet.Weights = append(testSet.Weights, dataset.Weights[i])
373373
}
374374
if dataset.Target[i] > 0 {
375375
testSet.PositiveCount++
@@ -385,8 +385,8 @@ func (dataset *Dataset) Split(ratio float32, seed int64) (*Dataset, *Dataset) {
385385
}
386386
trainSet.Target = append(trainSet.Target, dataset.Target[i])
387387

388-
if dataset.SampleWeights != nil {
389-
trainSet.SampleWeights = append(trainSet.SampleWeights, dataset.SampleWeights[i])
388+
if dataset.Weights != nil {
389+
trainSet.Weights = append(trainSet.Weights, dataset.Weights[i])
390390
}
391391
if dataset.Target[i] > 0 {
392392
trainSet.PositiveCount++

model/ctr/data_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,7 @@ func TestDataset_GetWeight(t *testing.T) {
200200

201201
t.Run("with weights", func(t *testing.T) {
202202
dataset := &Dataset{
203-
SampleWeights: []float32{1.0, 2.0, 3.0},
203+
Weights: []float32{1.0, 2.0, 3.0},
204204
}
205205
assert.Equal(t, float32(1.0), dataset.GetWeight(0))
206206
assert.Equal(t, float32(2.0), dataset.GetWeight(1))
@@ -209,7 +209,7 @@ func TestDataset_GetWeight(t *testing.T) {
209209

210210
t.Run("out of range returns 1.0", func(t *testing.T) {
211211
dataset := &Dataset{
212-
SampleWeights: []float32{1.0},
212+
Weights: []float32{1.0},
213213
}
214214
assert.Equal(t, float32(1.0), dataset.GetWeight(100))
215215
})

0 commit comments

Comments
 (0)