@@ -218,7 +218,6 @@ namespace webnn::native ::dml {
218218 std::vector<UINT> newFilterDims (4 );
219219 switch (filterLayout) {
220220 case wnn::Conv2dFilterOperandLayout::Ohwi:
221- newFilterDims.resize (4 );
222221 newFilterDims[0 ] = filterDims[0 ];
223222 newFilterDims[1 ] = filterDims[3 ];
224223 newFilterDims[2 ] = filterDims[1 ];
@@ -243,6 +242,30 @@ namespace webnn::native ::dml {
243242 return newFilterDims;
244243 }
245244
245+ std::vector<UINT> transposeFilterDimensionsAsIohw (
246+ wnn::ConvTranspose2dFilterOperandLayout filterLayout,
247+ const std::vector<UINT>& filterDims) {
248+ std::vector<UINT> newFilterDims (4 );
249+ switch (filterLayout) {
250+ case wnn::ConvTranspose2dFilterOperandLayout::Hwoi:
251+ newFilterDims[0 ] = filterDims[3 ];
252+ newFilterDims[1 ] = filterDims[2 ];
253+ newFilterDims[2 ] = filterDims[0 ];
254+ newFilterDims[3 ] = filterDims[1 ];
255+ break ;
256+ case wnn::ConvTranspose2dFilterOperandLayout::Ohwi:
257+ newFilterDims[0 ] = filterDims[3 ];
258+ newFilterDims[1 ] = filterDims[0 ];
259+ newFilterDims[2 ] = filterDims[1 ];
260+ newFilterDims[3 ] = filterDims[2 ];
261+ break ;
262+ default :
263+ DAWN_ASSERT (0 );
264+ break ;
265+ }
266+ return newFilterDims;
267+ }
268+
246269 std::vector<UINT> transposeFilterStridesAsOihw (wnn::Conv2dFilterOperandLayout filterLayout,
247270 const std::vector<UINT>& filterDims) {
248271 UINT hStride = 0 , wStride = 0 , iStride = 0 , oStride = 0 ;
@@ -272,6 +295,30 @@ namespace webnn::native ::dml {
272295 return {oStride, iStride, hStride, wStride};
273296 }
274297
298+ std::vector<UINT> transposeFilterStridesAsIohw (
299+ wnn::ConvTranspose2dFilterOperandLayout filterLayout,
300+ const std::vector<UINT>& filterDims) {
301+ UINT hStride = 0 , wStride = 0 , iStride = 0 , oStride = 0 ;
302+ switch (filterLayout) {
303+ case wnn::ConvTranspose2dFilterOperandLayout::Hwoi:
304+ hStride = filterDims[1 ] * filterDims[2 ] * filterDims[3 ];
305+ wStride = filterDims[2 ] * filterDims[3 ];
306+ oStride = filterDims[3 ];
307+ iStride = 1 ;
308+ break ;
309+ case wnn::ConvTranspose2dFilterOperandLayout::Ohwi:
310+ oStride = filterDims[1 ] * filterDims[2 ] * filterDims[3 ];
311+ hStride = filterDims[2 ] * filterDims[3 ];
312+ wStride = filterDims[3 ];
313+ iStride = 1 ;
314+ break ;
315+ default :
316+ DAWN_ASSERT (0 );
317+ break ;
318+ }
319+ return {iStride, oStride, hStride, wStride};
320+ }
321+
275322 template <typename T>
276323 std::vector<UINT> ImplicitPadding (const T* options,
277324 const std::vector<UINT>& inputDims,
@@ -1671,7 +1718,161 @@ namespace webnn::native ::dml {
16711718 }
16721719
16731720 MaybeError Graph::AddConvTranspose2d (const op::ConvTranspose2d* convTranspose2d) {
1674- return DAWN_UNIMPLEMENTED_ERROR (" ConvTranspose2D has not been supported on DirectML." );
1721+ auto inputsOperand = convTranspose2d->Inputs ();
1722+ DAWN_ASSERT (inputsOperand.size () == 2 || inputsOperand.size () == 3 );
1723+ DAWN_ASSERT (mGraphEdgesMap .find (inputsOperand[0 ].Get ()) != mGraphEdgesMap .end ());
1724+ DAWN_ASSERT (mGraphEdgesMap .find (inputsOperand[1 ].Get ()) != mGraphEdgesMap .end ());
1725+
1726+ auto inputEdge = mGraphEdgesMap [inputsOperand[0 ].Get ()];
1727+ auto filterEdge = mGraphEdgesMap [inputsOperand[1 ].Get ()];
1728+
1729+ auto inputDims = ConvertDimensions (inputsOperand[0 ].Get ()->Shape ());
1730+ auto filterDims = ConvertDimensions (inputsOperand[1 ].Get ()->Shape ());
1731+ std::vector<UINT> newInputDims = inputDims, newFilterDims = filterDims, newInputStrides,
1732+ newFilterStrides;
1733+
1734+ const ConvTranspose2dOptions* options = convTranspose2d->GetOptions ();
1735+
1736+ DML_TENSOR_DESC inputTensorDesc = inputEdge->outputTensorDESC ;
1737+ if (options->inputLayout == wnn::InputOperandLayout::Nhwc) {
1738+ newInputDims = transposeDimensions (NhwcToNchw, inputDims);
1739+ newInputStrides = transposeStridesToNchw (inputDims, inputTensorDesc);
1740+
1741+ std::shared_ptr<DmlTensorDesc> inputDmlTensorDesc (new DmlTensorDesc);
1742+ if (!CreateDmlTensorDesc (mDmlTensorsDesc , inputDmlTensorDesc,
1743+ &inputEdge->outputTensorDESC , newInputDims, newInputStrides)) {
1744+ return DAWN_INTERNAL_ERROR (" Failed to create DML tensor description." );
1745+ }
1746+ inputTensorDesc = {DML_TENSOR_TYPE_BUFFER, &inputDmlTensorDesc->bufferDesc };
1747+ }
1748+
1749+ DML_TENSOR_DESC filterTensorDesc = filterEdge->outputTensorDESC ;
1750+ if (options->filterLayout != wnn::ConvTranspose2dFilterOperandLayout::Iohw) {
1751+ newFilterDims = transposeFilterDimensionsAsIohw (options->filterLayout , filterDims);
1752+ newFilterStrides = transposeFilterStridesAsIohw (options->filterLayout , filterDims);
1753+
1754+ std::shared_ptr<DmlTensorDesc> filterDmlTensorDesc (new DmlTensorDesc);
1755+ if (!CreateDmlTensorDesc (mDmlTensorsDesc , filterDmlTensorDesc,
1756+ &filterEdge->outputTensorDESC , newFilterDims,
1757+ newFilterStrides)) {
1758+ return DAWN_INTERNAL_ERROR (" Failed to create DML tensor description." );
1759+ }
1760+ filterTensorDesc = {DML_TENSOR_TYPE_BUFFER, &filterDmlTensorDesc->bufferDesc };
1761+ }
1762+
1763+ std::vector<std::shared_ptr<EdgeInfoBase>> inputEdges = {inputEdge, filterEdge};
1764+
1765+ const DML_TENSOR_DESC* biasTensorDescPtr = nullptr ;
1766+ DML_TENSOR_DESC newBiasTensorDesc = {};
1767+ if (options->bias != nullptr ) {
1768+ DAWN_ASSERT (mGraphEdgesMap .find (inputsOperand[2 ].Get ()) != mGraphEdgesMap .end ());
1769+ auto biasEdge = mGraphEdgesMap [inputsOperand[2 ].Get ()];
1770+ auto biasDims = ConvertDimensions (convTranspose2d->Inputs ()[2 ].Get ()->Shape ());
1771+ if (biasDims[0 ] != newFilterDims[0 ] || biasDims.size () != 1 ) {
1772+ return DAWN_INTERNAL_ERROR (
1773+ " The bias should be 1-D tensor with the shape of [output_channels]." );
1774+ }
1775+
1776+ // Reshape bias from 1-D to 4-D for NCHW layout.
1777+ std::vector<UINT> newBiasDims = {1 , biasDims[0 ], 1 , 1 };
1778+ std::shared_ptr<DmlTensorDesc> biasDmlTensorDesc (new DmlTensorDesc);
1779+ if (!CreateDmlTensorDesc (mDmlTensorsDesc , biasDmlTensorDesc,
1780+ &biasEdge->outputTensorDESC , newBiasDims)) {
1781+ return DAWN_INTERNAL_ERROR (" Failed to create DML tensor description." );
1782+ }
1783+ newBiasTensorDesc = {DML_TENSOR_TYPE_BUFFER, &biasDmlTensorDesc->bufferDesc };
1784+ biasTensorDescPtr = &newBiasTensorDesc;
1785+ inputEdges.push_back (biasEdge);
1786+ }
1787+
1788+ std::vector<UINT> outputDims (4 );
1789+ if (options->outputSizes != nullptr ) {
1790+ std::vector<UINT> outputSizes;
1791+ outputSizes.assign (options->outputSizes ,
1792+ options->outputSizes + options->outputSizesCount );
1793+ if (options->inputLayout == wnn::InputOperandLayout::Nchw) {
1794+ outputDims = {inputDims[0 ], newFilterDims[1 ], outputSizes[0 ], outputSizes[1 ]};
1795+ } else {
1796+ outputDims = {inputDims[0 ], outputSizes[0 ], outputSizes[1 ], newFilterDims[1 ]};
1797+ }
1798+ } else {
1799+ outputDims = ConvertDimensions (convTranspose2d->Outputs ()[0 ]->Shape ());
1800+ }
1801+ std::vector<UINT> newOutputDims = outputDims;
1802+ if (options->inputLayout == wnn::InputOperandLayout::Nhwc) {
1803+ newOutputDims = transposeDimensions (NhwcToNchw, outputDims);
1804+ }
1805+ std::shared_ptr<DmlTensorDesc> outputDmlTensorDesc (new DmlTensorDesc);
1806+ if (!CreateDmlTensorDesc (mDmlTensorsDesc , outputDmlTensorDesc, &inputEdge->outputTensorDESC ,
1807+ newOutputDims, {}, true )) {
1808+ return DAWN_INTERNAL_ERROR (" Failed to create DML tensor description." );
1809+ }
1810+ DML_TENSOR_DESC outputTensorDesc = {DML_TENSOR_TYPE_BUFFER,
1811+ &outputDmlTensorDesc->bufferDesc };
1812+
1813+ // FIXME(nhu): strides, dilations, padding should be uint32_t
1814+ // need to fix the spec.
1815+ std::vector<UINT> strides, dilations, outputPadding;
1816+ strides.assign (options->strides , options->strides + options->stridesCount );
1817+ dilations.assign (options->dilations , options->dilations + options->dilationsCount );
1818+ outputPadding.assign (options->outputPadding ,
1819+ options->outputPadding + options->outputPaddingCount );
1820+
1821+ std::vector<UINT> padding (4 );
1822+ if (options->autoPad == wnn::AutoPad::Explicit) {
1823+ padding = ExplicitPadding<ConvTranspose2dOptions>(options);
1824+ } else {
1825+ std::vector<UINT> inputSize = {inputDims[2 ], inputDims[3 ]};
1826+ std::vector<UINT> filterSize = {filterDims[2 ], filterDims[3 ]};
1827+ padding = webnn::native::utils::ComputeImplicitPaddingForConvTranspose2dAutoPad (
1828+ options, inputSize, filterSize);
1829+ }
1830+ std::vector<UINT> startPadding = {padding[0 ], padding[2 ]};
1831+ std::vector<UINT> endPadding = {padding[1 ], padding[3 ]};
1832+
1833+ DML_ACTIVATION_LINEAR_OPERATOR_DESC dmlActicationOperatorDesc{};
1834+ DML_OPERATOR_DESC dmlFusedOperatorDesc = {};
1835+ DML_OPERATOR_DESC* fusedActivation = CreateFusedOperator (
1836+ options->activation , dmlActicationOperatorDesc, dmlFusedOperatorDesc);
1837+
1838+ ComPtr<IDMLOperator> dmlOperator;
1839+ DML_CONVOLUTION_OPERATOR_DESC dmlSpecificOperatorDesc{};
1840+ dmlSpecificOperatorDesc.InputTensor = &inputTensorDesc;
1841+ dmlSpecificOperatorDesc.FilterTensor = &filterTensorDesc;
1842+ dmlSpecificOperatorDesc.BiasTensor = biasTensorDescPtr;
1843+ dmlSpecificOperatorDesc.OutputTensor = &outputTensorDesc;
1844+
1845+ dmlSpecificOperatorDesc.Mode = DML_CONVOLUTION_MODE_CONVOLUTION;
1846+ dmlSpecificOperatorDesc.Direction = DML_CONVOLUTION_DIRECTION_BACKWARD;
1847+ dmlSpecificOperatorDesc.DimensionCount = inputDims.size () - 2 ;
1848+ dmlSpecificOperatorDesc.Strides = strides.data ();
1849+ dmlSpecificOperatorDesc.Dilations = dilations.data ();
1850+ dmlSpecificOperatorDesc.StartPadding = startPadding.data ();
1851+ dmlSpecificOperatorDesc.EndPadding = endPadding.data ();
1852+ dmlSpecificOperatorDesc.OutputPadding = outputPadding.data ();
1853+ dmlSpecificOperatorDesc.GroupCount = static_cast <UINT>(options->groups );
1854+ dmlSpecificOperatorDesc.FusedActivation = fusedActivation;
1855+
1856+ DML_OPERATOR_DESC dmlOperatorDesc = {};
1857+ dmlOperatorDesc.Type = DML_OPERATOR_CONVOLUTION;
1858+ dmlOperatorDesc.Desc = &dmlSpecificOperatorDesc;
1859+ WEBNN_CHECK (mDevice ->CreateOperator (&dmlOperatorDesc, IID_PPV_ARGS (&dmlOperator)));
1860+
1861+ auto outputEdge = CreateEdgeFromThisNode (outputTensorDesc, mGraphDesc .NodeCount ());
1862+ AddNodeAndEdgesToGraphDesc (mGraphDesc , inputEdges, dmlOperator);
1863+
1864+ // Transpose output from nchw->nhwc.
1865+ if (options->inputLayout == wnn::InputOperandLayout::Nhwc) {
1866+ if (TransposeOutputToNhwc (outputEdge, newOutputDims).IsError ()) {
1867+ return DAWN_INTERNAL_ERROR (" Failed to transpose output from Nchw to Nhwc." );
1868+ };
1869+ }
1870+
1871+ if (EmulateFusedOperator (options->activation , outputEdge, outputDims).IsError ()) {
1872+ return DAWN_INTERNAL_ERROR (" Failed to emulate fused operator." );
1873+ }
1874+ mGraphEdgesMap [convTranspose2d->PrimaryOutput ()] = outputEdge;
1875+ return {};
16751876 }
16761877
16771878 MaybeError Graph::AddGru (const op::Gru* gru) {
0 commit comments