diff --git a/src/CodeGen_Internal.cpp b/src/CodeGen_Internal.cpp index 0ac2e45d026c..07b2e0744c18 100644 --- a/src/CodeGen_Internal.cpp +++ b/src/CodeGen_Internal.cpp @@ -607,7 +607,6 @@ void get_target_options(const llvm::Module &module, llvm::TargetOptions &options bool use_soft_float_abi = get_modflag_bool(module, "halide_use_soft_float_abi"); std::string mabi = get_modflag_string(module, "halide_mabi"); - // FIXME: can this be migrated into `set_function_attributes_from_halide_target_options()`? bool per_instruction_fast_math_flags = get_modflag_bool(module, "halide_per_instruction_fast_math_flags"); options = llvm::TargetOptions(); @@ -615,7 +614,9 @@ void get_target_options(const llvm::Module &module, llvm::TargetOptions &options #if LLVM_VERSION < 230 options.NoInfsFPMath = !per_instruction_fast_math_flags; #endif +#if LLVM_VERSION < 230 options.NoNaNsFPMath = !per_instruction_fast_math_flags; +#endif options.HonorSignDependentRoundingFPMathOption = !per_instruction_fast_math_flags; options.NoZerosInBSS = false; options.GuaranteedTailCallOpt = false; @@ -720,6 +721,16 @@ void set_function_attributes_from_halide_target_options(llvm::Function &fn) { fn.addFnAttr(llvm::Attribute::getWithVScaleRangeArgs( module.getContext(), vscale_range, vscale_range)); } + + // When not using per-instruction fast-math flags (i.e., the whole module + // is in fast-math mode), propagate the fast-math assumptions as function + // attributes. In LLVM 23+, NoNaNsFPMath and NoInfsFPMath were removed from + // TargetOptions in favor of these per-function attributes. + bool per_instruction_fast_math_flags = get_modflag_bool(module, "halide_per_instruction_fast_math_flags"); + if (!per_instruction_fast_math_flags) { + fn.addFnAttr("no-nans-fp-math", "true"); + fn.addFnAttr("no-infs-fp-math", "true"); + } } void embed_bitcode(llvm::Module *M, const string &halide_command) { diff --git a/src/CodeGen_PTX_Dev.cpp b/src/CodeGen_PTX_Dev.cpp index 3deab616ecc5..638fc6188ab0 100644 --- a/src/CodeGen_PTX_Dev.cpp +++ b/src/CodeGen_PTX_Dev.cpp @@ -230,6 +230,11 @@ void CodeGen_PTX_Dev::init_module() { module = get_initial_module_for_ptx_device(target, context); + // Propagate the strict-float flag as a module flag so that + // set_function_attributes_from_halide_target_options can read it. + module->addModuleFlag(llvm::Module::Warning, "halide_per_instruction_fast_math_flags", + CodeGen_GPU_Dev::any_strict_float ? 1 : 0); + struct Intrinsic { const char *name; Type ret_type; @@ -631,7 +636,9 @@ vector CodeGen_PTX_Dev::compile_to_src() { #if LLVM_VERSION < 230 options.NoInfsFPMath = !CodeGen_GPU_Dev::any_strict_float; #endif +#if LLVM_VERSION < 230 options.NoNaNsFPMath = !CodeGen_GPU_Dev::any_strict_float; +#endif options.HonorSignDependentRoundingFPMathOption = !CodeGen_GPU_Dev::any_strict_float; options.NoZerosInBSS = false; options.GuaranteedTailCallOpt = false;