[WIP] Implements RNNT+MMI#1030
Conversation
| repeat_num = us_row_splits1_data[us_idx0 + 1] - | ||
| us_row_splits1_data[us_idx0]; | ||
|
|
||
| arc.score = -logf(1 - powf(1 - sampling_prob, repeat_num)); |
There was a problem hiding this comment.
I only include the "predictor" head output in C++ part, the other two scores (i.e. hybrid output and lm_output) will add on python part, it would be easier to enable autograd for hybrid output.
| a_value = getattr(lattice, "scores") | ||
| # Enable autograd for path_scores | ||
| b_value = index_select(path_scores.flatten(), arc_map) | ||
| value = a_value + b_value |
There was a problem hiding this comment.
path_scores here will contain hybrid_output and detached lm_output. I include the path_scores here and enable antograd to path_scores.
There was a problem hiding this comment.
Yes, OK. Right, we treat those as differentiable, but the negated sampling_prob is treated as just a constant.
| # index == 0 means the sampled symbol is blank | ||
| t_mask = index == 0 | ||
| # t_index = torch.where(t_mask, t_index + 1, t_index) | ||
| t_index = t_index + 1 |
There was a problem hiding this comment.
If we use regular RNN-T, it is possible to generate too many symbols for a specific frame, and that might be chances to generate a lattice containing cycles, which is not expected. I am not sure whether we will encounter such a issue at the very beginning of training.
There was a problem hiding this comment.
Hm, a valid point. Yes, computing forward backward scores would not work correctly if there are cycles. One possibility would be to augment the state with a sub-frame, i.e. instead of (ctx, t) it becomes (ctx, t, sub_t) with sub_t = (0, 1, 2, ..). That would prevent cycles, although it might prevent a small number of paths from recombining that might otherwise be able to recombine.
It runs normally in my self-constructed test case, not fully tested yet, though.
The sampled paths:
The corresponding lattice:

Note: There is an arc from state 2 to state 17 in the second lattice, because the last symbol of the second path of second sequence is sampled at frame 1, it is a simulation of reaching final frame.