1- // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1+ // Copyright (c) 2021-2024 , NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22//
33// Licensed under the Apache License, Version 2.0 (the "License");
44// you may not use this file except in compliance with the License.
@@ -23,6 +23,7 @@ DALI_SCHEMA(GetProperty)
2323
2424The type of the output will depend on the ``key`` of the requested property.)code" )
2525 .NumInput(1 )
26+ .InputDevice(0 , InputDevice::Metadata)
2627 .NumOutput(1 )
2728 .AddArg(" key" ,
2829 R"code( Specifies, which property is requested.
@@ -38,6 +39,110 @@ The following properties are supported:
3839)code" ,
3940 DALI_STRING);
4041
42+ template <typename Backend, typename SampleShapeFunc, typename CopySampleFunc>
43+ void GetPerSample (TensorList<CPUBackend> &out, const TensorList<Backend> &in,
44+ SampleShapeFunc &&sample_shape, CopySampleFunc &©_sample) {
45+ int N = in.num_samples ();
46+ TensorListShape<> tls;
47+ for (int i = 0 ; i < N; i++) {
48+ auto shape = sample_shape (in, i);
49+ if (i == 0 )
50+ tls.resize (N, shape.sample_dim ());
51+ tls.set_tensor_shape (i, shape);
52+ }
53+ out.Resize (tls, DALI_UINT8);
54+ for (int i = 0 ; i < N; i++) {
55+ copy_sample (out, in, i);
56+ }
57+ }
58+
59+ template <typename Backend>
60+ void SourceInfoToTL (TensorList<CPUBackend> &out, const TensorList<Backend> &in) {
61+ GetPerSample (out, in,
62+ [](auto &in, int idx) {
63+ auto &info = in.GetMeta (idx).GetSourceInfo ();
64+ return TensorShape<1 >(info.length ());
65+ },
66+ [](auto &out, auto &in, int idx) {
67+ auto &info = in.GetMeta (idx).GetSourceInfo ();
68+ std::memcpy (out.raw_mutable_tensor (idx), info.c_str (), info.length ());
69+ });
70+ }
71+
72+ template <typename Backend>
73+ void SourceInfoToTL (TensorList<GPUBackend> &out, const TensorList<Backend> &in) {
74+ TensorList<CPUBackend> tmp;
75+ tmp.set_pinned (true );
76+ SourceInfoToTL (tmp, in);
77+ tmp.set_order (out.order ());
78+ out.Copy (tmp);
79+ }
80+
81+ template <typename OutputBackend>
82+ void SourceInfoToTL (TensorList<OutputBackend> &out, const Workspace &ws) {
83+ ws.Output <OutputBackend>(0 ).set_order (ws.output_order ());
84+ if (ws.InputIsType <CPUBackend>(0 ))
85+ return SourceInfoToTL (out, ws.Input <CPUBackend>(0 ));
86+ else if (ws.InputIsType <GPUBackend>(0 ))
87+ return SourceInfoToTL (out, ws.Input <GPUBackend>(0 ));
88+ else
89+ DALI_FAIL (" Internal error - input 0 is neither CPU nor GPU." );
90+ }
91+
92+ template <typename Backend>
93+ void RepeatTensor (TensorList<Backend> &tl, const Tensor<Backend> &t, int N) {
94+ tl.Reset ();
95+ tl.set_device_id (t.device_id ());
96+ tl.SetSize (N);
97+ tl.set_sample_dim (t.ndim ());
98+ tl.set_type (t.type ());
99+ tl.SetLayout (t.GetLayout ());
100+ for (int i = 0 ; i < N; i++)
101+ tl.SetSample (i, t);
102+ }
103+
104+ template <typename Backend>
105+ void RepeatFirstSample (TensorList<Backend> &tl, int N) {
106+ Tensor<Backend> t;
107+ TensorShape<> shape = tl[0 ].shape ();
108+ t.ShareData (unsafe_sample_owner (tl, 0 ), shape.num_elements (), tl.is_pinned (),
109+ shape, tl.type (), tl.device_id (), tl.order ());
110+ t.SetMeta (tl.GetMeta (0 ));
111+ RepeatTensor (tl, t, N);
112+ }
113+
114+ void LayoutToTL (TensorList<CPUBackend> &out, const Workspace &ws) {
115+ TensorLayout l = ws.GetInputLayout (0 );
116+ out.Resize (uniform_list_shape (1 , { l.size () }), DALI_UINT8);
117+ memcpy (out.raw_mutable_tensor (0 ), l.data (), l.size ());
118+ RepeatFirstSample (out, ws.GetInputBatchSize (0 ));
119+ }
120+
121+ void LayoutToTL (TensorList<GPUBackend> &out, const Workspace &ws) {
122+ TensorLayout l = ws.GetInputLayout (0 );
123+ Tensor<CPUBackend> tmp_cpu;
124+ Tensor<GPUBackend> tmp_gpu;
125+ tmp_cpu.Resize (TensorShape<1 >(l.size ()), DALI_UINT8);
126+ memcpy (tmp_cpu.raw_mutable_data (), l.data (), l.size ());
127+ tmp_cpu.set_order (ws.output_order ());
128+ tmp_gpu.set_order (ws.output_order ());
129+ tmp_gpu.Copy (tmp_cpu);
130+
131+ RepeatTensor (out, tmp_gpu, ws.GetInputBatchSize (0 ));
132+ }
133+
134+ template <typename Backend>
135+ auto GetProperty<Backend>::GetPropertyReader(std::string_view key) -> PropertyReader {
136+ if (key == " source_info" ) {
137+ return static_cast <PropertyReaderFunc &>(SourceInfoToTL<Backend>);
138+ } else if (key == " layout" ) {
139+ return static_cast <PropertyReaderFunc &>(LayoutToTL);
140+ } else {
141+ DALI_FAIL (make_string (" Unsupported property key: " , key));
142+ }
143+ }
144+
145+
41146DALI_REGISTER_OPERATOR (GetProperty, GetProperty<CPUBackend>, CPU)
42147DALI_REGISTER_OPERATOR (GetProperty, GetProperty<GPUBackend>, GPU)
43148
0 commit comments