
Commit 1b5fe40

Fix DeepFM dense-only loading, DNN save tracking, and FNN logits (#555)
- Avoid empty FM branches in dense-only DeepFM models
- Build DNN activation layers without mutating tracked lists
- Remove the WDL-style linear logit from FNN
- Replace test asserts flagged by Codacy
- Add targeted regression tests
1 parent a827877 commit 1b5fe40

8 files changed: 65 additions & 26 deletions

deepctr/estimator/models/fnn.py (4 additions & 8 deletions)

@@ -8,7 +8,7 @@
 """
 import tensorflow as tf

-from ..feature_column import get_linear_logit, input_from_feature_columns
+from ..feature_column import input_from_feature_columns
 from ..utils import deepctr_model_fn, DNN_SCOPE_NAME, variable_scope
 from ...layers.core import DNN
 from ...layers.utils import combined_dnn_input
@@ -20,11 +20,11 @@ def FNNEstimator(linear_feature_columns, dnn_feature_columns, dnn_hidden_units=(
                  dnn_optimizer='Adagrad', training_chief_hooks=None):
     """Instantiates the Factorization-supported Neural Network architecture.

-    :param linear_feature_columns: An iterable containing all the features used by linear part of the model.
+    :param linear_feature_columns: An iterable containing features kept for API compatibility.
     :param dnn_feature_columns: An iterable containing all the features used by deep part of the model.
     :param dnn_hidden_units: list,list of positive integer or empty list, the layer number and units in each layer of deep net
     :param l2_reg_embedding: float. L2 regularizer strength applied to embedding vector
-    :param l2_reg_linear: float. L2 regularizer strength applied to linear weight
+    :param l2_reg_linear: float. Kept for API compatibility.
     :param l2_reg_dnn: float . L2 regularizer strength applied to DNN
     :param seed: integer ,to use as random seed.
     :param dnn_dropout: float in [0,1), the probability we will drop out a given DNN coordinate.
@@ -47,8 +47,6 @@ def FNNEstimator(linear_feature_columns, dnn_feature_columns, dnn_hidden_units=(
     def _model_fn(features, labels, mode, config):
         train_flag = (mode == tf.estimator.ModeKeys.TRAIN)

-        linear_logits = get_linear_logit(features, linear_feature_columns, l2_reg_linear=l2_reg_linear)
-
         with variable_scope(DNN_SCOPE_NAME):
             sparse_embedding_list, dense_value_list = input_from_feature_columns(features, dnn_feature_columns,
                                                                                  l2_reg_embedding=l2_reg_embedding)
@@ -57,9 +55,7 @@ def _model_fn(features, labels, mode, config):
             dnn_logit = tf.keras.layers.Dense(
                 1, use_bias=False, kernel_initializer=tf.keras.initializers.glorot_normal(seed))(deep_out)

-        logits = linear_logits + dnn_logit
-
-        return deepctr_model_fn(features, mode, logits, labels, task, linear_optimizer, dnn_optimizer,
+        return deepctr_model_fn(features, mode, dnn_logit, labels, task, linear_optimizer, dnn_optimizer,
                                 training_chief_hooks=training_chief_hooks)

     return tf.estimator.Estimator(_model_fn, model_dir=model_dir, config=config)
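
Note: with the linear logit gone, the estimator's prediction comes from the DNN tower alone; the first argument and l2_reg_linear survive only for API compatibility. A minimal usage sketch, not part of the commit (the column name 'dense_feature' is illustrative):

import tensorflow as tf
from deepctr.estimator.models.fnn import FNNEstimator

# Hypothetical dense column for illustration; after this change the first
# argument no longer contributes a wide Linear branch to the logits.
dnn_columns = [tf.feature_column.numeric_column('dense_feature')]
estimator = FNNEstimator(dnn_columns, dnn_columns, dnn_hidden_units=(4,), task='binary')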

deepctr/layers/core.py (4 additions & 4 deletions)

@@ -179,10 +179,10 @@ def build(self, input_shape):
         self.dropout_layers = [Dropout(self.dropout_rate, seed=self.seed + i) for i in
                                range(len(self.hidden_units))]

-        self.activation_layers = [activation_layer(self.activation) for _ in range(len(self.hidden_units))]
-
-        if self.output_activation:
-            self.activation_layers[-1] = activation_layer(self.output_activation)
+        self.activation_layers = [
+            activation_layer(
+                self.output_activation if i == len(self.hidden_units) - 1 and self.output_activation else self.activation)
+            for i in range(len(self.hidden_units))]

         super(DNN, self).build(input_shape)  # Be sure to call this somewhere!
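
Note: this rewrite matters because tf.keras auto-tracks lists of sublayers assigned as attributes, wrapping them so every element becomes a checkpoint dependency; replacing an element in place afterwards (the old `self.activation_layers[-1] = ...`) can leave that tracking inconsistent and break model saving. A standalone illustration of the wrapping, not deepctr code:

import tensorflow as tf

holder = tf.keras.layers.Layer()
holder.acts = [tf.keras.layers.ReLU(), tf.keras.layers.ReLU()]
# Keras silently wraps the plain list in a tracking data structure so each
# element becomes a checkpoint dependency.
print(type(holder.acts).__name__)  # ListWrapper
# Mutating the tracked list after assignment is the save/restore hazard
# the new single comprehension in DNN.build avoids:
holder.acts[-1] = tf.keras.layers.Softmax()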

deepctr/layers/utils.py (2 additions & 2 deletions)

@@ -142,14 +142,14 @@ def build(self, input_shape):
                                           trainable=True)
         if self.mode == 1:
             self.kernel = self.add_weight(
-                'linear_kernel',
+                name='linear_kernel',
                 shape=[int(input_shape[-1]), 1],
                 initializer=glorot_normal(self.seed),
                 regularizer=l2(self.l2_reg),
                 trainable=True)
         elif self.mode == 2:
             self.kernel = self.add_weight(
-                'linear_kernel',
+                name='linear_kernel',
                 shape=[int(input_shape[1][-1]), 1],
                 initializer=glorot_normal(self.seed),
                 regularizer=l2(self.l2_reg),
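
Note: `Layer.add_weight` takes the weight name as its first parameter, but passing it positionally is brittle across TensorFlow versions; the keyword form is unambiguous. A self-contained sketch of the pattern, not deepctr code:

import tensorflow as tf

class TinyLinear(tf.keras.layers.Layer):
    def build(self, input_shape):
        # Naming the weight via the keyword keeps the call robust even if
        # add_weight's positional signature shifts between TF releases.
        self.kernel = self.add_weight(name='linear_kernel',
                                      shape=[int(input_shape[-1]), 1],
                                      initializer=tf.keras.initializers.glorot_normal(seed=1024),
                                      trainable=True)
        super(TinyLinear, self).build(input_shape)

    def call(self, inputs):
        return tf.matmul(inputs, self.kernel)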

deepctr/models/deepfm.py (3 additions & 3 deletions)

@@ -50,15 +50,15 @@ def DeepFM(linear_feature_columns, dnn_feature_columns, fm_group=(DEFAULT_GROUP_
     group_embedding_dict, dense_value_list = input_from_feature_columns(features, dnn_feature_columns, l2_reg_embedding,
                                                                         seed, support_group=True)

-    fm_logit = add_func([FM()(concat_func(v, axis=1))
-                         for k, v in group_embedding_dict.items() if k in fm_group])
+    fm_logit_list = [FM()(concat_func(v, axis=1))
+                     for k, v in group_embedding_dict.items() if k in fm_group]

     dnn_input = combined_dnn_input(list(chain.from_iterable(
         group_embedding_dict.values())), dense_value_list)
     dnn_output = DNN(dnn_hidden_units, dnn_activation, l2_reg_dnn, dnn_dropout, dnn_use_bn, seed=seed)(dnn_input)
     dnn_logit = Dense(1, use_bias=False)(dnn_output)

-    final_logit = add_func([linear_logit, fm_logit, dnn_logit])
+    final_logit = add_func([linear_logit, dnn_logit] + fm_logit_list)

     output = PredictionLayer(task)(final_logit)
     model = Model(inputs=inputs_list, outputs=output)
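
Note: with dense-only feature columns there are no sparse embedding groups, so the old code handed add_func an empty list and wired a degenerate FM branch into the graph, which is what broke dense-only model loading per the commit title. Keeping the per-group logits as a plain list means a dense-only model simply contributes nothing from the FM side. A usage sketch mirroring the new regression test further below:

import numpy as np
from deepctr.feature_column import DenseFeat
from deepctr.models import DeepFM

# Dense-only DeepFM: no SparseFeat columns, hence no FM groups at all.
cols = [DenseFeat('dense_feature_' + str(i), 1) for i in range(2)]
model = DeepFM(cols, cols, dnn_hidden_units=(4,))
model.compile('adam', 'binary_crossentropy')
x = {c.name: np.random.random(16) for c in cols}  # 16 samples, illustrative
y = np.random.randint(0, 2, (16, 1))
model.fit(x, y, epochs=1, verbose=0)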

deepctr/models/fnn.py (5 additions & 9 deletions)

@@ -9,21 +9,21 @@
 from tensorflow.keras.models import Model
 from tensorflow.keras.layers import Dense

-from ..feature_column import build_input_features, get_linear_logit, input_from_feature_columns
+from ..feature_column import build_input_features, input_from_feature_columns
 from ..layers.core import PredictionLayer, DNN
-from ..layers.utils import add_func, combined_dnn_input
+from ..layers.utils import combined_dnn_input


 def FNN(linear_feature_columns, dnn_feature_columns, dnn_hidden_units=(256, 128, 64),
         l2_reg_embedding=1e-5, l2_reg_linear=1e-5, l2_reg_dnn=0, seed=1024, dnn_dropout=0,
         dnn_activation='relu', task='binary'):
     """Instantiates the Factorization-supported Neural Network architecture.

-    :param linear_feature_columns: An iterable containing all the features used by linear part of the model.
+    :param linear_feature_columns: An iterable containing features kept for API compatibility.
     :param dnn_feature_columns: An iterable containing all the features used by deep part of the model.
     :param dnn_hidden_units: list,list of positive integer or empty list, the layer number and units in each layer of deep net
     :param l2_reg_embedding: float. L2 regularizer strength applied to embedding vector
-    :param l2_reg_linear: float. L2 regularizer strength applied to linear weight
+    :param l2_reg_linear: float. Kept for API compatibility.
     :param l2_reg_dnn: float . L2 regularizer strength applied to DNN
     :param seed: integer ,to use as random seed.
     :param dnn_dropout: float in [0,1), the probability we will drop out a given DNN coordinate.
@@ -36,18 +36,14 @@ def FNN(linear_feature_columns, dnn_feature_columns, dnn_hidden_units=(256, 128,

     inputs_list = list(features.values())

-    linear_logit = get_linear_logit(features, linear_feature_columns, seed=seed, prefix='linear',
-                                    l2_reg=l2_reg_linear)
-
     sparse_embedding_list, dense_value_list = input_from_feature_columns(features, dnn_feature_columns,
                                                                          l2_reg_embedding, seed)

     dnn_input = combined_dnn_input(sparse_embedding_list, dense_value_list)
     deep_out = DNN(dnn_hidden_units, dnn_activation, l2_reg_dnn, dnn_dropout, False, seed=seed)(dnn_input)
     dnn_logit = Dense(1, use_bias=False)(deep_out)
-    final_logit = add_func([dnn_logit, linear_logit])

-    output = PredictionLayer(task)(final_logit)
+    output = PredictionLayer(task)(dnn_logit)

     model = Model(inputs=inputs_list, outputs=output)
     return model
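
Note: FNN is now a pure deep tower; `linear_feature_columns` and `l2_reg_linear` stay in the signature for API compatibility but no longer affect the output. A brief usage sketch, with feature names taken from the new test:

import numpy as np
from deepctr.feature_column import DenseFeat, SparseFeat
from deepctr.models import FNN

cols = [SparseFeat('sparse_feature', 4, embedding_dim=4),
        DenseFeat('dense_feature', 1)]
model = FNN(cols, cols, dnn_hidden_units=(4,))
x = {'sparse_feature': np.random.randint(0, 4, 8),  # 8 samples, illustrative
     'dense_feature': np.random.random(8)}
print(model.predict(x).shape)  # (8, 1): the DNN logit through PredictionLayer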

tests/layers/core_test.py (9 additions & 0 deletions)

@@ -43,6 +43,15 @@ def test_DNN(hidden_units, use_bn):
                                                     BATCH_SIZE, EMBEDDING_SIZE))


+def test_DNN_output_activation():
+    with CustomObjectScope({'DNN': layers.DNN}):
+        x = tf.keras.layers.Input(shape=(EMBEDDING_SIZE,))
+        y = layers.DNN((10,), output_activation='sigmoid')(x)
+        model = tf.keras.models.Model(x, y)
+        if model.output_shape != (None, 10):
+            raise AssertionError("Unexpected DNN output shape")
+
+
 @pytest.mark.parametrize(
     'task,use_bias',
     [(task, use_bias)
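
Note: `output_activation` swaps the activation of the last hidden layer only; earlier layers keep `activation`. A minimal sketch outside the test harness, assuming DNN is importable from deepctr.layers:

import tensorflow as tf
from deepctr.layers import DNN

x = tf.keras.layers.Input(shape=(4,))
# Two hidden layers: relu on the first, sigmoid (output_activation) on the last.
y = DNN((16, 10), activation='relu', output_activation='sigmoid')(x)
model = tf.keras.models.Model(x, y)
print(model.output_shape)  # (None, 10)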

tests/models/DeepFM_test.py (27 additions & 0 deletions)

@@ -1,5 +1,11 @@
 import pytest
+import numpy as np
+import os
+import tempfile
+from tensorflow.keras.models import load_model, save_model

+from deepctr.feature_column import DenseFeat
+from deepctr.layers import custom_objects
 from deepctr.models import DeepFM
 from ..utils import check_model, get_test_data, SAMPLE_SIZE, get_test_data_estimator, check_estimator, TEST_Estimator

@@ -21,6 +27,27 @@ def test_DeepFM(hidden_size, sparse_feature_num):
     check_model(model, model_name, x, y)


+def test_DeepFM_dense_only_model_io():
+    sample_size = SAMPLE_SIZE
+    feature_columns = [DenseFeat('dense_feature_' + str(i), 1) for i in range(2)]
+    x = {fc.name: np.random.random(sample_size) for fc in feature_columns}
+    y = np.random.randint(0, 2, (sample_size, 1))
+
+    model = DeepFM(feature_columns, feature_columns, dnn_hidden_units=(4,), dnn_dropout=0)
+    model.compile('adam', 'binary_crossentropy',
+                  metrics=['binary_crossentropy'])
+    model.fit(x, y, batch_size=4, epochs=1, validation_split=0.5)
+
+    fd = tempfile.NamedTemporaryFile(suffix='.h5', delete=False)
+    model_path = fd.name
+    fd.close()
+    try:
+        save_model(model, model_path)
+        load_model(model_path, custom_objects)
+    finally:
+        os.remove(model_path)
+
+
 @pytest.mark.parametrize(
     'hidden_size,sparse_feature_num',
     [
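
Note: the `delete=False` plus explicit `os.remove` dance is presumably for Windows compatibility, where an open NamedTemporaryFile cannot be reopened by another writer; closing the handle first lets save_model write to the path on any platform.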

tests/models/FNN_test.py (11 additions & 0 deletions)

@@ -1,6 +1,7 @@
 import pytest
 import tensorflow as tf

+from deepctr.feature_column import DenseFeat, SparseFeat
 from deepctr.models import FNN
 from ..utils import check_model, get_test_data, SAMPLE_SIZE, get_test_data_estimator, check_estimator, TEST_Estimator

@@ -23,6 +24,16 @@ def test_FNN(sparse_feature_num, dense_feature_num):
     check_model(model, model_name, x, y)


+def test_FNN_does_not_add_wide_linear_logit():
+    feature_columns = [SparseFeat('sparse_feature', 4, embedding_dim=4),
+                       DenseFeat('dense_feature', 1)]
+
+    model = FNN(feature_columns, feature_columns, dnn_hidden_units=(4,), dnn_dropout=0)
+
+    if not all(layer.__class__.__name__ != 'Linear' for layer in model.layers):
+        raise AssertionError("FNN should not include a wide Linear layer")
+
+
 # @pytest.mark.parametrize(
 #     'sparse_feature_num,dense_feature_num',
 #     [(0, 1), (1, 0)
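
Note: the if/raise pattern here and in core_test.py is what the commit message means by replacing asserts flagged by Codacy: bare `assert` statements are commonly flagged by static analysis (for example Bandit's B101 rule) because they are stripped when Python runs with -O, so these checks raise AssertionError explicitly instead.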
