
Commit b3931e0

mergennachin and claude committed
Update torch pin to nightly 2.12.0.dev20260410
Bumps torch/torchvision/torchaudio/torchcodec to the 2026-04-10 nightly, updates the PyTorch commit pin, grafts the corresponding c10/headeronly headers, and restores the guarding_hint_or_throw fallback in exir/sym_util.py so shape_env.size_hint is only used when the newer API is unavailable.

Co-authored-by: Claude <noreply@anthropic.com>
1 parent 8999865 commit b3931e0

11 files changed: 122 additions & 37 deletions
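
A quick sanity check after installing the pinned nightlies (a minimal sketch, not part of the commit; the package list and the dev date come from the commit message and the diffs below, and torchaudio's exact version string is not shown here, so only the date suffix is checked):

# Minimal sketch: confirm the installed nightlies carry the 2026-04-10 dev tag.
from importlib.metadata import PackageNotFoundError, version

EXPECTED_DATE = "dev20260410"
for pkg in ("torch", "torchvision", "torchaudio", "torchcodec"):
    try:
        installed = version(pkg)
    except PackageNotFoundError:
        print(f"{pkg}: not installed")
        continue
    status = "ok" if EXPECTED_DATE in installed else "MISMATCH"
    print(f"{pkg} {installed} [{status}]")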

PyTorch commit pin

Lines changed: 1 addition & 1 deletion

@@ -1 +1 @@
-659af3c353e49b35c191cdd2dba3b3c79d0e6822
+7527b8c5c21e98eaa88ba6a1e86a56c7019beb9e

.ci/scripts/test_model_e2e.sh

Lines changed: 1 addition & 1 deletion

@@ -260,7 +260,7 @@ if [ "$AUDIO_URL" != "" ]; then
 elif [[ "$MODEL_NAME" == *whisper* ]] || [ "$MODEL_NAME" = "voxtral_realtime" ]; then
   conda install -y -c conda-forge "ffmpeg<8"
   pip install datasets soundfile
-  pip install torchcodec==0.11.0.dev20260217 --extra-index-url https://download.pytorch.org/whl/nightly/cpu
+  pip install torchcodec==0.12.0.dev20260410 --extra-index-url https://download.pytorch.org/whl/nightly/cpu
   python -c "from datasets import load_dataset;import soundfile as sf;sample = load_dataset('distil-whisper/librispeech_long', 'clean', split='validation')[0]['audio'];sf.write('${MODEL_DIR}/$AUDIO_FILE', sample['array'][:sample['sampling_rate']*30], sample['sampling_rate'])"
 fi

examples/models/moshi/mimi/install_requirements.sh

Lines changed: 1 addition & 1 deletion

@@ -8,7 +8,7 @@
 set -x
 
 sudo apt install ffmpeg -y
-pip install torchcodec==0.11.0.dev20260217 --extra-index-url https://download.pytorch.org/whl/nightly/cpu
+pip install torchcodec==0.12.0.dev20260410 --extra-index-url https://download.pytorch.org/whl/nightly/cpu
 pip install moshi==0.2.11
 pip install bitsandbytes soundfile einops
 # Run llama2/install requirements for torchao deps

examples/models/parakeet/export_parakeet_tdt.py

Lines changed: 3 additions & 5 deletions

@@ -507,13 +507,11 @@ def _create_metal_partitioners(programs):
 
     # Run decompositions for non-preprocessor programs
     updated_programs = {}
+    decomp_table = torch.export.default_decompositions()
+    decomp_table[torch.ops.aten.linear.default] = _linear_bias_decomposition
     for key, ep in programs.items():
-        # print(f"Running decompositions for {key}")
-        # print(ep.graph_module)
         if key != "preprocessor":
-            updated_programs[key] = ep.run_decompositions(
-                {torch.ops.aten.linear.default: _linear_bias_decomposition}
-            )
+            updated_programs[key] = ep.run_decompositions(decomp_table)
         else:
             updated_programs[key] = ep
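
The change above swaps a per-call decomposition dict for a table seeded from torch.export.default_decompositions(), so the linear override is applied on top of the stock decompositions rather than instead of them. A minimal, self-contained sketch of the same pattern; the decomposition body below is a hypothetical stand-in for the repo's _linear_bias_decomposition:

import torch

# Hypothetical stand-in for _linear_bias_decomposition: rewrite aten.linear
# as an explicit matmul plus bias add.
def linear_bias_decomposition(input, weight, bias=None):
    out = torch.matmul(input, weight.t())
    if bias is not None:
        out = out + bias
    return out

class TinyLinear(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(8, 4)

    def forward(self, x):
        return self.fc(x)

ep = torch.export.export(TinyLinear(), (torch.randn(2, 8),))

# Start from the default table and override a single op, as in the diff above.
decomp_table = torch.export.default_decompositions()
decomp_table[torch.ops.aten.linear.default] = linear_bias_decomposition
ep = ep.run_decompositions(decomp_table)
print(ep.graph_module.code)  # linear is now expressed via the override

The voxtral_realtime export below applies the same table-based pattern.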

examples/models/voxtral_realtime/export_voxtral_rt.py

Lines changed: 3 additions & 3 deletions

@@ -402,10 +402,10 @@ def lower_to_executorch(programs, metadata, backend="xnnpack"):
 
     # Run decompositions for Metal backend
     updated_programs = {}
+    decomp_table = torch.export.default_decompositions()
+    decomp_table[torch.ops.aten.linear.default] = _linear_bias_decomposition
     for key, ep in programs.items():
-        updated_programs[key] = ep.run_decompositions(
-            {torch.ops.aten.linear.default: _linear_bias_decomposition}
-        )
+        updated_programs[key] = ep.run_decompositions(decomp_table)
     programs = updated_programs
 
     partitioner = {}

exir/sym_util.py

Lines changed: 5 additions & 1 deletion

@@ -25,7 +25,11 @@ def eval_expr(symint: Union[int, torch.SymInt]) -> Optional[int]:
     shape_env = node.shape_env
     expr = node.expr
     try:
-        output = shape_env.size_hint(expr)
+        if hasattr(shape_env, "guarding_hint_or_throw"):
+            output = shape_env.guarding_hint_or_throw(expr)
+        else:
+            # size_hint is deprecated, delete this code path.
+            output = shape_env.size_hint(expr)
     except torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode:
         return None
     return int(output)
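
A quick way to check which branch a given PyTorch build will take (a minimal probe, not part of the commit):

import torch
from torch.fx.experimental.symbolic_shapes import ShapeEnv

shape_env = ShapeEnv()
# On nightlies that ship the newer API this prints True and eval_expr uses
# guarding_hint_or_throw; on older builds it prints False and the deprecated
# size_hint path is still exercised.
print(hasattr(shape_env, "guarding_hint_or_throw"))
print(torch.__version__)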

install_requirements.py

Lines changed: 1 addition & 1 deletion

@@ -119,7 +119,7 @@ def install_optional_example_requirements(use_pytorch_nightly):
     print("Installing torch domain libraries")
     DOMAIN_LIBRARIES = [
         (
-            f"torchvision==0.26.0.{NIGHTLY_VERSION}"
+            f"torchvision==0.27.0.{NIGHTLY_VERSION}"
             if use_pytorch_nightly
             else "torchvision"
         ),

runtime/core/portable_type/c10/c10/util/complex_math.h

Lines changed: 35 additions & 0 deletions

@@ -86,6 +86,41 @@ C10_HOST_DEVICE inline c10::complex<T> pow(
 #endif
 }
 
+// Regression in ROCm 7.2. See https://github.com/ROCm/rocm-libraries/pull/3836.
+// Specialized version for complex<float> on AMD GPUs to use FMA-based
+// multiplication
+#if defined(__HIPCC__)
+namespace detail {
+// FMA-aware complex multiplication for float precision on AMD GPUs.
+// This prevents SLP vectorizer from breaking FMA formation, which causes
+// numerical precision loss in complex arithmetic.
+// The issue occurs when vectorizer packs scalar multiplies before backend
+// can form FMA instructions, resulting in double rounding instead of single.
+C10_HOST_DEVICE inline thrust::complex<float> complex_mul_fma(
+    thrust::complex<float> a,
+    thrust::complex<float> b) {
+  // Complex multiplication: (a.r + a.i*i) * (b.r + b.i*i)
+  //   = (a.r*b.r - a.i*b.i) + (a.r*b.i + a.i*b.r)*i
+  // Using __builtin_fmaf ensures FMA at source level:
+  //   real: a.r*b.r + (-(a.i*b.i)) = FMA(a.r, b.r, -(a.i*b.i))
+  //   imag: a.i*b.r + a.r*b.i = FMA(a.r, b.i, a.i*b.r)
+  float real_part = __builtin_fmaf(a.real(), b.real(), -(a.imag() * b.imag()));
+  float imag_part = __builtin_fmaf(a.real(), b.imag(), a.imag() * b.real());
+  return thrust::complex<float>(real_part, imag_part);
+}
+} // namespace detail
+
+template <>
+C10_HOST_DEVICE inline c10::complex<float> pow(
+    const c10::complex<float>& x,
+    const c10::complex<float>& y) {
+  auto log_x = thrust::log(static_cast<thrust::complex<float>>(x));
+  auto y_log_x =
+      detail::complex_mul_fma(static_cast<thrust::complex<float>>(y), log_x);
+  return static_cast<c10::complex<float>>(thrust::exp(y_log_x));
+}
+#endif
+
 template <typename T>
 C10_HOST_DEVICE inline c10::complex<T> pow(
     const c10::complex<T>& x,

runtime/core/portable_type/c10/torch/headeronly/macros/Macros.h

Lines changed: 63 additions & 16 deletions

@@ -325,41 +325,88 @@ constexpr uint32_t CUDA_THREADS_PER_BLOCK_FALLBACK = 256;
 #define C10_HIP_HOST_DEVICE
 #endif
 
-#if defined(USE_ROCM)
 // C10_WARP_SIZE is only allowed for device code.
-// Host code _must_ use at::cuda::warp_size()
+// Host code dynamically-sized launch configs _must_ use at::cuda::warp_size().
+// Host or device statically-sized arrays _must_ use either
+// C10_WARP_SIZE_UPPER_BOUND or C10_WARP_SIZE_LOWER_BOUND, as needed.
+//
 // HIP header used to define warpSize as a constexpr that was either 32 or 64
 // depending on the target device, and then always set it to 64 for host code.
-// Host pass of HIP compiler needs C10_WARP_SIZE defined to _something_ so we
-// set it to something unreasonable to trigger obvious host code errors.
-
+// For a time, that allowed C10_WARP_SIZE to be defined like so:
+//
+// #ifdef USE_ROCM
+// #define C10_WARP_SIZE warpSize
+// #else
+// #define C10_WARP_SIZE 32
+// #endif
+//
+// In ROCm 7, warpSize is no longer constexpr, matching CUDA behavior.
+// We can now only use warpSize for C10_WARP_SIZE in device code and this is
+// enforced by using __device__ in its definition. In host code where
+// C10_WARP_SIZE was previously used as a compile-time constant, this will now
+// cause a compile-time error.
+//
+// If an array was previously expected to be sized at compile-time using
+// C10_WARP_SIZE, users must now use either C10_WARP_SIZE_UPPER_BOUND or
+// C10_WARP_SIZE_LOWER_BOUND depending on the situation.
+//
+// If C10_WARP_SIZE was previously used to determine kernel launch sizes, users
+// must now use at::cuda::warp_size() for the dynamic runtime query.
+//
+// Unfortunately, C10_WARP_SIZE has been public and available for both host and
+// device since approximately 2019, so forcing it to be device-only would break
+// existing code in the wild.
+#if defined(USE_ROCM)
 namespace at::cuda {
 TORCH_CUDA_CPP_API int warp_size();
 }
-#ifdef __HIPCC__
-static inline int __host__ C10_WARP_SIZE_INTERNAL() {
+#if defined(__HIPCC__)
+static __host__ inline int C10_WARP_SIZE_INTERNAL() {
   return at::cuda::warp_size();
 }
-
-static inline constexpr int __device__ C10_WARP_SIZE_INTERNAL() {
+// NOTE: __device__ C10_WARP_SIZE_INTERNAL
+// For __SPIRV__, we must use dynamic warpSize. When not targeting __SPIRV__,
+// we can use constexpr. This matches prior behavior. We preserve this for
+// backward compatibility instead of forcing old code to use dynamic warpSize
+// and losing constexpr. However, compiling for --offload-arch=amdgcnspirv
+// could expose where C10_WARP_SIZE was used incorrectly where the dynamic
+// warpSize is not allowed.
+#if defined(__SPIRV__)
+static __device__ inline int C10_WARP_SIZE_INTERNAL() {
+  return warpSize;
+}
+#else // __SPIRV__
+static __device__ inline constexpr int C10_WARP_SIZE_INTERNAL() {
 #if defined(__GFX9__)
   return 64;
 #else // __GFX9__
   return 32;
 #endif // __GFX9__
 }
-#else // __HIPCC__
+#endif // __SPIRV__
+#if defined(__SPIRV__)
+#define C10_WARP_SIZE_LOWER_BOUND 32
+#define C10_WARP_SIZE_UPPER_BOUND 64
+#elif defined(__GFX9__)
+#define C10_WARP_SIZE_LOWER_BOUND 64
+#define C10_WARP_SIZE_UPPER_BOUND 64
+#else
+#define C10_WARP_SIZE_LOWER_BOUND 32
+#define C10_WARP_SIZE_UPPER_BOUND 32
+#endif
+#else // !__HIPCC__
 static inline int C10_WARP_SIZE_INTERNAL() {
   return at::cuda::warp_size();
 }
+#define C10_WARP_SIZE_LOWER_BOUND 32
+#define C10_WARP_SIZE_UPPER_BOUND 64
 #endif // __HIPCC__
-
 #define C10_WARP_SIZE (C10_WARP_SIZE_INTERNAL())
-#define C10_WARP_SIZE_STATIC 64
-
-#else // defined(USE_ROCM)
+#else // !USE_ROCM
 #define C10_WARP_SIZE 32
-#endif
+#define C10_WARP_SIZE_LOWER_BOUND 32
+#define C10_WARP_SIZE_UPPER_BOUND 32
+#endif // USE_ROCM
 
 #if defined(_MSC_VER) && _MSC_VER <= 1900
 #define __func__ __FUNCTION__

@@ -629,7 +676,7 @@ __host__ __device__
 // This macro is used to find older C++ compilers
 // that don't support move optimization for return values.
 
-#if (defined(__GNUC__) && __GNUC__ < 13) || \
+#if (defined(__GNUC__) && __GNUC__ < 13 && __cplusplus < 202002L) || \
     (defined(__clang_major__) && __clang_major__ < 13)
 #define C10_RETURN_MOVE_IF_OLD_COMPILER 1
 #else
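
The comment block above pushes host code toward at::cuda::warp_size() for runtime queries and the new *_BOUND macros for static sizing. From Python, the analogous runtime query is the warp_size field on the device properties; a minimal sketch, only meaningful on a CUDA or ROCm build:

import torch

# Runtime warp-size query from Python: typically 32 on NVIDIA, 64 on GFX9-class
# AMD parts, matching the C10_WARP_SIZE_INTERNAL values above.
if torch.cuda.is_available():
    props = torch.cuda.get_device_properties(0)
    print(props.name, props.warp_size)
else:
    print("no CUDA/ROCm device visible")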

runtime/core/portable_type/c10/torch/headeronly/util/BFloat16.h

Lines changed: 7 additions & 6 deletions

@@ -12,7 +12,7 @@
 #include <iosfwd>
 #include <ostream>
 
-#if defined(__CUDACC__) && !defined(USE_ROCM)
+#if defined(__CUDACC__) && (!defined(USE_ROCM) || (TORCH_HIP_VERSION >= 702))
 #include <cuda_bf16.h>
 #endif
 

@@ -46,7 +46,7 @@ struct alignas(2) BFloat16 {
   /* implicit */ inline C10_HOST_DEVICE BFloat16(float value);
   inline C10_HOST_DEVICE operator float() const;
 
-#if defined(__CUDACC__) && !defined(USE_ROCM)
+#if defined(__CUDACC__) && (!defined(USE_ROCM) || (TORCH_HIP_VERSION >= 702))
   inline C10_HOST_DEVICE BFloat16(const __nv_bfloat16& value);
   explicit inline C10_HOST_DEVICE operator __nv_bfloat16() const;
 #endif

@@ -124,8 +124,9 @@ C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion")
 /// Constructors
 inline C10_HOST_DEVICE BFloat16::BFloat16(float value)
     :
-#if defined(__CUDACC__) && !defined(USE_ROCM) && defined(__CUDA_ARCH__) && \
-    __CUDA_ARCH__ >= 800
+#if defined(__CUDACC__) && \
+    (!defined(USE_ROCM) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 || \
+     defined(USE_ROCM) && (TORCH_HIP_VERSION >= 702))
       x(__bfloat16_as_ushort(__float2bfloat16(value)))
 #elif defined(__SYCL_DEVICE_ONLY__) && \
     defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS)

@@ -139,7 +140,7 @@ inline C10_HOST_DEVICE BFloat16::BFloat16(float value)
 
 /// Implicit conversions
 inline C10_HOST_DEVICE BFloat16::operator float() const {
-#if defined(__CUDACC__) && !defined(USE_ROCM)
+#if defined(__CUDACC__) && (!defined(USE_ROCM) || (TORCH_HIP_VERSION >= 702))
   return __bfloat162float(*reinterpret_cast<const __nv_bfloat16*>(&x));
 #elif defined(__SYCL_DEVICE_ONLY__) && \
     defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS)

@@ -149,7 +150,7 @@ inline C10_HOST_DEVICE BFloat16::operator float() const {
 #endif
 }
 
-#if defined(__CUDACC__) && !defined(USE_ROCM)
+#if defined(__CUDACC__) && (!defined(USE_ROCM) || (TORCH_HIP_VERSION >= 702))
 inline C10_HOST_DEVICE BFloat16::BFloat16(const __nv_bfloat16& value) {
   x = *reinterpret_cast<const unsigned short*>(&value);
 }
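
The guarded conversions above implement the float <-> bfloat16 round trip, which keeps float32's exponent range but only 7 explicit mantissa bits. A quick way to see that truncation from Python, independent of the CUDA/HIP intrinsics being gated here (a minimal sketch):

import torch

x = torch.tensor(1.0 / 3.0, dtype=torch.float32)
bf = x.to(torch.bfloat16)
# The round-trip value differs from the original in the low mantissa bits.
print(f"float32:  {x.item():.10f}")
print(f"bfloat16: {bf.to(torch.float32).item():.10f}")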
