Skip to content

Commit d4318ce

Browse files
authored
Merge pull request #271 from miaobin/op-convTranspose2d-dml
[DML] Add convTranspose2d op for DML backend
2 parents b030a61 + c54b1ea commit d4318ce

1 file changed

Lines changed: 203 additions & 2 deletions

File tree

src/webnn/native/dml/GraphDML.cpp

Lines changed: 203 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,6 @@ namespace webnn::native ::dml {
218218
std::vector<UINT> newFilterDims(4);
219219
switch (filterLayout) {
220220
case wnn::Conv2dFilterOperandLayout::Ohwi:
221-
newFilterDims.resize(4);
222221
newFilterDims[0] = filterDims[0];
223222
newFilterDims[1] = filterDims[3];
224223
newFilterDims[2] = filterDims[1];
@@ -243,6 +242,30 @@ namespace webnn::native ::dml {
243242
return newFilterDims;
244243
}
245244

245+
std::vector<UINT> transposeFilterDimensionsAsIohw(
246+
wnn::ConvTranspose2dFilterOperandLayout filterLayout,
247+
const std::vector<UINT>& filterDims) {
248+
std::vector<UINT> newFilterDims(4);
249+
switch (filterLayout) {
250+
case wnn::ConvTranspose2dFilterOperandLayout::Hwoi:
251+
newFilterDims[0] = filterDims[3];
252+
newFilterDims[1] = filterDims[2];
253+
newFilterDims[2] = filterDims[0];
254+
newFilterDims[3] = filterDims[1];
255+
break;
256+
case wnn::ConvTranspose2dFilterOperandLayout::Ohwi:
257+
newFilterDims[0] = filterDims[3];
258+
newFilterDims[1] = filterDims[0];
259+
newFilterDims[2] = filterDims[1];
260+
newFilterDims[3] = filterDims[2];
261+
break;
262+
default:
263+
DAWN_ASSERT(0);
264+
break;
265+
}
266+
return newFilterDims;
267+
}
268+
246269
std::vector<UINT> transposeFilterStridesAsOihw(wnn::Conv2dFilterOperandLayout filterLayout,
247270
const std::vector<UINT>& filterDims) {
248271
UINT hStride = 0, wStride = 0, iStride = 0, oStride = 0;
@@ -272,6 +295,30 @@ namespace webnn::native ::dml {
272295
return {oStride, iStride, hStride, wStride};
273296
}
274297

298+
std::vector<UINT> transposeFilterStridesAsIohw(
299+
wnn::ConvTranspose2dFilterOperandLayout filterLayout,
300+
const std::vector<UINT>& filterDims) {
301+
UINT hStride = 0, wStride = 0, iStride = 0, oStride = 0;
302+
switch (filterLayout) {
303+
case wnn::ConvTranspose2dFilterOperandLayout::Hwoi:
304+
hStride = filterDims[1] * filterDims[2] * filterDims[3];
305+
wStride = filterDims[2] * filterDims[3];
306+
oStride = filterDims[3];
307+
iStride = 1;
308+
break;
309+
case wnn::ConvTranspose2dFilterOperandLayout::Ohwi:
310+
oStride = filterDims[1] * filterDims[2] * filterDims[3];
311+
hStride = filterDims[2] * filterDims[3];
312+
wStride = filterDims[3];
313+
iStride = 1;
314+
break;
315+
default:
316+
DAWN_ASSERT(0);
317+
break;
318+
}
319+
return {iStride, oStride, hStride, wStride};
320+
}
321+
275322
template <typename T>
276323
std::vector<UINT> ImplicitPadding(const T* options,
277324
const std::vector<UINT>& inputDims,
@@ -1671,7 +1718,161 @@ namespace webnn::native ::dml {
16711718
}
16721719

16731720
MaybeError Graph::AddConvTranspose2d(const op::ConvTranspose2d* convTranspose2d) {
1674-
return DAWN_UNIMPLEMENTED_ERROR("ConvTranspose2D has not been supported on DirectML.");
1721+
auto inputsOperand = convTranspose2d->Inputs();
1722+
DAWN_ASSERT(inputsOperand.size() == 2 || inputsOperand.size() == 3);
1723+
DAWN_ASSERT(mGraphEdgesMap.find(inputsOperand[0].Get()) != mGraphEdgesMap.end());
1724+
DAWN_ASSERT(mGraphEdgesMap.find(inputsOperand[1].Get()) != mGraphEdgesMap.end());
1725+
1726+
auto inputEdge = mGraphEdgesMap[inputsOperand[0].Get()];
1727+
auto filterEdge = mGraphEdgesMap[inputsOperand[1].Get()];
1728+
1729+
auto inputDims = ConvertDimensions(inputsOperand[0].Get()->Shape());
1730+
auto filterDims = ConvertDimensions(inputsOperand[1].Get()->Shape());
1731+
std::vector<UINT> newInputDims = inputDims, newFilterDims = filterDims, newInputStrides,
1732+
newFilterStrides;
1733+
1734+
const ConvTranspose2dOptions* options = convTranspose2d->GetOptions();
1735+
1736+
DML_TENSOR_DESC inputTensorDesc = inputEdge->outputTensorDESC;
1737+
if (options->inputLayout == wnn::InputOperandLayout::Nhwc) {
1738+
newInputDims = transposeDimensions(NhwcToNchw, inputDims);
1739+
newInputStrides = transposeStridesToNchw(inputDims, inputTensorDesc);
1740+
1741+
std::shared_ptr<DmlTensorDesc> inputDmlTensorDesc(new DmlTensorDesc);
1742+
if (!CreateDmlTensorDesc(mDmlTensorsDesc, inputDmlTensorDesc,
1743+
&inputEdge->outputTensorDESC, newInputDims, newInputStrides)) {
1744+
return DAWN_INTERNAL_ERROR("Failed to create DML tensor description.");
1745+
}
1746+
inputTensorDesc = {DML_TENSOR_TYPE_BUFFER, &inputDmlTensorDesc->bufferDesc};
1747+
}
1748+
1749+
DML_TENSOR_DESC filterTensorDesc = filterEdge->outputTensorDESC;
1750+
if (options->filterLayout != wnn::ConvTranspose2dFilterOperandLayout::Iohw) {
1751+
newFilterDims = transposeFilterDimensionsAsIohw(options->filterLayout, filterDims);
1752+
newFilterStrides = transposeFilterStridesAsIohw(options->filterLayout, filterDims);
1753+
1754+
std::shared_ptr<DmlTensorDesc> filterDmlTensorDesc(new DmlTensorDesc);
1755+
if (!CreateDmlTensorDesc(mDmlTensorsDesc, filterDmlTensorDesc,
1756+
&filterEdge->outputTensorDESC, newFilterDims,
1757+
newFilterStrides)) {
1758+
return DAWN_INTERNAL_ERROR("Failed to create DML tensor description.");
1759+
}
1760+
filterTensorDesc = {DML_TENSOR_TYPE_BUFFER, &filterDmlTensorDesc->bufferDesc};
1761+
}
1762+
1763+
std::vector<std::shared_ptr<EdgeInfoBase>> inputEdges = {inputEdge, filterEdge};
1764+
1765+
const DML_TENSOR_DESC* biasTensorDescPtr = nullptr;
1766+
DML_TENSOR_DESC newBiasTensorDesc = {};
1767+
if (options->bias != nullptr) {
1768+
DAWN_ASSERT(mGraphEdgesMap.find(inputsOperand[2].Get()) != mGraphEdgesMap.end());
1769+
auto biasEdge = mGraphEdgesMap[inputsOperand[2].Get()];
1770+
auto biasDims = ConvertDimensions(convTranspose2d->Inputs()[2].Get()->Shape());
1771+
if (biasDims[0] != newFilterDims[0] || biasDims.size() != 1) {
1772+
return DAWN_INTERNAL_ERROR(
1773+
"The bias should be 1-D tensor with the shape of [output_channels].");
1774+
}
1775+
1776+
// Reshape bias from 1-D to 4-D for NCHW layout.
1777+
std::vector<UINT> newBiasDims = {1, biasDims[0], 1, 1};
1778+
std::shared_ptr<DmlTensorDesc> biasDmlTensorDesc(new DmlTensorDesc);
1779+
if (!CreateDmlTensorDesc(mDmlTensorsDesc, biasDmlTensorDesc,
1780+
&biasEdge->outputTensorDESC, newBiasDims)) {
1781+
return DAWN_INTERNAL_ERROR("Failed to create DML tensor description.");
1782+
}
1783+
newBiasTensorDesc = {DML_TENSOR_TYPE_BUFFER, &biasDmlTensorDesc->bufferDesc};
1784+
biasTensorDescPtr = &newBiasTensorDesc;
1785+
inputEdges.push_back(biasEdge);
1786+
}
1787+
1788+
std::vector<UINT> outputDims(4);
1789+
if (options->outputSizes != nullptr) {
1790+
std::vector<UINT> outputSizes;
1791+
outputSizes.assign(options->outputSizes,
1792+
options->outputSizes + options->outputSizesCount);
1793+
if (options->inputLayout == wnn::InputOperandLayout::Nchw) {
1794+
outputDims = {inputDims[0], newFilterDims[1], outputSizes[0], outputSizes[1]};
1795+
} else {
1796+
outputDims = {inputDims[0], outputSizes[0], outputSizes[1], newFilterDims[1]};
1797+
}
1798+
} else {
1799+
outputDims = ConvertDimensions(convTranspose2d->Outputs()[0]->Shape());
1800+
}
1801+
std::vector<UINT> newOutputDims = outputDims;
1802+
if (options->inputLayout == wnn::InputOperandLayout::Nhwc) {
1803+
newOutputDims = transposeDimensions(NhwcToNchw, outputDims);
1804+
}
1805+
std::shared_ptr<DmlTensorDesc> outputDmlTensorDesc(new DmlTensorDesc);
1806+
if (!CreateDmlTensorDesc(mDmlTensorsDesc, outputDmlTensorDesc, &inputEdge->outputTensorDESC,
1807+
newOutputDims, {}, true)) {
1808+
return DAWN_INTERNAL_ERROR("Failed to create DML tensor description.");
1809+
}
1810+
DML_TENSOR_DESC outputTensorDesc = {DML_TENSOR_TYPE_BUFFER,
1811+
&outputDmlTensorDesc->bufferDesc};
1812+
1813+
// FIXME(nhu): strides, dilations, padding should be uint32_t
1814+
// need to fix the spec.
1815+
std::vector<UINT> strides, dilations, outputPadding;
1816+
strides.assign(options->strides, options->strides + options->stridesCount);
1817+
dilations.assign(options->dilations, options->dilations + options->dilationsCount);
1818+
outputPadding.assign(options->outputPadding,
1819+
options->outputPadding + options->outputPaddingCount);
1820+
1821+
std::vector<UINT> padding(4);
1822+
if (options->autoPad == wnn::AutoPad::Explicit) {
1823+
padding = ExplicitPadding<ConvTranspose2dOptions>(options);
1824+
} else {
1825+
std::vector<UINT> inputSize = {inputDims[2], inputDims[3]};
1826+
std::vector<UINT> filterSize = {filterDims[2], filterDims[3]};
1827+
padding = webnn::native::utils::ComputeImplicitPaddingForConvTranspose2dAutoPad(
1828+
options, inputSize, filterSize);
1829+
}
1830+
std::vector<UINT> startPadding = {padding[0], padding[2]};
1831+
std::vector<UINT> endPadding = {padding[1], padding[3]};
1832+
1833+
DML_ACTIVATION_LINEAR_OPERATOR_DESC dmlActicationOperatorDesc{};
1834+
DML_OPERATOR_DESC dmlFusedOperatorDesc = {};
1835+
DML_OPERATOR_DESC* fusedActivation = CreateFusedOperator(
1836+
options->activation, dmlActicationOperatorDesc, dmlFusedOperatorDesc);
1837+
1838+
ComPtr<IDMLOperator> dmlOperator;
1839+
DML_CONVOLUTION_OPERATOR_DESC dmlSpecificOperatorDesc{};
1840+
dmlSpecificOperatorDesc.InputTensor = &inputTensorDesc;
1841+
dmlSpecificOperatorDesc.FilterTensor = &filterTensorDesc;
1842+
dmlSpecificOperatorDesc.BiasTensor = biasTensorDescPtr;
1843+
dmlSpecificOperatorDesc.OutputTensor = &outputTensorDesc;
1844+
1845+
dmlSpecificOperatorDesc.Mode = DML_CONVOLUTION_MODE_CONVOLUTION;
1846+
dmlSpecificOperatorDesc.Direction = DML_CONVOLUTION_DIRECTION_BACKWARD;
1847+
dmlSpecificOperatorDesc.DimensionCount = inputDims.size() - 2;
1848+
dmlSpecificOperatorDesc.Strides = strides.data();
1849+
dmlSpecificOperatorDesc.Dilations = dilations.data();
1850+
dmlSpecificOperatorDesc.StartPadding = startPadding.data();
1851+
dmlSpecificOperatorDesc.EndPadding = endPadding.data();
1852+
dmlSpecificOperatorDesc.OutputPadding = outputPadding.data();
1853+
dmlSpecificOperatorDesc.GroupCount = static_cast<UINT>(options->groups);
1854+
dmlSpecificOperatorDesc.FusedActivation = fusedActivation;
1855+
1856+
DML_OPERATOR_DESC dmlOperatorDesc = {};
1857+
dmlOperatorDesc.Type = DML_OPERATOR_CONVOLUTION;
1858+
dmlOperatorDesc.Desc = &dmlSpecificOperatorDesc;
1859+
WEBNN_CHECK(mDevice->CreateOperator(&dmlOperatorDesc, IID_PPV_ARGS(&dmlOperator)));
1860+
1861+
auto outputEdge = CreateEdgeFromThisNode(outputTensorDesc, mGraphDesc.NodeCount());
1862+
AddNodeAndEdgesToGraphDesc(mGraphDesc, inputEdges, dmlOperator);
1863+
1864+
// Transpose output from nchw->nhwc.
1865+
if (options->inputLayout == wnn::InputOperandLayout::Nhwc) {
1866+
if (TransposeOutputToNhwc(outputEdge, newOutputDims).IsError()) {
1867+
return DAWN_INTERNAL_ERROR("Failed to transpose output from Nchw to Nhwc.");
1868+
};
1869+
}
1870+
1871+
if (EmulateFusedOperator(options->activation, outputEdge, outputDims).IsError()) {
1872+
return DAWN_INTERNAL_ERROR("Failed to emulate fused operator.");
1873+
}
1874+
mGraphEdgesMap[convTranspose2d->PrimaryOutput()] = outputEdge;
1875+
return {};
16751876
}
16761877

16771878
MaybeError Graph::AddGru(const op::Gru* gru) {

0 commit comments

Comments
 (0)