Add support for conv1d, Phase 3 #2941
// h' = hi * g + (ci % g**2) // g
// w' = wi * g + (ci % g)
// 3. Flatten (gW, gH, C) into idx_out = (c * g * h) * w' + (c) * h' + c'
// FIXME why are these interchanged???
It seems like h and w were exchanged in the definitions of hOut and wOut.
Oh no! @mdgrs could you link this into an issue? I don't want to block this, but I'm also noticing an error in flattening the final output in the ISL string.
asraa
left a comment
Thank you! I'll take a look. Could you try rebasing? Something strange happened in the diff: it appears to delete some recent changes.
int64_t filterSize = filterType.getDimSize(2);
int64_t outputW = (dataSize + 2 * padding - filterSize) / stride + 1;

auto rowInterchangeRelation = get1dConvRowInterchangeRelation(
Sanity check: you want this because the cross-channel filters are stacked a certain way that, without the row interchange, requires more diagonals.
I admit it took us a while to derive the conv2d version, and I am struggling to remember how this interacted with the linalg canonicalization lowering it to a loop (which is also present in the PR for conv1d). @asraa didn't we end up having a problem with converting it to a loop, and didn't we ultimately resort to a dedicated lowering for the Conv2DNchwFchwOp as a whole?
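For readers following the diagonal argument, here is a minimal sketch of how reordering rows can reduce the number of nonzero diagonals. It assumes "diagonals" means wrap-around diagonals, where entry (i, j) lies on diagonal (j - i) mod cols, as in diagonal-packed matrix-vector products. The helper names and the 4x4 matrix are hypothetical and are not taken from this PR; the permutation is a toy one, not the relation returned by get1dConvRowInterchangeRelation.

```cpp
#include <cassert>
#include <cstddef>
#include <set>
#include <vector>

using Matrix = std::vector<std::vector<int>>;

// Counts the distinct wrap-around diagonals that hold at least one nonzero
// entry. Entry (i, j) is assigned to diagonal (j - i) mod cols.
std::size_t countNonzeroDiagonals(const Matrix &m) {
  std::set<std::size_t> diagonals;
  for (std::size_t i = 0; i < m.size(); ++i) {
    const std::size_t cols = m[i].size();
    for (std::size_t j = 0; j < cols; ++j) {
      if (m[i][j] != 0) diagonals.insert((j + cols - (i % cols)) % cols);
    }
  }
  return diagonals.size();
}

// Returns a copy of `m` with rows reordered so that result[r] = m[perm[r]].
Matrix permuteRows(const Matrix &m, const std::vector<std::size_t> &perm) {
  Matrix out;
  out.reserve(perm.size());
  for (std::size_t src : perm) {
    assert(src < m.size());
    out.push_back(m[src]);
  }
  return out;
}

// Hypothetical example with nonzeros at (0,0), (1,2), (2,1), (3,3).
Matrix exampleMatrix() {
  return {{1, 0, 0, 0}, {0, 0, 1, 0}, {0, 1, 0, 0}, {0, 0, 0, 1}};
}
```

In this toy case the original matrix touches diagonals 0, 1, and 3, while swapping rows 1 and 2 (permutation {0, 2, 1, 3}) leaves only diagonal 0.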
I added a test to highlight how the number of diagonals is reduced with the interchange.
For the linalg canonicalization, I saw that currently we have for loops that loop over the extra dimensions of Conv2dNchwFchwOp and rewrite it as a series of Conv2dOp. I did the same for Conv1dNcwFcw.
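As a rough illustration of the loop rewrite described above, the sketch below computes a multi-channel 1-D convolution for a single batch element as a sum of single-channel convolutions. It is plain C++ rather than MLIR, the function names are hypothetical, and it assumes no padding, stride 1, and the cross-correlation convention (no kernel flip).

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using Vec = std::vector<long>;

// Single-channel 1-D convolution (cross-correlation, no padding, stride 1).
Vec conv1d(const Vec &input, const Vec &filter) {
  assert(!filter.empty() && input.size() >= filter.size());
  Vec out(input.size() - filter.size() + 1, 0);
  for (std::size_t w = 0; w < out.size(); ++w)
    for (std::size_t k = 0; k < filter.size(); ++k)
      out[w] += input[w + k] * filter[k];
  return out;
}

// Multi-channel convolution for one batch element, written as a loop of
// single-channel convolutions. Shapes: input [C][W], filter [F][C][K],
// result [F][W - K + 1].
std::vector<Vec> conv1dNcwFcwAsLoop(
    const std::vector<Vec> &input,
    const std::vector<std::vector<Vec>> &filter) {
  std::vector<Vec> output;
  for (const auto &perOutputChannel : filter) {
    assert(perOutputChannel.size() == input.size());
    Vec acc;
    for (std::size_t c = 0; c < input.size(); ++c) {
      // One single-channel convolution per (output channel, input channel).
      Vec partial = conv1d(input[c], perOutputChannel[c]);
      if (acc.empty()) acc.assign(partial.size(), 0);
      for (std::size_t w = 0; w < partial.size(); ++w) acc[w] += partial[w];
    }
    output.push_back(acc);
  }
  return output;
}
```

For example, with input channels {1, 2, 3} and {4, 5, 6} and one output channel whose per-channel filters are {1, 1} and {1, 0}, the result is {7, 10}.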
We generally don't want to lower to a loop anymore, because preserving the multi-channel input/output lets us arrange each kernel into blocks of a larger Toeplitz matrix, so I don't expect that we'll often want to rewrite as a loop of single convolutions.
The other "problem" with the loop approach is that the conv_nd operations don't support strides so we can't use that to represent pooling / downsampling!
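To make the stride point concrete, here is a minimal single-channel sketch of a strided, zero-padded 1-D convolution written as a matrix-vector product. The matrix is Toeplitz-like: each row holds the filter shifted by `stride` columns. The output width uses the same formula as `outputW` in the diff; the function names are hypothetical and this is not the layout code from the PR.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

using Row = std::vector<int64_t>;

// Output width of a 1-D convolution, same formula as `outputW` in the diff.
int64_t convOutputWidth(int64_t dataSize, int64_t filterSize, int64_t padding,
                        int64_t stride) {
  assert(stride > 0 && dataSize + 2 * padding >= filterSize);
  return (dataSize + 2 * padding - filterSize) / stride + 1;
}

// Builds the outputW x dataSize matrix T such that T * x is the strided,
// zero-padded cross-correlation of x with `filter`.
std::vector<Row> stridedToeplitz(const Row &filter, int64_t dataSize,
                                 int64_t padding, int64_t stride) {
  const int64_t filterSize = static_cast<int64_t>(filter.size());
  const int64_t outputW =
      convOutputWidth(dataSize, filterSize, padding, stride);
  std::vector<Row> t(static_cast<std::size_t>(outputW),
                     Row(static_cast<std::size_t>(dataSize), 0));
  for (int64_t row = 0; row < outputW; ++row) {
    for (int64_t k = 0; k < filterSize; ++k) {
      // Column in the unpadded input; out-of-range columns are padding zeros.
      const int64_t col = row * stride + k - padding;
      if (col >= 0 && col < dataSize) t[row][col] = filter[k];
    }
  }
  return t;
}

// Plain matrix-vector product.
Row matVec(const std::vector<Row> &m, const Row &x) {
  Row y(m.size(), 0);
  for (std::size_t i = 0; i < m.size(); ++i)
    for (std::size_t j = 0; j < x.size(); ++j) y[i] += m[i][j] * x[j];
  return y;
}
```

With filter {1, 2, 3}, dataSize 5, padding 1, and stride 2, the output width is 3, and multiplying the resulting matrix by an all-ones input gives {5, 6, 3}.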
fix annotate overwrite
};

// Lower linalg.conv_1d_ncw_fcw to a loop of linalg.conv_1d operations.
struct LowerConv1DNcwFcw
See my comment earlier, but I think we might not actually want this since you already have the ncw_fcw kernels written. In the end-to-end pipeline, this pass runs very early, which means later passes will only see the linalg.conv_1d ops.
This builds on the previous work in #2919 (that work is the first commit here).
It uses the work on Conv2d_nchw_fchw as a blueprint to add support for Conv1d_ncw_fcw. In my tests, this is the linalg operation that my PyTorch conv1d layer gets lowered to.
It makes conv1d turn green for #2923.