Skip to content

Commit 4221fbb

Browse files
BradPepersAMDclaude
andcommitted
Add miopenGetTensorDescriptorV2 and fix build errors
Add size_t-based tensor descriptor getter to the public API, matching the existing miopenSetTensorDescriptorV2. This avoids silent truncation of strides that exceed INT_MAX when using the int-based getter. - Add miopenGetTensorDescriptorV2() to miopen.h (beta API) and implement in tensor_api.cpp - Update tensor_utils::GetLengths/GetStrides to use V2 API and return vector<size_t> - Fix tensor_layout.hpp: add missing #include <numeric> - Fix tensor_holder.hpp: add size_t delegating constructor for layout+dims+strides - Fix dropout_gpu_emulator.hpp: define MAX_PRNG_STATE locally after removing miopen/dropout.hpp internal include - Add miopenHandle_t handle parameter to RNN/GRU/LSTM backward verify functions, needed by RunDropoutBackwardEmulator which now uses miopenGetDropoutDescriptor (public API) instead of miopen::deref() Co-Authored-By: Claude Sonnet 4 <noreply@anthropic.com>
1 parent 1cb4151 commit 4221fbb

11 files changed

Lines changed: 84 additions & 15 deletions

File tree

projects/miopen/common_utils/include/common_utils/tensor_utils.hpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,19 +32,19 @@ inline miopenDataType_t GetType(miopenTensorDescriptor_t desc)
3232
return dt;
3333
}
3434

35-
inline std::vector<int> GetLengths(miopenTensorDescriptor_t desc)
35+
inline std::vector<size_t> GetLengths(miopenTensorDescriptor_t desc)
3636
{
3737
int ndim = GetNumDims(desc);
38-
std::vector<int> lens(ndim);
39-
miopenGetTensorDescriptor(desc, nullptr, lens.data(), nullptr);
38+
std::vector<size_t> lens(ndim);
39+
miopenGetTensorDescriptorV2(desc, nullptr, lens.data(), nullptr);
4040
return lens;
4141
}
4242

43-
inline std::vector<int> GetStrides(miopenTensorDescriptor_t desc)
43+
inline std::vector<size_t> GetStrides(miopenTensorDescriptor_t desc)
4444
{
4545
int ndim = GetNumDims(desc);
46-
std::vector<int> strides(ndim);
47-
miopenGetTensorDescriptor(desc, nullptr, nullptr, strides.data());
46+
std::vector<size_t> strides(ndim);
47+
miopenGetTensorDescriptorV2(desc, nullptr, nullptr, strides.data());
4848
return strides;
4949
}
5050

projects/miopen/driver/dropout_gpu_emulator.hpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,11 @@
3636
#include <cmath>
3737
#include <vector>
3838

39+
// Maximum PRNG states for dropout emulation (matches kernel definition).
40+
#ifndef MAX_PRNG_STATE
41+
#define MAX_PRNG_STATE (256 * 64)
42+
#endif
43+
3944
// disable __device__ qualifiers
4045
#ifdef FQUALIFIERS
4146
#error rocrand FQUALIFIERS defined externally, probably one of rocrand device header included prior to this

