Skip to content

Commit cbb4355

Browse files
authored
fix bug of empty deep column (#552)
* fix bug of empty deep column
1 parent f27e6df commit cbb4355

1 file changed

Lines changed: 13 additions & 10 deletions

File tree

easy_rec/python/layers/input_layer.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -303,12 +303,15 @@ def single_call_input_layer(self,
303303
group_columns, group_seq_columns = feature_group.select_columns(
304304
self._fc_parser)
305305
cols_to_output_tensors = OrderedDict()
306-
output_features = feature_column.input_layer(
307-
features,
308-
group_columns,
309-
cols_to_output_tensors=cols_to_output_tensors,
310-
feature_name_to_output_tensors=feature_name_to_output_tensors,
311-
is_training=self._is_training)
306+
if group_columns:
307+
output_features = feature_column.input_layer(
308+
features,
309+
group_columns,
310+
cols_to_output_tensors=cols_to_output_tensors,
311+
feature_name_to_output_tensors=feature_name_to_output_tensors,
312+
is_training=self._is_training)
313+
else:
314+
output_features = None
312315

313316
embedding_reg_lst = []
314317
builder = feature_column._LazyBuilder(features)
@@ -349,13 +352,14 @@ def single_call_input_layer(self,
349352
cols_to_output_tensors[column] = cnn_feature
350353
else:
351354
raise NotImplementedError
355+
all_features = ([output_features] if output_features is not None else []) \
356+
+ seq_features
352357
if self._variational_dropout_config is not None:
353358
features_dimension = OrderedDict([
354359
(k.raw_name, int(v.shape[-1]))
355360
for k, v in cols_to_output_tensors.items()
356361
])
357-
concat_features = array_ops.concat(
358-
[output_features] + seq_features, axis=-1)
362+
concat_features = array_ops.concat(all_features, axis=-1)
359363
variational_dropout = variational_dropout_layer.VariationalDropoutLayer(
360364
self._variational_dropout_config,
361365
features_dimension,
@@ -365,8 +369,7 @@ def single_call_input_layer(self,
365369
group_features = tf.split(
366370
concat_features, list(features_dimension.values()), axis=-1)
367371
else:
368-
concat_features = array_ops.concat(
369-
[output_features] + seq_features, axis=-1)
372+
concat_features = array_ops.concat(all_features, axis=-1)
370373
group_features = [cols_to_output_tensors[x] for x in group_columns] + \
371374
[cols_to_output_tensors[x] for x in group_seq_columns]
372375

0 commit comments

Comments
 (0)