Skip to content

[WIP] Implements RNNT+MMI#1030

Draft
pkufool wants to merge 9 commits into
k2-fsa:masterfrom
pkufool:rnnt_mmi
Draft

[WIP] Implements RNNT+MMI#1030
pkufool wants to merge 9 commits into
k2-fsa:masterfrom
pkufool:rnnt_mmi

Conversation

@pkufool
Copy link
Copy Markdown
Collaborator

@pkufool pkufool commented Aug 9, 2022

It runs normally in my self-constructed test case, not fully tested yet, though.

The sampled paths:

sampled_paths = torch.tensor([ [ [ 3, 5, 0, 4, 6, 0, 2, 1 ],
                                 [ 2, 0, 5, 4, 0, 6, 1, 2 ],
                                 [ 3, 5, 2, 0, 0, 1, 6, 4 ]],
                               [ [ 7, 0, 4, 0, 6, 0, 3, 0 ],
                                 [ 0, 7, 3, 0, 2, 0, 4, 5 ],
                                 [ 7, 0, 3, 4, 0, 1, 2, 0 ]]], dtype=torch.int32)

The corresponding lattice:
image

image

Note: There is an arc from state 2 to state 17 in the second lattice, because the last symbol of the second path of second sequence is sampled at frame 1, it is a simulation of reaching final frame.

Copy link
Copy Markdown
Collaborator Author

@pkufool pkufool left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@danpovey Do you have any good idea to test this function, I can only think of constructing simple test cases.

Comment thread k2/csrc/fsa_algo.cu Outdated
repeat_num = us_row_splits1_data[us_idx0 + 1] -
us_row_splits1_data[us_idx0];

arc.score = -logf(1 - powf(1 - sampling_prob, repeat_num));
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I only include the "predictor" head output in C++ part, the other two scores (i.e. hybrid output and lm_output) will add on python part, it would be easier to enable autograd for hybrid output.

Comment thread k2/python/k2/fsa_algo.py Outdated
a_value = getattr(lattice, "scores")
# Enable autograd for path_scores
b_value = index_select(path_scores.flatten(), arc_map)
value = a_value + b_value
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

path_scores here will contain hybrid_output and detached lm_output. I include the path_scores here and enable antograd to path_scores.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, OK. Right, we treat those as differentiable, but the negated sampling_prob is treated as just a constant.

# index == 0 means the sampled symbol is blank
t_mask = index == 0
# t_index = torch.where(t_mask, t_index + 1, t_index)
t_index = t_index + 1
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we use regular RNN-T, it is possible to generate too many symbols for a specific frame, and that might be chances to generate a lattice containing cycles, which is not expected. I am not sure whether we will encounter such a issue at the very beginning of training.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hm, a valid point. Yes, computing forward backward scores would not work correctly if there are cycles. One possibility would be to augment the state with a sub-frame, i.e. instead of (ctx, t) it becomes (ctx, t, sub_t) with sub_t = (0, 1, 2, ..). That would prevent cycles, although it might prevent a small number of paths from recombining that might otherwise be able to recombine.

@pkufool pkufool marked this pull request as draft August 17, 2022 02:43
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants