Skip to content

Commit 1840799

Browse files
committed
Metadata inference working in pipeline mode.
Signed-off-by: Michał Zientkiewicz <mzient@gmail.com>
1 parent ca4d476 commit 1840799

38 files changed

Lines changed: 731 additions & 285 deletions

dali/operators/bbox/bb_flip.cc

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,7 @@ system, that is 0.0-1.0)code")
3737
1, true)
3838
.AddOptionalArg("vertical",
3939
R"code(Flip vertical dimension.)code",
40-
0, true)
41-
.OutputDType(0, [](const OpSpec &, span<const DALIDataType> in) { return in[0]; })
42-
.OutputNdim(0, [](const OpSpec &, span<const int> in) { return in[0]; })
43-
.OutputLayout(0, [](const OpSpec &, span<const TensorLayout> in) { return in[0]; });
40+
0, true);
4441

4542
void BbFlipCPU::RunImpl(Workspace &ws) {
4643
const auto &input = ws.Input<CPUBackend>(0);

dali/operators/decoder/image_decoder.cc

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -192,9 +192,9 @@ Please note that GPU acceleration for JPEG 2000 decoding is only available for C
192192
.NumOutput(1)
193193
.AddParent("ImageDecoderAttr")
194194
.AddParent("CachedDecoderAttr")
195-
.OutputDType(0, [](const OpSpec &, span<const DALIDataType>) { return DALI_UINT8; })
196-
.OutputNdim(0, [](const OpSpec &, span<const int>) { return 3; })
197-
.OutputLayout(0, [](const OpSpec &, span<const TensorLayout>) { return "HWC"; });
195+
.OutputDType(0, DALI_UINT8)
196+
.OutputNDim(0, 3)
197+
.OutputLayout(0, "HWC");
198198

199199
// Fused
200200

@@ -313,9 +313,9 @@ of the slice (s0, s1, s2, …).
313313
Integer coordinates are interpreted as absolute coordinates, while float coordinates can be
314314
interpreted as absolute or relative coordinates, depending on the value of
315315
`normalized_shape`.)code")
316-
.OutputDType(0, [](const OpSpec &, span<const DALIDataType>) { return DALI_UINT8; })
317-
.OutputNdim(0, [](const OpSpec &, span<const int>) { return 3; })
318-
.OutputLayout(0, [](const OpSpec &, span<const TensorLayout>) { return "HWC"; });
316+
.OutputDType(0, DALI_UINT8)
317+
.OutputNDim(0, 3)
318+
.OutputLayout(0, "HWC");
319319

320320

321321
// Deprecated aliases

dali/operators/generic/cast.cc

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -71,12 +71,7 @@ DALI_SCHEMA(Cast)
7171
.NumOutput(1)
7272
.AllowSequences()
7373
.SupportVolumetric()
74-
.AddTypeArg("dtype", R"code(Output data type.)code")
75-
.OutputDType(0, [](const OpSpec &spec, span<const DALIDataType>) {
76-
return spec.GetArgument<DALIDataType>("dtype");
77-
})
78-
.OutputNdim(0, [](const OpSpec &, span<const int> in) { return in[0]; })
79-
.OutputLayout(0, [](const OpSpec &, span<const TensorLayout> in) { return in[0]; });
74+
.AddTypeArg("dtype", R"code(Output data type.)code");
8075

8176
DALI_SCHEMA(CastLike)
8277
.DocStr("Cast the first tensor to the type of the second tensor.")

