Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file modified docs/images/qrcode/dinggroup1.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/images/qrcode/dinggroup2.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2 changes: 2 additions & 0 deletions docs/source/feature/feature.rst
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,8 @@ Sequence类特征格式一般为“XX\|XX\|XX”,如用户行为序列特征
- embedding_dim: embedding的dimension
- hash_bucket_size: 同离散值特征
- sub_feature_type: 用于描述序列特征里子特征的类型,目前支持 IdFeature 和 RawFeature 两种形式,默认为 IdFeature
- pad_sequence_length: 序列的固定长度;短于该长度的序列会被补齐,超过该长度的序列会被截断
- max_seq_len: 最大序列长度,超过该长度的序列会被截断;当配置了 ``pad_sequence_length`` 时, ``max_seq_len`` 会被忽略
- NOTE:SequenceFeature一般用在DIN算法或者BST算法里面。

在模型中可支持对序列特征使用Target Attention(DIN),方法如下:
Expand Down
3 changes: 1 addition & 2 deletions easy_rec/python/builders/loss_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,8 +240,7 @@ def build_kd_loss(kds, prediction_dict, label_dict, feature_dict):
loss_dict = {}
for kd in kds:
assert kd.pred_name in prediction_dict, 'invalid predict_name: %s available ones: %s' % (
kd.pred_name,
','.join(prediction_dict.keys()))
kd.pred_name, ','.join(prediction_dict.keys()))

loss_name = kd.loss_name
if not loss_name:
Expand Down
8 changes: 8 additions & 0 deletions easy_rec/python/feature_column/feature_column.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,8 +165,11 @@ def _cmp_embed_config(a, b):
config = self._share_embed_infos[embed_name]
max_seq_len = config.max_seq_len if config.HasField(
'max_seq_len') else -1
pad_sequence_length = config.pad_sequence_length if config.HasField(
'pad_sequence_length') else -1
for fc in share_embed_fcs:
fc.max_seq_length = max_seq_len
fc.pad_sequence_length = pad_sequence_length
self._deep_share_embed_columns[embed_name] = share_embed_fcs

# for handling wide share embedding columns
Expand All @@ -182,8 +185,11 @@ def _cmp_embed_config(a, b):
config = self._share_embed_infos[embed_name]
max_seq_len = config.max_seq_len if config.HasField(
'max_seq_len') else -1
pad_sequence_length = config.pad_sequence_length if config.HasField(
'pad_sequence_length') else -1
for fc in share_embed_fcs:
fc.max_seq_length = max_seq_len
fc.pad_sequence_length = pad_sequence_length
self._wide_share_embed_columns[embed_name] = share_embed_fcs

for fc_name in self._deep_columns:
Expand Down Expand Up @@ -647,6 +653,8 @@ def _add_deep_embedding_column(self, fc, config):
ev_params=ev_params)
fc.max_seq_length = config.max_seq_len if config.HasField(
'max_seq_len') else -1
fc.pad_sequence_length = config.pad_sequence_length if config.HasField(
'pad_sequence_length') else -1

if config.feature_type != config.SequenceFeature:
self._deep_columns[feature_name] = fc
Expand Down
6 changes: 5 additions & 1 deletion easy_rec/python/layers/input_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,11 @@ def get_sequence_feature(self, features, group_name):
with variable_scope.variable_scope('input_layer/' +
fc.categorical_column.name):
tmp_embedding, tmp_seq_len = fc._get_sequence_dense_tensor(builder)
if fc.max_seq_length > 0:
# If pad_sequence_length is set, pad or truncate to fixed length
if fc.pad_sequence_length > 0:
tmp_embedding, tmp_seq_len = shape_utils.pad_or_truncate_sequence(
tmp_embedding, tmp_seq_len, fc.pad_sequence_length)
elif fc.max_seq_length > 0:
tmp_embedding, tmp_seq_len = shape_utils.truncate_sequence(
tmp_embedding, tmp_seq_len, fc.max_seq_length)
seq_features.append((tmp_embedding, tmp_seq_len))
Expand Down
2 changes: 2 additions & 0 deletions easy_rec/python/protos/feature_config.proto
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ message FeatureConfig {
optional string seq_multi_sep = 101;
// truncate sequence data to max_seq_len; ignored when pad_sequence_length is set
optional uint32 max_seq_len = 102;
// pad or truncate sequence to a fixed length (shorter sequences are padded,
// longer ones are truncated); takes precedence over max_seq_len
optional uint32 pad_sequence_length = 103;

optional string vocab_file = 11;
repeated string vocab_list = 12;
Expand Down
2 changes: 1 addition & 1 deletion easy_rec/version.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# -*- encoding:utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.

__version__ = '0.8.7'
__version__ = '0.8.8'
Loading