projects/miopen/driver/gru_verify_gemm.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -889,6 +889,7 @@ void RunGRUBackwardDataGEMMCPUVerify(std::vector<Tref>& din_host,
889889
std::vector<Tref>& wkspace_host,
890890
bool use_dropout,
891891
miopenDropoutDescriptor_t dropoutDesc,
892+
miopenHandle_t handle,
892893
bool hx_is_null = false,
893894
bool dhy_is_null = false)
894895
{

projects/miopen/driver/lstm_verify_gemm.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -727,6 +727,7 @@ void RunLSTMBackwardDataGEMMCPUVerify(
727727
std::vector<Tref>& wkspace_host,
728728
bool use_dropout,
729729
miopenDropoutDescriptor_t dropoutDesc,
730+
miopenHandle_t handle,
730731
bool cx_is_null = false,
731732
bool dhy_is_null = false,
732733
bool dcy_is_null = false)

projects/miopen/driver/rnn_driver.hpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1433,7 +1433,8 @@ int RNNDriver<Tgpu, Tref>::RunBackwardDataCPU()
14331433
reservespace_host,
14341434
workspace_host,
14351435
bool(inflags.GetValueInt("use_dropout")),
1436-
DropoutDesc);
1436+
DropoutDesc,
1437+
GetHandle());
14371438
}
14381439
else if(mode == miopenLSTM)
14391440
{
@@ -1461,7 +1462,8 @@ int RNNDriver<Tgpu, Tref>::RunBackwardDataCPU()
14611462
reservespace_host,
14621463
workspace_host,
14631464
bool(inflags.GetValueInt("use_dropout")),
1464-
DropoutDesc);
1465+
DropoutDesc,
1466+
GetHandle());
14651467
}
14661468
else if(mode == miopenGRU)
14671469
{
@@ -1486,7 +1488,8 @@ int RNNDriver<Tgpu, Tref>::RunBackwardDataCPU()
14861488
reservespace_host,
14871489
workspace_host,
14881490
bool(inflags.GetValueInt("use_dropout")),
1489-
DropoutDesc);
1491+
DropoutDesc,
1492+
GetHandle());
14901493
}
14911494
else
14921495
{

projects/miopen/driver/rnn_seq_driver.hpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1644,7 +1644,8 @@ int RNNSeqDriver<Tgpu, Tref>::RunBackwardDataCPU()
16441644
reservespace_host,
16451645
workspace_host,
16461646
bool(inflags.GetValueInt("use_dropout")),
1647-
DropoutDesc);
1647+
DropoutDesc,
1648+
GetHandle());
16481649
}
16491650
else if(mode == miopenLSTM)
16501651
{
@@ -1672,7 +1673,8 @@ int RNNSeqDriver<Tgpu, Tref>::RunBackwardDataCPU()
16721673
reservespace_host,
16731674
workspace_host,
16741675
bool(inflags.GetValueInt("use_dropout")),
1675-
DropoutDesc);
1676+
DropoutDesc,
1677+
GetHandle());
16761678
}
16771679
else if(mode == miopenGRU)
16781680
{
@@ -1697,7 +1699,8 @@ int RNNSeqDriver<Tgpu, Tref>::RunBackwardDataCPU()
16971699
reservespace_host,
16981700
workspace_host,
16991701
bool(inflags.GetValueInt("use_dropout")),
1700-
DropoutDesc);
1702+
DropoutDesc,
1703+
GetHandle());
17011704
}
17021705
else
17031706
{

projects/miopen/driver/rnn_verify_gemm.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -600,6 +600,7 @@ void RunRNNBackwardDataGEMMCPUVerify(std::vector<Tref>& din_host,
600600
std::vector<Tref>& wkspace_host,
601601
bool use_dropout,
602602
miopenDropoutDescriptor_t dropoutDesc,
603+
miopenHandle_t handle,
603604
bool dhy_is_null = false)
604605
{
605606
/*

projects/miopen/include/miopen/miopen.h

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -825,6 +825,25 @@ MIOPEN_EXPORT miopenStatus_t miopenGetTensorDescriptor(miopenTensorDescriptor_t
825825
int* dimsA,
826826
int* stridesA);
827827

828+
#ifdef MIOPEN_BETA_API
829+
/*! @brief Get the details of the tensor descriptor
830+
*
831+
* Returns the same information as miopenGetTensorDescriptor() but uses size_t
832+
* arrays, matching miopenSetTensorDescriptorV2(). This avoids truncation for
833+
* tensors whose strides exceed INT_MAX.
834+
*
835+
* @param tensorDesc Tensor descriptor (input)
836+
* @param dataType MIOpen datatype (output)
837+
* @param dimsA Array containing the size of dimensions (output)
838+
* @param stridesA Array containing the size of stride (output)
839+
* @return miopenStatus_t
840+
*/
841+
MIOPEN_EXPORT miopenStatus_t miopenGetTensorDescriptorV2(miopenTensorDescriptor_t tensorDesc,
842+
miopenDataType_t* dataType,
843+
size_t* dimsA,
844+
size_t* stridesA);
845+
#endif
846+
828847
/*! @brief Destroys the tensor descriptor
829848
*
830849
* @param tensorDesc Tensor descriptor (input)

projects/miopen/miopen_utils/include/miopen_utils/tensor_holder.hpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,15 @@ struct tensor
180180

181181
template <class X>
182182
tensor(miopenTensorLayout_t layout, const std::vector<X>& dims, const std::vector<X>& strides)
183+
: tensor(layout,
184+
std::vector<std::size_t>(dims.begin(), dims.end()),
185+
std::vector<std::size_t>(strides.begin(), strides.end()))
186+
{
187+
}
188+
189+
tensor(miopenTensorLayout_t layout,
190+
const std::vector<std::size_t>& dims,
191+
const std::vector<std::size_t>& strides)
183192
: desc(miopen_type<T>{}, layout, dims, strides), data(desc.GetElementSpace())
184193
{
185194
assert(dims.size() == strides.size());

projects/miopen/src/include/miopen/tensor_layout.hpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,12 @@
2727
#define GUARD_TENSOR_LAYOUT_HPP
2828

2929
#include <miopen/errors.hpp>
30-
#include <map>
3130
#include <algorithm>
32-
#include <vector>
33-
#include <string>
3431
#include <iterator>
32+
#include <map>
33+
#include <numeric>
34+
#include <string>
35+
#include <vector>
3536

3637
namespace miopen {
3738

0 commit comments

Comments
 (0)