dali/operators/generic/expand_dims.cc

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,47 @@ layout will be empty.")code")
4343
.AddOptionalArg("new_axis_names", R"code(Names of the new dimensions in the data layout.
4444
4545
The length of `new_axis_names` must match the length of `axes`.
46-
If argument isn't be provided, the layout will be cleared.)code", TensorLayout(""));
46+
If argument isn't be provided, the layout will be cleared.)code", TensorLayout(""))
47+
.OutputNDim(0, [](const OpSpec &spec)->std::optional<int> {
48+
auto &desc = spec.InputDesc(0);
49+
if (!desc.ndim)
50+
return std::nullopt;
51+
return *desc.ndim + spec.GetRepeatedArgument<int>("axes").size();
52+
})
53+
.OutputLayout(0, [](const OpSpec &spec)->std::optional<TensorLayout> {
54+
auto &desc = spec.InputDesc(0);
55+
if (!desc.layout)
56+
return std::nullopt;
57+
58+
auto axes = spec.GetRepeatedArgument<int>("axes");
59+
if (axes.empty())
60+
return desc.layout;
61+
62+
auto names = spec.GetArgument<TensorLayout>("axis_names");
63+
int num_new_axes = ssize(axes);
64+
if (num_new_axes != names.ndim())
65+
return "";
66+
67+
SmallVector<std::pair<int, char>, 6> ind_with_layout;
68+
for (size_t i = 0; i < axes.size(); i++) {
69+
ind_with_layout.push_back({ i, names[i] });
70+
}
71+
std::sort(ind_with_layout.begin(), ind_with_layout.end());
72+
73+
TensorLayout out_layout = "";
74+
int out_ndim = desc.layout->ndim() + names.ndim();
75+
int src_axis = 0;
76+
int new_axis = 0;
77+
for (int j = 0; j < out_ndim; j++) {
78+
if (new_axis < num_new_axes && axes[new_axis] == j) { // inserting new axis
79+
out_layout += names[new_axis++];
80+
} else {
81+
assert(src_axis < desc.layout->ndim());
82+
out_layout += (*desc.layout)[src_axis++];
83+
}
84+
}
85+
return out_layout;
86+
});
4787

4888
template <typename Backend>
4989
ExpandDims<Backend>::ExpandDims(const OpSpec &spec)

dali/operators/generic/flip.cc

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,7 @@ and depthwise).)code")
3434
.AddOptionalArg("depthwise", R"code(Flip the depthwise dimension.)code", 0, true)
3535
.InputLayout({"FDHWC", "FHWC", "DHWC", "HWC", "FCDHW", "FCHW", "CDHW", "CHW"})
3636
.AllowSequences()
37-
.SupportVolumetric()
38-
.OutputDType(0, [](const OpSpec &, span<const DALIDataType> in) { return in[0]; })
39-
.OutputNdim(0, [](const OpSpec &, span<const int> in) { return in[0]; })
40-
.OutputLayout(0, [](const OpSpec &, span<const TensorLayout> in) { return in[0]; });
37+
.SupportVolumetric();
4138

4239

4340
template <>

dali/operators/generic/join.cc

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,32 @@ constructed by inserting that character into the input layout at the position in
5858
For example, specifying ``axis = 0`` and ``axis_name = "C"`` with input layout "HW" will yield
5959
the output layout "CHW")", nullptr, false)
6060
.NumInput(1, 999)
61-
.NumOutput(1);
61+
.NumOutput(1)
62+
.OutputNDim(0, [](const OpSpec &spec)->std::optional<int> {
63+
std::optional<int> ndim;
64+
for (int i = 0; i < spec.NumInput(); i++)
65+
if (spec.InputDesc(i).ndim) { // any input will do - they must have the same ndim
66+
ndim = spec.InputDesc(i).ndim;
67+
break;
68+
}
69+
if (ndim)
70+
return *ndim + 1;
71+
else
72+
return std::nullopt;
73+
})
74+
.OutputLayout(0, [](const OpSpec &spec)->std::optional<TensorLayout> {
75+
std::string new_axis_name;
76+
if (!spec.TryGetArgument(new_axis_name, "axis_name") || new_axis_name.length() != 1)
77+
return std::nullopt;
78+
int axis = spec.GetArgument<int>("axis");
79+
for (int i = 0; i < spec.NumInput(); i++) {
80+
auto &desc = spec.InputDesc(i);
81+
if (!desc.layout || desc.layout->empty())
82+
continue;
83+
return desc.layout->sub(0, axis) + new_axis_name + desc.layout->sub(axis);
84+
}
85+
return std::nullopt;
86+
});
6287

6388
#define TENSOR_JOIN_TYPES (bool, uint8_t, int8_t, uint16_t, int16_t, uint32_t, int32_t, \
6489
uint64_t, int64_t, float16, float, double)

dali/operators/generic/reshape.cc

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright (c) 2019-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1+
// Copyright (c) 2019-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
//
33
// Licensed under the Apache License, Version 2.0 (the "License");
44
// you may not use this file except in compliance with the License.
@@ -25,6 +25,16 @@
2525

2626
namespace dali {
2727

28+
inline std::optional<int> ReshapeNDimFunc(const OpSpec &spec) {
29+
std::vector<int> shape;
30+
if (spec.TryGetRepeatedArgument(shape, "shape"))
31+
return shape.size();
32+
std::vector<float> rel_shape;
33+
if (spec.TryGetRepeatedArgument(rel_shape, "rel_shape"))
34+
return rel_shape.size();
35+
return std::nullopt;
36+
}
37+
2838
DALI_SCHEMA(Reshape)
2939
.DocStr(R"code(Treats content of the input as if it had a different shape and/or layout.
3040
@@ -94,7 +104,8 @@ extents in `rel_shape` describe to the target dimensions. In the example above,
94104
``rel_shape = [-1, 0.5, 2]`` would result in the output shape ``[1, 100, 600]``.
95105
96106
All indices must be in the range of valid dimensions of the input, or -1.)code",
97-
nullptr, true);
107+
nullptr, true)
108+
.OutputNDim(0, ReshapeNDimFunc);
98109

99110
DALI_SCHEMA(Reinterpret)
100111
.DocStr(R"(Treats content of the input as if it had a different type, shape, and/or layout.

dali/operators/generic/shapes.cc

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright (c) 2019-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1+
// Copyright (c) 2019-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
//
33
// Licensed under the Apache License, Version 2.0 (the "License");
44
// you may not use this file except in compliance with the License.
@@ -26,7 +26,9 @@ DALI_SCHEMA(Shapes)
2626
.AddOptionalTypeArg("dtype", "Data type to which the sizes are converted.", DALI_INT64)
2727
.DeprecateArgInFavorOf("type", "dtype", "0.27")
2828
.MakeDocHidden()
29-
.Deprecate("1.44", "", "Use :meth:`nvidia.dali.pipeline.DataNode.shape` instead.");
29+
.Deprecate("1.44", "", "Use :meth:`nvidia.dali.pipeline.DataNode.shape` instead.")
30+
.OutputNDim(0, 1)
31+
.OutputLayout(0, std::nullopt);
3032

3133
DALI_SCHEMA(_Shape)
3234
.DocStr(R"(Returns the shapes of tensors in the input batch.
@@ -39,7 +41,9 @@ INTERNAL ONLY; used by DataNode.shape()
3941
.AllowSequences()
4042
.SupportVolumetric()
4143
.MakeDocHidden()
42-
.AddOptionalTypeArg("dtype", "Data type to which the sizes are converted.", DALI_INT64);
44+
.AddOptionalTypeArg("dtype", "Data type to which the sizes are converted.", DALI_INT64)
45+
.OutputNDim(0, 1)
46+
.OutputLayout(0, std::nullopt);
4347

4448
DALI_REGISTER_OPERATOR(Shapes, Shapes<CPUBackend>, CPU);
4549
DALI_REGISTER_OPERATOR(Shapes, Shapes<GPUBackend>, GPU);

dali/operators/generic/slice/subscript.cc

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright (c) 2021-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1+
// Copyright (c) 2021-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
//
33
// Licensed under the Apache License, Version 2.0 (the "License");
44
// you may not use this file except in compliance with the License.
@@ -64,7 +64,20 @@ DALI_SCHEMA(_TensorSubscript)
6464
.INDEX_ARGS(28)
6565
.INDEX_ARGS(29)
6666
.INDEX_ARGS(30)
67-
.INDEX_ARGS(31);
67+
.INDEX_ARGS(31)
68+
.OutputNDim(0, [](const OpSpec &spec)->std::optional<int> {
69+
auto &input_desc = spec.InputDesc(0);
70+
if (!input_desc.ndim.has_value())
71+
return std::nullopt;
72+
int ndim = *input_desc.ndim;
73+
for (int i = 0; i < kMaxSubscripts; i++) {
74+
if (spec.ArgumentDefined(make_string("at_", i)))
75+
ndim--;
76+
}
77+
if (ndim < 0)
78+
return std::nullopt;
79+
return ndim;
80+
});
6881

6982
template <>
7083
template <int ndim, int element_size>

dali/operators/image/color/brightness_contrast.cc

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -89,15 +89,7 @@ This operator can also change the type of data.)code")
8989
.NumOutput(1)
9090
.AllowSequences()
9191
.SupportVolumetric()
92-
.InputLayout({"FHWC", "DHWC", "HWC"})
93-
.OutputDType(0, [](const OpSpec &spec, span<const DALIDataType> in) {
94-
DALIDataType dtype;
95-
if (spec.TryGetArgument(dtype, "dtype"))
96-
return dtype;
97-
return in[0];
98-
})
99-
.OutputNdim(0, [](const OpSpec &, span<const int> in) { return in[0]; })
100-
.OutputLayout(0, [](const OpSpec &, span<const TensorLayout> in) { return in[0]; });
92+
.InputLayout({"FHWC", "DHWC", "HWC"});
10193

10294
DALI_REGISTER_OPERATOR(BrightnessContrast, BrightnessContrastCpu, CPU)
10395
DALI_REGISTER_OPERATOR(Brightness, BrightnessContrastCpu, CPU);

0 commit comments

Comments
 (0)