Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 13 additions & 10 deletions easy_rec/python/layers/input_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,12 +303,15 @@ def single_call_input_layer(self,
group_columns, group_seq_columns = feature_group.select_columns(
self._fc_parser)
cols_to_output_tensors = OrderedDict()
output_features = feature_column.input_layer(
features,
group_columns,
cols_to_output_tensors=cols_to_output_tensors,
feature_name_to_output_tensors=feature_name_to_output_tensors,
is_training=self._is_training)
if group_columns:
output_features = feature_column.input_layer(
features,
group_columns,
cols_to_output_tensors=cols_to_output_tensors,
feature_name_to_output_tensors=feature_name_to_output_tensors,
is_training=self._is_training)
else:
output_features = None

embedding_reg_lst = []
builder = feature_column._LazyBuilder(features)
Expand Down Expand Up @@ -349,13 +352,14 @@ def single_call_input_layer(self,
cols_to_output_tensors[column] = cnn_feature
else:
raise NotImplementedError
all_features = ([output_features] if output_features is not None else []) \
+ seq_features
if self._variational_dropout_config is not None:
features_dimension = OrderedDict([
(k.raw_name, int(v.shape[-1]))
for k, v in cols_to_output_tensors.items()
])
concat_features = array_ops.concat(
[output_features] + seq_features, axis=-1)
concat_features = array_ops.concat(all_features, axis=-1)
variational_dropout = variational_dropout_layer.VariationalDropoutLayer(
self._variational_dropout_config,
features_dimension,
Expand All @@ -365,8 +369,7 @@ def single_call_input_layer(self,
group_features = tf.split(
concat_features, list(features_dimension.values()), axis=-1)
else:
concat_features = array_ops.concat(
[output_features] + seq_features, axis=-1)
concat_features = array_ops.concat(all_features, axis=-1)
group_features = [cols_to_output_tensors[x] for x in group_columns] + \
[cols_to_output_tensors[x] for x in group_seq_columns]

Expand Down
Loading