I'm considering rewriting the whole search concept on the returnn-common side.
I think this is actually not too difficult.
Every function should treat unhandled dims just like the batch dim. Having an additional dim for the beam should therefore be no problem at all, and we don't need to merge the beam into the batch dim. This simplifies many things and also makes everything much cleaner. If there is any function which does not behave this way, we really should fix it in any case, independent of this proposal. This is a fundamental concept of the building blocks of returnn-common.
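To illustrate the principle, here is a minimal NumPy sketch (not returnn-common code; `linear` is a hypothetical helper) of a layer that only touches the feature axis. Every leading axis is handled like the batch dim, so the same function works with or without an extra beam axis:

```python
import numpy as np


def linear(x: np.ndarray, weight: np.ndarray) -> np.ndarray:
    """Hypothetical linear layer: contracts only the last (feature) axis.

    All leading axes (batch, beam, ...) are treated like the batch dim,
    so the function never needs to know whether a beam axis exists.
    """
    return np.einsum("...i,io->...o", x, weight)


rng = np.random.default_rng(0)
w = rng.standard_normal((4, 3))           # [in, out]
x_no_beam = rng.standard_normal((2, 4))   # [batch, in]
x_beam = rng.standard_normal((2, 5, 4))   # [batch, beam, in]
print(linear(x_no_beam, w).shape)  # (2, 3)
print(linear(x_beam, w).shape)     # (2, 5, 3)
```

Applying the layer to the whole [batch, beam, in] tensor gives the same result as applying it to each beam entry separately, which is exactly why merging beam into batch is unnecessary.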
We now have the very generic `top_k` function (#140, #143), which can operate over multiple axes together.
So consider an input [batch, beam, classes]; `top_k` can operate on [beam, classes]. It then returns two indices, one for the (source) beam and one for the classes.
The output is of shape [batch, k]. We can set k as the new beam.
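A minimal NumPy sketch of what such a joint top-k over [beam, classes] computes. This is not the returnn-common `top_k` API; `top_k_beam_classes` is a hypothetical name, and the sketch simply flattens the two axes and splits the flat index back into (source beam, class):

```python
import numpy as np


def top_k_beam_classes(scores: np.ndarray, k: int):
    """Hypothetical joint top-k over the beam and classes axes.

    scores: [batch, beam, classes]
    Returns (values, src_beam, labels), each of shape [batch, k],
    where k acts as the new beam dim.
    """
    batch, beam, num_classes = scores.shape
    flat = scores.reshape(batch, beam * num_classes)  # [batch, beam*classes]
    # Indices of the k largest scores per batch entry, best first.
    idx = np.argsort(-flat, axis=-1, kind="stable")[:, :k]  # [batch, k]
    values = np.take_along_axis(flat, idx, axis=-1)  # [batch, k]
    # Split the flat index into the source beam index and the class index.
    src_beam, labels = np.divmod(idx, num_classes)
    return values, src_beam, labels


scores = np.array([[[0.1, 0.5, 0.2],
                    [0.9, 0.0, 0.3]]])  # [batch=1, beam=2, classes=3]
values, src_beam, labels = top_k_beam_classes(scores, k=2)
print(values)    # [[0.9 0.5]]
print(src_beam)  # [[1 0]]
print(labels)    # [[0 1]]
```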
Combining [batch, old_beam, dim] with [batch, new_beam, dim] would not work automatically as it does in RETURNN. But we can introduce an explicit function (`translate_beam` or so, analogous to how it is done in RETURNN), and then the user needs to take care of this explicitly.
If this is in a loop, you probably later want to trace back through the beam indices to construct the K output sequences. This needs some further explicit manual logic.
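A possible semantics for such an explicit beam translation, as a NumPy sketch. The name `translate_beam` is only the placeholder from the proposal, and this signature is an assumption: it gathers each new beam entry from the old beam entry it originated from, using the source-beam indices returned by the top-k step:

```python
import numpy as np


def translate_beam(x: np.ndarray, src_beam: np.ndarray) -> np.ndarray:
    """Hypothetical beam translation.

    x:        [batch, old_beam, dim]  (state in terms of the old beam)
    src_beam: [batch, new_beam]       (old beam index for each new entry)
    Returns:  [batch, new_beam, dim]
    """
    batch_idx = np.arange(x.shape[0])[:, None]  # [batch, 1], broadcast over new_beam
    return x[batch_idx, src_beam]


state = np.array([[[0, 1],
                   [10, 11]]])      # [batch=1, old_beam=2, dim=2]
src_beam = np.array([[1, 1, 0]])   # [batch=1, new_beam=3]
print(translate_beam(state, src_beam).tolist())
# [[[10, 11], [10, 11], [0, 1]]]
```

After this, the translated state and any tensor with the new beam dim can be combined as usual.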
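The traceback could look roughly like this NumPy sketch, assuming the loop stored the source-beam indices and labels of every step (`backtrace` is a hypothetical helper, not an existing returnn-common function). Starting from the final beam entries, it walks the source-beam pointers backwards and collects the label chosen at each step:

```python
import numpy as np


def backtrace(src_beams, labels):
    """Hypothetical traceback through stored beam indices.

    src_beams[t], labels[t]: [batch, beam_t] from search step t, where
    src_beams[t] points into the beam of step t-1.
    Returns [batch, final_beam, T] label sequences.
    """
    num_steps = len(labels)
    batch, final_beam = labels[-1].shape
    batch_idx = np.arange(batch)[:, None]                 # [batch, 1]
    cur = np.tile(np.arange(final_beam), (batch, 1))      # [batch, final_beam]
    out = np.empty((batch, final_beam, num_steps), dtype=labels[-1].dtype)
    for t in reversed(range(num_steps)):
        out[:, :, t] = labels[t][batch_idx, cur]  # label chosen at step t
        cur = src_beams[t][batch_idx, cur]        # follow pointer to step t-1
    return out


src_beams = [np.array([[0, 0]]), np.array([[1, 0]])]  # T=2, batch=1
labels = [np.array([[5, 7]]), np.array([[2, 3]])]
print(backtrace(src_beams, labels).tolist())  # [[[7, 2], [5, 3]]]
```

Because `cur` always has the shape of the final beam, this also works when the beam size changes between steps (e.g. a beam of 1 at the first step).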